Skip to content

Commit c610405

Browse files
committed
Update OpenAI integration: flush session at end of runner, allow setting model retry opts, set parallel_tool_calls to false
1 parent a95e3ac commit c610405

2 files changed

Lines changed: 6 additions & 9 deletions

File tree

python/restate/ext/openai/__init__.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
from .runner_wrapper import (
1818
DurableRunner,
19-
DurableModelCalls,
2019
continue_on_terminal_errors,
2120
raise_terminal_errors,
2221
RestateSession,
@@ -43,12 +42,11 @@ def restate_context() -> Context:
4342

4443

4544
__all__ = [
46-
"DurableModelCalls",
47-
"continue_on_terminal_errors",
48-
"raise_terminal_errors",
49-
"RestateSession",
5045
"DurableRunner",
46+
"RestateSession",
5147
"LlmRetryOpts",
5248
"restate_object_context",
5349
"restate_context",
50+
"continue_on_terminal_errors",
51+
"raise_terminal_errors",
5452
]

python/restate/ext/openai/runner_wrapper.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ class DurableModelCalls(MultiProvider):
9393
A Restate model provider that wraps the OpenAI SDK's default MultiProvider.
9494
"""
9595

96-
def __init__(self, llm_retry_opts: LlmRetryOpts):
96+
def __init__(self, llm_retry_opts: LlmRetryOpts | None = None):
9797
super().__init__()
9898
self.llm_retry_opts = llm_retry_opts
9999

@@ -106,7 +106,7 @@ class RestateModelWrapper(Model):
106106
A wrapper around the OpenAI SDK's Model that persists LLM calls in the Restate journal.
107107
"""
108108

109-
def __init__(self, model: Model, llm_retry_opts: LlmRetryOpts):
109+
def __init__(self, model: Model, llm_retry_opts: LlmRetryOpts | None = LlmRetryOpts()):
110110
self.model = model
111111
self.model_name = "RestateModelWrapper"
112112
self.llm_retry_opts = llm_retry_opts
@@ -124,7 +124,6 @@ async def call_llm() -> RestateModelResponse:
124124
ctx = current_context()
125125
if ctx is None:
126126
raise RuntimeError("No current Restate context found, make sure to run inside a Restate handler")
127-
print("Calling LLM with retry options:", self.llm_retry_opts)
128127
result = await ctx.run_typed(
129128
"call LLM",
130129
call_llm,
@@ -255,7 +254,7 @@ async def run(
255254
"""
256255

257256
# Set persisting model calls
258-
llm_retry_opts = kwargs.get("llm_retry_opts", LlmRetryOpts())
257+
llm_retry_opts = kwargs.get("llm_retry_opts", None)
259258
run_config = kwargs.pop("run_config", RunConfig())
260259
run_config = dataclasses.replace(run_config, model_provider=DurableModelCalls(llm_retry_opts))
261260

0 commit comments

Comments
 (0)