Skip to content

Commit b6c55dc

Browse files
committed
fix: schema class can set Meta.unknown
Signed-off-by: Daniel Bluhm <dbluhm@pm.me>
1 parent 2b4c307 commit b6c55dc

2 files changed

Lines changed: 102 additions & 24 deletions

File tree

aries_cloudagent/messaging/models/base.py

Lines changed: 97 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from abc import ABC
77
from collections import namedtuple
8-
from typing import Mapping, Union
8+
from typing import Literal, Mapping, Optional, Type, TypeVar, Union, cast, overload
99

1010
from marshmallow import Schema, post_dump, pre_load, post_load, ValidationError, EXCLUDE
1111

@@ -17,7 +17,7 @@
1717
SerDe = namedtuple("SerDe", "ser de")
1818

1919

20-
def resolve_class(the_cls, relative_cls: type = None):
20+
def resolve_class(the_cls, relative_cls: Optional[type] = None) -> type:
2121
"""
2222
Resolve a class.
2323
@@ -38,6 +38,10 @@ def resolve_class(the_cls, relative_cls: type = None):
3838
elif isinstance(the_cls, str):
3939
default_module = relative_cls and relative_cls.__module__
4040
resolved = ClassLoader.load_class(the_cls, default_module)
41+
else:
42+
raise TypeError(
43+
f"Could not resolve class from {the_cls}; incorrect type {type(the_cls)}"
44+
)
4145
return resolved
4246

4347

@@ -70,6 +74,9 @@ class BaseModelError(BaseError):
7074
"""Base exception class for base model errors."""
7175

7276

77+
ModelType = TypeVar("ModelType", bound="BaseModel")
78+
79+
7380
class BaseModel(ABC):
7481
"""Base model that provides convenience methods."""
7582

@@ -94,18 +101,24 @@ def __init__(self):
94101
)
95102

96103
@classmethod
97-
def _get_schema_class(cls):
104+
def _get_schema_class(cls) -> Type["BaseModelSchema"]:
98105
"""
99106
Get the schema class.
100107
101108
Returns:
102109
The resolved schema class
103110
104111
"""
105-
return resolve_class(cls.Meta.schema_class, cls)
112+
resolved = resolve_class(cls.Meta.schema_class, cls)
113+
if issubclass(resolved, BaseModelSchema):
114+
return resolved
115+
116+
raise TypeError(
117+
f"Resolved class is not a subclass of BaseModelSchema: {resolved}"
118+
)
106119

107120
@property
108-
def Schema(self) -> type:
121+
def Schema(self) -> Type["BaseModelSchema"]:
109122
"""
110123
Accessor for the model's schema class.
111124
@@ -115,8 +128,46 @@ def Schema(self) -> type:
115128
"""
116129
return self._get_schema_class()
117130

131+
@overload
118132
@classmethod
119-
def deserialize(cls, obj, unknown: str = None, none2none: str = False):
133+
def deserialize(
134+
cls: Type[ModelType],
135+
obj,
136+
*,
137+
unknown: Optional[str] = None,
138+
) -> ModelType:
139+
...
140+
141+
@overload
142+
@classmethod
143+
def deserialize(
144+
cls: Type[ModelType],
145+
obj,
146+
*,
147+
none2none: Literal[False],
148+
unknown: Optional[str] = None,
149+
) -> ModelType:
150+
...
151+
152+
@overload
153+
@classmethod
154+
def deserialize(
155+
cls: Type[ModelType],
156+
obj,
157+
*,
158+
none2none: Literal[True],
159+
unknown: Optional[str] = None,
160+
) -> Optional[ModelType]:
161+
...
162+
163+
@classmethod
164+
def deserialize(
165+
cls: Type[ModelType],
166+
obj,
167+
*,
168+
unknown: Optional[str] = None,
169+
none2none: bool = False,
170+
) -> Optional[ModelType]:
120171
"""
121172
Convert from JSON representation to a model instance.
122173
@@ -132,18 +183,41 @@ def deserialize(cls, obj, unknown: str = None, none2none: str = False):
132183
if obj is None and none2none:
133184
return None
134185

