@@ -28,31 +28,35 @@ def mock_firestore_client():
2828 doc_ref = mock .MagicMock ()
2929 subcollection_ref = mock .MagicMock ()
3030 subdoc_ref = mock .MagicMock ()
31+ sessions_coll_ref = mock .MagicMock ()
32+ sessions_doc_ref = mock .MagicMock ()
3133
3234 client .collection .return_value = collection_ref
3335 collection_ref .document .return_value = doc_ref
3436 doc_ref .collection .return_value = subcollection_ref
3537 subcollection_ref .document .return_value = subdoc_ref
38+ subdoc_ref .collection .return_value = sessions_coll_ref
39+ sessions_coll_ref .document .return_value = sessions_doc_ref
3640
3741 doc_snapshot = mock .MagicMock ()
3842 doc_snapshot .exists = False
3943 doc_snapshot .to_dict .return_value = {}
4044
41- doc_ref .get = mock .AsyncMock (return_value = doc_snapshot )
4245 subdoc_ref .get = mock .AsyncMock (return_value = doc_snapshot )
46+ sessions_doc_ref .get = mock .AsyncMock (return_value = doc_snapshot )
4347
44- subdoc_ref .set = mock .AsyncMock ()
45- subdoc_ref .delete = mock .AsyncMock ()
48+ sessions_doc_ref .set = mock .AsyncMock ()
49+ sessions_doc_ref .delete = mock .AsyncMock ()
4650
4751 events_collection_ref = mock .MagicMock ()
48- subdoc_ref .collection .return_value = events_collection_ref
52+ sessions_doc_ref .collection .return_value = events_collection_ref
4953 events_collection_ref .order_by .return_value = events_collection_ref
5054 events_collection_ref .where .return_value = events_collection_ref
5155 events_collection_ref .limit_to_last .return_value = events_collection_ref
5256 events_collection_ref .get = mock .AsyncMock (return_value = [])
5357
54- subcollection_ref .get = mock .AsyncMock (return_value = [])
55- subcollection_ref .where .return_value = subcollection_ref
58+ sessions_coll_ref .get = mock .AsyncMock (return_value = [])
59+ sessions_coll_ref .where .return_value = sessions_coll_ref
5660
5761 client .collection_group .return_value = collection_ref
5862
@@ -92,11 +96,15 @@ async def test_create_session(mock_firestore_client):
9296 assert session .id
9397
9498 mock_firestore_client .collection .assert_called_once_with ("adk-session" )
95- collection_ref = mock_firestore_client .collection .return_value
96- collection_ref .document .assert_called_once_with (user_id )
97- doc_ref = collection_ref .document .return_value
98- doc_ref .collection .assert_called_once_with ("sessions" )
99- sessions_ref = doc_ref .collection .return_value
99+ root_coll = mock_firestore_client .collection .return_value
100+ root_coll .document .assert_called_once_with (app_name )
101+ app_ref = root_coll .document .return_value
102+ app_ref .collection .assert_called_once_with ("users" )
103+ users_coll = app_ref .collection .return_value
104+ users_coll .document .assert_called_once_with (user_id )
105+ user_ref = users_coll .document .return_value
106+ user_ref .collection .assert_called_once_with ("sessions" )
107+ sessions_ref = user_ref .collection .return_value
100108 sessions_ref .document .assert_called_once_with (session .id )
101109 session_doc_ref = sessions_ref .document .return_value
102110 from google .cloud import firestore
@@ -125,11 +133,15 @@ async def test_get_session_not_found(mock_firestore_client):
125133 assert session is None
126134
127135 mock_firestore_client .collection .assert_called_with ("adk-session" )
128- collection_ref = mock_firestore_client .collection .return_value
129- collection_ref .document .assert_called_with (user_id )
130- doc_ref = collection_ref .document .return_value
131- doc_ref .collection .assert_called_with ("sessions" )
132- sessions_ref = doc_ref .collection .return_value
136+ root_coll = mock_firestore_client .collection .return_value
137+ root_coll .document .assert_called_with (app_name )
138+ app_ref = root_coll .document .return_value
139+ app_ref .collection .assert_called_with ("users" )
140+ users_coll = app_ref .collection .return_value
141+ users_coll .document .assert_called_with (user_id )
142+ user_ref = users_coll .document .return_value
143+ user_ref .collection .assert_called_with ("sessions" )
144+ sessions_ref = user_ref .collection .return_value
133145 sessions_ref .document .assert_called_with (session_id )
134146
135147
@@ -153,7 +165,7 @@ async def test_get_session_found(mock_firestore_client):
153165 }
154166
155167 events_collection_ref = (
156- mock_firestore_client .collection .return_value .document .return_value .collection .return_value .document .return_value .collection .return_value
168+ mock_firestore_client .collection .return_value .document .return_value .collection .return_value .document .return_value .collection .return_value . document . return_value . collection . return_value
157169 )
158170 event_doc = mock .MagicMock ()
159171 event_doc .to_dict .return_value = {
@@ -180,7 +192,7 @@ async def test_delete_session(mock_firestore_client):
180192 session_id = "test_session"
181193
182194 events_ref = (
183- mock_firestore_client .collection .return_value .document .return_value .collection .return_value .document .return_value .collection .return_value
195+ mock_firestore_client .collection .return_value .document .return_value .collection .return_value .document .return_value .collection .return_value . document . return_value . collection . return_value
184196 )
185197 event_doc = mock .AsyncMock ()
186198
@@ -201,7 +213,7 @@ async def to_async_iter(iterable):
201213 batch .commit .assert_called_once ()
202214
203215 session_doc_ref = (
204- mock_firestore_client .collection .return_value .document .return_value .collection .return_value .document .return_value
216+ mock_firestore_client .collection .return_value .document .return_value .collection .return_value .document .return_value . collection . return_value . document . return_value
205217 )
206218 session_doc_ref .delete .assert_called_once ()
207219
@@ -334,10 +346,14 @@ def collection_side_effect(name):
334346 users_coll .document .return_value = user_doc_ref
335347 user_doc_ref .get = mock .AsyncMock (return_value = user_doc )
336348
337- user_doc_in_sessions = mock .MagicMock ()
338- sessions_coll .document .return_value = user_doc_in_sessions
349+ app_doc_in_root = mock .MagicMock ()
350+ sessions_coll .document .return_value = app_doc_in_root
351+ users_coll = mock .MagicMock ()
352+ app_doc_in_root .collection .return_value = users_coll
353+ user_doc_in_users = mock .MagicMock ()
354+ users_coll .document .return_value = user_doc_in_users
339355 sessions_subcoll = mock .MagicMock ()
340- user_doc_in_sessions .collection .return_value = sessions_subcoll
356+ user_doc_in_users .collection .return_value = sessions_subcoll
341357 sessions_query = mock .MagicMock ()
342358 sessions_subcoll .where .return_value = sessions_query
343359 sessions_query .get = mock .AsyncMock (return_value = [session_doc ])
@@ -448,7 +464,7 @@ async def test_get_session_with_config(mock_firestore_client):
448464 }
449465
450466 events_collection_ref = (
451- mock_firestore_client .collection .return_value .document .return_value .collection .return_value .document .return_value .collection .return_value
467+ mock_firestore_client .collection .return_value .document .return_value .collection .return_value .document .return_value .collection .return_value . document . return_value . collection . return_value
452468 )
453469
454470 from google .adk .sessions .base_session_service import GetSessionConfig
@@ -471,7 +487,7 @@ async def test_delete_session_batching(mock_firestore_client):
471487 session_id = "test_session"
472488
473489 events_ref = (
474- mock_firestore_client .collection .return_value .document .return_value .collection .return_value .document .return_value .collection .return_value
490+ mock_firestore_client .collection .return_value .document .return_value .collection .return_value .document .return_value .collection .return_value . document . return_value . collection . return_value
475491 )
476492
477493 dummy_docs = [mock .MagicMock () for _ in range (501 )]
@@ -631,10 +647,14 @@ def collection_side_effect(name):
631647 users_coll .document .return_value = user_doc_ref
632648 user_doc_ref .get = mock .AsyncMock (return_value = user_doc )
633649
634- user_doc_in_sessions = mock .MagicMock ()
635- sessions_coll .document .return_value = user_doc_in_sessions
650+ app_doc_in_root = mock .MagicMock ()
651+ sessions_coll .document .return_value = app_doc_in_root
652+ users_coll = mock .MagicMock ()
653+ app_doc_in_root .collection .return_value = users_coll
654+ user_doc_in_users = mock .MagicMock ()
655+ users_coll .document .return_value = user_doc_in_users
636656 sessions_subcoll = mock .MagicMock ()
637- user_doc_in_sessions .collection .return_value = sessions_subcoll
657+ user_doc_in_users .collection .return_value = sessions_subcoll
638658 sessions_query = mock .MagicMock ()
639659 sessions_subcoll .where .return_value = sessions_query
640660 sessions_query .get = mock .AsyncMock (return_value = [session_doc ])
0 commit comments