Skip to content

Commit bf47b5c

Browse files
Updating session service to address various concerns
These updates are derived from @anmolg1997 PR on the community repo: google/adk-python-community#104
1 parent 49580f6 commit bf47b5c

4 files changed

Lines changed: 430 additions & 22 deletions

File tree

src/google/adk/integrations/firestore/firestore_session_service.py

Lines changed: 115 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from typing import Any
2222
from typing import Optional
2323

24-
from google.cloud import firestore
2524
from pydantic import BaseModel
2625

2726
from ...events.event import Event
@@ -55,6 +54,14 @@ def __init__(
5554
root_collection: The root collection name. Defaults to 'adk-session' or
5655
the value of ADK_FIRESTORE_ROOT_COLLECTION env var.
5756
"""
57+
try:
58+
from google.cloud import firestore
59+
except ImportError as e:
60+
raise ImportError(
61+
"FirestoreSessionService requires google-cloud-firestore. "
62+
"Install it with: pip install google-cloud-firestore"
63+
) from e
64+
5865
self.client = client or firestore.AsyncClient()
5966
self.root_collection = (
6067
root_collection
@@ -66,6 +73,22 @@ def __init__(
6673
self.app_state_collection = DEFAULT_APP_STATE_COLLECTION
6774
self.user_state_collection = DEFAULT_USER_STATE_COLLECTION
6875

76+
@staticmethod
77+
def _merge_state(
78+
app_state: dict[str, Any],
79+
user_state: dict[str, Any],
80+
session_state: dict[str, Any],
81+
) -> dict[str, Any]:
82+
"""Merge app, user, and session states into a single state dictionary."""
83+
import copy
84+
85+
merged_state = copy.deepcopy(session_state)
86+
for key, value in app_state.items():
87+
merged_state["_app_" + key] = value
88+
for key, value in user_state.items():
89+
merged_state["_user_" + key] = value
90+
return merged_state
91+
6992
def _get_sessions_ref(
7093
self, user_id: str
7194
) -> firestore.AsyncCollectionReference:
@@ -84,6 +107,7 @@ async def create_session(
84107
session_id: Optional[str] = None,
85108
) -> Session:
86109
"""Creates a new session in Firestore."""
110+
from google.cloud import firestore
87111
if not session_id:
88112
from ...platform import uuid as platform_uuid
89113

@@ -202,17 +226,50 @@ async def list_sessions(
202226
)
203227
docs = await query.get()
204228

229+
# Fetch shared state once
230+
app_ref = self.client.collection(self.app_state_collection).document(
231+
app_name
232+
)
233+
app_doc = await app_ref.get()
234+
app_state = app_doc.to_dict() if app_doc.exists else {}
235+
236+
user_states_map = {}
237+
if user_id:
238+
user_ref = (
239+
self.client.collection(self.user_state_collection)
240+
.document(app_name)
241+
.collection("users")
242+
.document(user_id)
243+
)
244+
user_doc = await user_ref.get()
245+
if user_doc.exists:
246+
user_states_map[user_id] = user_doc.to_dict()
247+
else:
248+
users_ref = (
249+
self.client.collection(self.user_state_collection)
250+
.document(app_name)
251+
.collection("users")
252+
)
253+
users_docs = await users_ref.get()
254+
for u_doc in users_docs:
255+
user_states_map[u_doc.id] = u_doc.to_dict()
256+
205257
sessions = []
206258
for doc in docs:
207259
data = doc.to_dict()
208260
if data:
261+
u_id = data["userId"]
262+
s_state = data.get("state", {})
263+
u_state = user_states_map.get(u_id, {})
264+
merged = self._merge_state(app_state, u_state, s_state)
265+
209266
sessions.append(
210267
Session(
211268
id=data["id"],
212269
app_name=data["appName"],
213270
user_id=data["userId"],
214-
state={}, # Empty state for listing
215-
events=[], # Empty events for listing
271+
state=merged,
272+
events=[],
216273
last_update_time=0.0,
217274
)
218275
)
@@ -226,17 +283,65 @@ async def delete_session(
226283
session_ref = self._get_sessions_ref(user_id).document(session_id)
227284

228285
events_ref = session_ref.collection(self.events_collection)
229-
events_docs = await events_ref.get()
230-
286+
231287
batch = self.client.batch()
232-
for event_doc in events_docs:
288+
count = 0
289+
async for event_doc in events_ref.stream():
233290
batch.delete(event_doc.reference)
234-
await batch.commit()
291+
count += 1
292+
if count >= 500:
293+
await batch.commit()
294+
batch = self.client.batch()
295+
count = 0
296+
if count > 0:
297+
await batch.commit()
235298

236299
await session_ref.delete()
237300

301+
async def _update_app_state_transactional(
302+
self, app_name: str, delta: dict[str, Any]
303+
) -> dict[str, Any]:
304+
"""Atomically applies delta to app state inside a transaction."""
305+
from google.cloud import firestore
306+
doc_ref = self.client.collection(self.app_state_collection).document(app_name)
307+
308+
@firestore.async_transactional
309+
async def _txn(transaction):
310+
snap = await doc_ref.get(transaction=transaction)
311+
current = snap.to_dict() if snap.exists else {}
312+
current.update(delta)
313+
transaction.set(doc_ref, current, merge=True)
314+
return current
315+
316+
transaction = self.client.transaction()
317+
return await _txn(transaction)
318+
319+
async def _update_user_state_transactional(
320+
self, app_name: str, user_id: str, delta: dict[str, Any]
321+
) -> dict[str, Any]:
322+
"""Atomically applies delta to user state inside a transaction."""
323+
from google.cloud import firestore
324+
doc_ref = (
325+
self.client.collection(self.user_state_collection)
326+
.document(app_name)
327+
.collection("users")
328+
.document(user_id)
329+
)
330+
331+
@firestore.async_transactional
332+
async def _txn(transaction):
333+
snap = await doc_ref.get(transaction=transaction)
334+
current = snap.to_dict() if snap.exists else {}
335+
current.update(delta)
336+
transaction.set(doc_ref, current, merge=True)
337+
return current
338+
339+
transaction = self.client.transaction()
340+
return await _txn(transaction)
341+
238342
async def append_event(self, session: Session, event: Event) -> Event:
239343
"""Appends an event to a session in Firestore."""
344+
from google.cloud import firestore
240345
if event.partial:
241346
return event
242347

@@ -259,26 +364,16 @@ async def append_event(self, session: Session, event: Event) -> Event:
259364
else:
260365
session_updates[key] = value
261366

262-
batch = self.client.batch()
263-
264367
if app_updates:
265-
app_ref = self.client.collection(self.app_state_collection).document(
266-
session.app_name
267-
)
268-
batch.set(app_ref, app_updates, merge=True)
368+
await self._update_app_state_transactional(session.app_name, app_updates)
269369

270370
if user_updates:
271-
user_ref = (
272-
self.client.collection(self.user_state_collection)
273-
.document(session.app_name)
274-
.collection("users")
275-
.document(session.user_id)
276-
)
277-
batch.set(user_ref, user_updates, merge=True)
371+
await self._update_user_state_transactional(session.app_name, session.user_id, user_updates)
278372

279373
for k, v in session_updates.items():
280374
session.state[k] = v
281375

376+
batch = self.client.batch()
282377
batch.update(
283378
session_ref,
284379
{

tests/unittests/integrations/firestore/test_firestore_database_runner.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,22 @@ def test_create_firestore_runner_missing_bucket(mock_agent, monkeypatch):
7171
ValueError, match="Required property 'ADK_GCS_BUCKET_NAME' is not set"
7272
):
7373
create_firestore_runner(mock_agent)
74+
75+
76+
def test_create_firestore_runner_with_root_collection(mock_agent, monkeypatch):
77+
monkeypatch.setenv("ADK_GCS_BUCKET_NAME", "test_bucket")
78+
79+
with (
80+
mock.patch(
81+
"google.adk.integrations.firestore.firestore_database_runner.FirestoreSessionService"
82+
) as mock_session,
83+
mock.patch("google.adk.integrations.firestore.firestore_database_runner.FirestoreMemoryService"),
84+
mock.patch("google.adk.integrations.firestore.firestore_database_runner.GcsArtifactService"),
85+
):
86+
runner = create_firestore_runner(
87+
mock_agent, firestore_root_collection="custom_collection"
88+
)
89+
90+
assert runner is not None
91+
mock_session.assert_called_once_with(root_collection="custom_collection")
92+

tests/unittests/integrations/firestore/test_firestore_memory_service.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,75 @@ async def test_search_memory_with_results(mock_firestore_client):
9393
mock_firestore_client.collection_group.assert_called_with("events")
9494
collection_ref = mock_firestore_client.collection_group.return_value
9595
collection_ref.where.assert_called()
96+
97+
98+
@pytest.mark.asyncio
99+
async def test_search_memory_deduplication(mock_firestore_client):
100+
service = FirestoreMemoryService(client=mock_firestore_client)
101+
app_name = "test_app"
102+
user_id = "test_user"
103+
query = "quick fox"
104+
105+
event1 = Event(
106+
invocation_id="test_inv1",
107+
author="user",
108+
content=types.Content(parts=[types.Part(text="quick fox jumps")]),
109+
timestamp=1234567890.0,
110+
)
111+
event2 = Event(
112+
invocation_id="test_inv2",
113+
author="user",
114+
content=types.Content(parts=[types.Part(text="quick fox jumps")]),
115+
timestamp=1234567890.0,
116+
)
117+
118+
doc_snapshot1 = mock.MagicMock()
119+
doc_snapshot1.to_dict.return_value = {
120+
"event_data": event1.model_dump(exclude_none=True, mode="json")
121+
}
122+
123+
doc_snapshot2 = mock.MagicMock()
124+
doc_snapshot2.to_dict.return_value = {
125+
"event_data": event2.model_dump(exclude_none=True, mode="json")
126+
}
127+
128+
get_mock = mock.AsyncMock(side_effect=[[doc_snapshot1], [doc_snapshot2]])
129+
130+
mock_firestore_client.collection_group.return_value.where.return_value.where.return_value.where.return_value.get = get_mock
131+
132+
response = await service.search_memory(
133+
app_name=app_name, user_id=user_id, query=query
134+
)
135+
136+
assert response.memories
137+
assert len(response.memories) == 1
138+
assert response.memories[0].author == "user"
139+
140+
141+
@pytest.mark.asyncio
142+
async def test_search_memory_parsing_error(mock_firestore_client, caplog):
143+
service = FirestoreMemoryService(client=mock_firestore_client)
144+
app_name = "test_app"
145+
user_id = "test_user"
146+
query = "quick"
147+
148+
doc_snapshot = mock_firestore_client.collection_group.return_value.where.return_value.where.return_value.where.return_value.get.return_value[0]
149+
doc_snapshot.to_dict.return_value = {"event_data": "invalid_data"}
150+
151+
response = await service.search_memory(
152+
app_name=app_name, user_id=user_id, query=query
153+
)
154+
155+
assert not response.memories
156+
assert "Failed to parse event from Firestore" in caplog.text
157+
158+
159+
@pytest.mark.asyncio
160+
async def test_search_memory_only_stop_words(mock_firestore_client):
161+
service = FirestoreMemoryService(client=mock_firestore_client)
162+
response = await service.search_memory(
163+
app_name="test_app", user_id="test_user", query="the and or"
164+
)
165+
assert not response.memories
166+
mock_firestore_client.collection_group.assert_not_called()
167+

0 commit comments

Comments
 (0)