1919import logging
2020import os
2121from typing import Any
22+ from typing import cast
2223from typing import Optional
24+ from typing import TYPE_CHECKING
25+
26+ if TYPE_CHECKING :
27+ from google .cloud import firestore
2328
2429from pydantic import BaseModel
2530
3843DEFAULT_USER_STATE_COLLECTION = "user_states"
3944
4045
41- class FirestoreSessionService (BaseSessionService ):
46+ class FirestoreSessionService (BaseSessionService ): # type: ignore[misc]
4247 """Session service that uses Google Cloud Firestore as the backend."""
4348
4449 def __init__ (
@@ -309,16 +314,16 @@ async def _update_app_state_transactional(
309314 app_name
310315 )
311316
312- @firestore .async_transactional
313- async def _txn (transaction ) :
317+ @firestore .async_transactional # type: ignore[untyped-decorator]
318+ async def _txn (transaction : firestore . AsyncTransaction ) -> dict [ str , Any ] :
314319 snap = await doc_ref .get (transaction = transaction )
315320 current = snap .to_dict () if snap .exists else {}
316321 current .update (delta )
317322 transaction .set (doc_ref , current , merge = True )
318323 return current
319324
320325 transaction = self .client .transaction ()
321- return await _txn (transaction )
326+ return cast ( dict [ str , Any ], await _txn (transaction ) )
322327
323328 async def _update_user_state_transactional (
324329 self , app_name : str , user_id : str , delta : dict [str , Any ]
@@ -333,16 +338,16 @@ async def _update_user_state_transactional(
333338 .document (user_id )
334339 )
335340
336- @firestore .async_transactional
337- async def _txn (transaction ) :
341+ @firestore .async_transactional # type: ignore[untyped-decorator]
342+ async def _txn (transaction : firestore . AsyncTransaction ) -> dict [ str , Any ] :
338343 snap = await doc_ref .get (transaction = transaction )
339344 current = snap .to_dict () if snap .exists else {}
340345 current .update (delta )
341346 transaction .set (doc_ref , current , merge = True )
342347 return current
343348
344349 transaction = self .client .transaction ()
345- return await _txn (transaction )
350+ return cast ( dict [ str , Any ], await _txn (transaction ) )
346351
347352 async def append_event (self , session : Session , event : Event ) -> Event :
348353 """Appends an event to a session in Firestore."""
0 commit comments