Skip to content

Commit a95e3ac

Browse files
committed
Update OpenAI integration: flush session at end of runner, allow setting model retry opts, set parallel_tool_calls to false
1 parent a4765b0 commit a95e3ac

2 files changed

Lines changed: 133 additions & 99 deletions

File tree

python/restate/ext/openai/__init__.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,43 @@
1212
This module contains the optional OpenAI integration for Restate.
1313
"""
1414

15-
from .runner_wrapper import Runner, DurableModelCalls, continue_on_terminal_errors, raise_terminal_errors
15+
import typing
16+
17+
from .runner_wrapper import (
18+
DurableRunner,
19+
DurableModelCalls,
20+
continue_on_terminal_errors,
21+
raise_terminal_errors,
22+
RestateSession,
23+
LlmRetryOpts
24+
)
25+
from restate import ObjectContext, Context
26+
from restate.server_context import current_context
27+
28+
29+
def restate_object_context() -> ObjectContext:
30+
"""Get the current Restate ObjectContext."""
31+
ctx = current_context()
32+
if ctx is None:
33+
raise RuntimeError("No Restate context found.")
34+
return typing.cast(ObjectContext, ctx)
35+
36+
37+
def restate_context() -> Context:
38+
"""Get the current Restate Context."""
39+
ctx = current_context()
40+
if ctx is None:
41+
raise RuntimeError("No Restate context found.")
42+
return ctx
43+
1644

1745
__all__ = [
1846
"DurableModelCalls",
1947
"continue_on_terminal_errors",
2048
"raise_terminal_errors",
21-
"Runner",
49+
"RestateSession",
50+
"DurableRunner",
51+
"LlmRetryOpts",
52+
"restate_object_context",
53+
"restate_context",
2254
]

python/restate/ext/openai/runner_wrapper.py

Lines changed: 99 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -12,39 +12,61 @@
1212
This module contains the optional OpenAI integration for Restate.
1313
"""
1414

15-
import asyncio
1615
import dataclasses
17-
import typing
1816

1917
from agents import (
20-
Tool,
2118
Usage,
2219
Model,
2320
RunContextWrapper,
2421
AgentsException,
25-
Runner as OpenAIRunner,
2622
RunConfig,
2723
TContext,
2824
RunResult,
2925
Agent,
3026
ModelBehaviorError,
27+
ModelSettings,
3128
)
32-
3329
from agents.models.multi_provider import MultiProvider
3430
from agents.items import TResponseStreamEvent, TResponseOutputItem, ModelResponse
3531
from agents.memory.session import SessionABC
3632
from agents.items import TResponseInputItem
37-
from typing import List, Any
38-
from typing import AsyncIterator
39-
40-
from agents.tool import FunctionTool
41-
from agents.tool_context import ToolContext
33+
from agents.run import (
34+
AgentRunner,
35+
DEFAULT_AGENT_RUNNER,
36+
)
37+
from datetime import timedelta
38+
from typing import List, Any, AsyncIterator, Optional, cast
4239
from pydantic import BaseModel
40+
4341
from restate.exceptions import SdkInternalBaseException
4442
from restate.extensions import current_context
45-
4643
from restate import RunOptions, ObjectContext, TerminalError
4744

45+
@dataclasses.dataclass
46+
class LlmRetryOpts:
47+
max_attempts: Optional[int] = 10
48+
"""Max number of attempts (including the initial), before giving up.
49+
50+
When giving up, the LLM call will throw a `TerminalError` wrapping the original error message."""
51+
max_duration: Optional[timedelta] = None
52+
"""Max duration of retries, before giving up.
53+
54+
When giving up, the LLM call will throw a `TerminalError` wrapping the original error message."""
55+
initial_retry_interval: Optional[timedelta] = timedelta(seconds=1)
56+
"""Initial interval for the first retry attempt.
57+
Retry interval will grow by a factor specified in `retry_interval_factor`.
58+
59+
If any of the other retry related fields is specified, the default for this field is 50 milliseconds, otherwise restate will fallback to the overall invocation retry policy."""
60+
max_retry_interval: Optional[timedelta] = None
61+
"""Max interval between retries.
62+
Retry interval will grow by a factor specified in `retry_interval_factor`.
63+
64+
The default is 10 seconds."""
65+
retry_interval_factor: Optional[float] = None
66+
"""Exponentiation factor to use when computing the next retry delay.
67+
68+
If any of the other retry related fields is specified, the default for this field is `2`, meaning retry interval will double at each attempt, otherwise restate will fallback to the overall invocation retry policy."""
69+
4870

