Skip to content

Commit 960aa91

Browse files
authored
Merge pull request openwallet-foundation#1970 from shaangill025/mediation_fix
Fix: `--mediator-invitation` with OOB invitation + cleanup
2 parents 9dc2fa3 + 42b5029 commit 960aa91

5 files changed

Lines changed: 42 additions & 18 deletions

File tree

aries_cloudagent/core/conductor.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from ..config.logging import LoggingConfigurator
2727
from ..config.provider import ClassProvider
2828
from ..config.wallet import wallet_config
29+
from ..connections.models.conn_record import ConnRecord
2930
from ..core.profile import Profile
3031
from ..indy.verifier import IndyVerifier
3132
from ..ledger.base import BaseLedger
@@ -450,14 +451,23 @@ async def start(self) -> None:
450451
if mediation_connections_invite
451452
else OutOfBandManager(self.root_profile)
452453
)
453-
454-
conn_record = await mgr.receive_invitation(
455-
invitation=invitation_handler.from_url(
456-
mediation_invite_record.invite
457-
),
458-
auto_accept=True,
459-
)
460454
async with self.root_profile.session() as session:
455+
invitation = invitation_handler.from_url(
456+
mediation_invite_record.invite
457+
)
458+
if isinstance(mgr, OutOfBandManager):
459+
oob_record = await mgr.receive_invitation(
460+
invitation=invitation,
461+
auto_accept=True,
462+
)
463+
conn_record = await ConnRecord.retrieve_by_id(
464+
session, oob_record.connection_id
465+
)
466+
else:
467+
conn_record = await mgr.receive_invitation(
468+
invitation=invitation,
469+
auto_accept=True,
470+
)
461471
await (
462472
MediationInviteStore(
463473
session.context.inject(BaseStorage)

aries_cloudagent/core/dispatcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,10 +296,10 @@ async def make_message(
296296
if not isinstance(parsed_msg, dict):
297297
raise MessageParseError("Expected a JSON object")
298298
message_type = parsed_msg.get("@type")
299-
message_type_rec_version = get_version_from_message_type(message_type)
300299

301300
if not message_type:
302301
raise MessageParseError("Message does not contain '@type' parameter")
302+
message_type_rec_version = get_version_from_message_type(message_type)
303303

304304
registry: ProtocolRegistry = self.profile.inject(ProtocolRegistry)
305305
try:

aries_cloudagent/core/tests/test_conductor.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,7 +1161,9 @@ async def test_mediator_invitation_0434(self, mock_from_url, _):
11611161
"test": async_mock.MagicMock(schemes=["http"])
11621162
}
11631163
await conductor.setup()
1164-
1164+
conductor.root_profile.context.update_settings(
1165+
{"mediation.connections_invite": False}
1166+
)
11651167
conn_record = ConnRecord(
11661168
invitation_key="3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx",
11671169
their_label="Hello",
@@ -1170,12 +1172,15 @@ async def test_mediator_invitation_0434(self, mock_from_url, _):
11701172
)
11711173
conn_record.accept = ConnRecord.ACCEPT_MANUAL
11721174
await conn_record.save(await conductor.root_profile.session())
1175+
oob_record = async_mock.MagicMock(
1176+
connection_id=conn_record.connection_id,
1177+
)
11731178
with async_mock.patch.object(
11741179
test_module,
11751180
"OutOfBandManager",
11761181
async_mock.MagicMock(
11771182
return_value=async_mock.MagicMock(
1178-
receive_invitation=async_mock.AsyncMock(return_value=conn_record)
1183+
receive_invitation=async_mock.AsyncMock(return_value=oob_record)
11791184
)
11801185
),
11811186
) as mock_mgr, async_mock.patch.object(
@@ -1185,10 +1190,10 @@ async def test_mediator_invitation_0434(self, mock_from_url, _):
11851190
return_value=async_mock.MagicMock(value=f"v{__version__}")
11861191
),
11871192
):
1193+
assert not conductor.root_profile.settings["mediation.connections_invite"]
11881194
await conductor.start()
11891195
await conductor.stop()
11901196
mock_from_url.assert_called_once_with("test-invite")
1191-
mock_mgr.return_value.receive_invitation.assert_called_once()
11921197

11931198
@async_mock.patch.object(test_module, "MediationInviteStore")
11941199
@async_mock.patch.object(test_module.ConnectionInvitation, "from_url")

aries_cloudagent/protocols/connections/v1_0/routes.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
request_schema,
1111
response_schema,
1212
)
13-
13+
from typing import cast
1414
from marshmallow import fields, validate, validates_schema
1515

1616
from ....admin.request_context import AdminRequestContext
@@ -115,7 +115,7 @@ class CreateInvitationRequestSchema(OpenAPISchema):
115115
mediation_id = fields.Str(
116116
required=False,
117117
description="Identifier for active mediation record to be used",
118-
**MEDIATION_ID_SCHEMA
118+
**MEDIATION_ID_SCHEMA,
119119
)
120120

121121

@@ -247,7 +247,7 @@ class ReceiveInvitationQueryStringSchema(OpenAPISchema):
247247
mediation_id = fields.Str(
248248
required=False,
249249
description="Identifier for active mediation record to be used",
250-
**MEDIATION_ID_SCHEMA
250+
**MEDIATION_ID_SCHEMA,
251251
)
252252

253253

@@ -261,7 +261,7 @@ class AcceptInvitationQueryStringSchema(OpenAPISchema):
261261
mediation_id = fields.Str(
262262
required=False,
263263
description="Identifier for active mediation record to be used",
264-
**MEDIATION_ID_SCHEMA
264+
**MEDIATION_ID_SCHEMA,
265265
)
266266

267267

@@ -536,11 +536,16 @@ async def connections_create_invitation(request: web.BaseRequest):
536536
metadata=metadata,
537537
mediation_id=mediation_id,
538538
)
539-
539+
invitation_url = invitation.to_url(base_url)
540+
base_endpoint = service_endpoint or cast(
541+
str, profile.settings.get("default_endpoint")
542+
)
540543
result = {
541544
"connection_id": connection and connection.connection_id,
542545
"invitation": invitation.serialize(),
543-
"invitation_url": invitation.to_url(base_url),
546+
"invitation_url": f"{base_endpoint}{invitation_url}"
547+
if invitation_url.startswith("?")
548+
else invitation_url,
544549
}
545550
except (ConnectionManagerError, StorageError, BaseModelError) as err:
546551
raise web.HTTPBadRequest(reason=err.roll_up) from err

aries_cloudagent/protocols/endorse_transaction/v1_0/routes.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -764,11 +764,15 @@ async def on_startup_event(profile: Profile, event: Event):
764764
invite = InvitationMessage.from_url(endorser_invitation)
765765
if invite:
766766
oob_mgr = OutOfBandManager(profile)
767-
conn_record = await oob_mgr.receive_invitation(
767+
oob_record = await oob_mgr.receive_invitation(
768768
invitation=invite,
769769
auto_accept=True,
770770
alias=endorser_alias,
771771
)
772+
async with profile.session() as session:
773+
conn_record = await ConnRecord.retrieve_by_id(
774+
session, oob_record.connection_id
775+
)
772776
else:
773777
invite = ConnectionInvitation.from_url(endorser_invitation)
774778
if invite:

0 commit comments

Comments
 (0)