88import threading
99from collections .abc import Iterable
1010from functools import partial
11- from typing import IO
1211from typing import Final
1312
1413import databento_dbn
2019from databento .common .enums import ReconnectPolicy
2120from databento .common .error import BentoError
2221from databento .common .publishers import Dataset
23- from databento .common .types import DBNRecord
22+ from databento .common .types import ClientStream , DBNRecord
2423from databento .common .types import ExceptionCallback
2524from databento .common .types import ReconnectCallback
2625from databento .common .types import RecordCallback
@@ -148,6 +147,12 @@ class SessionMetadata:
148147 def __bool__ (self ) -> bool :
149148 return self .data is not None
150149
150+ @property
151+ def has_ts_out (self ) -> bool :
152+ if self .data is None :
153+ return False
154+ return self .data .ts_out
155+
151156 def check (self , other : databento_dbn .Metadata ) -> None :
152157 """
153158 Verify the Metadata is compatible with another Metadata message. This
@@ -191,7 +196,7 @@ def __init__(
191196 dataset : Dataset | str ,
192197 dbn_queue : DBNQueue ,
193198 user_callbacks : list [tuple [RecordCallback , ExceptionCallback | None ]],
194- user_streams : list [tuple [ IO [ bytes ], ExceptionCallback | None ] ],
199+ user_streams : list [ClientStream ],
195200 loop : asyncio .AbstractEventLoop ,
196201 metadata : SessionMetadata ,
197202 ts_out : bool = False ,
@@ -210,21 +215,15 @@ def received_metadata(self, metadata: databento_dbn.Metadata) -> None:
210215 if self ._metadata :
211216 self ._metadata .check (metadata )
212217 else :
213- metadata_bytes = metadata .encode ()
214- for stream , exc_callback in self ._user_streams :
218+ for stream in self ._user_streams :
215219 try :
216- stream .write (metadata_bytes )
220+ stream .write (metadata . encode () )
217221 except Exception as exc :
218- stream_name = getattr (stream , "name" , str (stream ))
219222 logger .error (
220- "error writing %d bytes to `%s` stream" ,
221- len (metadata_bytes ),
222- stream_name ,
223+ "error writing metadata to `%s` stream" ,
224+ stream .stream_name ,
223225 exc_info = exc ,
224226 )
225- if exc_callback is not None :
226- exc_callback (exc )
227-
228227 self ._metadata .data = metadata
229228 return super ().received_metadata (metadata )
230229
@@ -252,26 +251,20 @@ def _dispatch_callbacks(self, record: DBNRecord) -> None:
252251 exc_callback (exc )
253252
254253 def _dispatch_writes (self , record : DBNRecord ) -> None :
255- if hasattr (record , "ts_out" ):
256- ts_out_bytes = struct .pack ("Q" , record .ts_out )
257- else :
258- ts_out_bytes = b""
259-
260- record_bytes = bytes (record ) + ts_out_bytes
261-
262- for stream , exc_callback in self ._user_streams :
254+ record_bytes = bytes (record )
255+ ts_out_bytes = struct .pack ("Q" , record .ts_out ) if self ._metadata .has_ts_out else b""
256+ for stream in self ._user_streams :
263257 try :
264258 stream .write (record_bytes )
259+ stream .write (ts_out_bytes )
265260 except Exception as exc :
266- stream_name = getattr (stream , "name" , str (stream ))
267261 logger .error (
268- "error writing %d bytes to `%s` stream" ,
269- len (record_bytes ),
270- stream_name ,
262+ "error writing %s record (%d bytes) to `%s` stream" ,
263+ type (record ).__name__ ,
264+ len (record_bytes ) + len (ts_out_bytes ),
265+ stream .stream_name ,
271266 exc_info = exc ,
272267 )
273- if exc_callback is not None :
274- exc_callback (exc )
275268
276269 def _queue_for_iteration (self , record : DBNRecord ) -> None :
277270 self ._dbn_queue .put (record )
@@ -323,7 +316,7 @@ def __init__(
323316 self ._metadata = SessionMetadata ()
324317 self ._user_gateway : str | None = user_gateway
325318 self ._user_callbacks : list [tuple [RecordCallback , ExceptionCallback | None ]] = []
326- self ._user_streams : list [tuple [ IO [ bytes ], ExceptionCallback | None ] ] = []
319+ self ._user_streams : list [ClientStream ] = []
327320 self ._user_reconnect_callbacks : list [tuple [ReconnectCallback , ExceptionCallback | None ]] = (
328321 []
329322 )
@@ -551,10 +544,11 @@ async def wait_for_close(self) -> None:
551544 def _cleanup (self ) -> None :
552545 logger .debug ("cleaning up session_id=%s" , self .session_id )
553546 self ._user_callbacks .clear ()
554- for item in self ._user_streams :
555- stream , _ = item
556- if not stream .closed :
547+ for stream in self ._user_streams :
548+ if not stream .is_closed :
557549 stream .flush ()
550+ if stream .is_managed :
551+ stream .close ()
558552
559553 self ._user_callbacks .clear ()
560554 self ._user_streams .clear ()
0 commit comments