2121from typing import Any
2222from typing import Optional
2323
24- from google .cloud import firestore
2524from pydantic import BaseModel
2625
2726from ...events .event import Event
@@ -55,6 +54,14 @@ def __init__(
5554 root_collection: The root collection name. Defaults to 'adk-session' or
5655 the value of ADK_FIRESTORE_ROOT_COLLECTION env var.
5756 """
57+ try :
58+ from google .cloud import firestore
59+ except ImportError as e :
60+ raise ImportError (
61+ "FirestoreSessionService requires google-cloud-firestore. "
62+ "Install it with: pip install google-cloud-firestore"
63+ ) from e
64+
5865 self .client = client or firestore .AsyncClient ()
5966 self .root_collection = (
6067 root_collection
@@ -66,6 +73,22 @@ def __init__(
6673 self .app_state_collection = DEFAULT_APP_STATE_COLLECTION
6774 self .user_state_collection = DEFAULT_USER_STATE_COLLECTION
6875
76+ @staticmethod
77+ def _merge_state (
78+ app_state : dict [str , Any ],
79+ user_state : dict [str , Any ],
80+ session_state : dict [str , Any ],
81+ ) -> dict [str , Any ]:
82+ """Merge app, user, and session states into a single state dictionary."""
83+ import copy
84+
85+ merged_state = copy .deepcopy (session_state )
86+ for key , value in app_state .items ():
87+ merged_state ["_app_" + key ] = value
88+ for key , value in user_state .items ():
89+ merged_state ["_user_" + key ] = value
90+ return merged_state
91+
6992 def _get_sessions_ref (
7093 self , user_id : str
7194 ) -> firestore .AsyncCollectionReference :
@@ -84,6 +107,7 @@ async def create_session(
84107 session_id : Optional [str ] = None ,
85108 ) -> Session :
86109 """Creates a new session in Firestore."""
110+ from google .cloud import firestore
87111 if not session_id :
88112 from ...platform import uuid as platform_uuid
89113
@@ -202,17 +226,50 @@ async def list_sessions(
202226 )
203227 docs = await query .get ()
204228
229+ # Fetch shared state once
230+ app_ref = self .client .collection (self .app_state_collection ).document (
231+ app_name
232+ )
233+ app_doc = await app_ref .get ()
234+ app_state = app_doc .to_dict () if app_doc .exists else {}
235+
236+ user_states_map = {}
237+ if user_id :
238+ user_ref = (
239+ self .client .collection (self .user_state_collection )
240+ .document (app_name )
241+ .collection ("users" )
242+ .document (user_id )
243+ )
244+ user_doc = await user_ref .get ()
245+ if user_doc .exists :
246+ user_states_map [user_id ] = user_doc .to_dict ()
247+ else :
248+ users_ref = (
249+ self .client .collection (self .user_state_collection )
250+ .document (app_name )
251+ .collection ("users" )
252+ )
253+ users_docs = await users_ref .get ()
254+ for u_doc in users_docs :
255+ user_states_map [u_doc .id ] = u_doc .to_dict ()
256+
205257 sessions = []
206258 for doc in docs :
207259 data = doc .to_dict ()
208260 if data :
261+ u_id = data ["userId" ]
262+ s_state = data .get ("state" , {})
263+ u_state = user_states_map .get (u_id , {})
264+ merged = self ._merge_state (app_state , u_state , s_state )
265+
209266 sessions .append (
210267 Session (
211268 id = data ["id" ],
212269 app_name = data ["appName" ],
213270 user_id = data ["userId" ],
214- state = {}, # Empty state for listing
215- events = [], # Empty events for listing
271+ state = merged ,
272+ events = [],
216273 last_update_time = 0.0 ,
217274 )
218275 )
@@ -226,17 +283,65 @@ async def delete_session(
226283 session_ref = self ._get_sessions_ref (user_id ).document (session_id )
227284
228285 events_ref = session_ref .collection (self .events_collection )
229- events_docs = await events_ref .get ()
230-
286+
231287 batch = self .client .batch ()
232- for event_doc in events_docs :
288+ count = 0
289+ async for event_doc in events_ref .stream ():
233290 batch .delete (event_doc .reference )
234- await batch .commit ()
291+ count += 1
292+ if count >= 500 :
293+ await batch .commit ()
294+ batch = self .client .batch ()
295+ count = 0
296+ if count > 0 :
297+ await batch .commit ()
235298
236299 await session_ref .delete ()
237300
301+ async def _update_app_state_transactional (
302+ self , app_name : str , delta : dict [str , Any ]
303+ ) -> dict [str , Any ]:
304+ """Atomically applies delta to app state inside a transaction."""
305+ from google .cloud import firestore
306+ doc_ref = self .client .collection (self .app_state_collection ).document (app_name )
307+
308+ @firestore .async_transactional
309+ async def _txn (transaction ):
310+ snap = await doc_ref .get (transaction = transaction )
311+ current = snap .to_dict () if snap .exists else {}
312+ current .update (delta )
313+ transaction .set (doc_ref , current , merge = True )
314+ return current
315+
316+ transaction = self .client .transaction ()
317+ return await _txn (transaction )
318+
319+ async def _update_user_state_transactional (
320+ self , app_name : str , user_id : str , delta : dict [str , Any ]
321+ ) -> dict [str , Any ]:
322+ """Atomically applies delta to user state inside a transaction."""
323+ from google .cloud import firestore
324+ doc_ref = (
325+ self .client .collection (self .user_state_collection )
326+ .document (app_name )
327+ .collection ("users" )
328+ .document (user_id )
329+ )
330+
331+ @firestore .async_transactional
332+ async def _txn (transaction ):
333+ snap = await doc_ref .get (transaction = transaction )
334+ current = snap .to_dict () if snap .exists else {}
335+ current .update (delta )
336+ transaction .set (doc_ref , current , merge = True )
337+ return current
338+
339+ transaction = self .client .transaction ()
340+ return await _txn (transaction )
341+
238342 async def append_event (self , session : Session , event : Event ) -> Event :
239343 """Appends an event to a session in Firestore."""
344+ from google .cloud import firestore
240345 if event .partial :
241346 return event
242347
@@ -259,26 +364,16 @@ async def append_event(self, session: Session, event: Event) -> Event:
259364 else :
260365 session_updates [key ] = value
261366
262- batch = self .client .batch ()
263-
264367 if app_updates :
265- app_ref = self .client .collection (self .app_state_collection ).document (
266- session .app_name
267- )
268- batch .set (app_ref , app_updates , merge = True )
368+ await self ._update_app_state_transactional (session .app_name , app_updates )
269369
270370 if user_updates :
271- user_ref = (
272- self .client .collection (self .user_state_collection )
273- .document (session .app_name )
274- .collection ("users" )
275- .document (session .user_id )
276- )
277- batch .set (user_ref , user_updates , merge = True )
371+ await self ._update_user_state_transactional (session .app_name , session .user_id , user_updates )
278372
279373 for k , v in session_updates .items ():
280374 session .state [k ] = v
281375
376+ batch = self .client .batch ()
282377 batch .update (
283378 session_ref ,
284379 {
0 commit comments