@@ -357,9 +357,19 @@ def _build_reasoning_param(self, model_settings: ModelSettings) -> Any:
357357 reasoning_param = {
358358 "effort" : model_settings .reasoning .effort ,
359359 }
360- # Add generate_summary if specified and not None
361- if hasattr (model_settings .reasoning , 'generate_summary' ) and model_settings .reasoning .generate_summary is not None :
362- reasoning_param ["summary" ] = model_settings .reasoning .generate_summary
360+ # Add summary if specified (check both 'summary' and 'generate_summary' for compatibility)
361+ summary_value = None
362+ if hasattr (model_settings .reasoning , 'summary' ) and model_settings .reasoning .summary is not None :
363+ summary_value = model_settings .reasoning .summary
364+ elif (
365+ hasattr (model_settings .reasoning , 'generate_summary' )
366+ and model_settings .reasoning .generate_summary is not None
367+ ):
368+ summary_value = model_settings .reasoning .generate_summary
369+
370+ if summary_value is not None :
371+ reasoning_param ["summary" ] = summary_value
372+
363373 logger .debug (f"[TemporalStreamingModel] Using reasoning param: { reasoning_param } " )
364374 return reasoning_param
365375
@@ -842,10 +852,16 @@ def stream_response(self, *args, **kwargs):
842852class TemporalStreamingModelProvider (ModelProvider ):
843853 """Custom model provider that returns a streaming-capable model."""
844854
845- def __init__ (self ):
846- """Initialize the provider."""
855+ def __init__ (self , openai_client : Optional [AsyncOpenAI ] = None ):
856+ """Initialize the provider.
857+
858+ Args:
859+ openai_client: Optional custom AsyncOpenAI client to use for all models.
860+ If not provided, each model will create its own default client.
861+ """
847862 super ().__init__ ()
848- logger .info ("[TemporalStreamingModelProvider] Initialized" )
863+ self .openai_client = openai_client
864+ logger .info (f"[TemporalStreamingModelProvider] Initialized, custom_client={ openai_client is not None } " )
849865
850866 @override
851867 def get_model (self , model_name : Union [str , None ]) -> Model :
@@ -860,5 +876,5 @@ def get_model(self, model_name: Union[str, None]) -> Model:
860876 # Use the provided model_name or default to gpt-4o
861877 actual_model = model_name if model_name else "gpt-4o"
862878 logger .info (f"[TemporalStreamingModelProvider] Creating TemporalStreamingModel for model_name: { actual_model } " )
863- model = TemporalStreamingModel (model_name = actual_model )
879+ model = TemporalStreamingModel (model_name = actual_model , openai_client = self . openai_client )
864880 return model
0 commit comments