|
18 | 18 | # under the License. |
19 | 19 | # |
20 | 20 |
|
| 21 | +import base64 |
21 | 22 | import math |
22 | 23 | import requests |
23 | 24 | from typing import List |
|
29 | 30 | from enum import Enum |
30 | 31 | import json |
31 | 32 | from fastavro.schema import load_schema |
| 33 | +from google.protobuf import descriptor_pb2, descriptor_pool, message_factory |
| 34 | + |
| 35 | + |
def _add_protobuf_field(message, name, number, field_type, type_name=None):
    """Append one optional field to a message descriptor proto under construction.

    ``message`` is the descriptor proto whose ``field`` list is extended.
    ``name`` and ``number`` are the field's name and tag number, and
    ``field_type`` is a ``FieldDescriptorProto.TYPE_*`` value. ``type_name``
    is only written to the new field when it is truthy (it is used for the
    message-typed field in this file).
    """
    new_field = message.field.add()
    new_field.name = name
    new_field.number = number
    # Every field built through this helper is declared with the optional label.
    new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL
    new_field.type = field_type
    if not type_name:
        return
    new_field.type_name = type_name
| 44 | + |
| 45 | + |
def _get_message_classes(pool, message_names):
    """Return a tuple with one message class per entry of ``message_names``.

    Each name is looked up in ``pool`` with ``FindMessageTypeByName``. The
    descriptor is turned into a class by ``message_factory.GetMessageClass``
    when the imported ``message_factory`` module exposes that function, and by
    ``MessageFactory(pool).GetPrototype`` otherwise.
    """
    if hasattr(message_factory, 'GetMessageClass'):
        to_class = message_factory.GetMessageClass
    else:
        # Fallback for message_factory modules that lack GetMessageClass.
        to_class = message_factory.MessageFactory(pool).GetPrototype

    classes = []
    for full_name in message_names:
        descriptor = pool.FindMessageTypeByName(full_name)
        classes.append(to_class(descriptor))
    return tuple(classes)
| 57 | + |
| 58 | + |
def _build_protobuf_test_messages():
    """Build the protobuf message classes used by the tests in this module.

    A ``FileDescriptorProto`` named ``test_schema.proto`` (package ``test``,
    syntax ``proto3``) is assembled in memory, registered in a fresh
    ``DescriptorPool``, and the classes for ``TestMessage``,
    ``TestMessageWithNested`` and ``TestInner`` are returned in that order.
    """
    types = descriptor_pb2.FieldDescriptorProto

    # (message name, ((field name, tag number, field type, type name), ...)),
    # listed in the order the messages and fields are added to the file.
    layout = (
        ('TestMessage', (
            ('name', 1, types.TYPE_STRING, None),
            ('value', 2, types.TYPE_INT32, None),
        )),
        ('TestMessageWithNested', (
            ('str_field', 1, types.TYPE_STRING, None),
            ('int_field', 2, types.TYPE_INT32, None),
            ('double_field', 3, types.TYPE_DOUBLE, None),
            ('nested', 4, types.TYPE_MESSAGE, '.test.TestInner'),
        )),
        ('TestInner', (
            ('inner_str', 1, types.TYPE_STRING, None),
            ('inner_int', 2, types.TYPE_INT64, None),
        )),
    )

    schema_file = descriptor_pb2.FileDescriptorProto()
    schema_file.name = 'test_schema.proto'
    schema_file.package = 'test'
    schema_file.syntax = 'proto3'

    for message_name, field_specs in layout:
        message = schema_file.message_type.add()
        message.name = message_name
        for field_name, number, field_type, type_name in field_specs:
            _add_protobuf_field(message, field_name, number, field_type, type_name)

    pool = descriptor_pool.DescriptorPool()
    pool.AddSerializedFile(schema_file.SerializeToString())

    class_names = ('test.TestMessage', 'test.TestMessageWithNested', 'test.TestInner')
    return _get_message_classes(pool, class_names)
