Skip to content

Commit 7bb7809

Browse files
committed
refactor: connection retrieval through route manager
Signed-off-by: Daniel Bluhm <dbluhm@pm.me>
1 parent 3d3f92b commit 7bb7809

7 files changed

Lines changed: 186 additions & 126 deletions

File tree

aries_cloudagent/multitenant/base.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,19 @@ async def _get_wallet_by_key(self, recipient_key: str) -> Optional[WalletRecord]
381381
except (RouteNotFoundError):
382382
pass
383383

384+
async def get_profile_for_key(
385+
self, context: InjectionContext, recipient_key: str
386+
) -> Optional[Profile]:
387+
"""Retrieve a wallet profile by recipient key."""
388+
wallet = await self._get_wallet_by_key(recipient_key)
389+
if not wallet:
390+
return None
391+
392+
if wallet.requires_external_key:
393+
raise WalletKeyMissingError()
394+
395+
return await self.get_wallet_profile(context, wallet)
396+
384397
async def get_wallets_by_message(
385398
self, message_body, wire_format: BaseWireFormat = None
386399
) -> List[WalletRecord]:

aries_cloudagent/multitenant/route_manager.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
from typing import List, Optional, Tuple
66

7+
from ..connections.models.conn_record import ConnRecord
78
from ..core.profile import Profile
89
from ..messaging.responder import BaseResponder
910
from ..protocols.coordinate_mediation.v1_0.manager import MediationManager
@@ -14,6 +15,7 @@
1415
from ..protocols.routing.v1_0.manager import RoutingManager
1516
from ..protocols.routing.v1_0.models.route_record import RouteRecord
1617
from ..storage.error import StorageNotFoundError
18+
from .manager import MultitenantManager
1719

1820

