Skip to content

Commit 7d0b4cc

Browse files
Addressing mypy errors
1 parent 28a571e commit 7d0b4cc

1 file changed

Lines changed: 12 additions & 7 deletions

File tree

src/google/adk/integrations/firestore/firestore_session_service.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,12 @@
1919
import logging
2020
import os
2121
from typing import Any
22+
from typing import cast
2223
from typing import Optional
24+
from typing import TYPE_CHECKING
25+
26+
if TYPE_CHECKING:
27+
from google.cloud import firestore
2328

2429
from pydantic import BaseModel
2530

@@ -38,7 +43,7 @@
3843
DEFAULT_USER_STATE_COLLECTION = "user_states"
3944

4045

41-
class FirestoreSessionService(BaseSessionService):
46+
class FirestoreSessionService(BaseSessionService): # type: ignore[misc]
4247
"""Session service that uses Google Cloud Firestore as the backend."""
4348

4449
def __init__(
@@ -309,16 +314,16 @@ async def _update_app_state_transactional(
309314
app_name
310315
)
311316

312-
@firestore.async_transactional
313-
async def _txn(transaction):
317+
@firestore.async_transactional # type: ignore[untyped-decorator]
318+
async def _txn(transaction: firestore.AsyncTransaction) -> dict[str, Any]:
314319
snap = await doc_ref.get(transaction=transaction)
315320
current = snap.to_dict() if snap.exists else {}
316321
current.update(delta)
317322
transaction.set(doc_ref, current, merge=True)
318323
return current
319324

320325
transaction = self.client.transaction()
321-
return await _txn(transaction)
326+
return cast(dict[str, Any], await _txn(transaction))
322327

323328
async def _update_user_state_transactional(
324329
self, app_name: str, user_id: str, delta: dict[str, Any]
@@ -333,16 +338,16 @@ async def _update_user_state_transactional(
333338
.document(user_id)
334339
)
335340

336-
@firestore.async_transactional
337-
async def _txn(transaction):
341+
@firestore.async_transactional # type: ignore[untyped-decorator]
342+
async def _txn(transaction: firestore.AsyncTransaction) -> dict[str, Any]:
338343
snap = await doc_ref.get(transaction=transaction)
339344
current = snap.to_dict() if snap.exists else {}
340345
current.update(delta)
341346
transaction.set(doc_ref, current, merge=True)
342347
return current
343348

344349
transaction = self.client.transaction()
345-
return await _txn(transaction)
350+
return cast(dict[str, Any], await _txn(transaction))
346351

347352
async def append_event(self, session: Session, event: Event) -> Event:
348353
"""Appends an event to a session in Firestore."""

0 commit comments

Comments
 (0)