| 90 | + |
| 91 | + |
# Message classes are built once at import time and shared by the
# ProtobufNativeSchema tests further down in this module.
TestMessage, TestMessageWithNested, TestInner = _build_protobuf_test_messages()
| 93 | + |
32 | 94 |
|
33 | 95 | class ExampleRecord(Record): |
34 | 96 | str_field = String() |
@@ -1404,5 +1466,90 @@ def test_schema_type_promotion(self): |
1404 | 1466 | client.close() |
1405 | 1467 |
|
1406 | 1468 |
|
class ProtobufNativeSchemaTest(TestCase):
    """Unit tests for ProtobufNativeSchema (no Pulsar broker required).

    The tests use the in-memory message classes ``TestMessage``,
    ``TestMessageWithNested`` and ``TestInner`` defined at module level.
    """

    def test_schema_type(self):
        """Schema type must be PROTOBUF_NATIVE."""
        import _pulsar
        schema = ProtobufNativeSchema(TestMessage)
        self.assertEqual(schema.schema_info().schema_type(), _pulsar.SchemaType.PROTOBUF_NATIVE)

    def test_schema_definition_keys(self):
        """Schema definition JSON must contain the three required keys."""
        schema = ProtobufNativeSchema(TestMessage)
        schema_def = json.loads(schema.schema_info().schema())
        self.assertIn('fileDescriptorSet', schema_def)
        self.assertIn('rootMessageTypeName', schema_def)
        self.assertIn('rootFileDescriptorName', schema_def)

    def test_schema_definition_values(self):
        """rootMessageTypeName and rootFileDescriptorName must match the descriptor."""
        schema = ProtobufNativeSchema(TestMessage)
        schema_def = json.loads(schema.schema_info().schema())
        self.assertEqual(schema_def['rootMessageTypeName'], 'test.TestMessage')
        self.assertEqual(schema_def['rootFileDescriptorName'], 'test_schema.proto')

    def test_file_descriptor_set_is_valid_base64_proto(self):
        """fileDescriptorSet must be valid base64-encoded FileDescriptorSet bytes."""
        schema = ProtobufNativeSchema(TestMessage)
        schema_def = json.loads(schema.schema_info().schema())
        # validate=True makes b64decode raise on characters outside the base64
        # alphabet; the default silently discards them, which would let a
        # malformed value pass a test that claims to check validity.
        raw = base64.b64decode(schema_def['fileDescriptorSet'], validate=True)
        fds = descriptor_pb2.FileDescriptorSet.FromString(raw)
        file_names = [f.name for f in fds.file]
        self.assertIn('test_schema.proto', file_names)

    def test_encode_decode_roundtrip(self):
        """encode then decode must reproduce the original message."""
        schema = ProtobufNativeSchema(TestMessage)
        original = TestMessage(name='hello', value=42)
        encoded = schema.encode(original)
        decoded = schema.decode(encoded)
        self.assertEqual(decoded.name, 'hello')
        self.assertEqual(decoded.value, 42)

    def test_encode_produces_protobuf_binary(self):
        """Encoded bytes must be valid protobuf binary (parseable by the class directly)."""
        schema = ProtobufNativeSchema(TestMessage)
        msg = TestMessage(name='pulsar', value=100)
        encoded = schema.encode(msg)
        # Verify with protobuf's own parser
        reparsed = TestMessage.FromString(encoded)
        self.assertEqual(reparsed, msg)

    def test_encode_decode_nested_message(self):
        """encode/decode round-trip works for messages containing nested message fields."""
        schema = ProtobufNativeSchema(TestMessageWithNested)
        original = TestMessageWithNested(
            str_field='test',
            int_field=7,
            double_field=3.14,
            nested=TestInner(inner_str='inner', inner_int=999),
        )
        decoded = schema.decode(schema.encode(original))
        self.assertEqual(decoded.str_field, 'test')
        self.assertEqual(decoded.int_field, 7)
        self.assertAlmostEqual(decoded.double_field, 3.14)
        self.assertEqual(decoded.nested.inner_str, 'inner')
        self.assertEqual(decoded.nested.inner_int, 999)

    def test_wrong_type_raises(self):
        """Encoding an object of the wrong type must raise TypeError."""
        schema = ProtobufNativeSchema(TestMessage)
        with self.assertRaises(TypeError):
            schema.encode("not a protobuf message")

    def test_non_message_class_raises(self):
        """Constructing with a non-Message class must raise TypeError."""
        with self.assertRaises(TypeError):
            ProtobufNativeSchema(str)

    def test_schema_name(self):
        """Schema name must be 'PROTOBUF_NATIVE'."""
        schema = ProtobufNativeSchema(TestMessage)
        self.assertEqual(schema.schema_info().name(), 'PROTOBUF_NATIVE')
| 1552 | + |
| 1553 | + |
1407 | 1554 | if __name__ == '__main__': |
1408 | 1555 | main() |
0 commit comments