 This module contains the optional OpenAI integration for Restate.
 """
 
-import asyncio
 import dataclasses
-import typing
 
 from agents import (
-    Tool,
     Usage,
     Model,
     RunContextWrapper,
     AgentsException,
-    Runner as OpenAIRunner,
    RunConfig,
     TContext,
     RunResult,
     Agent,
     ModelBehaviorError,
+    ModelSettings,
 )
-
 from agents.models.multi_provider import MultiProvider
 from agents.items import TResponseStreamEvent, TResponseOutputItem, ModelResponse
 from agents.memory.session import SessionABC
 from agents.items import TResponseInputItem
-from typing import List, Any
-from typing import AsyncIterator
-
-from agents.tool import FunctionTool
-from agents.tool_context import ToolContext
+from agents.run import (
+    AgentRunner,
+    DEFAULT_AGENT_RUNNER,
+)
+from datetime import timedelta
+from typing import List, Any, AsyncIterator, Optional, cast
 from pydantic import BaseModel
+
 from restate.exceptions import SdkInternalBaseException
 from restate.extensions import current_context
-
 from restate import RunOptions, ObjectContext, TerminalError
 
+@dataclasses.dataclass
+class LlmRetryOpts:
+    max_attempts: Optional[int] = 10
+    """Max number of attempts (including the initial), before giving up.
+
+    When giving up, the LLM call will throw a `TerminalError` wrapping the original error message."""
+    max_duration: Optional[timedelta] = None
+    """Max duration of retries, before giving up.
+
+    When giving up, the LLM call will throw a `TerminalError` wrapping the original error message."""
+    initial_retry_interval: Optional[timedelta] = timedelta(seconds=1)
+    """Initial interval for the first retry attempt.
+    Retry interval will grow by a factor specified in `retry_interval_factor`.
+
+    If any of the other retry-related fields is specified, the default for this field is 50 milliseconds; otherwise Restate will fall back to the overall invocation retry policy."""
+    max_retry_interval: Optional[timedelta] = None
+    """Max interval between retries.
+    Retry interval will grow by a factor specified in `retry_interval_factor`.
+
+    The default is 10 seconds."""
+    retry_interval_factor: Optional[float] = None
+    """Exponentiation factor to use when computing the next retry delay.
+
+    If any of the other retry-related fields is specified, the default for this field is `2`, meaning the retry interval will double at each attempt; otherwise Restate will fall back to the overall invocation retry policy."""
+
 
 # The OpenAI ModelResponse class is a dataclass with Pydantic fields.
 # The Restate SDK cannot serialize this, so we turn the ModelResponse into a Pydantic model.
@@ -71,23 +93,23 @@ class DurableModelCalls(MultiProvider):
     A Restate model provider that wraps the OpenAI SDK's default MultiProvider.
     """
 
-    def __init__(self, max_retries: int | None = 3):
+    def __init__(self, llm_retry_opts: LlmRetryOpts):
         super().__init__()
-        self.max_retries = max_retries
+        self.llm_retry_opts = llm_retry_opts
 
     def get_model(self, model_name: str | None) -> Model:
-        return RestateModelWrapper(super().get_model(model_name or None), self.max_retries)
+        return RestateModelWrapper(super().get_model(model_name or None), self.llm_retry_opts)
 
 
 class RestateModelWrapper(Model):
     """
     A wrapper around the OpenAI SDK's Model that persists LLM calls in the Restate journal.
     """
 
-    def __init__(self, model: Model, max_retries: int | None = 3):
+    def __init__(self, model: Model, llm_retry_opts: LlmRetryOpts):
         self.model = model
         self.model_name = "RestateModelWrapper"
-        self.max_retries = max_retries
+        self.llm_retry_opts = llm_retry_opts
 
     async def get_response(self, *args, **kwargs) -> ModelResponse:
         async def call_llm() -> RestateModelResponse:
@@ -102,7 +124,18 @@ async def call_llm() -> RestateModelResponse:
         ctx = current_context()
         if ctx is None:
             raise RuntimeError("No current Restate context found, make sure to run inside a Restate handler")
-        result = await ctx.run_typed("call LLM", call_llm, RunOptions(max_attempts=self.max_retries))
127+ print ("Calling LLM with retry options:" , self .llm_retry_opts )
+        result = await ctx.run_typed(
+            "call LLM",
+            call_llm,
+            RunOptions(
+                max_attempts=self.llm_retry_opts.max_attempts,
+                max_duration=self.llm_retry_opts.max_duration,
+                initial_retry_interval=self.llm_retry_opts.initial_retry_interval,
+                max_retry_interval=self.llm_retry_opts.max_retry_interval,
+                retry_interval_factor=self.llm_retry_opts.retry_interval_factor,
+            ),
+        )
         # convert back to original ModelResponse
         return ModelResponse(
             output=result.output,
@@ -117,33 +150,43 @@ def stream_response(self, *args, **kwargs) -> AsyncIterator[TResponseStreamEvent
 class RestateSession(SessionABC):
     """Restate session implementation following the Session protocol."""
 
+    def __init__(self):
+        self._items: List[TResponseInputItem] | None = None
+
     def _ctx(self) -> ObjectContext:
-        return typing.cast(ObjectContext, current_context())
+        return cast(ObjectContext, current_context())
+
+    async def _load_items_if_needed(self) -> None:
+        """Load items from context if not already loaded."""
+        if self._items is None:
+            self._items = await self._ctx().get("items", type_hint=List[TResponseInputItem]) or []
 
     async def get_items(self, limit: int | None = None) -> List[TResponseInputItem]:
         """Retrieve conversation history for this session."""
-        current_items = await self._ctx().get("items", type_hint=List[TResponseInputItem]) or []
+        await self._load_items_if_needed()
         if limit is not None:
-            return current_items[-limit:]
-        return current_items
+            return self._items[-limit:]
+        return self._items.copy()
 
     async def add_items(self, items: List[TResponseInputItem]) -> None:
         """Store new items for this session."""
-        # Your implementation here
-        current_items = await self.get_items() or []
-        self._ctx().set("items", current_items + items)
+        await self._load_items_if_needed()
+        self._items.extend(items)
 
     async def pop_item(self) -> TResponseInputItem | None:
         """Remove and return the most recent item from this session."""
-        current_items = await self.get_items() or []
-        if current_items:
-            item = current_items.pop()
-            self._ctx().set("items", current_items)
-            return item
+        await self._load_items_if_needed()
+        if self._items:
+            return self._items.pop()
         return None
 
+    async def flush(self) -> None:
+        """Flush the session items to the context."""
+        self._ctx().set("items", self._items)
+
     async def clear_session(self) -> None:
         """Clear all items for this session."""
+        self._items = []
         self._ctx().clear("items")
 
 
@@ -189,7 +232,7 @@ def continue_on_terminal_errors(context: RunContextWrapper[Any], error: Exceptio
         raise error
 
 
-class Runner:
+class DurableRunner:
     """
     A wrapper around Runner.run that automatically configures RunConfig for Restate contexts.
 
@@ -201,9 +244,7 @@ class Runner:
     @staticmethod
     async def run(
         starting_agent: Agent[TContext],
-        disable_tool_autowrapping: bool = False,
-        *args: typing.Any,
-        run_config: RunConfig | None = None,
+        input: str | list[TResponseInputItem],
         **kwargs,
     ) -> RunResult:
         """
@@ -213,71 +254,32 @@ async def run(
             The result from Runner.run
         """
 
-        current_run_config = run_config or RunConfig()
-        new_run_config = dataclasses.replace(
-            current_run_config,
-            model_provider=DurableModelCalls(),
-        )
-        restate_agent = sequentialize_and_wrap_tools(starting_agent, disable_tool_autowrapping)
-        return await OpenAIRunner.run(restate_agent, *args, run_config=new_run_config, **kwargs)
-
+        # Configure the model provider that persists LLM calls in the Restate journal
+        llm_retry_opts = kwargs.pop("llm_retry_opts", LlmRetryOpts())
+        run_config = kwargs.pop("run_config", RunConfig())
+        run_config = dataclasses.replace(run_config, model_provider=DurableModelCalls(llm_retry_opts))
 
-def sequentialize_and_wrap_tools(
-    agent: Agent[TContext],
-    disable_tool_autowrapping: bool,
-) -> Agent[TContext]:
-    """
-    Wrap the tools of an agent to use the Restate error handling.
-
-    Returns:
-        A new agent with wrapped tools.
-    """
-
-    # Restate does not allow parallel tool calls, so we use a lock to ensure sequential execution.
-    # This lock only affects tools for this agent; handoff agents are wrapped recursively.
-    sequential_tools_lock = asyncio.Lock()
-    wrapped_tools: list[Tool] = []
-    for tool in agent.tools:
-        if isinstance(tool, FunctionTool):
-
-            def create_wrapper(captured_tool):
-                async def on_invoke_tool_wrapper(tool_context: ToolContext[Any], tool_input: Any) -> Any:
-                    await sequential_tools_lock.acquire()
-
-                    async def invoke():
-                        result = await captured_tool.on_invoke_tool(tool_context, tool_input)
-                        # Ensure Pydantic objects are serialized to dict for LLM compatibility
-                        if hasattr(result, "model_dump"):
-                            return result.model_dump()
-                        elif hasattr(result, "dict"):
-                            return result.dict()
-                        return result
-
-                    try:
-                        if disable_tool_autowrapping:
-                            return await invoke()
-
-                        ctx = current_context()
-                        if ctx is None:
-                            raise RuntimeError(
-                                "No current Restate context found, make sure to run inside a Restate handler"
-                            )
-                        return await ctx.run_typed(captured_tool.name, invoke)
-                    finally:
-                        sequential_tools_lock.release()
-
-                return on_invoke_tool_wrapper
-
-            wrapped_tools.append(dataclasses.replace(tool, on_invoke_tool=create_wrapper(tool)))
+        # Disable parallel tool calls
+        model_settings = run_config.model_settings
+        if model_settings is None:
+            model_settings = ModelSettings(parallel_tool_calls=False)
         else:
-            wrapped_tools.append(tool)
+            model_settings = dataclasses.replace(
+                model_settings,
+                parallel_tool_calls=False,
+            )
+        run_config = dataclasses.replace(
+            run_config,
+            model_settings=model_settings,
+        )
 
-    handoffs_with_wrapped_tools = []
-    for handoff in agent.handoffs:
-        # recursively wrap tools in handoff agents
-        handoffs_with_wrapped_tools.append(sequentialize_and_wrap_tools(handoff, disable_tool_autowrapping))  # type: ignore
+        runner = DEFAULT_AGENT_RUNNER or AgentRunner()
+        try:
+            result = await runner.run(starting_agent=starting_agent, input=input, run_config=run_config, **kwargs)
+        finally:
+            # Flush session items to Restate
+            session = kwargs.get("session", None)
+            if session is not None and isinstance(session, RestateSession):
+                await session.flush()
 
-    return agent.clone(
-        tools=wrapped_tools,
-        handoffs=handoffs_with_wrapped_tools,
-    )
+        return result
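
A minimal usage sketch of the new API, assuming the run happens inside a Restate handler so that current_context() resolves. The service name, handler name, agent definition, and retry values below are illustrative assumptions; only DurableRunner, RestateSession, and LlmRetryOpts come from this module.

    import restate
    from datetime import timedelta
    from agents import Agent

    # Hypothetical agent and virtual object, used only for illustration.
    assistant = Agent(name="assistant", instructions="Answer briefly.")
    chat = restate.VirtualObject("Chat")

    @chat.handler()
    async def message(ctx: restate.ObjectContext, prompt: str) -> str:
        result = await DurableRunner.run(
            assistant,
            prompt,
            session=RestateSession(),  # conversation history kept in Restate state
            llm_retry_opts=LlmRetryOpts(max_attempts=5, max_duration=timedelta(minutes=2)),
        )
        return result.final_output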