@@ -89,7 +89,8 @@ async def test_create_session(mock_firestore_client):
8989 app_name = "test_app"
9090 user_id = "test_user"
9191
92- session = await service .create_session (app_name = app_name , user_id = user_id )
92+ with mock .patch ("google.cloud.firestore.async_transactional" , lambda x : x ):
93+ session = await service .create_session (app_name = app_name , user_id = user_id )
9394
9495 assert session .app_name == app_name
9596 assert session .user_id == user_id
@@ -109,14 +110,16 @@ async def test_create_session(mock_firestore_client):
109110 session_doc_ref = sessions_ref .document .return_value
110111 from google .cloud import firestore
111112
112- session_doc_ref .set .assert_called_once ()
113- args , kwargs = session_doc_ref .set .call_args
114- assert args [0 ]["id" ] == session .id
115- assert args [0 ]["appName" ] == app_name
116- assert args [0 ]["userId" ] == user_id
117- assert args [0 ]["state" ] == {}
118- assert args [0 ]["createTime" ] == firestore .SERVER_TIMESTAMP
119- assert args [0 ]["updateTime" ] == firestore .SERVER_TIMESTAMP
113+ transaction = mock_firestore_client .transaction .return_value
114+ transaction .set .assert_called_once ()
115+ args , kwargs = transaction .set .call_args
116+ assert args [0 ] == session_doc_ref
117+ assert args [1 ]["id" ] == session .id
118+ assert args [1 ]["appName" ] == app_name
119+ assert args [1 ]["userId" ] == user_id
120+ assert args [1 ]["state" ] == {}
121+ assert args [1 ]["createTime" ] == firestore .SERVER_TIMESTAMP
122+ assert args [1 ]["updateTime" ] == firestore .SERVER_TIMESTAMP
120123
121124
122125@pytest .mark .asyncio
@@ -228,7 +231,8 @@ async def test_append_event(mock_firestore_client):
228231 session = Session (id = "test_session" , app_name = app_name , user_id = user_id )
229232 event = Event (invocation_id = "test_inv" , author = "user" )
230233
231- await service .append_event (session , event )
234+ with mock .patch ("google.cloud.firestore.async_transactional" , lambda x : x ):
235+ await service .append_event (session , event )
232236
233237 from google .cloud import firestore
234238
@@ -273,31 +277,22 @@ async def test_append_event_with_state_delta(mock_firestore_client):
273277 service ._update_app_state_transactional = mock .AsyncMock ()
274278 service ._update_user_state_transactional = mock .AsyncMock ()
275279
276- await service .append_event (session , event )
277-
278- mock_firestore_client .batch .assert_called_once ()
279- service ._update_app_state_transactional .assert_called_once_with (
280- "test_app" , {"my_key" : "app_val" }
281- )
282- service ._update_user_state_transactional .assert_called_once_with (
283- "test_app" , "test_user" , {"my_key" : "user_val" }
284- )
285-
286- batch = mock_firestore_client .batch .return_value
280+ with mock .patch ("google.cloud.firestore.async_transactional" , lambda x : x ):
281+ await service .append_event (session , event )
287282
288- batch .set .assert_called ()
283+ transaction = mock_firestore_client .transaction .return_value
284+ transaction .set .assert_called ()
289285
290286 assert session .state ["session_key" ] == "session_val"
291287
292288 from google .cloud import firestore
293289
294- batch .update .assert_called_once ()
295- args , kwargs = batch .update .call_args
290+ transaction .update .assert_called_once ()
291+ args , kwargs = transaction .update .call_args
292+ assert args [0 ] == session_ref
296293 assert args [1 ]["state" ] == session .state
297294 assert args [1 ]["updateTime" ] == firestore .SERVER_TIMESTAMP
298295
299- batch .commit .assert_called_once ()
300-
301296
302297@pytest .mark .asyncio
303298async def test_list_sessions_with_user_id (mock_firestore_client ):
@@ -497,10 +492,11 @@ async def test_create_session_already_exists(mock_firestore_client):
497492
498493 from google .adk .errors .already_exists_error import AlreadyExistsError
499494
500- with pytest .raises (AlreadyExistsError ):
501- await service .create_session (
502- app_name = app_name , user_id = user_id , session_id = "existing_id"
503- )
495+ with mock .patch ("google.cloud.firestore.async_transactional" , lambda x : x ):
496+ with pytest .raises (AlreadyExistsError ):
497+ await service .create_session (
498+ app_name = app_name , user_id = user_id , session_id = "existing_id"
499+ )
504500
505501
506502@pytest .mark .asyncio
0 commit comments