Skip to content

Commit 2daa18d

Browse files
Append all event data in a single transaction
1 parent 486c21a commit 2daa18d

2 files changed

Lines changed: 84 additions & 72 deletions

File tree

src/google/adk/integrations/firestore/firestore_session_service.py

Lines changed: 58 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,6 @@ async def create_session(
138138

139139
session_ref = self._get_sessions_ref(app_name, user_id).document(session_id)
140140

141-
# Check if session already exists
142-
doc = await session_ref.get()
143-
if doc.exists:
144-
from ...errors.already_exists_error import AlreadyExistsError
145-
146-
raise AlreadyExistsError(f"Session {session_id} already exists.")
147-
148141
session_data = {
149142
"id": session_id,
150143
"appName": app_name,
@@ -154,7 +147,16 @@ async def create_session(
154147
"updateTime": now,
155148
}
156149

157-
await session_ref.set(session_data)
150+
@firestore.async_transactional # type: ignore[untyped-decorator]
151+
async def _create_txn(transaction: firestore.AsyncTransaction) -> None:
152+
snap = await session_ref.get(transaction=transaction)
153+
if snap.exists:
154+
from ...errors.already_exists_error import AlreadyExistsError
155+
raise AlreadyExistsError(f"Session {session_id} already exists.")
156+
transaction.set(session_ref, session_data)
157+
158+
transaction_obj = self.client.transaction()
159+
await _create_txn(transaction_obj)
158160

159161
# We need a timestamp for the Session object. Since SERVER_TIMESTAMP is
160162
# evaluated on the server, we might want to use local time for the object
@@ -318,6 +320,7 @@ async def delete_session(
318320

319321
await session_ref.delete()
320322

323+
321324
async def _update_app_state_transactional(
322325
self, app_name: str, delta: dict[str, Any]
323326
) -> dict[str, Any]:
@@ -389,44 +392,57 @@ async def append_event(self, session: Session, event: Event) -> Event:
389392
else:
390393
session_updates[key] = value
391394

392-
if app_updates:
393-
await self._update_app_state_transactional(
394-
session.app_name, app_updates
395-
)
395+
app_ref = self.client.collection(self.app_state_collection).document(session.app_name)
396+
user_ref = (
397+
self.client.collection(self.user_state_collection)
398+
.document(session.app_name)
399+
.collection("users")
400+
.document(session.user_id)
401+
)
396402

397-
if user_updates:
398-
await self._update_user_state_transactional(
399-
session.app_name, session.user_id, user_updates
403+
@firestore.async_transactional # type: ignore[untyped-decorator]
404+
async def _append_txn(transaction: firestore.AsyncTransaction) -> None:
405+
# 1. Reads
406+
app_snap = await app_ref.get(transaction=transaction) if app_updates else None
407+
user_snap = await user_ref.get(transaction=transaction) if user_updates else None
408+
409+
# 2. Writes
410+
if app_updates and app_snap is not None:
411+
current_app = app_snap.to_dict() if app_snap.exists else {}
412+
current_app.update(app_updates)
413+
transaction.set(app_ref, current_app, merge=True)
414+
415+
if user_updates and user_snap is not None:
416+
current_user = user_snap.to_dict() if user_snap.exists else {}
417+
current_user.update(user_updates)
418+
transaction.set(user_ref, current_user, merge=True)
419+
420+
for k, v in session_updates.items():
421+
session.state[k] = v
422+
423+
transaction.update(
424+
session_ref,
425+
{
426+
"state": session.state,
427+
"updateTime": firestore.SERVER_TIMESTAMP,
428+
},
400429
)
401430

402-
for k, v in session_updates.items():
403-
session.state[k] = v
404-
405-
batch = self.client.batch()
406-
batch.update(
407-
session_ref,
408-
{
409-
"state": session.state,
410-
"updateTime": firestore.SERVER_TIMESTAMP,
411-
},
412-
)
413-
414-
event_id = event.id
415-
event_ref = session_ref.collection(self.events_collection).document(
416-
event_id
417-
)
418-
event_data = event.model_dump(exclude_none=True, mode="json")
419-
batch.set(
420-
event_ref,
421-
{
422-
"event_data": event_data,
423-
"timestamp": firestore.SERVER_TIMESTAMP,
424-
"appName": session.app_name,
425-
"userId": session.user_id,
426-
},
427-
)
431+
event_id = event.id
432+
event_ref = session_ref.collection(self.events_collection).document(event_id)
433+
event_data = event.model_dump(exclude_none=True, mode="json")
434+
transaction.set(
435+
event_ref,
436+
{
437+
"event_data": event_data,
438+
"timestamp": firestore.SERVER_TIMESTAMP,
439+
"appName": session.app_name,
440+
"userId": session.user_id,
441+
},
442+
)
428443

429-
await batch.commit()
444+
transaction_obj = self.client.transaction()
445+
await _append_txn(transaction_obj)
430446
else:
431447
batch = self.client.batch()
432448
event_id = event.id

tests/unittests/integrations/firestore/test_firestore_session_service.py

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ async def test_create_session(mock_firestore_client):
8989
app_name = "test_app"
9090
user_id = "test_user"
9191

92-
session = await service.create_session(app_name=app_name, user_id=user_id)
92+
with mock.patch("google.cloud.firestore.async_transactional", lambda x: x):
93+
session = await service.create_session(app_name=app_name, user_id=user_id)
9394

9495
assert session.app_name == app_name
9596
assert session.user_id == user_id
@@ -109,14 +110,16 @@ async def test_create_session(mock_firestore_client):
109110
session_doc_ref = sessions_ref.document.return_value
110111
from google.cloud import firestore
111112

112-
session_doc_ref.set.assert_called_once()
113-
args, kwargs = session_doc_ref.set.call_args
114-
assert args[0]["id"] == session.id
115-
assert args[0]["appName"] == app_name
116-
assert args[0]["userId"] == user_id
117-
assert args[0]["state"] == {}
118-
assert args[0]["createTime"] == firestore.SERVER_TIMESTAMP
119-
assert args[0]["updateTime"] == firestore.SERVER_TIMESTAMP
113+
transaction = mock_firestore_client.transaction.return_value
114+
transaction.set.assert_called_once()
115+
args, kwargs = transaction.set.call_args
116+
assert args[0] == session_doc_ref
117+
assert args[1]["id"] == session.id
118+
assert args[1]["appName"] == app_name
119+
assert args[1]["userId"] == user_id
120+
assert args[1]["state"] == {}
121+
assert args[1]["createTime"] == firestore.SERVER_TIMESTAMP
122+
assert args[1]["updateTime"] == firestore.SERVER_TIMESTAMP
120123

121124

122125
@pytest.mark.asyncio
@@ -228,7 +231,8 @@ async def test_append_event(mock_firestore_client):
228231
session = Session(id="test_session", app_name=app_name, user_id=user_id)
229232
event = Event(invocation_id="test_inv", author="user")
230233

231-
await service.append_event(session, event)
234+
with mock.patch("google.cloud.firestore.async_transactional", lambda x: x):
235+
await service.append_event(session, event)
232236

233237
from google.cloud import firestore
234238

@@ -273,31 +277,22 @@ async def test_append_event_with_state_delta(mock_firestore_client):
273277
service._update_app_state_transactional = mock.AsyncMock()
274278
service._update_user_state_transactional = mock.AsyncMock()
275279

276-
await service.append_event(session, event)
277-
278-
mock_firestore_client.batch.assert_called_once()
279-
service._update_app_state_transactional.assert_called_once_with(
280-
"test_app", {"my_key": "app_val"}
281-
)
282-
service._update_user_state_transactional.assert_called_once_with(
283-
"test_app", "test_user", {"my_key": "user_val"}
284-
)
285-
286-
batch = mock_firestore_client.batch.return_value
280+
with mock.patch("google.cloud.firestore.async_transactional", lambda x: x):
281+
await service.append_event(session, event)
287282

288-
batch.set.assert_called()
283+
transaction = mock_firestore_client.transaction.return_value
284+
transaction.set.assert_called()
289285

290286
assert session.state["session_key"] == "session_val"
291287

292288
from google.cloud import firestore
293289

294-
batch.update.assert_called_once()
295-
args, kwargs = batch.update.call_args
290+
transaction.update.assert_called_once()
291+
args, kwargs = transaction.update.call_args
292+
assert args[0] == session_ref
296293
assert args[1]["state"] == session.state
297294
assert args[1]["updateTime"] == firestore.SERVER_TIMESTAMP
298295

299-
batch.commit.assert_called_once()
300-
301296

302297
@pytest.mark.asyncio
303298
async def test_list_sessions_with_user_id(mock_firestore_client):
@@ -497,10 +492,11 @@ async def test_create_session_already_exists(mock_firestore_client):
497492

498493
from google.adk.errors.already_exists_error import AlreadyExistsError
499494

500-
with pytest.raises(AlreadyExistsError):
501-
await service.create_session(
502-
app_name=app_name, user_id=user_id, session_id="existing_id"
503-
)
495+
with mock.patch("google.cloud.firestore.async_transactional", lambda x: x):
496+
with pytest.raises(AlreadyExistsError):
497+
await service.create_session(
498+
app_name=app_name, user_id=user_id, session_id="existing_id"
499+
)
504500

505501

506502
@pytest.mark.asyncio

0 commit comments

Comments
 (0)