55
66from abc import ABC
77from collections import namedtuple
8- from typing import Mapping , Union
8+ from typing import Literal , Mapping , Optional , Type , TypeVar , Union , cast , overload
99
1010from marshmallow import Schema , post_dump , pre_load , post_load , ValidationError , EXCLUDE
1111
1717SerDe = namedtuple ("SerDe" , "ser de" )
1818
1919
20- def resolve_class (the_cls , relative_cls : type = None ):
20+ def resolve_class (the_cls , relative_cls : Optional [ type ] = None ) -> type :
2121 """
2222 Resolve a class.
2323
@@ -38,6 +38,10 @@ def resolve_class(the_cls, relative_cls: type = None):
3838 elif isinstance (the_cls , str ):
3939 default_module = relative_cls and relative_cls .__module__
4040 resolved = ClassLoader .load_class (the_cls , default_module )
41+ else :
42+ raise TypeError (
43+ f"Could not resolve class from { the_cls } ; incorrect type { type (the_cls )} "
44+ )
4145 return resolved
4246
4347
@@ -70,6 +74,9 @@ class BaseModelError(BaseError):
7074 """Base exception class for base model errors."""
7175
7276
77+ ModelType = TypeVar ("ModelType" , bound = "BaseModel" )
78+
79+
7380class BaseModel (ABC ):
7481 """Base model that provides convenience methods."""
7582
@@ -94,18 +101,24 @@ def __init__(self):
94101 )
95102
96103 @classmethod
97- def _get_schema_class (cls ):
104+ def _get_schema_class (cls ) -> Type [ "BaseModelSchema" ] :
98105 """
99106 Get the schema class.
100107
101108 Returns:
102109 The resolved schema class
103110
104111 """
105- return resolve_class (cls .Meta .schema_class , cls )
112+ resolved = resolve_class (cls .Meta .schema_class , cls )
113+ if issubclass (resolved , BaseModelSchema ):
114+ return resolved
115+
116+ raise TypeError (
117+ f"Resolved class is not a subclass of BaseModelSchema: { resolved } "
118+ )
106119
107120 @property
108- def Schema (self ) -> type :
121+ def Schema (self ) -> Type [ "BaseModelSchema" ] :
109122 """
110123 Accessor for the model's schema class.
111124
@@ -115,8 +128,46 @@ def Schema(self) -> type:
115128 """
116129 return self ._get_schema_class ()
117130
131+ @overload
118132 @classmethod
119- def deserialize (cls , obj , unknown : str = None , none2none : str = False ):
133+ def deserialize (
134+ cls : Type [ModelType ],
135+ obj ,
136+ * ,
137+ unknown : Optional [str ] = None ,
138+ ) -> ModelType :
139+ ...
140+
141+ @overload
142+ @classmethod
143+ def deserialize (
144+ cls : Type [ModelType ],
145+ obj ,
146+ * ,
147+ none2none : Literal [False ],
148+ unknown : Optional [str ] = None ,
149+ ) -> ModelType :
150+ ...
151+
152+ @overload
153+ @classmethod
154+ def deserialize (
155+ cls : Type [ModelType ],
156+ obj ,
157+ * ,
158+ none2none : Literal [True ],
159+ unknown : Optional [str ] = None ,
160+ ) -> Optional [ModelType ]:
161+ ...
162+
163+ @classmethod
164+ def deserialize (
165+ cls : Type [ModelType ],
166+ obj ,
167+ * ,
168+ unknown : Optional [str ] = None ,
169+ none2none : bool = False ,
170+ ) -> Optional [ModelType ]:
120171 """
121172 Convert from JSON representation to a model instance.
122173
@@ -132,18 +183,41 @@ def deserialize(cls, obj, unknown: str = None, none2none: str = False):
132183 if obj is None and none2none :
133184 return None
134185
135- schema = cls ._get_schema_class ()(unknown = unknown or EXCLUDE )
186+ schema_cls = cls ._get_schema_class ()
187+ schema = schema_cls (unknown = unknown or schema_cls .Meta .unknown )
188+
136189 try :
137- return schema .loads (obj ) if isinstance (obj , str ) else schema .load (obj )
190+ return cast (
191+ ModelType ,
192+ schema .loads (obj ) if isinstance (obj , str ) else schema .load (obj ),
193+ )
138194 except (AttributeError , ValidationError ) as err :
139195 LOGGER .exception (f"{ cls .__name__ } message validation error:" )
140196 raise BaseModelError (f"{ cls .__name__ } schema validation failed" ) from err
141197
198+ @overload
142199 def serialize (
143200 self ,
144- as_string = False ,
145- unknown : str = None ,
201+ * ,
202+ as_string : Literal [True ],
203+ unknown : Optional [str ] = None ,
204+ ) -> str :
205+ ...
206+
207+ @overload
208+ def serialize (
209+ self ,
210+ * ,
211+ unknown : Optional [str ] = None ,
146212 ) -> dict :
213+ ...
214+
215+ def serialize (
216+ self ,
217+ * ,
218+ as_string : bool = False ,
219+ unknown : Optional [str ] = None ,
220+ ) -> Union [str , dict ]:
147221 """
148222 Create a JSON-compatible dict representation of the model instance.
149223
@@ -154,7 +228,8 @@ def serialize(
154228 A dict representation of this model, or a JSON string if as_string is True
155229
156230 """
157- schema = self .Schema (unknown = unknown or EXCLUDE )
231+ schema_cls = self ._get_schema_class ()
232+ schema = schema_cls (unknown = unknown or schema_cls .Meta .unknown )
158233 try :
159234 return (
160235 schema .dumps (self , separators = ("," , ":" ))
@@ -168,18 +243,17 @@ def serialize(
168243 ) from err
169244
170245 @classmethod
171- def serde (cls , obj : Union ["BaseModel" , Mapping ]) -> SerDe :
246+ def serde (cls , obj : Union ["BaseModel" , Mapping ]) -> Optional [ SerDe ] :
172247 """Return serialized, deserialized representations of input object."""
248+ if obj is None :
249+ return None
173250
174- return (
175- SerDe (obj .serialize (), obj )
176- if isinstance (obj , BaseModel )
177- else None
178- if obj is None
179- else SerDe (obj , cls .deserialize (obj ))
180- )
251+ if isinstance (obj , BaseModel ):
252+ return SerDe (obj .serialize (), obj )
253+
254+ return SerDe (obj , cls .deserialize (obj ))
181255
182- def validate (self , unknown : str = None ):
256+ def validate (self , unknown : Optional [ str ] = None ):
183257 """Validate a constructed model."""
184258 schema = self .Schema (unknown = unknown )
185259 errors = schema .validate (self .serialize ())
@@ -191,7 +265,7 @@ def validate(self, unknown: str = None):
191265 def from_json (
192266 cls ,
193267 json_repr : Union [str , bytes ],
194- unknown : str = None ,
268+ unknown : Optional [ str ] = None ,
195269 ):
196270 """
197271 Parse a JSON string into a model instance.
@@ -218,7 +292,7 @@ def to_json(self, unknown: str = None) -> str:
218292 A JSON representation of this message
219293
220294 """
221- return json .dumps (self .serialize (unknown = unknown or EXCLUDE ))
295+ return json .dumps (self .serialize (unknown = unknown ))
222296
223297 def __repr__ (self ) -> str :
224298 """
@@ -246,6 +320,7 @@ class Meta:
246320 model_class = None
247321 skip_values = [None ]
248322 ordered = True
323+ unknown = EXCLUDE
249324
250325 def __init__ (self , * args , ** kwargs ):
251326 """
0 commit comments