1414
1515from __future__ import annotations
1616
17+ import asyncio
18+ from contextlib import asynccontextmanager
1719from datetime import datetime
1820from datetime import timezone
1921import logging
2022import os
2123from typing import Any
24+ from typing import AsyncIterator
2225from typing import cast
2326from typing import Optional
2427from typing import TYPE_CHECKING
2528
29+ _SessionLockKey = tuple [str , str , str ]
30+
2631if TYPE_CHECKING :
2732 from google .cloud import firestore
2833
@@ -96,10 +101,40 @@ def __init__(
96101 or DEFAULT_ROOT_COLLECTION
97102 )
98103 self .sessions_collection = DEFAULT_SESSIONS_COLLECTION
104+
105+ # Per-session locks used to serialize append_event calls in this process.
106+ self ._session_locks : dict [_SessionLockKey , asyncio .Lock ] = {}
107+ self ._session_lock_ref_count : dict [_SessionLockKey , int ] = {}
108+ self ._session_locks_guard = asyncio .Lock ()
99109 self .events_collection = DEFAULT_EVENTS_COLLECTION
100110 self .app_state_collection = DEFAULT_APP_STATE_COLLECTION
101111 self .user_state_collection = DEFAULT_USER_STATE_COLLECTION
102112
113+ @asynccontextmanager
114+ async def _with_session_lock (
115+ self , * , app_name : str , user_id : str , session_id : str
116+ ) -> AsyncIterator [None ]:
117+ """Serializes event appends for the same session within this process."""
118+ lock_key = (app_name , user_id , session_id )
119+ async with self ._session_locks_guard :
120+ lock = self ._session_locks .get (lock_key , asyncio .Lock ())
121+ self ._session_locks [lock_key ] = lock
122+ self ._session_lock_ref_count [lock_key ] = (
123+ self ._session_lock_ref_count .get (lock_key , 0 ) + 1
124+ )
125+
126+ try :
127+ async with lock :
128+ yield
129+ finally :
130+ async with self ._session_locks_guard :
131+ remaining = self ._session_lock_ref_count .get (lock_key , 0 ) - 1
132+ if remaining <= 0 and not lock .locked ():
133+ self ._session_lock_ref_count .pop (lock_key , None )
134+ self ._session_locks .pop (lock_key , None )
135+ else :
136+ self ._session_lock_ref_count [lock_key ] = remaining
137+
103138 @staticmethod
104139 def _merge_state (
105140 app_state : dict [str , Any ],
@@ -171,6 +206,7 @@ async def create_session(
171206 "state" : session_state ,
172207 "createTime" : now ,
173208 "updateTime" : now ,
209+ "revision" : 1 ,
174210 }
175211
176212 @firestore .async_transactional # type: ignore[untyped-decorator]
@@ -228,14 +264,16 @@ async def _create_txn(transaction: firestore.AsyncTransaction) -> None:
228264
229265 local_now = datetime .now (timezone .utc ).timestamp ()
230266
231- return Session (
267+ session = Session (
232268 id = session_id ,
233269 app_name = app_name ,
234270 user_id = user_id ,
235271 state = merged_state ,
236272 events = [],
237273 last_update_time = local_now ,
238274 )
275+ session ._storage_update_marker = "1"
276+ return session
239277
240278 async def get_session (
241279 self ,
@@ -307,14 +345,19 @@ async def get_session(
307345 except (ValueError , TypeError ):
308346 pass
309347
310- return Session (
348+ current_revision = data .get ("revision" , 0 )
349+ session = Session (
311350 id = session_id ,
312351 app_name = app_name ,
313352 user_id = user_id ,
314353 state = merged_state ,
315354 events = events ,
316355 last_update_time = last_update_time ,
317356 )
357+ session ._storage_update_marker = (
358+ str (current_revision ) if current_revision > 0 else None
359+ )
360+ return session
318361
319362 async def list_sessions (
320363 self , * , app_name : str , user_id : Optional [str ] = None
@@ -385,8 +428,24 @@ async def delete_session(
385428 self , * , app_name : str , user_id : str , session_id : str
386429 ) -> None :
387430 """Deletes a session and its events from Firestore."""
431+ from google .cloud import firestore
432+
388433 session_ref = self ._get_sessions_ref (app_name , user_id ).document (session_id )
389434
435+ @firestore .async_transactional # type: ignore[untyped-decorator]
436+ async def _mark_deleting_txn (
437+ transaction : firestore .AsyncTransaction ,
438+ ) -> None :
439+ snap = await session_ref .get (transaction = transaction )
440+ if snap .exists :
441+ transaction .update (session_ref , {"status" : "DELETING" })
442+
443+ try :
444+ transaction_obj = self .client .transaction ()
445+ await _mark_deleting_txn (transaction_obj )
446+ except Exception :
447+ pass
448+
390449 events_ref = session_ref .collection (self .events_collection )
391450
392451 batch = self .client .batch ()
@@ -417,26 +476,52 @@ async def append_event(self, session: Session, event: Event) -> Event:
417476 session .app_name , session .user_id
418477 ).document (session .id )
419478
420- if event .actions and event .actions .state_delta :
421- state_delta = event .actions .state_delta
422- state_deltas = _session_util .extract_state_delta (state_delta )
423- app_updates = state_deltas ["app" ]
424- user_updates = state_deltas ["user" ]
425- session_updates = state_deltas ["session" ]
479+ state_delta = (
480+ event .actions .state_delta
481+ if event .actions and event .actions .state_delta
482+ else {}
483+ )
484+ state_deltas = _session_util .extract_state_delta (state_delta )
485+ app_updates = state_deltas ["app" ]
486+ user_updates = state_deltas ["user" ]
487+ session_updates = state_deltas ["session" ]
426488
427- app_ref = self .client .collection (self .app_state_collection ).document (
428- session .app_name
429- )
430- user_ref = (
431- self .client .collection (self .user_state_collection )
432- .document (session .app_name )
433- .collection ("users" )
434- .document (session .user_id )
435- )
489+ app_ref = self .client .collection (self .app_state_collection ).document (
490+ session .app_name
491+ )
492+ user_ref = (
493+ self .client .collection (self .user_state_collection )
494+ .document (session .app_name )
495+ .collection ("users" )
496+ .document (session .user_id )
497+ )
498+
499+ async with self ._with_session_lock (
500+ app_name = session .app_name ,
501+ user_id = session .user_id ,
502+ session_id = session .id ,
503+ ):
436504
437505 @firestore .async_transactional # type: ignore[untyped-decorator]
438- async def _append_txn (transaction : firestore .AsyncTransaction ) -> None :
506+ async def _append_txn (transaction : firestore .AsyncTransaction ) -> int :
439507 # 1. Reads
508+ session_snap = await session_ref .get (transaction = transaction )
509+ if not session_snap .exists :
510+ raise ValueError (f"Session { session .id } not found." )
511+
512+ session_doc = session_snap .to_dict () or {}
513+ if session_doc .get ("status" ) == "DELETING" :
514+ raise ValueError (f"Session { session .id } is currently being deleted." )
515+
516+ current_revision = session_doc .get ("revision" , 0 )
517+
518+ if session ._storage_update_marker is not None :
519+ if session ._storage_update_marker != str (current_revision ):
520+ raise ValueError (
521+ "The session has been modified in storage since it was loaded. "
522+ "Please reload the session before appending more events."
523+ )
524+
440525 app_snap = (
441526 await app_ref .get (transaction = transaction ) if app_updates else None
442527 )
@@ -460,11 +545,19 @@ async def _append_txn(transaction: firestore.AsyncTransaction) -> None:
460545 for k , v in session_updates .items ():
461546 session .state [k ] = v
462547
548+ new_revision = current_revision + 1
549+ session_only_state = {
550+ k : v
551+ for k , v in session .state .items ()
552+ if not k .startswith (State .APP_PREFIX )
553+ and not k .startswith (State .USER_PREFIX )
554+ }
463555 transaction .update (
464556 session_ref ,
465557 {
466- "state" : session . state ,
558+ "state" : session_only_state ,
467559 "updateTime" : firestore .SERVER_TIMESTAMP ,
560+ "revision" : new_revision ,
468561 },
469562 )
470563
@@ -483,26 +576,11 @@ async def _append_txn(transaction: firestore.AsyncTransaction) -> None:
483576 },
484577 )
485578
579+ return new_revision
580+
486581 transaction_obj = self .client .transaction ()
487- await _append_txn (transaction_obj )
488- else :
489- batch = self .client .batch ()
490- event_id = event .id
491- event_ref = session_ref .collection (self .events_collection ).document (
492- event_id
493- )
494- event_data = event .model_dump (exclude_none = True , mode = "json" )
495- batch .set (
496- event_ref ,
497- {
498- "event_data" : event_data ,
499- "timestamp" : firestore .SERVER_TIMESTAMP ,
500- "appName" : session .app_name ,
501- "userId" : session .user_id ,
502- },
503- )
504- batch .update (session_ref , {"updateTime" : firestore .SERVER_TIMESTAMP })
505- await batch .commit ()
582+ new_revision_count = await _append_txn (transaction_obj )
583+ session ._storage_update_marker = str (new_revision_count )
506584
507585 await super ().append_event (session , event )
508586 return event
0 commit comments