4971
# The OpenAI ModelResponse class is a dataclass with Pydantic fields.
5072
# The Restate SDK cannot serialize this. So we turn the ModelResponse into a Pydantic model.
@@ -71,23 +93,23 @@ class DurableModelCalls(MultiProvider):
7193
A Restate model provider that wraps the OpenAI SDK's default MultiProvider.
7294
"""
7395

74-
def __init__(self, max_retries: int | None = 3):
96+
def __init__(self, llm_retry_opts: LlmRetryOpts):
7597
super().__init__()
76-
self.max_retries = max_retries
98+
self.llm_retry_opts = llm_retry_opts
7799

78100
def get_model(self, model_name: str | None) -> Model:
79-
return RestateModelWrapper(super().get_model(model_name or None), self.max_retries)
101+
return RestateModelWrapper(super().get_model(model_name or None), self.llm_retry_opts)
80102

81103

82104
class RestateModelWrapper(Model):
83105
"""
84106
A wrapper around the OpenAI SDK's Model that persists LLM calls in the Restate journal.
85107
"""
86108

87-
def __init__(self, model: Model, max_retries: int | None = 3):
109+
def __init__(self, model: Model, llm_retry_opts: LlmRetryOpts):
88110
self.model = model
89111
self.model_name = "RestateModelWrapper"
90-
self.max_retries = max_retries
112+
self.llm_retry_opts = llm_retry_opts
91113

92114
async def get_response(self, *args, **kwargs) -> ModelResponse:
93115
async def call_llm() -> RestateModelResponse:
@@ -102,7 +124,18 @@ async def call_llm() -> RestateModelResponse:
102124
ctx = current_context()
103125
if ctx is None:
104126
raise RuntimeError("No current Restate context found, make sure to run inside a Restate handler")
105-
result = await ctx.run_typed("call LLM", call_llm, RunOptions(max_attempts=self.max_retries))
127+
print("Calling LLM with retry options:", self.llm_retry_opts)
128+
result = await ctx.run_typed(
129+
"call LLM",
130+
call_llm,
131+
RunOptions(
132+
max_attempts=self.llm_retry_opts.max_attempts,
133+
max_duration=self.llm_retry_opts.max_duration,
134+
initial_retry_interval=self.llm_retry_opts.initial_retry_interval,
135+
max_retry_interval=self.llm_retry_opts.max_retry_interval,
136+
retry_interval_factor=self.llm_retry_opts.retry_interval_factor,
137+
),
138+
)
106139
# convert back to original ModelResponse
107140
return ModelResponse(
108141
output=result.output,
@@ -117,33 +150,43 @@ def stream_response(self, *args, **kwargs) -> AsyncIterator[TResponseStreamEvent
117150
class RestateSession(SessionABC):
118151
"""Restate session implementation following the Session protocol."""
119152

153+
def __init__(self):
154+
self._items: List[TResponseInputItem] | None = None
155+
120156
def _ctx(self) -> ObjectContext:
121-
return typing.cast(ObjectContext, current_context())
157+
return cast(ObjectContext, current_context())
158+
159+
async def _load_items_if_needed(self) -> None:
160+
"""Load items from context if not already loaded."""
161+
if self._items is None:
162+
self._items = await self._ctx().get("items", type_hint=List[TResponseInputItem]) or []
122163

123164
async def get_items(self, limit: int | None = None) -> List[TResponseInputItem]:
124165
"""Retrieve conversation history for this session."""
125-
current_items = await self._ctx().get("items", type_hint=List[TResponseInputItem]) or []
166+
await self._load_items_if_needed()
126167
if limit is not None:
127-
return current_items[-limit:]
128-
return current_items
168+
return self._items[-limit:]
169+
return self._items.copy()
129170

130171
async def add_items(self, items: List[TResponseInputItem]) -> None:
131172
"""Store new items for this session."""
132-
# Your implementation here
133-
current_items = await self.get_items() or []
134-
self._ctx().set("items", current_items + items)
173+
await self._load_items_if_needed()
174+
self._items.extend(items)
135175

136176
async def pop_item(self) -> TResponseInputItem | None:
137177
"""Remove and return the most recent item from this session."""
138-
current_items = await self.get_items() or []
139-
if current_items:
140-
item = current_items.pop()
141-
self._ctx().set("items", current_items)
142-
return item
178+
await self._load_items_if_needed()
179+
if self._items:
180+
return self._items.pop()
143181
return None
144182

183+
async def flush(self) -> None:
184+
"""Flush the session items to the context."""
185+
self._ctx().set("items", self._items)
186+
145187
async def clear_session(self) -> None:
146188
"""Clear all items for this session."""
189+
self._items = []
147190
self._ctx().clear("items")
148191

149192

@@ -189,7 +232,7 @@ def continue_on_terminal_errors(context: RunContextWrapper[Any], error: Exceptio
189232
raise error
190233

191234

192-
class Runner:
235+
class DurableRunner:
193236
"""
194237
A wrapper around Runner.run that automatically configures RunConfig for Restate contexts.
195238
@@ -201,9 +244,7 @@ class Runner:
201244
@staticmethod
202245
async def run(
203246
starting_agent: Agent[TContext],
204-
disable_tool_autowrapping: bool = False,
205-
*args: typing.Any,
206-
run_config: RunConfig | None = None,
247+
input: str | list[TResponseInputItem],
207248
**kwargs,
208249
) -> RunResult:
209250
"""
@@ -213,71 +254,32 @@ async def run(
213254
The result from Runner.run
214255
"""
215256

