Skip to content

Commit ca724d2

Browse files
Hardening firestore against concurrent modification
1 parent 6a48d0e commit ca724d2

3 files changed

Lines changed: 188 additions & 75 deletions

File tree

src/google/adk/integrations/firestore/firestore_session_service.py

Lines changed: 116 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,20 @@
1414

1515
from __future__ import annotations
1616

17+
import asyncio
18+
from contextlib import asynccontextmanager
1719
from datetime import datetime
1820
from datetime import timezone
1921
import logging
2022
import os
2123
from typing import Any
24+
from typing import AsyncIterator
2225
from typing import cast
2326
from typing import Optional
2427
from typing import TYPE_CHECKING
2528

29+
_SessionLockKey = tuple[str, str, str]
30+
2631
if TYPE_CHECKING:
2732
from google.cloud import firestore
2833

@@ -96,10 +101,40 @@ def __init__(
96101
or DEFAULT_ROOT_COLLECTION
97102
)
98103
self.sessions_collection = DEFAULT_SESSIONS_COLLECTION
104+
105+
# Per-session locks used to serialize append_event calls in this process.
106+
self._session_locks: dict[_SessionLockKey, asyncio.Lock] = {}
107+
self._session_lock_ref_count: dict[_SessionLockKey, int] = {}
108+
self._session_locks_guard = asyncio.Lock()
99109
self.events_collection = DEFAULT_EVENTS_COLLECTION
100110
self.app_state_collection = DEFAULT_APP_STATE_COLLECTION
101111
self.user_state_collection = DEFAULT_USER_STATE_COLLECTION
102112

113+
@asynccontextmanager
async def _with_session_lock(
    self, *, app_name: str, user_id: str, session_id: str
) -> AsyncIterator[None]:
    """Serializes event appends for the same session within this process.

    Looks up (or lazily creates) the ``asyncio.Lock`` registered under the
    ``(app_name, user_id, session_id)`` key, holds it for the duration of the
    ``async with`` body, and drops the registry entries again once the last
    caller for that key has left.

    Args:
        app_name: Application name component of the lock key.
        user_id: User id component of the lock key.
        session_id: Session id component of the lock key.

    Yields:
        None. The per-session lock is held while the body runs.
    """
    lock_key = (app_name, user_id, session_id)

    # Registry bookkeeping happens under the guard lock.
    async with self._session_locks_guard:
        lock = self._session_locks.get(lock_key)
        if lock is None:
            # Create the lock only when none is registered yet;
            # `dict.get(key, asyncio.Lock())` would build a throwaway lock on
            # every call because the default argument is evaluated eagerly.
            lock = asyncio.Lock()
            self._session_locks[lock_key] = lock
        # Count this caller (holder or waiter) so the entry is kept alive
        # until everyone interested in this session has finished.
        self._session_lock_ref_count[lock_key] = (
            self._session_lock_ref_count.get(lock_key, 0) + 1
        )

    try:
        async with lock:
            yield
    finally:
        # Runs on normal exit, on exceptions raised in the body, and when
        # waiting for the lock is interrupted, so the count stays balanced.
        async with self._session_locks_guard:
            remaining = self._session_lock_ref_count.get(lock_key, 0) - 1
            if remaining <= 0 and not lock.locked():
                # Last user is gone: forget the lock so the registry does
                # not grow with every session ever touched.
                self._session_lock_ref_count.pop(lock_key, None)
                self._session_locks.pop(lock_key, None)
            else:
                # Never store a negative count; it would skew later cleanup.
                self._session_lock_ref_count[lock_key] = max(remaining, 0)
