2929from pydantic import BaseModel
3030
3131from ...events .event import Event
32+ from ...sessions import _session_util
3233from ...sessions .base_session_service import BaseSessionService
3334from ...sessions .base_session_service import GetSessionConfig
3435from ...sessions .base_session_service import ListSessionsResponse
3536from ...sessions .session import Session
37+ from ...sessions .state import State
3638
3739logger = logging .getLogger ("google_adk." + __name__ )
3840
4648class FirestoreSessionService (BaseSessionService ): # type: ignore[misc]
4749 """Session service that uses Google Cloud Firestore as the backend.
4850
49- It creates a hierarchy in Firestore to hold events by app, user, and session :
51+ Hierarchy for sessions :
5052 adk-session
5153 ↳ <app name>
5254 ↳ users
@@ -55,7 +57,15 @@ class FirestoreSessionService(BaseSessionService): # type: ignore[misc]
5557 ↳ <session ID>
5658 ↳ events
5759 ↳ <event ID>
58- ↳ Event document
60+
61+ Hierarchy for shared App/User state configurations:
62+ app_states
63+ ↳ <app name>
64+
65+ user_states
66+ ↳ <app name>
67+ ↳ users
68+ ↳ <user ID>
5969 """
6070
6171 def __init__ (
@@ -101,9 +111,9 @@ def _merge_state(
101111
102112 merged_state = copy .deepcopy (session_state )
103113 for key , value in app_state .items ():
104- merged_state ["_app_" + key ] = value
114+ merged_state [State . APP_PREFIX + key ] = value
105115 for key , value in user_state .items ():
106- merged_state ["_user_" + key ] = value
116+ merged_state [State . USER_PREFIX + key ] = value
107117 return merged_state
108118
109119 def _get_sessions_ref (
@@ -138,38 +148,91 @@ async def create_session(
138148
139149 session_ref = self ._get_sessions_ref (app_name , user_id ).document (session_id )
140150
151+ # Extract state deltas
152+ state_deltas = _session_util .extract_state_delta (initial_state )
153+ app_state_delta = state_deltas ["app" ]
154+ user_state_delta = state_deltas ["user" ]
155+ session_state = state_deltas ["session" ]
156+
157+ app_ref = self .client .collection (self .app_state_collection ).document (
158+ app_name
159+ )
160+ user_ref = (
161+ self .client .collection (self .user_state_collection )
162+ .document (app_name )
163+ .collection ("users" )
164+ .document (user_id )
165+ )
166+
141167 session_data = {
142168 "id" : session_id ,
143169 "appName" : app_name ,
144170 "userId" : user_id ,
145- "state" : initial_state ,
171+ "state" : session_state ,
146172 "createTime" : now ,
147173 "updateTime" : now ,
148174 }
149175
150176 @firestore .async_transactional # type: ignore[untyped-decorator]
151177 async def _create_txn (transaction : firestore .AsyncTransaction ) -> None :
178+ # 1. Reads
152179 snap = await session_ref .get (transaction = transaction )
153180 if snap .exists :
154181 from ...errors .already_exists_error import AlreadyExistsError
155182
156183 raise AlreadyExistsError (f"Session { session_id } already exists." )
184+
185+ app_snap = (
186+ await app_ref .get (transaction = transaction )
187+ if app_state_delta
188+ else None
189+ )
190+ user_snap = (
191+ await user_ref .get (transaction = transaction )
192+ if user_state_delta
193+ else None
194+ )
195+
196+ # 2. Writes
197+ if app_state_delta :
198+ current_app = (
199+ app_snap .to_dict () if (app_snap and app_snap .exists ) else {}
200+ )
201+ current_app .update (app_state_delta )
202+ transaction .set (app_ref , current_app , merge = True )
203+
204+ if user_state_delta :
205+ current_user = (
206+ user_snap .to_dict () if (user_snap and user_snap .exists ) else {}
207+ )
208+ current_user .update (user_state_delta )
209+ transaction .set (user_ref , current_user , merge = True )
210+
157211 transaction .set (session_ref , session_data )
158212
159213 transaction_obj = self .client .transaction ()
160214 await _create_txn (transaction_obj )
161215
162- # We need a timestamp for the Session object. Since SERVER_TIMESTAMP is
163- # evaluated on the server, we might want to use local time for the object
164- # or read it back. Reading it back is expensive. We'll use local time for
165- # the object, but the DB will have SERVER_TIMESTAMP.
216+ storage_app_doc = await app_ref .get ()
217+ storage_app_state = (
218+ storage_app_doc .to_dict () if storage_app_doc .exists else {}
219+ )
220+ storage_user_doc = await user_ref .get ()
221+ storage_user_state = (
222+ storage_user_doc .to_dict () if storage_user_doc .exists else {}
223+ )
224+
225+ merged_state = self ._merge_state (
226+ storage_app_state , storage_user_state , session_state
227+ )
228+
166229 local_now = datetime .now (timezone .utc ).timestamp ()
167230
168231 return Session (
169232 id = session_id ,
170233 app_name = app_name ,
171234 user_id = user_id ,
172- state = initial_state ,
235+ state = merged_state ,
173236 events = [],
174237 last_update_time = local_now ,
175238 )
@@ -215,6 +278,23 @@ async def get_session(
215278 # Let's continue getting session.
216279 session_state = data .get ("state" , {})
217280
281+ # Fetch shared state
282+ app_ref = self .client .collection (self .app_state_collection ).document (
283+ app_name
284+ )
285+ user_ref = (
286+ self .client .collection (self .user_state_collection )
287+ .document (app_name )
288+ .collection ("users" )
289+ .document (user_id )
290+ )
291+ app_doc = await app_ref .get ()
292+ app_state = app_doc .to_dict () if app_doc .exists else {}
293+ user_doc = await user_ref .get ()
294+ user_state = user_doc .to_dict () if user_doc .exists else {}
295+
296+ merged_state = self ._merge_state (app_state , user_state , session_state )
297+
218298 # Convert timestamp
219299 update_time = data .get ("updateTime" )
220300 last_update_time = 0.0
@@ -231,7 +311,7 @@ async def get_session(
231311 id = session_id ,
232312 app_name = app_name ,
233313 user_id = user_id ,
234- state = session_state ,
314+ state = merged_state ,
235315 events = events ,
236316 last_update_time = last_update_time ,
237317 )
@@ -339,17 +419,10 @@ async def append_event(self, session: Session, event: Event) -> Event:
339419
340420 if event .actions and event .actions .state_delta :
341421 state_delta = event .actions .state_delta
342- app_updates = {}
343- user_updates = {}
344- session_updates = {}
345-
346- for key , value in state_delta .items ():
347- if key .startswith ("_app_" ):
348- app_updates [key [len ("_app_" ) :]] = value
349- elif key .startswith ("_user_" ):
350- user_updates [key [len ("_user_" ) :]] = value
351- else :
352- session_updates [key ] = value
422+ state_deltas = _session_util .extract_state_delta (state_delta )
423+ app_updates = state_deltas ["app" ]
424+ user_updates = state_deltas ["user" ]
425+ session_updates = state_deltas ["session" ]
353426
354427 app_ref = self .client .collection (self .app_state_collection ).document (
355428 session .app_name
0 commit comments