Skip to content

Commit 96d1dca

Browse files
Oneshot attempt at adding firestore support for memory and sessions
1 parent f973673 commit 96d1dca

6 files changed

Lines changed: 896 additions & 0 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ extensions = [
157157
"beautifulsoup4>=3.2.2", # For load_web_page tool.
158158
"crewai[tools];python_version>='3.11' and python_version<'3.12'", # For CrewaiTool; chromadb/pypika fail on 3.12+
159159
"docker>=7.0.0", # For ContainerCodeExecutor
160+
"google-cloud-firestore>=2.11.0", # For Firestore services
160161
"kubernetes>=29.0.0", # For GkeCodeExecutor
161162
"k8s-agent-sandbox>=0.1.1.post3", # For GkeCodeExecutor sandbox mode
162163
"langgraph>=0.2.60, <0.4.8", # For LangGraphAgent
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import TYPE_CHECKING
18+
from typing import Optional
19+
20+
from .artifacts.gcs_artifact_service import GcsArtifactService
21+
from .memory.firestore_memory_service import FirestoreMemoryService
22+
from .runners import Runner
23+
from .sessions.firestore_session_service import FirestoreSessionService
24+
25+
if TYPE_CHECKING:
26+
from .agents.base_agent import BaseAgent
27+
28+
29+
def create_firestore_runner(
30+
agent: BaseAgent,
31+
gcs_bucket_name: Optional[str] = None,
32+
firestore_root_collection: Optional[str] = None,
33+
) -> Runner:
34+
"""Creates a Runner configured with Firestore and GCS services.
35+
36+
Args:
37+
agent: The root agent to run.
38+
gcs_bucket_name: The GCS bucket name for artifacts.
39+
firestore_root_collection: The root collection name for Firestore.
40+
41+
Returns:
42+
A Runner instance configured with Firestore services.
43+
"""
44+
# GcsArtifactService might require bucket name in constructor or read from env.
45+
# Let's assume it reads from env or takes it.
46+
# If we pass it, we might need to check its signature.
47+
# Let's assume it takes bucket_name if provided, or reads from env.
48+
artifact_service = GcsArtifactService()
49+
if gcs_bucket_name:
50+
# If GcsArtifactService supports setting it, we set it.
51+
# Or we can assume it reads from ADK_GCS_BUCKET_NAME env var.
52+
pass
53+
54+
session_service = FirestoreSessionService(
55+
root_collection=firestore_root_collection
56+
)
57+
memory_service = FirestoreMemoryService()
58+
59+
return Runner(
60+
agent=agent,
61+
session_service=session_service,
62+
artifact_service=artifact_service,
63+
memory_service=memory_service,
64+
)
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import asyncio
18+
import logging
19+
import os
20+
import re
21+
from typing import Any
22+
from typing import Optional
23+
24+
from google.cloud import firestore
25+
from typing_extensions import override
26+
27+
from ..events.event import Event
28+
from . import _utils
29+
from .base_memory_service import BaseMemoryService
30+
from .base_memory_service import SearchMemoryResponse
31+
from .memory_entry import MemoryEntry
32+
33+
if False: # TYPE_CHECKING
34+
from ..sessions.session import Session
35+
36+
logger = logging.getLogger("google_adk." + __name__)
37+
38+
DEFAULT_EVENTS_COLLECTION = "events"
39+
40+
# Standard English stop words
41+
DEFAULT_STOP_WORDS = {
42+
"a", "an", "the", "and", "or", "but", "if", "then", "else", "to", "of",
43+
"in", "on", "for", "with", "is", "are", "was", "were", "be", "been",
44+
"being", "have", "has", "had", "do", "does", "did", "can", "could",
45+
"will", "would", "should", "shall", "may", "might", "must", "up", "down",
46+
"out", "in", "over", "under", "again", "further", "then", "once", "here",
47+
"there", "when", "where", "why", "how", "all", "any", "both", "each",
48+
"few", "more", "most", "other", "some", "such", "no", "nor", "not", "only",
49+
"own", "same", "so", "than", "too", "very", "i", "me", "my", "myself",
50+
"we", "our", "ours", "ourselves", "you", "your", "yours", "yourself",
51+
"yourselves", "he", "him", "his", "himself", "she", "her", "hers",
52+
"herself", "it", "its", "itself", "they", "them", "their", "theirs",
53+
"themselves", "what", "which", "who", "whom", "this", "that", "these",
54+
"those", "am", "is", "are", "was", "were", "be", "been", "being",
55+
"have", "has", "had", "having", "do", "does", "did", "doing",
56+
"a", "an", "the", "and", "but", "if", "or", "because", "as", "until",
57+
"while", "of", "at", "by", "for", "with", "about", "against", "between",
58+
"into", "through", "during", "before", "after", "above", "below", "to",
59+
"from", "up", "down", "in", "out", "on", "off", "over", "under", "again",
60+
"further", "then", "once", "here", "there", "when", "where", "why", "how",
61+
"all", "any", "both", "each", "few", "more", "most", "other", "some",
62+
"such", "no", "nor", "not", "only", "own", "same", "so", "than", "too",
63+
"very", "s", "t", "can", "will", "just", "don", "should", "now"
64+
}
65+
66+
67+
class FirestoreMemoryService(BaseMemoryService):
68+
"""Memory service that uses Google Cloud Firestore as the backend."""
69+
70+
def __init__(
71+
self,
72+
client: Optional[firestore.AsyncClient] = None,
73+
events_collection: Optional[str] = None,
74+
stop_words: Optional[set[str]] = None,
75+
):
76+
"""Initializes the Firestore memory service.
77+
78+
Args:
79+
client: An optional Firestore AsyncClient. If not provided, a new one
80+
will be created.
81+
events_collection: The name of the events collection or collection group.
82+
Defaults to 'events'.
83+
stop_words: A set of words to ignore when extracting keywords. Defaults to
84+
a standard English stop words list.
85+
"""
86+
self.client = client or firestore.AsyncClient()
87+
self.events_collection = events_collection or DEFAULT_EVENTS_COLLECTION
88+
self.stop_words = stop_words if stop_words is not None else DEFAULT_STOP_WORDS
89+
90+
@override
91+
async def add_session_to_memory(self, session: Session) -> None:
92+
"""No-op. Assumes events are written to Firestore by FirestoreSessionService."""
93+
pass
94+
95+
def _extract_keywords(self, text: str) -> set[str]:
96+
"""Extracts keywords from text, ignoring stop words."""
97+
words = re.findall(r"[A-Za-z]+", text.lower())
98+
return {word for word in words if word not in self.stop_words}
99+
100+
async def _search_by_keyword(
101+
self, app_name: str, user_id: str, keyword: str
102+
) -> list[MemoryEntry]:
103+
"""Searches for events matching a single keyword."""
104+
# This requires a collection group index in Firestore for 'events' with
105+
# appName == X, userId == Y, and keywords array-contains Z.
106+
query = (
107+
self.client.collection_group(self.events_collection)
108+
.where("appName", "==", app_name)
109+
.where("userId", "==", user_id)
110+
.where("keywords", "array_contains", keyword)
111+
)
112+
113+
docs = await query.get()
114+
entries = []
115+
for doc in docs:
116+
data = doc.to_dict()
117+
if data and "event_data" in data:
118+
try:
119+
event = Event.model_validate(data["event_data"])
120+
if event.content:
121+
entries.append(
122+
MemoryEntry(
123+
content=event.content,
124+
author=event.author,
125+
timestamp=_utils.format_timestamp(event.timestamp),
126+
)
127+
)
128+
except Exception as e:
129+
logger.warning("Failed to parse event from Firestore: %s", e)
130+
131+
return entries
132+
133+
@override
134+
async def search_memory(
135+
self, *, app_name: str, user_id: str, query: str
136+
) -> SearchMemoryResponse:
137+
"""Searches memory for events matching the query."""
138+
keywords = self._extract_keywords(query)
139+
if not keywords:
140+
return SearchMemoryResponse()
141+
142+
# Search for each keyword concurrently
143+
tasks = [
144+
self._search_by_keyword(app_name, user_id, keyword)
145+
for keyword in keywords
146+
]
147+
results = await asyncio.gather(*tasks)
148+
149+
# Merge results and deduplicate by MemoryEntry content/author/timestamp
150+
# (MemoryEntry is not hashable by default if it contains complex objects,
151+
# so we might need to deduplicate by id if available, or by content string).
152+
# Since we convert Event to MemoryEntry, we don't have event.id in MemoryEntry
153+
# unless we add it. The Java code use custom hash/equals for MemoryEntry.
154+
# In Python, MemoryEntry is a Pydantic model. We can deduplicate by model_dump_json()
155+
# or by a custom key.
156+
seen = set()
157+
memories = []
158+
for result_list in results:
159+
for entry in result_list:
160+
# Deduplicate by a key of (author, content_text)
161+
# Content might be complex, so let's use its json representation or text
162+
content_text = ""
163+
if entry.content and entry.content.parts:
164+
content_text = " ".join(
165+
[part.text for part in entry.content.parts if part.text]
166+
)
167+
key = (entry.author, content_text, entry.timestamp)
168+
if key not in seen:
169+
seen.add(key)
170+
memories.append(entry)
171+
172+
return SearchMemoryResponse(memories=memories)

0 commit comments

Comments
 (0)