Skip to content

Commit 4939c59

Browse files
committed
fix: actually return mediated conn ids
Signed-off-by: Daniel Bluhm <dbluhm@pm.me>
1 parent 98cfa03 commit 4939c59

2 files changed

Lines changed: 55 additions & 38 deletions

File tree

aries_cloudagent/protocols/coordinate_mediation/v1_0/manager.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66

77
from ....core.error import BaseError
88
from ....core.profile import Profile, ProfileSession
9+
from ....connections.models.conn_record import ConnRecord
910
from ....storage.base import BaseStorage
1011
from ....storage.error import StorageNotFoundError
1112
from ....storage.record import StorageRecord
1213
from ....wallet.base import BaseWallet
1314
from ....wallet.did_info import DIDInfo
1415
from ....wallet.did_method import DIDMethod
16+
from ....wallet.error import WalletNotFoundError
1517
from ....wallet.key_type import KeyType
1618
from ...routing.v1_0.manager import RoutingManager
1719
from ...routing.v1_0.models.route_record import RouteRecord
@@ -600,21 +602,34 @@ async def store_update_results(
600602
for record_for_removal in to_remove:
601603
await record_for_removal.delete_record(session)
602604

605+
async def _conn_id_from_recipient_key(
606+
self, session: ProfileSession, wallet: BaseWallet, recipient_key: str
607+
) -> str:
608+
try:
609+
conn = await ConnRecord.retrieve_by_invitation_key(
610+
session, invitation_key=normalize_from_did_key(recipient_key)
611+
)
612+
except StorageNotFoundError:
613+
did_info = await wallet.get_local_did_for_verkey(
614+
normalize_from_did_key(recipient_key)
615+
)
616+
conn = await ConnRecord.retrieve_by_did(session, my_did=did_info.did)
617+
return conn.connection_id
618+
603619
async def notify_keylist_updated(
604620
self, connection_id: str, response: KeylistUpdateResponse
605621
):
606622
"""Notify of keylist update response received."""
607-
# Retrieve connection IDs associated with recipient keys
608-
# recipient key -> connection id
609623
async with self._profile.session() as session:
624+
wallet = session.inject(BaseWallet)
610625
try:
611-
routes = [
612-
await RouteRecord.retrieve_by_recipient_key(
613-
session, updated.recipient_key
626+
routes = {
627+
updated.recipient_key: await self._conn_id_from_recipient_key(
628+
session, wallet, updated.recipient_key
614629
)
615630
for updated in response.updated
616-
]
617-
except StorageNotFoundError as err:
631+
}
632+
except (StorageNotFoundError, WalletNotFoundError) as err:
618633
raise MediationManagerError(
619634
"Unknown recipient key received in keylist update response"
620635
) from err
@@ -625,9 +640,7 @@ async def notify_keylist_updated(
625640
"connection_id": connection_id,
626641
"thread_id": response._thread_id,
627642
"updated": [update.serialize() for update in response.updated],
628-
"mediated_connections": {
629-
route.recipient_key: route.connection_id for route in routes
630-
},
643+
"mediated_connections": routes,
631644
},
632645
)
633646

aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_manager.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Test MediationManager."""
22
import logging
3-
from typing import AsyncIterable, Iterable
3+
from typing import AsyncGenerator, AsyncIterable, Iterable
44

5+
from functools import partial
56
from asynctest import mock as async_mock
67
import pytest
78

@@ -482,37 +483,40 @@ async def test_notify_keylist_updated(
482483
self,
483484
manager: MediationManager,
484485
mock_event_bus: MockEventBus,
485-
session: ProfileSession,
486486
):
487487
"""test notify_keylist_updated."""
488-
await RouteRecord(
489-
role=RouteRecord.ROLE_CLIENT,
490-
connection_id="conn_id_1",
491-
recipient_key=TEST_ROUTE_VERKEY,
492-
).save(session)
493-
await RouteRecord(
494-
role=RouteRecord.ROLE_CLIENT,
495-
connection_id="conn_id_2",
496-
recipient_key=TEST_VERKEY,
497-
).save(session)
498488

499-
response = KeylistUpdateResponse(
500-
updated=[
501-
KeylistUpdated(
502-
recipient_key=TEST_ROUTE_VERKEY,
503-
action=KeylistUpdateRule.RULE_ADD,
504-
result=KeylistUpdated.RESULT_SUCCESS,
505-
),
506-
KeylistUpdated(
507-
recipient_key=TEST_VERKEY,
508-
action=KeylistUpdateRule.RULE_REMOVE,
509-
result=KeylistUpdated.RESULT_SUCCESS,
510-
),
511-
],
512-
)
489+
async def _result_generator():
490+
yield "conn_id_1"
491+
yield "conn_id_2"
513492

514-
response.assign_thread_id(TEST_THREAD_ID)
515-
await manager.notify_keylist_updated(TEST_CONN_ID, response)
493+
async def _retrieve_by_invitation_key(
494+
generator: AsyncGenerator, *args, **kwargs
495+
):
496+
return await generator.__anext__()
497+
498+
with async_mock.patch.object(
499+
manager,
500+
"_conn_id_from_recipient_key",
501+
partial(_retrieve_by_invitation_key, _result_generator()),
502+
):
503+
response = KeylistUpdateResponse(
504+
updated=[
505+
KeylistUpdated(
506+
recipient_key=TEST_ROUTE_VERKEY,
507+
action=KeylistUpdateRule.RULE_ADD,
508+
result=KeylistUpdated.RESULT_SUCCESS,
509+
),
510+
KeylistUpdated(
511+
recipient_key=TEST_VERKEY,
512+
action=KeylistUpdateRule.RULE_REMOVE,
513+
result=KeylistUpdated.RESULT_SUCCESS,
514+
),
515+
],
516+
)
517+
518+
response.assign_thread_id(TEST_THREAD_ID)
519+
await manager.notify_keylist_updated(TEST_CONN_ID, response)
516520
assert mock_event_bus.events
517521
assert mock_event_bus.events[0][1].topic == manager.KEYLIST_UPDATED_EVENT
518522
assert mock_event_bus.events[0][1].payload == {

0 commit comments

Comments
 (0)