|
| 1 | +# Copyright 2026 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from __future__ import annotations |
| 16 | + |
| 17 | +import asyncio |
| 18 | +import logging |
| 19 | +import os |
| 20 | +import re |
| 21 | +from typing import Any |
| 22 | +from typing import Optional |
| 23 | + |
| 24 | +from google.cloud import firestore |
| 25 | +from typing_extensions import override |
| 26 | + |
| 27 | +from ..events.event import Event |
| 28 | +from . import _utils |
| 29 | +from .base_memory_service import BaseMemoryService |
| 30 | +from .base_memory_service import SearchMemoryResponse |
| 31 | +from .memory_entry import MemoryEntry |
| 32 | + |
| 33 | +if False: # TYPE_CHECKING |
| 34 | + from ..sessions.session import Session |
| 35 | + |
| 36 | +logger = logging.getLogger("google_adk." + __name__) |
| 37 | + |
| 38 | +DEFAULT_EVENTS_COLLECTION = "events" |
| 39 | + |
| 40 | +# Standard English stop words |
| 41 | +DEFAULT_STOP_WORDS = { |
| 42 | + "a", "an", "the", "and", "or", "but", "if", "then", "else", "to", "of", |
| 43 | + "in", "on", "for", "with", "is", "are", "was", "were", "be", "been", |
| 44 | + "being", "have", "has", "had", "do", "does", "did", "can", "could", |
| 45 | + "will", "would", "should", "shall", "may", "might", "must", "up", "down", |
| 46 | + "out", "in", "over", "under", "again", "further", "then", "once", "here", |
| 47 | + "there", "when", "where", "why", "how", "all", "any", "both", "each", |
| 48 | + "few", "more", "most", "other", "some", "such", "no", "nor", "not", "only", |
| 49 | + "own", "same", "so", "than", "too", "very", "i", "me", "my", "myself", |
| 50 | + "we", "our", "ours", "ourselves", "you", "your", "yours", "yourself", |
| 51 | + "yourselves", "he", "him", "his", "himself", "she", "her", "hers", |
| 52 | + "herself", "it", "its", "itself", "they", "them", "their", "theirs", |
| 53 | + "themselves", "what", "which", "who", "whom", "this", "that", "these", |
| 54 | + "those", "am", "is", "are", "was", "were", "be", "been", "being", |
| 55 | + "have", "has", "had", "having", "do", "does", "did", "doing", |
| 56 | + "a", "an", "the", "and", "but", "if", "or", "because", "as", "until", |
| 57 | + "while", "of", "at", "by", "for", "with", "about", "against", "between", |
| 58 | + "into", "through", "during", "before", "after", "above", "below", "to", |
| 59 | + "from", "up", "down", "in", "out", "on", "off", "over", "under", "again", |
| 60 | + "further", "then", "once", "here", "there", "when", "where", "why", "how", |
| 61 | + "all", "any", "both", "each", "few", "more", "most", "other", "some", |
| 62 | + "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", |
| 63 | + "very", "s", "t", "can", "will", "just", "don", "should", "now" |
| 64 | +} |
| 65 | + |
| 66 | + |
| 67 | +class FirestoreMemoryService(BaseMemoryService): |
| 68 | + """Memory service that uses Google Cloud Firestore as the backend.""" |
| 69 | + |
| 70 | + def __init__( |
| 71 | + self, |
| 72 | + client: Optional[firestore.AsyncClient] = None, |
| 73 | + events_collection: Optional[str] = None, |
| 74 | + stop_words: Optional[set[str]] = None, |
| 75 | + ): |
| 76 | + """Initializes the Firestore memory service. |
| 77 | +
|
| 78 | + Args: |
| 79 | + client: An optional Firestore AsyncClient. If not provided, a new one |
| 80 | + will be created. |
| 81 | + events_collection: The name of the events collection or collection group. |
| 82 | + Defaults to 'events'. |
| 83 | + stop_words: A set of words to ignore when extracting keywords. Defaults to |
| 84 | + a standard English stop words list. |
| 85 | + """ |
| 86 | + self.client = client or firestore.AsyncClient() |
| 87 | + self.events_collection = events_collection or DEFAULT_EVENTS_COLLECTION |
| 88 | + self.stop_words = stop_words if stop_words is not None else DEFAULT_STOP_WORDS |
| 89 | + |
| 90 | + @override |
| 91 | + async def add_session_to_memory(self, session: Session) -> None: |
| 92 | + """No-op. Assumes events are written to Firestore by FirestoreSessionService.""" |
| 93 | + pass |
| 94 | + |
| 95 | + def _extract_keywords(self, text: str) -> set[str]: |
| 96 | + """Extracts keywords from text, ignoring stop words.""" |
| 97 | + words = re.findall(r"[A-Za-z]+", text.lower()) |
| 98 | + return {word for word in words if word not in self.stop_words} |
| 99 | + |
| 100 | + async def _search_by_keyword( |
| 101 | + self, app_name: str, user_id: str, keyword: str |
| 102 | + ) -> list[MemoryEntry]: |
| 103 | + """Searches for events matching a single keyword.""" |
| 104 | + # This requires a collection group index in Firestore for 'events' with |
| 105 | + # appName == X, userId == Y, and keywords array-contains Z. |
| 106 | + query = ( |
| 107 | + self.client.collection_group(self.events_collection) |
| 108 | + .where("appName", "==", app_name) |
| 109 | + .where("userId", "==", user_id) |
| 110 | + .where("keywords", "array_contains", keyword) |
| 111 | + ) |
| 112 | + |
| 113 | + docs = await query.get() |
| 114 | + entries = [] |
| 115 | + for doc in docs: |
| 116 | + data = doc.to_dict() |
| 117 | + if data and "event_data" in data: |
| 118 | + try: |
| 119 | + event = Event.model_validate(data["event_data"]) |
| 120 | + if event.content: |
| 121 | + entries.append( |
| 122 | + MemoryEntry( |
| 123 | + content=event.content, |
| 124 | + author=event.author, |
| 125 | + timestamp=_utils.format_timestamp(event.timestamp), |
| 126 | + ) |
| 127 | + ) |
| 128 | + except Exception as e: |
| 129 | + logger.warning("Failed to parse event from Firestore: %s", e) |
| 130 | + |
| 131 | + return entries |
| 132 | + |
| 133 | + @override |
| 134 | + async def search_memory( |
| 135 | + self, *, app_name: str, user_id: str, query: str |
| 136 | + ) -> SearchMemoryResponse: |
| 137 | + """Searches memory for events matching the query.""" |
| 138 | + keywords = self._extract_keywords(query) |
| 139 | + if not keywords: |
| 140 | + return SearchMemoryResponse() |
| 141 | + |
| 142 | + # Search for each keyword concurrently |
| 143 | + tasks = [ |
| 144 | + self._search_by_keyword(app_name, user_id, keyword) |
| 145 | + for keyword in keywords |
| 146 | + ] |
| 147 | + results = await asyncio.gather(*tasks) |
| 148 | + |
| 149 | + # Merge results and deduplicate by MemoryEntry content/author/timestamp |
| 150 | + # (MemoryEntry is not hashable by default if it contains complex objects, |
| 151 | + # so we might need to deduplicate by id if available, or by content string). |
| 152 | + # Since we convert Event to MemoryEntry, we don't have event.id in MemoryEntry |
| 153 | + # unless we add it. The Java code use custom hash/equals for MemoryEntry. |
| 154 | + # In Python, MemoryEntry is a Pydantic model. We can deduplicate by model_dump_json() |
| 155 | + # or by a custom key. |
| 156 | + seen = set() |
| 157 | + memories = [] |
| 158 | + for result_list in results: |
| 159 | + for entry in result_list: |
| 160 | + # Deduplicate by a key of (author, content_text) |
| 161 | + # Content might be complex, so let's use its json representation or text |
| 162 | + content_text = "" |
| 163 | + if entry.content and entry.content.parts: |
| 164 | + content_text = " ".join( |
| 165 | + [part.text for part in entry.content.parts if part.text] |
| 166 | + ) |
| 167 | + key = (entry.author, content_text, entry.timestamp) |
| 168 | + if key not in seen: |
| 169 | + seen.add(key) |
| 170 | + memories.append(entry) |
| 171 | + |
| 172 | + return SearchMemoryResponse(memories=memories) |
0 commit comments