@@ -194,18 +194,16 @@ def __init__(self, agent: BaseAgent[OutputT]) -> None:
194194 middleware .extend (after_user_middlewares )
195195
196196 model_impl = _create_langchain_model (agent .model )
197- lc_middleware : list [LC_AgentMiddleware ] = [
198- _Middleware (m , model_impl , agent .logger ) for m in (middleware or [])
199- ]
197+
198+ lc_middleware : list [LC_AgentMiddleware ] = [_Middleware (middleware , model_impl )]
200199
201200 # This middleware is executed just after the tool execution and populates
202201 # the artifact field for failed tool calls, since in such cases we can't
203202 # populate the artifact in LC directly since this is an LC_ToolException that only
204203 # allows setting of the content field.
205204 # We do that here, to avoid doing this logic in the individual conversion helpers.
206205 #
207- # TODO: once we move middlewares into one LC middleware, we should move
208- # that piece of logic there (DVPL-12959).
206+ # TODO: we could move this logic to _Middleware.
209207 class _ToolFailureArtifact (LC_AgentMiddleware ):
210208 @override
211209 async def awrap_tool_call (
@@ -338,8 +336,7 @@ class _SubagentArgumentPacker(LC_AgentMiddleware):
338336 # This middleware performs the corresponding pack/unpack at the two
339337 # points in the LangChain call graph where raw args are needed/retreived.
340338 #
341- # TODO: once we move middlewares into one LC middleware, we should move
342- # that piece of logic there (DVPL-12959).
339+ # TODO: we could move this logic to _Middleware.
343340 @override
344341 async def awrap_model_call (
345342 self ,
@@ -587,47 +584,78 @@ def _prepare_langchain_tools(agent_tools: Sequence[Tool]) -> list[BaseTool]:
587584
588585
589586class _Middleware (LC_AgentMiddleware ):
590- _middleware : AgentMiddleware
587+ _middleware : list [ AgentMiddleware ]
591588 _model : BaseChatModel
592- _logger : logging .Logger
593- _name : str
594589
595- def __init__ (
596- self ,
597- middleware : AgentMiddleware ,
598- model : BaseChatModel ,
599- logger : logging .Logger ,
600- ) -> None :
590+ def __init__ (self , middleware : list [AgentMiddleware ], model : BaseChatModel ) -> None :
601591 self ._middleware = middleware
602592 self ._model = model
603- self ._logger = logger
604- self ._name = str (uuid .uuid4 ())
605593
606- def _is_overridden ( self , method_name : str ) -> bool :
607- """Return True if the middleware method was overridden by the user."""
608- return getattr ( type ( self . _middleware ), method_name ) is not getattr (
609- AgentMiddleware , method_name
610- )
594+ def _with_model_middleware (
595+ self , model_invoke : ModelMiddlewareHandler
596+ ) -> Callable [[ ModelRequest ], Awaitable [ ModelResponse ]]:
597+ invoke = model_invoke
598+ for middleware in reversed ( self . _middleware or []):
611599
612- @property
613- @override
614- def name (self ) -> str :
615- return self ._name
600+ def make_next (
601+ m : AgentMiddleware , h : ModelMiddlewareHandler
602+ ) -> ModelMiddlewareHandler :
603+ async def next (r : ModelRequest ) -> ModelResponse :
604+ return await m .model_middleware (r , h )
605+
606+ return next
607+
608+ invoke = make_next (middleware , invoke )
609+
610+ return invoke
611+
612+ def _with_tool_call_middleware (
613+ self , tool_invoke : ToolMiddlewareHandler
614+ ) -> Callable [[ToolRequest ], Awaitable [ToolResponse ]]:
615+ invoke = tool_invoke
616+ for middleware in reversed (self ._middleware or []):
617+
618+ def make_next (
619+ m : AgentMiddleware , h : ToolMiddlewareHandler
620+ ) -> ToolMiddlewareHandler :
621+ async def next (r : ToolRequest ) -> ToolResponse :
622+ return await m .tool_middleware (r , h )
623+
624+ return next
625+
626+ invoke = make_next (middleware , invoke )
627+
628+ return invoke
629+
630+ def _with_subagent_call_middleware (
631+ self , subagent_invoke : SubagentMiddlewareHandler
632+ ) -> Callable [[SubagentRequest ], Awaitable [SubagentResponse ]]:
633+ invoke = subagent_invoke
634+ for middleware in reversed (self ._middleware or []):
635+
636+ def make_next (
637+ m : AgentMiddleware , h : SubagentMiddlewareHandler
638+ ) -> SubagentMiddlewareHandler :
639+ async def next (r : SubagentRequest ) -> SubagentResponse :
640+ return await m .subagent_middleware (r , h )
641+
642+ return next
643+
644+ invoke = make_next (middleware , invoke )
645+
646+ return invoke
616647
617648 @override
618649 async def awrap_model_call (
619650 self ,
620651 request : LC_ModelRequest ,
621652 handler : Callable [[LC_ModelRequest ], Awaitable [LC_ModelCallResult ]],
622653 ) -> LC_ModelCallResult :
623- if not self ._is_overridden ("model_middleware" ):
624- # Optimization: if not overridden, then skip the conversion overhead.
625- return await handler (request )
626-
627- sdk_response = await self ._middleware .model_middleware (
628- _convert_model_request_from_lc (request , self ._model ),
629- _convert_model_handler_from_lc (handler , original_request = request ),
654+ req = _convert_model_request_from_lc (request , self ._model )
655+ final_handler = _convert_model_handler_from_lc (
656+ handler , original_request = request
630657 )
658+ sdk_response = await self ._with_model_middleware (final_handler )(req )
631659 return _convert_model_response_to_model_result (sdk_response )
632660
633661 @override
@@ -641,14 +669,11 @@ async def awrap_tool_call(
641669 call = _map_tool_call_from_langchain (request .tool_call )
642670
643671 if isinstance (call , ToolCall ):
644- if not self ._is_overridden ("tool_middleware" ):
645- # Optimization: if not overridden, skip the conversion overhead.
646- return await handler (request )
647-
648- sdk_response = await self ._middleware .tool_middleware (
649- _convert_tool_request_from_lc (request , self ._model ),
650- _convert_tool_handler_from_lc (handler , original_request = request ),
672+ req = _convert_tool_request_from_lc (request , self ._model )
673+ final_handler = _convert_tool_handler_from_lc (
674+ handler , original_request = request
651675 )
676+ sdk_response = await self ._with_tool_call_middleware (final_handler )(req )
652677
653678 sdk_result = sdk_response .result
654679 match sdk_result :
@@ -672,14 +697,11 @@ async def awrap_tool_call(
672697 artifact = sdk_result ,
673698 )
674699
675- if not self ._is_overridden ("subagent_middleware" ):
676- # Optimization: if not overridden, skip the conversion overhead.
677- return await handler (request )
678-
679- sdk_response = await self ._middleware .subagent_middleware (
680- _convert_subagent_request_from_lc (request , self ._model ),
681- _convert_subagent_handler_from_lc (handler , original_request = request ),
700+ req = _convert_subagent_request_from_lc (request , self ._model )
701+ final_handler = _convert_subagent_handler_from_lc (
702+ handler , original_request = request
682703 )
704+ sdk_response = await self ._with_subagent_call_middleware (final_handler )(req )
683705
684706 sdk_result = sdk_response .result
685707 match sdk_result :
0 commit comments