Skip to content

Commit d7458b7

Browse files
Ensure memory generation does not go over the firestore batch limit
1 parent e05dd5b commit d7458b7

2 files changed

Lines changed: 41 additions & 3 deletions

File tree

src/google/adk/integrations/firestore/firestore_memory_service.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def __init__(
8686
async def add_session_to_memory(self, session: Session) -> None:
8787
"""Extracts keywords from session events and stores them in the memories collection."""
8888
batch = self.client.batch()
89-
has_updates = False
89+
count = 0
9090

9191
for event in session.events:
9292
if not event.content or not event.content.parts:
@@ -114,9 +114,13 @@ async def add_session_to_memory(self, session: Session) -> None:
114114
"timestamp": event.timestamp,
115115
},
116116
)
117-
has_updates = True
117+
count += 1
118+
if count >= 500:
119+
await batch.commit()
120+
batch = self.client.batch()
121+
count = 0
118122

119-
if has_updates:
123+
if count > 0:
120124
await batch.commit()
121125

122126
def _extract_keywords(self, text: str) -> set[str]:

tests/unittests/integrations/firestore/test_firestore_memory_service.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,3 +348,37 @@ async def test_add_session_to_memory_commit_error(mock_firestore_client):
348348

349349
with pytest.raises(Exception, match="Firestore commit failed"):
350350
await service.add_session_to_memory(session)
351+
352+
353+
@pytest.mark.asyncio
354+
async def test_add_session_to_memory_exceeds_batch_limit(mock_firestore_client):
355+
service = FirestoreMemoryService(client=mock_firestore_client)
356+
357+
from google.adk.sessions.session import Session
358+
359+
session = Session(id="test_session", app_name="test_app", user_id="test_user")
360+
361+
for i in range(501):
362+
content = types.Content(parts=[types.Part.from_text(text=f"event keyword {i}")])
363+
event = Event(
364+
invocation_id=f"test_inv_{i}",
365+
author="user",
366+
content=content,
367+
timestamp=1234567890.0 + i,
368+
)
369+
session.events.append(event)
370+
371+
batch1 = mock.MagicMock()
372+
batch2 = mock.MagicMock()
373+
batch1.commit = mock.AsyncMock()
374+
batch2.commit = mock.AsyncMock()
375+
mock_firestore_client.batch.side_effect = [batch1, batch2]
376+
377+
await service.add_session_to_memory(session)
378+
379+
assert mock_firestore_client.batch.call_count == 2
380+
assert batch1.set.call_count == 500
381+
batch1.commit.assert_called_once()
382+
assert batch2.set.call_count == 1
383+
batch2.commit.assert_called_once()
384+

0 commit comments

Comments
 (0)