|
18 | 18 |
|
19 | 19 | from google.adk.events.event import Event |
20 | 20 | from google.adk.integrations.firestore.firestore_memory_service import FirestoreMemoryService |
| 21 | +from google.cloud.firestore_v1.base_query import FieldFilter |
21 | 22 | from google.genai import types |
22 | 23 | import pytest |
23 | 24 |
|
@@ -94,7 +95,38 @@ async def test_search_memory_with_results(mock_firestore_client): |
94 | 95 |
|
95 | 96 | mock_firestore_client.collection.assert_called_with("memories") |
96 | 97 | collection_ref = mock_firestore_client.collection.return_value |
97 | | - collection_ref.where.assert_called() |
| 98 | + |
| 99 | + assert collection_ref.where.call_count == 6 |
| 100 | + calls = collection_ref.where.call_args_list |
| 101 | + |
| 102 | + app_name_calls = 0 |
| 103 | + user_id_calls = 0 |
| 104 | + keyword_calls = 0 |
| 105 | + |
| 106 | + for call in calls: |
| 107 | + kwargs = call.kwargs |
| 108 | + filt = kwargs.get("filter") |
| 109 | + if filt: |
| 110 | + if ( |
| 111 | + filt.field_path == "appName" |
| 112 | + and filt.op_string == "==" |
| 113 | + and filt.value == app_name |
| 114 | + ): |
| 115 | + app_name_calls += 1 |
| 116 | + elif ( |
| 117 | + filt.field_path == "userId" |
| 118 | + and filt.op_string == "==" |
| 119 | + and filt.value == user_id |
| 120 | + ): |
| 121 | + user_id_calls += 1 |
| 122 | + elif filt.field_path == "keywords" and filt.op_string == "array_contains": |
| 123 | + |
| 124 | + if filt.value in ["quick", "fox"]: |
| 125 | + keyword_calls += 1 |
| 126 | + |
| 127 | + assert app_name_calls == 2 |
| 128 | + assert user_id_calls == 2 |
| 129 | + assert keyword_calls == 2 |
98 | 130 |
|
99 | 131 |
|
100 | 132 | @pytest.mark.asyncio |
@@ -167,3 +199,117 @@ async def test_search_memory_only_stop_words(mock_firestore_client): |
167 | 199 | ) |
168 | 200 | assert not response.memories |
169 | 201 | mock_firestore_client.collection.assert_not_called() |
| 202 | + |
| 203 | + |
| 204 | +def test_init_default_client(): |
| 205 | + with mock.patch("google.cloud.firestore.AsyncClient") as mock_client_class: |
| 206 | + mock_instance = mock.MagicMock() |
| 207 | + mock_client_class.return_value = mock_instance |
| 208 | + |
| 209 | + service = FirestoreMemoryService() |
| 210 | + |
| 211 | + mock_client_class.assert_called_once() |
| 212 | + assert service.client == mock_instance |
| 213 | + |
| 214 | + |
| 215 | +@pytest.mark.asyncio |
| 216 | +async def test_add_session_to_memory(mock_firestore_client): |
| 217 | + service = FirestoreMemoryService(client=mock_firestore_client) |
| 218 | + |
| 219 | + from google.adk.sessions.session import Session |
| 220 | + |
| 221 | + session = Session(id="test_session", app_name="test_app", user_id="test_user") |
| 222 | + |
| 223 | + content = types.Content(parts=[types.Part.from_text(text="quick brown fox")]) |
| 224 | + event = Event( |
| 225 | + invocation_id="test_inv", |
| 226 | + author="user", |
| 227 | + content=content, |
| 228 | + timestamp=1234567890.0, |
| 229 | + ) |
| 230 | + session.events.append(event) |
| 231 | + |
| 232 | + batch = mock.MagicMock() |
| 233 | + mock_firestore_client.batch.return_value = batch |
| 234 | + batch.commit = mock.AsyncMock() |
| 235 | + |
| 236 | + doc_ref = mock.MagicMock() |
| 237 | + mock_firestore_client.collection.return_value.document.return_value = doc_ref |
| 238 | + |
| 239 | + await service.add_session_to_memory(session) |
| 240 | + |
| 241 | + mock_firestore_client.batch.assert_called_once() |
| 242 | + mock_firestore_client.collection.assert_called_with("memories") |
| 243 | + batch.set.assert_called_once() |
| 244 | + batch.commit.assert_called_once() |
| 245 | + |
| 246 | + args, kwargs = batch.set.call_args |
| 247 | + assert args[0] == doc_ref |
| 248 | + data = args[1] |
| 249 | + assert data["appName"] == "test_app" |
| 250 | + assert data["userId"] == "test_user" |
| 251 | + assert "quick" in data["keywords"] |
| 252 | + assert data["author"] == "user" |
| 253 | + assert data["timestamp"] == 1234567890.0 |
| 254 | + |
| 255 | + |
| 256 | +@pytest.mark.asyncio |
| 257 | +async def test_add_session_to_memory_no_events(mock_firestore_client): |
| 258 | + service = FirestoreMemoryService(client=mock_firestore_client) |
| 259 | + |
| 260 | + from google.adk.sessions.session import Session |
| 261 | + |
| 262 | + session = Session(id="test_session", app_name="test_app", user_id="test_user") |
| 263 | + |
| 264 | + batch = mock.MagicMock() |
| 265 | + mock_firestore_client.batch.return_value = batch |
| 266 | + |
| 267 | + await service.add_session_to_memory(session) |
| 268 | + |
| 269 | + mock_firestore_client.batch.assert_called_once() |
| 270 | + batch.set.assert_not_called() |
| 271 | + batch.commit.assert_not_called() |
| 272 | + |
| 273 | + |
| 274 | +@pytest.mark.asyncio |
| 275 | +async def test_add_session_to_memory_no_keywords(mock_firestore_client): |
| 276 | + service = FirestoreMemoryService(client=mock_firestore_client) |
| 277 | + |
| 278 | + from google.adk.sessions.session import Session |
| 279 | + |
| 280 | + session = Session(id="test_session", app_name="test_app", user_id="test_user") |
| 281 | + |
| 282 | + content = types.Content(parts=[types.Part.from_text(text="the and or")]) |
| 283 | + event = Event(invocation_id="test_inv", author="user", content=content) |
| 284 | + session.events.append(event) |
| 285 | + |
| 286 | + batch = mock.MagicMock() |
| 287 | + mock_firestore_client.batch.return_value = batch |
| 288 | + |
| 289 | + await service.add_session_to_memory(session) |
| 290 | + |
| 291 | + mock_firestore_client.batch.assert_called_once() |
| 292 | + batch.set.assert_not_called() |
| 293 | + batch.commit.assert_not_called() |
| 294 | + |
| 295 | + |
| 296 | +@pytest.mark.asyncio |
| 297 | +async def test_add_session_to_memory_commit_error(mock_firestore_client): |
| 298 | + service = FirestoreMemoryService(client=mock_firestore_client) |
| 299 | + |
| 300 | + from google.adk.sessions.session import Session |
| 301 | + |
| 302 | + session = Session(id="test_session", app_name="test_app", user_id="test_user") |
| 303 | + |
| 304 | + content = types.Content(parts=[types.Part.from_text(text="quick brown fox")]) |
| 305 | + event = Event(invocation_id="test_inv", author="user", content=content) |
| 306 | + session.events.append(event) |
| 307 | + |
| 308 | + batch = mock.MagicMock() |
| 309 | + mock_firestore_client.batch.return_value = batch |
| 310 | + batch.commit = mock.AsyncMock( |
| 311 | + side_effect=Exception("Firestore commit failed") |
| 312 | + ) |
| 313 | + |
| 314 | + with pytest.raises(Exception, match="Firestore commit failed"): |
| 315 | + await service.add_session_to_memory(session) |
0 commit comments