1921
LOGGER = logging.getLogger(__name__)
@@ -103,3 +105,23 @@ async def routing_info(
103105
my_endpoint = mediation_record.endpoint
104106

105107
return routing_keys, my_endpoint
108+
109+
async def connection_from_recipient_key(
110+
self, profile: Profile, recipient_key: str
111+
) -> ConnRecord:
112+
"""Retrieve a connection by recipient key.
113+
114+
The recipient key is expected to be a local key owned by this agent.
115+
116+
Since the multi-tenant base wallet can receive and send keylist updates
117+
for sub wallets, we check the sub wallet's connections before the base
118+
wallet.
119+
"""
120+
manager = MultitenantManager(self.root_profile)
121+
profile_to_search = (
122+
await manager.get_profile_for_key(profile.context, recipient_key) or profile
123+
)
124+
125+
return await super().connection_from_recipient_key(
126+
profile_to_search, recipient_key
127+
)

aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/keylist_update_response_handler.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
"""Handler for keylist-update-response message."""
22

3+
from .....core.profile import Profile
34
from .....messaging.base_handler import BaseHandler, HandlerException
45
from .....messaging.request_context import RequestContext
56
from .....messaging.responder import BaseResponder
6-
7-
from ..messages.keylist_update_response import KeylistUpdateResponse
7+
from .....storage.error import StorageNotFoundError
8+
from .....wallet.error import WalletNotFoundError
89
from ..manager import MediationManager
10+
from ..messages.keylist_update_response import KeylistUpdateResponse
11+
from ..route_manager import RouteManager
912

1013

1114
class KeylistUpdateResponseHandler(BaseHandler):
@@ -25,6 +28,35 @@ async def handle(self, context: RequestContext, responder: BaseResponder):
2528
await mgr.store_update_results(
2629
context.connection_record.connection_id, context.message.updated
2730
)
28-
await mgr.notify_keylist_updated(
29-
context.connection_record.connection_id, context.message
31+
await self.notify_keylist_updated(
32+
context.profile, context.connection_record.connection_id, context.message
33+
)
34+
35+
async def notify_keylist_updated(
36+
self, profile: Profile, connection_id: str, response: KeylistUpdateResponse
37+
):
38+
"""Notify of keylist update response received."""
39+
route_manager = profile.inject(RouteManager)
40+
try:
41+
key_to_connection = {
42+
updated.recipient_key: await route_manager.connection_from_recipient_key(
43+
profile, updated.recipient_key
44+
)
45+
for updated in response.updated
46+
}
47+
except (StorageNotFoundError, WalletNotFoundError) as err:
48+
raise HandlerException(
49+
"Unknown recipient key received in keylist update response"
50+
) from err
51+
52+
await profile.notify(
53+
MediationManager.KEYLIST_UPDATED_EVENT,
54+
{
55+
"connection_id": connection_id,
56+
"thread_id": response._thread_id,
57+
"updated": [update.serialize() for update in response.updated],
58+
"mediated_connections": {
59+
key: conn.connection_id for key, conn in key_to_connection.items()
60+
},
61+
},
3062
)

aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/tests/test_keylist_update_response_handler.py

Lines changed: 93 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,29 @@
11
"""Test handler for keylist-update-response message."""
22

3+
from functools import partial
4+
from typing import AsyncGenerator
35
import pytest
46
from asynctest import TestCase as AsyncTestCase
57
from asynctest import mock as async_mock
68

9+
710
from ......connections.models.conn_record import ConnRecord
11+
from ......core.event_bus import EventBus, MockEventBus
812
from ......messaging.base_handler import HandlerException
913
from ......messaging.request_context import RequestContext
1014
from ......messaging.responder import MockResponder
1115
from ...messages.inner.keylist_update_rule import KeylistUpdateRule
1216
from ...messages.inner.keylist_updated import KeylistUpdated
1317
from ...messages.keylist_update_response import KeylistUpdateResponse
1418
from ...manager import MediationManager
19+
from ...route_manager import RouteManager
20+
from ...tests.test_route_manager import MockRouteManager
1521
from ..keylist_update_response_handler import KeylistUpdateResponseHandler
1622

1723
TEST_CONN_ID = "conn-id"
18-
TEST_VERKEY = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx"
24+
TEST_THREAD_ID = "thread-id"
25+
TEST_VERKEY = "did:key:z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL"
26+
TEST_ROUTE_VERKEY = "did:key:z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhya"
1927

2028

2129
class TestKeylistUpdateResponseHandler(AsyncTestCase):
@@ -34,6 +42,10 @@ async def setUp(self):
3442
self.context.message = KeylistUpdateResponse(updated=self.updated)
3543
self.context.connection_ready = True
3644
self.context.connection_record = ConnRecord(connection_id=TEST_CONN_ID)
45+
self.mock_event_bus = MockEventBus()
46+
self.context.injector.bind_instance(EventBus, self.mock_event_bus)
47+
self.route_manager = MockRouteManager()
48+
self.context.injector.bind_instance(RouteManager, self.route_manager)
3749

3850
async def test_handler_no_active_connection(self):
3951
handler, responder = KeylistUpdateResponseHandler(), MockResponder()
@@ -47,8 +59,86 @@ async def test_handler(self):
4759
with async_mock.patch.object(
4860
MediationManager, "store_update_results"
4961
) as mock_store, async_mock.patch.object(
50-
MediationManager, "notify_keylist_updated"
62+
handler, "notify_keylist_updated"
5163
) as mock_notify:
5264
await handler.handle(self.context, responder)
5365
mock_store.assert_called_once_with(TEST_CONN_ID, self.updated)
54-
mock_notify.assert_called_once_with(TEST_CONN_ID, self.context.message)
66+
mock_notify.assert_called_once_with(
67+
self.context.profile, TEST_CONN_ID, self.context.message
68+
)
69+
70+
async def test_notify_keylist_updated(self):
71+
"""test notify_keylist_updated."""
72+
handler = KeylistUpdateResponseHandler()
73+
74+
async def _result_generator():
75+
yield ConnRecord(connection_id="conn_id_1")
76+
yield ConnRecord(connection_id="conn_id_2")
77+
78+
async def _retrieve_by_invitation_key(
79+
generator: AsyncGenerator, *args, **kwargs
80+
):
81+
return await generator.__anext__()
82+
83+
with async_mock.patch.object(
84+
self.route_manager,
85+
"connection_from_recipient_key",
86+
partial(_retrieve_by_invitation_key, _result_generator()),
87+
):
88+
response = KeylistUpdateResponse(
89+
updated=[
90+
KeylistUpdated(
91+
recipient_key=TEST_ROUTE_VERKEY,
92+
action=KeylistUpdateRule.RULE_ADD,
93+
result=KeylistUpdated.RESULT_SUCCESS,
94+
),
95+
KeylistUpdated(
96+
recipient_key=TEST_VERKEY,
97+
action=KeylistUpdateRule.RULE_REMOVE,
98+
result=KeylistUpdated.RESULT_SUCCESS,
99+
),
100+
],
101+
)
102+
103+
response.assign_thread_id(TEST_THREAD_ID)
104+
await handler.notify_keylist_updated(
105+
self.context.profile, TEST_CONN_ID, response
106+
)
107+
assert self.mock_event_bus.events
108+
assert (
109+
self.mock_event_bus.events[0][1].topic
110+
== MediationManager.KEYLIST_UPDATED_EVENT
111+
)
112+
assert self.mock_event_bus.events[0][1].payload == {
113+
"connection_id": TEST_CONN_ID,
114+
"thread_id": TEST_THREAD_ID,
115+
"updated": [result.serialize() for result in response.updated],
116+
"mediated_connections": {
117+
TEST_ROUTE_VERKEY: "conn_id_1",
118+
TEST_VERKEY: "conn_id_2",
119+
},
120+
}
121+
122+
async def test_notify_keylist_updated_x_unknown_recip_key(self):
123+
"""test notify_keylist_updated."""
124+
handler = KeylistUpdateResponseHandler()
125+
response = KeylistUpdateResponse(
126+
updated=[
127+
KeylistUpdated(
128+
recipient_key=TEST_ROUTE_VERKEY,
129+
action=KeylistUpdateRule.RULE_ADD,
130+
result=KeylistUpdated.RESULT_SUCCESS,
131+
),
132+
KeylistUpdated(
133+
recipient_key=TEST_VERKEY,
134+
action=KeylistUpdateRule.RULE_REMOVE,
135+
result=KeylistUpdated.RESULT_SUCCESS,
136+
),
137+
],
138+
)
139+
140+
response.assign_thread_id(TEST_THREAD_ID)
141+
with pytest.raises(HandlerException):
142+
await handler.notify_keylist_updated(
143+
self.context.profile, TEST_CONN_ID, response
144+
)

aries_cloudagent/protocols/coordinate_mediation/v1_0/manager.py

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,14 @@
33
import logging
44
from typing import Optional, Sequence, Tuple
55

6-
76
from ....core.error import BaseError
87
from ....core.profile import Profile, ProfileSession
9-
from ....connections.models.conn_record import ConnRecord
108
from ....storage.base import BaseStorage
119
from ....storage.error import StorageNotFoundError
1210
from ....storage.record import StorageRecord
1311
from ....wallet.base import BaseWallet
1412
from ....wallet.did_info import DIDInfo
1513
from ....wallet.did_method import DIDMethod
16-
from ....wallet.error import WalletNotFoundError
1714
from ....wallet.key_type import KeyType
1815
from ...routing.v1_0.manager import RoutingManager
1916
from ...routing.v1_0.models.route_record import RouteRecord
@@ -602,48 +599,6 @@ async def store_update_results(
602599
for record_for_removal in to_remove:
603600
await record_for_removal.delete_record(session)
604601

605-
async def _conn_id_from_recipient_key(
606-
self, session: ProfileSession, wallet: BaseWallet, recipient_key: str
607-
) -> str:
608-
try:
609-
conn = await ConnRecord.retrieve_by_invitation_key(
610-
session, invitation_key=normalize_from_did_key(recipient_key)
611-
)
612-
except StorageNotFoundError:
613-
did_info = await wallet.get_local_did_for_verkey(
614-
normalize_from_did_key(recipient_key)
615-
)
616-
conn = await ConnRecord.retrieve_by_did(session, my_did=did_info.did)
617-
return conn.connection_id
618-
619-
async def notify_keylist_updated(
620-
self, connection_id: str, response: KeylistUpdateResponse
621-
):
622-
"""Notify of keylist update response received."""
623-
async with self._profile.session() as session:
624-
wallet = session.inject(BaseWallet)
625-
try:
626-
routes = {
627-
updated.recipient_key: await self._conn_id_from_recipient_key(
628-
session, wallet, updated.recipient_key
629-
)
630-
for updated in response.updated
631-
}
632-
except (StorageNotFoundError, WalletNotFoundError) as err:
633-
raise MediationManagerError(
634-
"Unknown recipient key received in keylist update response"
635-
) from err
636-
637-
await self._profile.notify(
638-
self.KEYLIST_UPDATED_EVENT,
639-
{
640-
"connection_id": connection_id,
641-
"thread_id": response._thread_id,
642-
"updated": [update.serialize() for update in response.updated],
643-
"mediated_connections": routes,
644-
},
645-
)
646-
647602
async def get_my_keylist(
648603
self, connection_id: Optional[str] = None
649604
) -> Sequence[RouteRecord]:

aries_cloudagent/protocols/coordinate_mediation/v1_0/route_manager.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from .manager import MediationManager
2121
from .messages.keylist_update import KeylistUpdate
2222
from .models.mediation_record import MediationRecord
23+
from .normalization import normalize_from_did_key
2324

2425

2526
LOGGER = logging.getLogger(__name__)
@@ -242,6 +243,27 @@ async def routing_info(
242243
) -> Tuple[List[str], str]:
243244
"""Retrieve routing keys."""
244245

246+
async def connection_from_recipient_key(
247+
self, profile: Profile, recipient_key: str
248+
) -> ConnRecord:
249+
"""Retrieve connection for a recipient_key.
250+
251+
The recipient key is expected to be a local key owned by this agent.
252+
"""
253+
async with profile.session() as session:
254+
wallet = session.inject(BaseWallet)
255+
try:
256+
conn = await ConnRecord.retrieve_by_invitation_key(
257+
session, invitation_key=normalize_from_did_key(recipient_key)
258+
)
259+
except StorageNotFoundError:
260+
did_info = await wallet.get_local_did_for_verkey(
261+
normalize_from_did_key(recipient_key)
262+
)
263+
conn = await ConnRecord.retrieve_by_did(session, my_did=did_info.did)
264+
265+
return conn
266+
245267

246268
class CoordinateMediationV1RouteManager(RouteManager):
247269
"""Manage routes using Coordinate Mediation protocol."""

0 commit comments

Comments
 (0)