Skip to content

Commit 0de5904

Browse files
committed
test: schema meta unknown respected
Signed-off-by: Daniel Bluhm <dbluhm@pm.me>
1 parent b6c55dc commit 0de5904

2 files changed

Lines changed: 76 additions & 15 deletions

File tree

aries_cloudagent/messaging/models/base.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
from abc import ABC
77
from collections import namedtuple
8-
from typing import Literal, Mapping, Optional, Type, TypeVar, Union, cast, overload
8+
from typing import Mapping, Optional, Type, TypeVar, Union, cast, overload
9+
from typing_extensions import Literal
910

1011
from marshmallow import Schema, post_dump, pre_load, post_load, ValidationError, EXCLUDE
1112

@@ -57,7 +58,10 @@ def resolve_meta_property(obj, prop_name: str, defval=None):
5758
The meta property
5859
5960
"""
60-
cls = obj.__class__
61+
if isinstance(obj, type):
62+
cls = obj
63+
else:
64+
cls = obj.__class__
6165
found = defval
6266
while cls:
6367
Meta = getattr(cls, "Meta", None)
@@ -184,7 +188,9 @@ def deserialize(
184188
return None
185189

186190
schema_cls = cls._get_schema_class()
187-
schema = schema_cls(unknown=unknown or schema_cls.Meta.unknown)
191+
schema = schema_cls(
192+
unknown=unknown or resolve_meta_property(schema_cls, "unknown", EXCLUDE)
193+
)
188194

189195
try:
190196
return cast(
@@ -229,7 +235,9 @@ def serialize(
229235
230236
"""
231237
schema_cls = self._get_schema_class()
232-
schema = schema_cls(unknown=unknown or schema_cls.Meta.unknown)
238+
schema = schema_cls(
239+
unknown=unknown or resolve_meta_property(schema_cls, "unknown", EXCLUDE)
240+
)
233241
try:
234242
return (
235243
schema.dumps(self, separators=(",", ":"))
@@ -320,7 +328,6 @@ class Meta:
320328
model_class = None
321329
skip_values = [None]
322330
ordered = True
323-
unknown = EXCLUDE
324331

325332
def __init__(self, *args, **kwargs):
326333
"""

aries_cloudagent/messaging/models/tests/test_base.py

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,6 @@
1-
import json
2-
31
from asynctest import TestCase as AsyncTestCase, mock as async_mock
42

5-
from marshmallow import EXCLUDE, fields, validates_schema, ValidationError
6-
7-
from ....cache.base import BaseCache
8-
from ....config.injection_context import InjectionContext
9-
from ....storage.base import BaseStorage, StorageRecord
10-
11-
from ...responder import BaseResponder, MockResponder
12-
from ...util import time_now
3+
from marshmallow import EXCLUDE, INCLUDE, fields, validates_schema, ValidationError
134

145
from ..base import BaseModel, BaseModelError, BaseModelSchema
156

@@ -35,6 +26,48 @@ def validate_fields(self, data, **kwargs):
3526
raise ValidationError("")
3627

3728

29+
class ModelImplWithUnknown(BaseModel):
30+
class Meta:
31+
schema_class = "SchemaImplWithUnknown"
32+
33+
def __init__(self, *, attr=None, **kwargs):
34+
self.attr = attr
35+
self.extra = kwargs
36+
37+
38+
class SchemaImplWithUnknown(BaseModelSchema):
39+
class Meta:
40+
model_class = ModelImplWithUnknown
41+
unknown = INCLUDE
42+
43+
attr = fields.String(required=True)
44+
45+
@validates_schema
46+
def validate_fields(self, data, **kwargs):
47+
if data["attr"] != "succeeds":
48+
raise ValidationError("")
49+
50+
51+
class ModelImplWithoutUnknown(BaseModel):
52+
class Meta:
53+
schema_class = "SchemaImplWithoutUnknown"
54+
55+
def __init__(self, *, attr=None):
56+
self.attr = attr
57+
58+
59+
class SchemaImplWithoutUnknown(BaseModelSchema):
60+
class Meta:
61+
model_class = ModelImplWithoutUnknown
62+
63+
attr = fields.String(required=True)
64+
65+
@validates_schema
66+
def validate_fields(self, data, **kwargs):
67+
if data["attr"] != "succeeds":
68+
raise ValidationError("")
69+
70+
3871
class TestBase(AsyncTestCase):
3972
def test_model_validate_fails(self):
4073
model = ModelImpl(attr="string")
@@ -63,3 +96,24 @@ def test_from_json_x(self):
6396
data = "{}{}"
6497
with self.assertRaises(BaseModelError):
6598
ModelImpl.from_json(data)
99+
100+
def test_model_with_unknown(self):
101+
model = ModelImplWithUnknown(attr="succeeds")
102+
model = model.validate()
103+
assert model.attr == "succeeds"
104+
105+
model = ModelImplWithUnknown.deserialize(
106+
{"attr": "succeeds", "another": "value"}
107+
)
108+
assert model.extra
109+
assert model.extra["another"] == "value"
110+
assert model.attr == "succeeds"
111+
112+
def test_model_without_unknown_default_exclude(self):
113+
model = ModelImplWithoutUnknown(attr="succeeds")
114+
model = model.validate()
115+
assert model.attr == "succeeds"
116+
117+
assert ModelImplWithoutUnknown.deserialize(
118+
{"attr": "succeeds", "another": "value"}
119+
)

0 commit comments

Comments
 (0)