137+
103138
@staticmethod
104139
def _merge_state(
105140
app_state: dict[str, Any],
@@ -171,6 +206,7 @@ async def create_session(
171206
"state": session_state,
172207
"createTime": now,
173208
"updateTime": now,
209+
"revision": 1,
174210
}
175211

176212
@firestore.async_transactional # type: ignore[untyped-decorator]
@@ -228,14 +264,16 @@ async def _create_txn(transaction: firestore.AsyncTransaction) -> None:
228264

229265
local_now = datetime.now(timezone.utc).timestamp()
230266

231-
return Session(
267+
session = Session(
232268
id=session_id,
233269
app_name=app_name,
234270
user_id=user_id,
235271
state=merged_state,
236272
events=[],
237273
last_update_time=local_now,
238274
)
275+
session._storage_update_marker = "1"
276+
return session
239277

240278
async def get_session(
241279
self,
@@ -307,14 +345,19 @@ async def get_session(
307345
except (ValueError, TypeError):
308346
pass
309347

310-
return Session(
348+
current_revision = data.get("revision", 0)
349+
session = Session(
311350
id=session_id,
312351
app_name=app_name,
313352
user_id=user_id,
314353
state=merged_state,
315354
events=events,
316355
last_update_time=last_update_time,
317356
)
357+
session._storage_update_marker = (
358+
str(current_revision) if current_revision > 0 else None
359+
)
360+
return session
318361

319362
async def list_sessions(
320363
self, *, app_name: str, user_id: Optional[str] = None
@@ -385,8 +428,24 @@ async def delete_session(
385428
self, *, app_name: str, user_id: str, session_id: str
386429
) -> None:
387430
"""Deletes a session and its events from Firestore."""
431+
from google.cloud import firestore
432+
388433
session_ref = self._get_sessions_ref(app_name, user_id).document(session_id)
389434

435+
@firestore.async_transactional # type: ignore[untyped-decorator]
436+
async def _mark_deleting_txn(
437+
transaction: firestore.AsyncTransaction,
438+
) -> None:
439+
snap = await session_ref.get(transaction=transaction)
440+
if snap.exists:
441+
transaction.update(session_ref, {"status": "DELETING"})
442+
443+
try:
444+
transaction_obj = self.client.transaction()
445+
await _mark_deleting_txn(transaction_obj)
446+
except Exception:
447+
pass
448+
390449
events_ref = session_ref.collection(self.events_collection)
391450

392451
batch = self.client.batch()
@@ -417,26 +476,52 @@ async def append_event(self, session: Session, event: Event) -> Event:
417476
session.app_name, session.user_id
418477
).document(session.id)
419478

420-
if event.actions and event.actions.state_delta:
421-
state_delta = event.actions.state_delta
422-
state_deltas = _session_util.extract_state_delta(state_delta)
423-
app_updates = state_deltas["app"]
424-
user_updates = state_deltas["user"]
425-
session_updates = state_deltas["session"]
479+
state_delta = (
480+
event.actions.state_delta
481+
if event.actions and event.actions.state_delta
482+
else {}
483+
)
484+
state_deltas = _session_util.extract_state_delta(state_delta)
485+
app_updates = state_deltas["app"]
486+
user_updates = state_deltas["user"]
487+
session_updates = state_deltas["session"]
426488

427-
app_ref = self.client.collection(self.app_state_collection).document(
428-
session.app_name
429-
)
430-
user_ref = (
431-
self.client.collection(self.user_state_collection)
432-
.document(session.app_name)
433-
.collection("users")
434-
.document(session.user_id)
435-
)
489+
app_ref = self.client.collection(self.app_state_collection).document(
490+
session.app_name
491+
)
492+
user_ref = (
493+
self.client.collection(self.user_state_collection)
494+
.document(session.app_name)
495+
.collection("users")
496+
.document(session.user_id)
497+
)
498+
499+
async with self._with_session_lock(
500+
app_name=session.app_name,
501+
user_id=session.user_id,
502+
session_id=session.id,
503+
):
436504

437505
@firestore.async_transactional # type: ignore[untyped-decorator]
438-
async def _append_txn(transaction: firestore.AsyncTransaction) -> None:
506+
async def _append_txn(transaction: firestore.AsyncTransaction) -> int:
439507
# 1. Reads
508+
session_snap = await session_ref.get(transaction=transaction)
509+
if not session_snap.exists:
510+
raise ValueError(f"Session {session.id} not found.")
511+
512+
session_doc = session_snap.to_dict() or {}
513+
if session_doc.get("status") == "DELETING":
514+
raise ValueError(f"Session {session.id} is currently being deleted.")
515+
516+
current_revision = session_doc.get("revision", 0)
517+
518+
if session._storage_update_marker is not None:
519+
if session._storage_update_marker != str(current_revision):
520+
raise ValueError(
521+
"The session has been modified in storage since it was loaded. "
522+
"Please reload the session before appending more events."
523+
)
524+
440525
app_snap = (
441526
await app_ref.get(transaction=transaction) if app_updates else None
442527
)
@@ -460,11 +545,19 @@ async def _append_txn(transaction: firestore.AsyncTransaction) -> None:
460545
for k, v in session_updates.items():
461546
session.state[k] = v
462547

548+
new_revision = current_revision + 1
549+
session_only_state = {
550+
k: v
551+
for k, v in session.state.items()
552+
if not k.startswith(State.APP_PREFIX)
553+
and not k.startswith(State.USER_PREFIX)
554+
}
463555
transaction.update(
464556
session_ref,
465557
{
466-
"state": session.state,
558+
"state": session_only_state,
467559
"updateTime": firestore.SERVER_TIMESTAMP,
560+
"revision": new_revision,
468561
},
469562
)
470563

@@ -483,26 +576,11 @@ async def _append_txn(transaction: firestore.AsyncTransaction) -> None:
483576
},
484577
)
485578

579+
return new_revision
580+
486581
transaction_obj = self.client.transaction()
487-
await _append_txn(transaction_obj)
488-
else:
489-
batch = self.client.batch()
490-
event_id = event.id
491-
event_ref = session_ref.collection(self.events_collection).document(
492-
event_id
493-
)
494-
event_data = event.model_dump(exclude_none=True, mode="json")
495-
batch.set(
496-
event_ref,
497-
{
498-
"event_data": event_data,
499-
"timestamp": firestore.SERVER_TIMESTAMP,
500-
"appName": session.app_name,
501-
"userId": session.user_id,
502-
},
503-
)
504-
batch.update(session_ref, {"updateTime": firestore.SERVER_TIMESTAMP})
505-
await batch.commit()
582+
new_revision_count = await _append_txn(transaction_obj)
583+
session._storage_update_marker = str(new_revision_count)
506584

507585
await super().append_event(session, event)
508586
return event

tests/unittests/integrations/firestore/test_firestore_memory_service.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -208,21 +208,24 @@ async def test_search_memory_partial_failures(mock_firestore_client, caplog):
208208
user_id = "test_user"
209209
query = "fox quick"
210210

211-
coll_ref = mock_firestore_client.collection.return_value.where.return_value.where.return_value.where.return_value
212-
211+
coll_ref = (
212+
mock_firestore_client.collection.return_value.where.return_value.where.return_value.where.return_value
213+
)
214+
213215
doc_snapshot = mock.MagicMock()
214216
doc_snapshot.to_dict.return_value = {
215217
"content": {"parts": [{"text": "quick response"}]},
216218
"author": "user",
217-
"timestamp": 1234567890.0
219+
"timestamp": 1234567890.0,
218220
}
219221

220222
call_count = 0
223+
221224
async def mock_get():
222225
nonlocal call_count
223226
call_count += 1
224227
if call_count == 1:
225-
raise ValueError("Mock generic network failure standalone")
228+
raise ValueError("Mock generic network failure standalone")
226229
return [doc_snapshot]
227230

228231
coll_ref.get = mock.AsyncMock(side_effect=mock_get)
@@ -359,7 +362,9 @@ async def test_add_session_to_memory_exceeds_batch_limit(mock_firestore_client):
359362
session = Session(id="test_session", app_name="test_app", user_id="test_user")
360363

361364
for i in range(501):
362-
content = types.Content(parts=[types.Part.from_text(text=f"event keyword {i}")])
365+
content = types.Content(
366+
parts=[types.Part.from_text(text=f"event keyword {i}")]
367+
)
363368
event = Event(
364369
invocation_id=f"test_inv_{i}",
365370
author="user",
@@ -381,4 +386,3 @@ async def test_add_session_to_memory_exceeds_batch_limit(mock_firestore_client):
381386
batch1.commit.assert_called_once()
382387
assert batch2.set.call_count == 1
383388
batch2.commit.assert_called_once()
384-

0 commit comments

Comments
 (0)