Skip to content

Commit 6f2ef55

Browse files
authored
Merge pull request openwallet-foundation#2060 from andrewwhitehead/fix/accept-unknown-hsproto
Do not reject OOB invitation with unknown handshake protocol(s)
2 parents 8a80f71 + 5aac8c0 commit 6f2ef55

4 files changed

Lines changed: 27 additions & 49 deletions

File tree

aries_cloudagent/messaging/valid.py

Lines changed: 15 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -23,60 +23,41 @@
2323
class StrOrDictField(Field):
2424
"""URI or Dict field for Marshmallow."""
2525

26-
def _serialize(self, value, attr, obj, **kwargs):
27-
return value
28-
2926
def _deserialize(self, value, attr, data, **kwargs):
30-
if isinstance(value, (str, dict)):
31-
return value
32-
else:
27+
if not isinstance(value, (str, dict)):
3328
raise ValidationError("Field should be str or dict")
29+
return super()._deserialize(value, attr, data, **kwargs)
3430

3531

3632
class StrOrNumberField(Field):
3733
"""String or Number field for Marshmallow."""
3834

39-
def _serialize(self, value, attr, obj, **kwargs):
40-
return value
41-
4235
def _deserialize(self, value, attr, data, **kwargs):
43-
if isinstance(value, (str, float, int)):
44-
return value
45-
else:
36+
if not isinstance(value, (str, float, int)):
4637
raise ValidationError("Field should be str or int or float")
38+
return super()._deserialize(value, attr, data, **kwargs)
4739

4840

4941
class DictOrDictListField(Field):
5042
"""Dict or Dict List field for Marshmallow."""
5143

52-
def _serialize(self, value, attr, obj, **kwargs):
53-
return value
54-
5544
def _deserialize(self, value, attr, data, **kwargs):
56-
# dict
57-
if isinstance(value, dict):
58-
return value
59-
# list of dicts
60-
elif isinstance(value, list) and all(isinstance(item, dict) for item in value):
61-
return value
62-
else:
63-
raise ValidationError("Field should be dict or list of dicts")
45+
if not isinstance(value, dict):
46+
if not isinstance(value, list) or not all(
47+
isinstance(item, dict) for item in value
48+
):
49+
raise ValidationError("Field should be dict or list of dicts")
50+
return super()._deserialize(value, attr, data, **kwargs)
6451

6552

6653
class UriOrDictField(StrOrDictField):
6754
"""URI or Dict field for Marshmallow."""
6855

69-
def __init__(self, *args, **kwargs):
70-
"""Initialize new UriOrDictField instance."""
71-
super().__init__(*args, **kwargs)
72-
73-
# Insert validation into self.validators so that multiple errors can be stored.
74-
self.validators.insert(0, self._uri_validator)
75-
76-
def _uri_validator(self, value):
77-
# Check if URI when
56+
def _deserialize(self, value, attr, data, **kwargs):
7857
if isinstance(value, str):
79-
return Uri()(value)
58+
# Check regex
59+
Uri()(value)
60+
return super()._deserialize(value, attr, data, **kwargs)
8061

8162

8263
class IntEpoch(Range):
@@ -775,7 +756,7 @@ def __call__(self, value):
775756
except ValidationError:
776757
raise ValidationError(
777758
f"credential subject id {value[0]} must be URI"
778-
)
759+
) from None
779760

780761
return value
781762

aries_cloudagent/protocols/connections/v1_0/models/connection_detail.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def _serialize(self, value, attr, obj, **kwargs):
2323
"""
2424
return value.serialize()
2525

26-
def _deserialize(self, value, attr, data, **kwargs):
26+
def _deserialize(self, value, attr=None, data=None, **kwargs):
2727
"""
2828
Deserialize a value into a DIDDoc.
2929

aries_cloudagent/protocols/out_of_band/v1_0/messages/invitation.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,17 @@ def _serialize(self, value, attr, obj, **kwargs):
9696
def _deserialize(self, value, attr, data, **kwargs):
9797
if isinstance(value, dict):
9898
return Service.deserialize(value)
99+
elif isinstance(value, Service):
100+
return value
99101
elif isinstance(value, str):
100-
if bool(DIDValidation.PATTERN.match(value)):
101-
return value
102-
else:
102+
if not DIDValidation.PATTERN.match(value):
103103
raise ValidationError(
104104
"Service item must be a valid decentralized identifier (DID)"
105105
)
106+
return value
107+
raise ValidationError(
108+
"Service item must be a valid decentralized identifier (DID) or object"
109+
)
106110

107111

108112
class InvitationMessage(AgentMessage):
@@ -221,9 +225,6 @@ class Meta:
221225
fields.Str(
222226
description="Handshake protocol",
223227
example=DIDCommPrefix.qualify_current(HSProto.RFC23.name),
224-
validate=lambda hsp: (
225-
DIDCommPrefix.unqualify(hsp) in [p.name for p in HSProto]
226-
),
227228
),
228229
required=False,
229230
)
@@ -276,13 +277,10 @@ def validate_fields(self, data, **kwargs):
276277
"""
277278
handshake_protocols = data.get("handshake_protocols")
278279
requests_attach = data.get("requests_attach")
279-
if not (
280-
(handshake_protocols and len(handshake_protocols) > 0)
281-
or (requests_attach and len(requests_attach) > 0)
282-
):
280+
if not handshake_protocols and not requests_attach:
283281
raise ValidationError(
284282
"Model must include non-empty "
285-
"handshake_protocols or requests_attach or both"
283+
"handshake_protocols or requests~attach or both"
286284
)
287285

288286
# services = data.get("services")

aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_invitation.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,8 @@ def test_invalid_invi_wrong_type_services(self):
139139
"services": [123],
140140
}
141141

142-
invi_schema = InvitationMessageSchema()
143-
with pytest.raises(test_module.ValidationError):
144-
invi_schema.validate_fields(obj_x)
142+
errs = InvitationMessageSchema().validate(obj_x)
143+
assert errs and "services" in errs
145144

146145
def test_assign_msg_type_version_to_model_inst(self):
147146
test_msg = InvitationMessage()

0 commit comments

Comments
 (0)