135-
schema = cls._get_schema_class()(unknown=unknown or EXCLUDE)
186+
schema_cls = cls._get_schema_class()
187+
schema = schema_cls(unknown=unknown or schema_cls.Meta.unknown)
188+
136189
try:
137-
return schema.loads(obj) if isinstance(obj, str) else schema.load(obj)
190+
return cast(
191+
ModelType,
192+
schema.loads(obj) if isinstance(obj, str) else schema.load(obj),
193+
)
138194
except (AttributeError, ValidationError) as err:
139195
LOGGER.exception(f"{cls.__name__} message validation error:")
140196
raise BaseModelError(f"{cls.__name__} schema validation failed") from err
141197

198+
@overload
142199
def serialize(
143200
self,
144-
as_string=False,
145-
unknown: str = None,
201+
*,
202+
as_string: Literal[True],
203+
unknown: Optional[str] = None,
204+
) -> str:
205+
...
206+
207+
@overload
208+
def serialize(
209+
self,
210+
*,
211+
unknown: Optional[str] = None,
146212
) -> dict:
213+
...
214+
215+
def serialize(
216+
self,
217+
*,
218+
as_string: bool = False,
219+
unknown: Optional[str] = None,
220+
) -> Union[str, dict]:
147221
"""
148222
Create a JSON-compatible dict representation of the model instance.
149223
@@ -154,7 +228,8 @@ def serialize(
154228
A dict representation of this model, or a JSON string if as_string is True
155229
156230
"""
157-
schema = self.Schema(unknown=unknown or EXCLUDE)
231+
schema_cls = self._get_schema_class()
232+
schema = schema_cls(unknown=unknown or schema_cls.Meta.unknown)
158233
try:
159234
return (
160235
schema.dumps(self, separators=(",", ":"))
@@ -168,18 +243,17 @@ def serialize(
168243
) from err
169244

170245
@classmethod
171-
def serde(cls, obj: Union["BaseModel", Mapping]) -> SerDe:
246+
def serde(cls, obj: Union["BaseModel", Mapping]) -> Optional[SerDe]:
172247
"""Return serialized, deserialized representations of input object."""
248+
if obj is None:
249+
return None
173250

174-
return (
175-
SerDe(obj.serialize(), obj)
176-
if isinstance(obj, BaseModel)
177-
else None
178-
if obj is None
179-
else SerDe(obj, cls.deserialize(obj))
180-
)
251+
if isinstance(obj, BaseModel):
252+
return SerDe(obj.serialize(), obj)
253+
254+
return SerDe(obj, cls.deserialize(obj))
181255

182-
def validate(self, unknown: str = None):
256+
def validate(self, unknown: Optional[str] = None):
183257
"""Validate a constructed model."""
184258
schema = self.Schema(unknown=unknown)
185259
errors = schema.validate(self.serialize())
@@ -191,7 +265,7 @@ def validate(self, unknown: str = None):
191265
def from_json(
192266
cls,
193267
json_repr: Union[str, bytes],
194-
unknown: str = None,
268+
unknown: Optional[str] = None,
195269
):
196270
"""
197271
Parse a JSON string into a model instance.
@@ -218,7 +292,7 @@ def to_json(self, unknown: str = None) -> str:
218292
A JSON representation of this message
219293
220294
"""
221-
return json.dumps(self.serialize(unknown=unknown or EXCLUDE))
295+
return json.dumps(self.serialize(unknown=unknown))
222296

223297
def __repr__(self) -> str:
224298
"""
@@ -246,6 +320,7 @@ class Meta:
246320
model_class = None
247321
skip_values = [None]
248322
ordered = True
323+
unknown = EXCLUDE
249324

250325
def __init__(self, *args, **kwargs):
251326
"""

aries_cloudagent/utils/classloader.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from importlib import import_module
88
from importlib.util import find_spec, resolve_name
99
from types import ModuleType
10-
from typing import Sequence, Type
10+
from typing import Optional, Sequence, Type
1111

1212
from ..core.error import BaseError
1313

@@ -75,7 +75,10 @@ def load_module(cls, mod_path: str, package: str = None) -> ModuleType:
7575

7676
@classmethod
7777
def load_class(
78-
cls, class_name: str, default_module: str = None, package: str = None
78+
cls,
79+
class_name: str,
80+
default_module: Optional[str] = None,
81+
package: Optional[str] = None,
7982
):
8083
"""
8184
Resolve a complete class path (ie. typing.Dict) to the class itself.

0 commit comments

Comments
 (0)