Skip to content

Commit 98cd7bc

Browse files
committed
Bug fixes for the streaming model provider
1 parent 5ad8b88 commit 98cd7bc

1 file changed

Lines changed: 23 additions & 7 deletions

File tree

src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -357,9 +357,19 @@ def _build_reasoning_param(self, model_settings: ModelSettings) -> Any:
357357
reasoning_param = {
358358
"effort": model_settings.reasoning.effort,
359359
}
360-
# Add generate_summary if specified and not None
361-
if hasattr(model_settings.reasoning, 'generate_summary') and model_settings.reasoning.generate_summary is not None:
362-
reasoning_param["summary"] = model_settings.reasoning.generate_summary
360+
# Add summary if specified (check both 'summary' and 'generate_summary' for compatibility)
361+
summary_value = None
362+
if hasattr(model_settings.reasoning, 'summary') and model_settings.reasoning.summary is not None:
363+
summary_value = model_settings.reasoning.summary
364+
elif (
365+
hasattr(model_settings.reasoning, 'generate_summary')
366+
and model_settings.reasoning.generate_summary is not None
367+
):
368+
summary_value = model_settings.reasoning.generate_summary
369+
370+
if summary_value is not None:
371+
reasoning_param["summary"] = summary_value
372+
363373
logger.debug(f"[TemporalStreamingModel] Using reasoning param: {reasoning_param}")
364374
return reasoning_param
365375

@@ -842,10 +852,16 @@ def stream_response(self, *args, **kwargs):
842852
class TemporalStreamingModelProvider(ModelProvider):
843853
"""Custom model provider that returns a streaming-capable model."""
844854

845-
def __init__(self):
846-
"""Initialize the provider."""
855+
def __init__(self, openai_client: Optional[AsyncOpenAI] = None):
856+
"""Initialize the provider.
857+
858+
Args:
859+
openai_client: Optional custom AsyncOpenAI client to use for all models.
860+
If not provided, each model will create its own default client.
861+
"""
847862
super().__init__()
848-
logger.info("[TemporalStreamingModelProvider] Initialized")
863+
self.openai_client = openai_client
864+
logger.info(f"[TemporalStreamingModelProvider] Initialized, custom_client={openai_client is not None}")
849865

850866
@override
851867
def get_model(self, model_name: Union[str, None]) -> Model:
@@ -860,5 +876,5 @@ def get_model(self, model_name: Union[str, None]) -> Model:
860876
# Use the provided model_name or default to gpt-4o
861877
actual_model = model_name if model_name else "gpt-4o"
862878
logger.info(f"[TemporalStreamingModelProvider] Creating TemporalStreamingModel for model_name: {actual_model}")
863-
model = TemporalStreamingModel(model_name=actual_model)
879+
model = TemporalStreamingModel(model_name=actual_model, openai_client=self.openai_client)
864880
return model

0 commit comments

Comments
 (0)