Skip to content

Commit b652669

Browse files
Much improved unit tests
1 parent e18a1f8 commit b652669

2 files changed

Lines changed: 360 additions & 3 deletions

File tree

tests/unittests/integrations/firestore/test_firestore_memory_service.py

Lines changed: 147 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from google.adk.events.event import Event
2020
from google.adk.integrations.firestore.firestore_memory_service import FirestoreMemoryService
21+
from google.cloud.firestore_v1.base_query import FieldFilter
2122
from google.genai import types
2223
import pytest
2324

@@ -94,7 +95,38 @@ async def test_search_memory_with_results(mock_firestore_client):
9495

9596
mock_firestore_client.collection.assert_called_with("memories")
9697
collection_ref = mock_firestore_client.collection.return_value
97-
collection_ref.where.assert_called()
98+
99+
assert collection_ref.where.call_count == 6
100+
calls = collection_ref.where.call_args_list
101+
102+
app_name_calls = 0
103+
user_id_calls = 0
104+
keyword_calls = 0
105+
106+
for call in calls:
107+
kwargs = call.kwargs
108+
filt = kwargs.get("filter")
109+
if filt:
110+
if (
111+
filt.field_path == "appName"
112+
and filt.op_string == "=="
113+
and filt.value == app_name
114+
):
115+
app_name_calls += 1
116+
elif (
117+
filt.field_path == "userId"
118+
and filt.op_string == "=="
119+
and filt.value == user_id
120+
):
121+
user_id_calls += 1
122+
elif filt.field_path == "keywords" and filt.op_string == "array_contains":
123+
124+
if filt.value in ["quick", "fox"]:
125+
keyword_calls += 1
126+
127+
assert app_name_calls == 2
128+
assert user_id_calls == 2
129+
assert keyword_calls == 2
98130

99131

100132
@pytest.mark.asyncio
@@ -167,3 +199,117 @@ async def test_search_memory_only_stop_words(mock_firestore_client):
167199
)
168200
assert not response.memories
169201
mock_firestore_client.collection.assert_not_called()
202+
203+
204+
def test_init_default_client():
205+
with mock.patch("google.cloud.firestore.AsyncClient") as mock_client_class:
206+
mock_instance = mock.MagicMock()
207+
mock_client_class.return_value = mock_instance
208+
209+
service = FirestoreMemoryService()
210+
211+
mock_client_class.assert_called_once()
212+
assert service.client == mock_instance
213+
214+
215+
@pytest.mark.asyncio
216+
async def test_add_session_to_memory(mock_firestore_client):
217+
service = FirestoreMemoryService(client=mock_firestore_client)
218+
219+
from google.adk.sessions.session import Session
220+
221+
session = Session(id="test_session", app_name="test_app", user_id="test_user")
222+
223+
content = types.Content(parts=[types.Part.from_text(text="quick brown fox")])
224+
event = Event(
225+
invocation_id="test_inv",
226+
author="user",
227+
content=content,
228+
timestamp=1234567890.0,
229+
)
230+
session.events.append(event)
231+
232+
batch = mock.MagicMock()
233+
mock_firestore_client.batch.return_value = batch
234+
batch.commit = mock.AsyncMock()
235+
236+
doc_ref = mock.MagicMock()
237+
mock_firestore_client.collection.return_value.document.return_value = doc_ref
238+
239+
await service.add_session_to_memory(session)
240+
241+
mock_firestore_client.batch.assert_called_once()
242+
mock_firestore_client.collection.assert_called_with("memories")
243+
batch.set.assert_called_once()
244+
batch.commit.assert_called_once()
245+
246+
args, kwargs = batch.set.call_args
247+
assert args[0] == doc_ref
248+
data = args[1]
249+
assert data["appName"] == "test_app"
250+
assert data["userId"] == "test_user"
251+
assert "quick" in data["keywords"]
252+
assert data["author"] == "user"
253+
assert data["timestamp"] == 1234567890.0
254+
255+
256+
@pytest.mark.asyncio
257+
async def test_add_session_to_memory_no_events(mock_firestore_client):
258+
service = FirestoreMemoryService(client=mock_firestore_client)
259+
260+
from google.adk.sessions.session import Session
261+
262+
session = Session(id="test_session", app_name="test_app", user_id="test_user")
263+
264+
batch = mock.MagicMock()
265+
mock_firestore_client.batch.return_value = batch
266+
267+
await service.add_session_to_memory(session)
268+
269+
mock_firestore_client.batch.assert_called_once()
270+
batch.set.assert_not_called()
271+
batch.commit.assert_not_called()
272+
273+
274+
@pytest.mark.asyncio
275+
async def test_add_session_to_memory_no_keywords(mock_firestore_client):
276+
service = FirestoreMemoryService(client=mock_firestore_client)
277+
278+
from google.adk.sessions.session import Session
279+
280+
session = Session(id="test_session", app_name="test_app", user_id="test_user")
281+
282+
content = types.Content(parts=[types.Part.from_text(text="the and or")])
283+
event = Event(invocation_id="test_inv", author="user", content=content)
284+
session.events.append(event)
285+
286+
batch = mock.MagicMock()
287+
mock_firestore_client.batch.return_value = batch
288+
289+
await service.add_session_to_memory(session)
290+
291+
mock_firestore_client.batch.assert_called_once()
292+
batch.set.assert_not_called()
293+
batch.commit.assert_not_called()
294+
295+
296+
@pytest.mark.asyncio
297+
async def test_add_session_to_memory_commit_error(mock_firestore_client):
298+
service = FirestoreMemoryService(client=mock_firestore_client)
299+
300+
from google.adk.sessions.session import Session
301+
302+
session = Session(id="test_session", app_name="test_app", user_id="test_user")
303+
304+
content = types.Content(parts=[types.Part.from_text(text="quick brown fox")])
305+
event = Event(invocation_id="test_inv", author="user", content=content)
306+
session.events.append(event)
307+
308+
batch = mock.MagicMock()
309+
mock_firestore_client.batch.return_value = batch
310+
batch.commit = mock.AsyncMock(
311+
side_effect=Exception("Firestore commit failed")
312+
)
313+
314+
with pytest.raises(Exception, match="Firestore commit failed"):
315+
await service.add_session_to_memory(session)

0 commit comments

Comments
 (0)