Skip to content

Commit 9c19c02

Browse files
Adding app to document hierarchy
1 parent b652669 commit 9c19c02

2 files changed

Lines changed: 64 additions & 40 deletions

File tree

src/google/adk/integrations/firestore/firestore_session_service.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,16 @@
4646
class FirestoreSessionService(BaseSessionService): # type: ignore[misc]
4747
"""Session service that uses Google Cloud Firestore as the backend.
4848
49-
It creates a hierarchy in Firestore to hold events by user and session:
49+
It creates a hierarchy in Firestore to hold events by app, user, and session:
5050
adk-session
51-
↳ <user ID>
52-
↳ sessions
53-
↳ <session ID>
54-
↳ events
55-
↳ <event ID>
56-
↳ Event document
51+
↳ <app name>
52+
↳ users
53+
↳ <user ID>
54+
↳ sessions
55+
↳ <session ID>
56+
↳ events
57+
↳ <event ID>
58+
↳ Event document
5759
"""
5860

5961
def __init__(
@@ -105,10 +107,12 @@ def _merge_state(
105107
return merged_state
106108

107109
def _get_sessions_ref(
108-
self, user_id: str
110+
self, app_name: str, user_id: str
109111
) -> firestore.AsyncCollectionReference:
110112
return (
111113
self.client.collection(self.root_collection)
114+
.document(app_name)
115+
.collection("users")
112116
.document(user_id)
113117
.collection(self.sessions_collection)
114118
)
@@ -132,7 +136,7 @@ async def create_session(
132136
initial_state = state or {}
133137
now = firestore.SERVER_TIMESTAMP
134138

135-
session_ref = self._get_sessions_ref(user_id).document(session_id)
139+
session_ref = self._get_sessions_ref(app_name, user_id).document(session_id)
136140

137141
# Check if session already exists
138142
doc = await session_ref.get()
@@ -176,7 +180,7 @@ async def get_session(
176180
config: Optional[GetSessionConfig] = None,
177181
) -> Optional[Session]:
178182
"""Gets a session from Firestore."""
179-
session_ref = self._get_sessions_ref(user_id).document(session_id)
183+
session_ref = self._get_sessions_ref(app_name, user_id).document(session_id)
180184
doc = await session_ref.get()
181185

182186
if not doc.exists:
@@ -234,7 +238,7 @@ async def list_sessions(
234238
) -> ListSessionsResponse:
235239
"""Lists sessions from Firestore."""
236240
if user_id:
237-
query = self._get_sessions_ref(user_id).where("appName", "==", app_name)
241+
query = self._get_sessions_ref(app_name, user_id).where("appName", "==", app_name)
238242
docs = await query.get()
239243
else:
240244
query = self.client.collection_group(self.sessions_collection).where(
@@ -296,7 +300,7 @@ async def delete_session(
296300
self, *, app_name: str, user_id: str, session_id: str
297301
) -> None:
298302
"""Deletes a session and its events from Firestore."""
299-
session_ref = self._get_sessions_ref(user_id).document(session_id)
303+
session_ref = self._get_sessions_ref(app_name, user_id).document(session_id)
300304

301305
events_ref = session_ref.collection(self.events_collection)
302306

@@ -369,7 +373,7 @@ async def append_event(self, session: Session, event: Event) -> Event:
369373
self._apply_temp_state(session, event)
370374
event = self._trim_temp_delta_state(event)
371375

372-
session_ref = self._get_sessions_ref(session.user_id).document(session.id)
376+
session_ref = self._get_sessions_ref(session.app_name, session.user_id).document(session.id)
373377

374378
if event.actions and event.actions.state_delta:
375379
state_delta = event.actions.state_delta

tests/unittests/integrations/firestore/test_firestore_session_service.py

Lines changed: 47 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -28,31 +28,35 @@ def mock_firestore_client():
2828
doc_ref = mock.MagicMock()
2929
subcollection_ref = mock.MagicMock()
3030
subdoc_ref = mock.MagicMock()
31+
sessions_coll_ref = mock.MagicMock()
32+
sessions_doc_ref = mock.MagicMock()
3133

3234
client.collection.return_value = collection_ref
3335
collection_ref.document.return_value = doc_ref
3436
doc_ref.collection.return_value = subcollection_ref
3537
subcollection_ref.document.return_value = subdoc_ref
38+
subdoc_ref.collection.return_value = sessions_coll_ref
39+
sessions_coll_ref.document.return_value = sessions_doc_ref
3640

3741
doc_snapshot = mock.MagicMock()
3842
doc_snapshot.exists = False
3943
doc_snapshot.to_dict.return_value = {}
4044

41-
doc_ref.get = mock.AsyncMock(return_value=doc_snapshot)
4245
subdoc_ref.get = mock.AsyncMock(return_value=doc_snapshot)
46+
sessions_doc_ref.get = mock.AsyncMock(return_value=doc_snapshot)
4347

44-
subdoc_ref.set = mock.AsyncMock()
45-
subdoc_ref.delete = mock.AsyncMock()
48+
sessions_doc_ref.set = mock.AsyncMock()
49+
sessions_doc_ref.delete = mock.AsyncMock()
4650

4751
events_collection_ref = mock.MagicMock()
48-
subdoc_ref.collection.return_value = events_collection_ref
52+
sessions_doc_ref.collection.return_value = events_collection_ref
4953
events_collection_ref.order_by.return_value = events_collection_ref
5054
events_collection_ref.where.return_value = events_collection_ref
5155
events_collection_ref.limit_to_last.return_value = events_collection_ref
5256
events_collection_ref.get = mock.AsyncMock(return_value=[])
5357

54-
subcollection_ref.get = mock.AsyncMock(return_value=[])
55-
subcollection_ref.where.return_value = subcollection_ref
58+
sessions_coll_ref.get = mock.AsyncMock(return_value=[])
59+
sessions_coll_ref.where.return_value = sessions_coll_ref
5660

5761
client.collection_group.return_value = collection_ref
5862

@@ -92,11 +96,15 @@ async def test_create_session(mock_firestore_client):
9296
assert session.id
9397

9498
mock_firestore_client.collection.assert_called_once_with("adk-session")
95-
collection_ref = mock_firestore_client.collection.return_value
96-
collection_ref.document.assert_called_once_with(user_id)
97-
doc_ref = collection_ref.document.return_value
98-
doc_ref.collection.assert_called_once_with("sessions")
99-
sessions_ref = doc_ref.collection.return_value
99+
root_coll = mock_firestore_client.collection.return_value
100+
root_coll.document.assert_called_once_with(app_name)
101+
app_ref = root_coll.document.return_value
102+
app_ref.collection.assert_called_once_with("users")
103+
users_coll = app_ref.collection.return_value
104+
users_coll.document.assert_called_once_with(user_id)
105+
user_ref = users_coll.document.return_value
106+
user_ref.collection.assert_called_once_with("sessions")
107+
sessions_ref = user_ref.collection.return_value
100108
sessions_ref.document.assert_called_once_with(session.id)
101109
session_doc_ref = sessions_ref.document.return_value
102110
from google.cloud import firestore
@@ -125,11 +133,15 @@ async def test_get_session_not_found(mock_firestore_client):
125133
assert session is None
126134

127135
mock_firestore_client.collection.assert_called_with("adk-session")
128-
collection_ref = mock_firestore_client.collection.return_value
129-
collection_ref.document.assert_called_with(user_id)
130-
doc_ref = collection_ref.document.return_value
131-
doc_ref.collection.assert_called_with("sessions")
132-
sessions_ref = doc_ref.collection.return_value
136+
root_coll = mock_firestore_client.collection.return_value
137+
root_coll.document.assert_called_with(app_name)
138+
app_ref = root_coll.document.return_value
139+
app_ref.collection.assert_called_with("users")
140+
users_coll = app_ref.collection.return_value
141+
users_coll.document.assert_called_with(user_id)
142+
user_ref = users_coll.document.return_value
143+
user_ref.collection.assert_called_with("sessions")
144+
sessions_ref = user_ref.collection.return_value
133145
sessions_ref.document.assert_called_with(session_id)
134146

135147

@@ -153,7 +165,7 @@ async def test_get_session_found(mock_firestore_client):
153165
}
154166

155167
events_collection_ref = (
156-
mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value
168+
mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value
157169
)
158170
event_doc = mock.MagicMock()
159171
event_doc.to_dict.return_value = {
@@ -180,7 +192,7 @@ async def test_delete_session(mock_firestore_client):
180192
session_id = "test_session"
181193

182194
events_ref = (
183-
mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value
195+
mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value
184196
)
185197
event_doc = mock.AsyncMock()
186198

@@ -201,7 +213,7 @@ async def to_async_iter(iterable):
201213
batch.commit.assert_called_once()
202214

203215
session_doc_ref = (
204-
mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value
216+
mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value.document.return_value
205217
)
206218
session_doc_ref.delete.assert_called_once()
207219

@@ -334,10 +346,14 @@ def collection_side_effect(name):
334346
users_coll.document.return_value = user_doc_ref
335347
user_doc_ref.get = mock.AsyncMock(return_value=user_doc)
336348

337-
user_doc_in_sessions = mock.MagicMock()
338-
sessions_coll.document.return_value = user_doc_in_sessions
349+
app_doc_in_root = mock.MagicMock()
350+
sessions_coll.document.return_value = app_doc_in_root
351+
users_coll = mock.MagicMock()
352+
app_doc_in_root.collection.return_value = users_coll
353+
user_doc_in_users = mock.MagicMock()
354+
users_coll.document.return_value = user_doc_in_users
339355
sessions_subcoll = mock.MagicMock()
340-
user_doc_in_sessions.collection.return_value = sessions_subcoll
356+
user_doc_in_users.collection.return_value = sessions_subcoll
341357
sessions_query = mock.MagicMock()
342358
sessions_subcoll.where.return_value = sessions_query
343359
sessions_query.get = mock.AsyncMock(return_value=[session_doc])
@@ -448,7 +464,7 @@ async def test_get_session_with_config(mock_firestore_client):
448464
}
449465

450466
events_collection_ref = (
451-
mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value
467+
mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value
452468
)
453469

454470
from google.adk.sessions.base_session_service import GetSessionConfig
@@ -471,7 +487,7 @@ async def test_delete_session_batching(mock_firestore_client):
471487
session_id = "test_session"
472488

473489
events_ref = (
474-
mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value
490+
mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value
475491
)
476492

477493
dummy_docs = [mock.MagicMock() for _ in range(501)]
@@ -631,10 +647,14 @@ def collection_side_effect(name):
631647
users_coll.document.return_value = user_doc_ref
632648
user_doc_ref.get = mock.AsyncMock(return_value=user_doc)
633649

634-
user_doc_in_sessions = mock.MagicMock()
635-
sessions_coll.document.return_value = user_doc_in_sessions
650+
app_doc_in_root = mock.MagicMock()
651+
sessions_coll.document.return_value = app_doc_in_root
652+
users_coll = mock.MagicMock()
653+
app_doc_in_root.collection.return_value = users_coll
654+
user_doc_in_users = mock.MagicMock()
655+
users_coll.document.return_value = user_doc_in_users
636656
sessions_subcoll = mock.MagicMock()
637-
user_doc_in_sessions.collection.return_value = sessions_subcoll
657+
user_doc_in_users.collection.return_value = sessions_subcoll
638658
sessions_query = mock.MagicMock()
639659
sessions_subcoll.where.return_value = sessions_query
640660
sessions_query.get = mock.AsyncMock(return_value=[session_doc])

0 commit comments

Comments
 (0)