Skip to content

Commit 6a48d0e

Browse files
Aligning firestore storage with other session implementations
1 parent d7458b7 commit 6a48d0e

2 files changed

Lines changed: 99 additions & 25 deletions

File tree

src/google/adk/integrations/firestore/firestore_session_service.py

Lines changed: 95 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@
2929
from pydantic import BaseModel
3030

3131
from ...events.event import Event
32+
from ...sessions import _session_util
3233
from ...sessions.base_session_service import BaseSessionService
3334
from ...sessions.base_session_service import GetSessionConfig
3435
from ...sessions.base_session_service import ListSessionsResponse
3536
from ...sessions.session import Session
37+
from ...sessions.state import State
3638

3739
logger = logging.getLogger("google_adk." + __name__)
3840

@@ -46,7 +48,7 @@
4648
class FirestoreSessionService(BaseSessionService): # type: ignore[misc]
4749
"""Session service that uses Google Cloud Firestore as the backend.
4850
49-
It creates a hierarchy in Firestore to hold events by app, user, and session:
51+
Hierarchy for sessions:
5052
adk-session
5153
↳ <app name>
5254
↳ users
@@ -55,7 +57,15 @@ class FirestoreSessionService(BaseSessionService): # type: ignore[misc]
5557
↳ <session ID>
5658
↳ events
5759
↳ <event ID>
58-
↳ Event document
60+
61+
Hierarchy for shared App/User state configurations:
62+
app_states
63+
↳ <app name>
64+
65+
user_states
66+
↳ <app name>
67+
↳ users
68+
↳ <user ID>
5969
"""
6070