216-
current_run_config = run_config or RunConfig()
217-
new_run_config = dataclasses.replace(
218-
current_run_config,
219-
model_provider=DurableModelCalls(),
220-
)
221-
restate_agent = sequentialize_and_wrap_tools(starting_agent, disable_tool_autowrapping)
222-
return await OpenAIRunner.run(restate_agent, *args, run_config=new_run_config, **kwargs)
223-
257+
# Set persisting model calls
258+
llm_retry_opts = kwargs.get("llm_retry_opts", LlmRetryOpts())
259+
run_config = kwargs.pop("run_config", RunConfig())
260+
run_config = dataclasses.replace(run_config, model_provider=DurableModelCalls(llm_retry_opts))
224261

225-
def sequentialize_and_wrap_tools(
226-
agent: Agent[TContext],
227-
disable_tool_autowrapping: bool,
228-
) -> Agent[TContext]:
229-
"""
230-
Wrap the tools of an agent to use the Restate error handling.
231-
232-
Returns:
233-
A new agent with wrapped tools.
234-
"""
235-
236-
# Restate does not allow parallel tool calls, so we use a lock to ensure sequential execution.
237-
# This lock only affects tools for this agent; handoff agents are wrapped recursively.
238-
sequential_tools_lock = asyncio.Lock()
239-
wrapped_tools: list[Tool] = []
240-
for tool in agent.tools:
241-
if isinstance(tool, FunctionTool):
242-
243-
def create_wrapper(captured_tool):
244-
async def on_invoke_tool_wrapper(tool_context: ToolContext[Any], tool_input: Any) -> Any:
245-
await sequential_tools_lock.acquire()
246-
247-
async def invoke():
248-
result = await captured_tool.on_invoke_tool(tool_context, tool_input)
249-
# Ensure Pydantic objects are serialized to dict for LLM compatibility
250-
if hasattr(result, "model_dump"):
251-
return result.model_dump()
252-
elif hasattr(result, "dict"):
253-
return result.dict()
254-
return result
255-
256-
try:
257-
if disable_tool_autowrapping:
258-
return await invoke()
259-
260-
ctx = current_context()
261-
if ctx is None:
262-
raise RuntimeError(
263-
"No current Restate context found, make sure to run inside a Restate handler"
264-
)
265-
return await ctx.run_typed(captured_tool.name, invoke)
266-
finally:
267-
sequential_tools_lock.release()
268-
269-
return on_invoke_tool_wrapper
270-
271-
wrapped_tools.append(dataclasses.replace(tool, on_invoke_tool=create_wrapper(tool)))
262+
# Disable parallel tool calls
263+
model_settings = run_config.model_settings
264+
if model_settings is None:
265+
model_settings = ModelSettings(parallel_tool_calls=False)
272266
else:
273-
wrapped_tools.append(tool)
267+
model_settings = dataclasses.replace(
268+
model_settings,
269+
parallel_tool_calls=False,
270+
)
271+
run_config = dataclasses.replace(
272+
run_config,
273+
model_settings=model_settings,
274+
)
274275

275-
handoffs_with_wrapped_tools = []
276-
for handoff in agent.handoffs:
277-
# recursively wrap tools in handoff agents
278-
handoffs_with_wrapped_tools.append(sequentialize_and_wrap_tools(handoff, disable_tool_autowrapping)) # type: ignore
276+
runner = DEFAULT_AGENT_RUNNER or AgentRunner()
277+
try:
278+
result = await runner.run(starting_agent=starting_agent, input=input, run_config=run_config, **kwargs)
279+
finally:
280+
# Flush session items to Restate
281+
session = kwargs.get("session", None)
282+
if session is not None and isinstance(session, RestateSession):
283+
await session.flush()
279284

280-
return agent.clone(
281-
tools=wrapped_tools,
282-
handoffs=handoffs_with_wrapped_tools,
283-
)
285+
return result

0 commit comments

Comments
 (0)