Skip to content

Commit 7e03a4d

Browse files
feat: implement ProtobufNative schema (#299)
1 parent 5d38ac9 commit 7e03a4d

6 files changed

Lines changed: 305 additions & 4 deletions

File tree

.github/workflows/ci-pr-validation.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ jobs:
7878
python3 -m pip install -U pip setuptools wheel requests
7979
python3 setup.py bdist_wheel
8080
WHEEL=$(find dist -name '*.whl')
81-
pip3 install ${WHEEL}[avro]
81+
pip3 install ${WHEEL}[avro,protobuf]
8282
8383
- name: Run Oauth2 tests
8484
run: |

pulsar/schema/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@
2222

2323
from .schema import Schema, BytesSchema, StringSchema, JsonSchema
2424
from .schema_avro import AvroSchema
25+
from .schema_protobuf import ProtobufNativeSchema

pulsar/schema/schema_protobuf.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
#
19+
20+
import base64
21+
import _pulsar
22+
23+
from .schema import Schema
24+
25+
try:
26+
from google.protobuf import descriptor_pb2
27+
from google.protobuf.message import Message as ProtobufMessage
28+
HAS_PROTOBUF = True
29+
except ImportError:
30+
HAS_PROTOBUF = False
31+
32+
33+
def _collect_file_descriptors(file_descriptor, visited, file_descriptor_set):
34+
"""Recursively collect all FileDescriptorProto objects into file_descriptor_set."""
35+
if file_descriptor.name in visited:
36+
return
37+
for dep in file_descriptor.dependencies:
38+
_collect_file_descriptors(dep, visited, file_descriptor_set)
39+
visited.add(file_descriptor.name)
40+
proto = descriptor_pb2.FileDescriptorProto()
41+
file_descriptor.CopyToProto(proto)
42+
file_descriptor_set.file.append(proto)
43+
44+
45+
def _build_schema_definition(descriptor):
46+
"""
47+
Build the schema definition dict used by Java's ``ProtobufNativeSchemaData``.
48+
49+
The returned mapping has these keys:
50+
51+
.. code-block:: text
52+
53+
fileDescriptorSet
54+
rootMessageTypeName
55+
rootFileDescriptorName
56+
57+
``fileDescriptorSet`` contains base64-encoded ``FileDescriptorSet`` bytes.
58+
This mirrors ``ProtobufNativeSchemaUtils.serialize()`` in the Java client.
59+
"""
60+
file_descriptor_set = descriptor_pb2.FileDescriptorSet()
61+
_collect_file_descriptors(descriptor.file, set(), file_descriptor_set)
62+
file_descriptor_set_bytes = file_descriptor_set.SerializeToString()
63+
return {
64+
"fileDescriptorSet": base64.b64encode(file_descriptor_set_bytes).decode('utf-8'),
65+
"rootMessageTypeName": descriptor.full_name,
66+
"rootFileDescriptorName": descriptor.file.name,
67+
}
68+
69+
70+
if HAS_PROTOBUF:
71+
class ProtobufNativeSchema(Schema):
72+
"""
73+
Schema for protobuf messages using the native protobuf binary encoding.
74+
75+
The schema definition is stored as a JSON-encoded ProtobufNativeSchemaData
76+
(fileDescriptorSet, rootMessageTypeName, rootFileDescriptorName), which is
77+
compatible with the Java client's ProtobufNativeSchema.
78+
79+
Parameters
80+
----------
81+
record_cls:
82+
A generated protobuf message class (subclass of google.protobuf.message.Message).
83+
84+
Example
85+
-------
86+
.. code-block:: python
87+
88+
import pulsar
89+
from pulsar.schema import ProtobufNativeSchema
90+
from my_proto_pb2 import MyMessage
91+
92+
client = pulsar.Client('pulsar://localhost:6650')
93+
schema = ProtobufNativeSchema(MyMessage)
94+
producer = client.create_producer('my-topic', schema=schema)
95+
consumer = client.subscribe('my-topic', 'my-sub', schema=schema)
96+
97+
message = MyMessage()
98+
message.field = 'value'
99+
producer.send(message)
100+
101+
received = consumer.receive(timeout_millis=5000)
102+
typed_value = received.value()
103+
consumer.acknowledge(received)
104+
105+
assert isinstance(typed_value, MyMessage)
106+
assert typed_value.field == 'value'
107+
108+
consumer.close()
109+
producer.close()
110+
client.close()
111+
"""
112+
113+
def __init__(self, record_cls):
114+
if not (isinstance(record_cls, type) and issubclass(record_cls, ProtobufMessage)):
115+
raise TypeError(
116+
f'record_cls must be a protobuf Message subclass, got {record_cls!r}'
117+
)
118+
schema_definition = _build_schema_definition(record_cls.DESCRIPTOR)
119+
super(ProtobufNativeSchema, self).__init__(
120+
record_cls, _pulsar.SchemaType.PROTOBUF_NATIVE, schema_definition, 'PROTOBUF_NATIVE'
121+
)
122+
123+
def encode(self, obj):
124+
self._validate_object_type(obj)
125+
return obj.SerializeToString()
126+
127+
def decode(self, data):
128+
return self._record_cls.FromString(data)
129+
130+
def __str__(self):
131+
return f'ProtobufNativeSchema({self._record_cls.__name__})'
132+
133+
else:
134+
class ProtobufNativeSchema(Schema):
135+
def __init__(self, _record_cls=None):
136+
raise Exception(
137+
"protobuf library support was not found. "
138+
"Install it with: pip install protobuf"
139+
)
140+
141+
def encode(self, obj):
142+
pass
143+
144+
def decode(self, data):
145+
pass

setup.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,21 @@ def build_extension(self, ext):
7676

7777
extras_require = {}
7878

79+
# protobuf schema dependencies
80+
extras_require["protobuf"] = sorted(
81+
{
82+
"protobuf>=6.33.6",
83+
}
84+
)
85+
7986
# functions dependencies
8087
extras_require["functions"] = sorted(
8188
{
82-
"protobuf>=3.6.1",
8389
"grpcio>=1.59.3",
8490
"apache-bookkeeper-client>=4.16.1",
8591
"prometheus_client",
86-
"ratelimit"
92+
"ratelimit",
93+
*extras_require["protobuf"],
8794
}
8895
)
8996

src/enums.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ void export_enums(py::module_& m) {
115115
.value("AVRO", pulsar::AVRO)
116116
.value("AUTO_CONSUME", pulsar::AUTO_CONSUME)
117117
.value("AUTO_PUBLISH", pulsar::AUTO_PUBLISH)
118-
.value("KEY_VALUE", pulsar::KEY_VALUE);
118+
.value("KEY_VALUE", pulsar::KEY_VALUE)
119+
.value("PROTOBUF_NATIVE", pulsar::PROTOBUF_NATIVE);
119120

120121
enum_<InitialPosition>(m, "InitialPosition", "Supported initial position")
121122
.value("Latest", InitialPositionLatest)

tests/schema_test.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
# under the License.
1919
#
2020

21+
import base64
2122
import math
2223
import requests
2324
from typing import List
@@ -29,6 +30,67 @@
2930
from enum import Enum
3031
import json
3132
from fastavro.schema import load_schema
33+
from google.protobuf import descriptor_pb2, descriptor_pool, message_factory
34+
35+
36+
def _add_protobuf_field(message, name, number, field_type, type_name=None):
37+
field = message.field.add()
38+
field.name = name
39+
field.number = number
40+
field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL
41+
field.type = field_type
42+
if type_name:
43+
field.type_name = type_name
44+
45+
46+
def _get_message_classes(pool, message_names):
47+
if hasattr(message_factory, 'GetMessageClass'):
48+
return tuple(
49+
message_factory.GetMessageClass(pool.FindMessageTypeByName(message_name))
50+
for message_name in message_names
51+
)
52+
factory = message_factory.MessageFactory(pool)
53+
return tuple(
54+
factory.GetPrototype(pool.FindMessageTypeByName(message_name))
55+
for message_name in message_names
56+
)
57+
58+
59+
def _build_protobuf_test_messages():
60+
file_proto = descriptor_pb2.FileDescriptorProto()
61+
file_proto.name = 'test_schema.proto'
62+
file_proto.package = 'test'
63+
file_proto.syntax = 'proto3'
64+
65+
test_message = file_proto.message_type.add()
66+
test_message.name = 'TestMessage'
67+
_add_protobuf_field(test_message, 'name', 1, descriptor_pb2.FieldDescriptorProto.TYPE_STRING)
68+
_add_protobuf_field(test_message, 'value', 2, descriptor_pb2.FieldDescriptorProto.TYPE_INT32)
69+
70+
nested_message = file_proto.message_type.add()
71+
nested_message.name = 'TestMessageWithNested'
72+
_add_protobuf_field(nested_message, 'str_field', 1, descriptor_pb2.FieldDescriptorProto.TYPE_STRING)
73+
_add_protobuf_field(nested_message, 'int_field', 2, descriptor_pb2.FieldDescriptorProto.TYPE_INT32)
74+
_add_protobuf_field(nested_message, 'double_field', 3, descriptor_pb2.FieldDescriptorProto.TYPE_DOUBLE)
75+
_add_protobuf_field(
76+
nested_message, 'nested', 4, descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE, '.test.TestInner'
77+
)
78+
79+
inner_message = file_proto.message_type.add()
80+
inner_message.name = 'TestInner'
81+
_add_protobuf_field(inner_message, 'inner_str', 1, descriptor_pb2.FieldDescriptorProto.TYPE_STRING)
82+
_add_protobuf_field(inner_message, 'inner_int', 2, descriptor_pb2.FieldDescriptorProto.TYPE_INT64)
83+
84+
pool = descriptor_pool.DescriptorPool()
85+
pool.AddSerializedFile(file_proto.SerializeToString())
86+
return _get_message_classes(
87+
pool,
88+
('test.TestMessage', 'test.TestMessageWithNested', 'test.TestInner'),
89+
)
90+
91+
92+
TestMessage, TestMessageWithNested, TestInner = _build_protobuf_test_messages()
93+
3294

3395
class ExampleRecord(Record):
3496
str_field = String()
@@ -1404,5 +1466,90 @@ def test_schema_type_promotion(self):
14041466
client.close()
14051467

14061468

1469+
class ProtobufNativeSchemaTest(TestCase):
1470+
"""Unit tests for ProtobufNativeSchema (no Pulsar broker required)."""
1471+
1472+
def test_schema_type(self):
1473+
"""Schema type must be PROTOBUF_NATIVE."""
1474+
import _pulsar
1475+
schema = ProtobufNativeSchema(TestMessage)
1476+
self.assertEqual(schema.schema_info().schema_type(), _pulsar.SchemaType.PROTOBUF_NATIVE)
1477+
1478+
def test_schema_definition_keys(self):
1479+
"""Schema definition JSON must contain the three required keys."""
1480+
schema = ProtobufNativeSchema(TestMessage)
1481+
schema_def = json.loads(schema.schema_info().schema())
1482+
self.assertIn('fileDescriptorSet', schema_def)
1483+
self.assertIn('rootMessageTypeName', schema_def)
1484+
self.assertIn('rootFileDescriptorName', schema_def)
1485+
1486+
def test_schema_definition_values(self):
1487+
"""rootMessageTypeName and rootFileDescriptorName must match the descriptor."""
1488+
schema = ProtobufNativeSchema(TestMessage)
1489+
schema_def = json.loads(schema.schema_info().schema())
1490+
self.assertEqual(schema_def['rootMessageTypeName'], 'test.TestMessage')
1491+
self.assertEqual(schema_def['rootFileDescriptorName'], 'test_schema.proto')
1492+
1493+
def test_file_descriptor_set_is_valid_base64_proto(self):
1494+
"""fileDescriptorSet must be valid base64-encoded FileDescriptorSet bytes."""
1495+
from google.protobuf import descriptor_pb2
1496+
schema = ProtobufNativeSchema(TestMessage)
1497+
schema_def = json.loads(schema.schema_info().schema())
1498+
raw = base64.b64decode(schema_def['fileDescriptorSet'])
1499+
fds = descriptor_pb2.FileDescriptorSet.FromString(raw)
1500+
file_names = [f.name for f in fds.file]
1501+
self.assertIn('test_schema.proto', file_names)
1502+
1503+
def test_encode_decode_roundtrip(self):
1504+
"""encode then decode must reproduce the original message."""
1505+
schema = ProtobufNativeSchema(TestMessage)
1506+
original = TestMessage(name='hello', value=42)
1507+
encoded = schema.encode(original)
1508+
decoded = schema.decode(encoded)
1509+
self.assertEqual(decoded.name, 'hello')
1510+
self.assertEqual(decoded.value, 42)
1511+
1512+
def test_encode_produces_protobuf_binary(self):
1513+
"""Encoded bytes must be valid protobuf binary (parseable by the class directly)."""
1514+
schema = ProtobufNativeSchema(TestMessage)
1515+
msg = TestMessage(name='pulsar', value=100)
1516+
encoded = schema.encode(msg)
1517+
# Verify with protobuf's own parser
1518+
reparsed = TestMessage.FromString(encoded)
1519+
self.assertEqual(reparsed, msg)
1520+
1521+
def test_encode_decode_nested_message(self):
1522+
"""encode/decode round-trip works for messages containing nested message fields."""
1523+
schema = ProtobufNativeSchema(TestMessageWithNested)
1524+
original = TestMessageWithNested(
1525+
str_field='test',
1526+
int_field=7,
1527+
double_field=3.14,
1528+
nested=TestInner(inner_str='inner', inner_int=999),
1529+
)
1530+
decoded = schema.decode(schema.encode(original))
1531+
self.assertEqual(decoded.str_field, 'test')
1532+
self.assertEqual(decoded.int_field, 7)
1533+
self.assertAlmostEqual(decoded.double_field, 3.14)
1534+
self.assertEqual(decoded.nested.inner_str, 'inner')
1535+
self.assertEqual(decoded.nested.inner_int, 999)
1536+
1537+
def test_wrong_type_raises(self):
1538+
"""Encoding an object of the wrong type must raise TypeError."""
1539+
schema = ProtobufNativeSchema(TestMessage)
1540+
with self.assertRaises(TypeError):
1541+
schema.encode("not a protobuf message")
1542+
1543+
def test_non_message_class_raises(self):
1544+
"""Constructing with a non-Message class must raise TypeError."""
1545+
with self.assertRaises(TypeError):
1546+
ProtobufNativeSchema(str)
1547+
1548+
def test_schema_name(self):
1549+
"""Schema name must be 'PROTOBUF_NATIVE'."""
1550+
schema = ProtobufNativeSchema(TestMessage)
1551+
self.assertEqual(schema.schema_info().name(), 'PROTOBUF_NATIVE')
1552+
1553+
14071554
if __name__ == '__main__':
14081555
main()

0 commit comments

Comments
 (0)