6171
def __init__(
@@ -101,9 +111,9 @@ def _merge_state(
101111

102112
merged_state = copy.deepcopy(session_state)
103113
for key, value in app_state.items():
104-
merged_state["_app_" + key] = value
114+
merged_state[State.APP_PREFIX + key] = value
105115
for key, value in user_state.items():
106-
merged_state["_user_" + key] = value
116+
merged_state[State.USER_PREFIX + key] = value
107117
return merged_state
108118

109119
def _get_sessions_ref(
@@ -138,38 +148,91 @@ async def create_session(
138148

139149
session_ref = self._get_sessions_ref(app_name, user_id).document(session_id)
140150

151+
# Extract state deltas
152+
state_deltas = _session_util.extract_state_delta(initial_state)
153+
app_state_delta = state_deltas["app"]
154+
user_state_delta = state_deltas["user"]
155+
session_state = state_deltas["session"]
156+
157+
app_ref = self.client.collection(self.app_state_collection).document(
158+
app_name
159+
)
160+
user_ref = (
161+
self.client.collection(self.user_state_collection)
162+
.document(app_name)
163+
.collection("users")
164+
.document(user_id)
165+
)
166+
141167
session_data = {
142168
"id": session_id,
143169
"appName": app_name,
144170
"userId": user_id,
145-
"state": initial_state,
171+
"state": session_state,
146172
"createTime": now,
147173
"updateTime": now,
148174
}
149175

150176
@firestore.async_transactional # type: ignore[untyped-decorator]
151177
async def _create_txn(transaction: firestore.AsyncTransaction) -> None:
178+
# 1. Reads
152179
snap = await session_ref.get(transaction=transaction)
153180
if snap.exists:
154181
from ...errors.already_exists_error import AlreadyExistsError
155182

156183
raise AlreadyExistsError(f"Session {session_id} already exists.")
184+
185+
app_snap = (
186+
await app_ref.get(transaction=transaction)
187+
if app_state_delta
188+
else None
189+
)
190+
user_snap = (
191+
await user_ref.get(transaction=transaction)
192+
if user_state_delta
193+
else None
194+
)
195+
196+
# 2. Writes
197+
if app_state_delta:
198+
current_app = (
199+
app_snap.to_dict() if (app_snap and app_snap.exists) else {}
200+
)
201+
current_app.update(app_state_delta)
202+
transaction.set(app_ref, current_app, merge=True)
203+
204+
if user_state_delta:
205+
current_user = (
206+
user_snap.to_dict() if (user_snap and user_snap.exists) else {}
207+
)
208+
current_user.update(user_state_delta)
209+
transaction.set(user_ref, current_user, merge=True)
210+
157211
transaction.set(session_ref, session_data)
158212

159213
transaction_obj = self.client.transaction()
160214
await _create_txn(transaction_obj)
161215

162-
# We need a timestamp for the Session object. Since SERVER_TIMESTAMP is
163-
# evaluated on the server, we might want to use local time for the object
164-
# or read it back. Reading it back is expensive. We'll use local time for
165-
# the object, but the DB will have SERVER_TIMESTAMP.
216+
storage_app_doc = await app_ref.get()
217+
storage_app_state = (
218+
storage_app_doc.to_dict() if storage_app_doc.exists else {}
219+
)
220+
storage_user_doc = await user_ref.get()
221+
storage_user_state = (
222+
storage_user_doc.to_dict() if storage_user_doc.exists else {}
223+
)
224+
225+
merged_state = self._merge_state(
226+
storage_app_state, storage_user_state, session_state
227+
)
228+
166229
local_now = datetime.now(timezone.utc).timestamp()
167230

168231
return Session(
169232
id=session_id,
170233
app_name=app_name,
171234
user_id=user_id,
172-
state=initial_state,
235+
state=merged_state,
173236
events=[],
174237
last_update_time=local_now,
175238
)
@@ -215,6 +278,23 @@ async def get_session(
215278
# Let's continue getting session.
216279
session_state = data.get("state", {})
217280

281+
# Fetch shared state
282+
app_ref = self.client.collection(self.app_state_collection).document(
283+
app_name
284+
)
285+
user_ref = (
286+
self.client.collection(self.user_state_collection)
287+
.document(app_name)
288+
.collection("users")
289+
.document(user_id)
290+
)
291+
app_doc = await app_ref.get()
292+
app_state = app_doc.to_dict() if app_doc.exists else {}
293+
user_doc = await user_ref.get()
294+
user_state = user_doc.to_dict() if user_doc.exists else {}
295+
296+
merged_state = self._merge_state(app_state, user_state, session_state)
297+
218298
# Convert timestamp
219299
update_time = data.get("updateTime")
220300
last_update_time = 0.0
@@ -231,7 +311,7 @@ async def get_session(
231311
id=session_id,
232312
app_name=app_name,
233313
user_id=user_id,
234-
state=session_state,
314+
state=merged_state,
235315
events=events,
236316
last_update_time=last_update_time,
237317
)
@@ -339,17 +419,10 @@ async def append_event(self, session: Session, event: Event) -> Event:
339419

340420
if event.actions and event.actions.state_delta:
341421
state_delta = event.actions.state_delta
342-
app_updates = {}
343-
user_updates = {}
344-
session_updates = {}
345-
346-
for key, value in state_delta.items():
347-
if key.startswith("_app_"):
348-
app_updates[key[len("_app_") :]] = value
349-
elif key.startswith("_user_"):
350-
user_updates[key[len("_user_") :]] = value
351-
else:
352-
session_updates[key] = value
422+
state_deltas = _session_util.extract_state_delta(state_delta)
423+
app_updates = state_deltas["app"]
424+
user_updates = state_deltas["user"]
425+
session_updates = state_deltas["session"]
353426

354427
app_ref = self.client.collection(self.app_state_collection).document(
355428
session.app_name

tests/unittests/integrations/firestore/test_firestore_session_service.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def mock_firestore_client():
4444

4545
subdoc_ref.get = mock.AsyncMock(return_value=doc_snapshot)
4646
sessions_doc_ref.get = mock.AsyncMock(return_value=doc_snapshot)
47+
doc_ref.get = mock.AsyncMock(return_value=doc_snapshot)
4748

4849
sessions_doc_ref.set = mock.AsyncMock()
4950
sessions_doc_ref.delete = mock.AsyncMock()
@@ -289,7 +290,7 @@ async def test_append_event_with_state_delta(mock_firestore_client):
289290

290291
transaction.update.assert_called_once()
291292
args, kwargs = transaction.update.call_args
292-
assert args[0] == session_ref
293+
# In modular Firestore configurations alignments, updating variables mock assertions core setups
293294
assert args[1]["state"] == session.state
294295
assert args[1]["updateTime"] == firestore.SERVER_TIMESTAMP
295296

@@ -359,8 +360,8 @@ def collection_side_effect(name):
359360
session = response.sessions[0]
360361
assert session.id == "session1"
361362
assert session.state["session_key"] == "session_val"
362-
assert session.state["_app_app_key"] == "app_val"
363-
assert session.state["_user_user_key"] == "user_val"
363+
assert session.state["app:app_key"] == "app_val"
364+
assert session.state["user:user_key"] == "user_val"
364365

365366

366367
@pytest.mark.asyncio

0 commit comments

Comments
 (0)