Skip to content

Commit 3026733

Browse files
authored
Use single LC middleware for SDK middlewares (#120)
1 parent 460d0e5 commit 3026733

1 file changed

Lines changed: 70 additions & 48 deletions

File tree

splunklib/ai/engines/langchain.py

Lines changed: 70 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -194,18 +194,16 @@ def __init__(self, agent: BaseAgent[OutputT]) -> None:
194194
middleware.extend(after_user_middlewares)
195195

196196
model_impl = _create_langchain_model(agent.model)
197-
lc_middleware: list[LC_AgentMiddleware] = [
198-
_Middleware(m, model_impl, agent.logger) for m in (middleware or [])
199-
]
197+
198+
lc_middleware: list[LC_AgentMiddleware] = [_Middleware(middleware, model_impl)]
200199

201200
# This middleware is executed just after the tool execution and populates
202201
# the artifact field for failed tool calls, since in such cases we can't
203202
# populate the artifact in LC directly since this is an LC_ToolException that only
204203
# allows setting of the content field.
205204
# We do that here, to avoid doing this logic in the individual conversion helpers.
206205
#
207-
# TODO: once we move middlewares into one LC middleware, we should move
208-
# that piece of logic there (DVPL-12959).
206+
# TODO: we could move this logic to _Middleware.
209207
class _ToolFailureArtifact(LC_AgentMiddleware):
210208
@override
211209
async def awrap_tool_call(
@@ -338,8 +336,7 @@ class _SubagentArgumentPacker(LC_AgentMiddleware):
338336
# This middleware performs the corresponding pack/unpack at the two
339337
# points in the LangChain call graph where raw args are needed/retreived.
340338
#
341-
# TODO: once we move middlewares into one LC middleware, we should move
342-
# that piece of logic there (DVPL-12959).
339+
# TODO: we could move this logic to _Middleware.
343340
@override
344341
async def awrap_model_call(
345342
self,
@@ -587,47 +584,78 @@ def _prepare_langchain_tools(agent_tools: Sequence[Tool]) -> list[BaseTool]:
587584

588585

589586
class _Middleware(LC_AgentMiddleware):
590-
_middleware: AgentMiddleware
587+
_middleware: list[AgentMiddleware]
591588
_model: BaseChatModel
592-
_logger: logging.Logger
593-
_name: str
594589

595-
def __init__(
596-
self,
597-
middleware: AgentMiddleware,
598-
model: BaseChatModel,
599-
logger: logging.Logger,
600-
) -> None:
590+
def __init__(self, middleware: list[AgentMiddleware], model: BaseChatModel) -> None:
601591
self._middleware = middleware
602592
self._model = model
603-
self._logger = logger
604-
self._name = str(uuid.uuid4())
605593

606-
def _is_overridden(self, method_name: str) -> bool:
607-
"""Return True if the middleware method was overridden by the user."""
608-
return getattr(type(self._middleware), method_name) is not getattr(
609-
AgentMiddleware, method_name
610-
)
594+
def _with_model_middleware(
595+
self, model_invoke: ModelMiddlewareHandler
596+
) -> Callable[[ModelRequest], Awaitable[ModelResponse]]:
597+
invoke = model_invoke
598+
for middleware in reversed(self._middleware or []):
611599

612-
@property
613-
@override
614-
def name(self) -> str:
615-
return self._name
600+
def make_next(
601+
m: AgentMiddleware, h: ModelMiddlewareHandler
602+
) -> ModelMiddlewareHandler:
603+
async def next(r: ModelRequest) -> ModelResponse:
604+
return await m.model_middleware(r, h)
605+
606+
return next
607+
608+
invoke = make_next(middleware, invoke)
609+
610+
return invoke
611+
612+
def _with_tool_call_middleware(
613+
self, tool_invoke: ToolMiddlewareHandler
614+
) -> Callable[[ToolRequest], Awaitable[ToolResponse]]:
615+
invoke = tool_invoke
616+
for middleware in reversed(self._middleware or []):
617+
618+
def make_next(
619+
m: AgentMiddleware, h: ToolMiddlewareHandler
620+
) -> ToolMiddlewareHandler:
621+
async def next(r: ToolRequest) -> ToolResponse:
622+
return await m.tool_middleware(r, h)
623+
624+
return next
625+
626+
invoke = make_next(middleware, invoke)
627+
628+
return invoke
629+
630+
def _with_subagent_call_middleware(
631+
self, subagent_invoke: SubagentMiddlewareHandler
632+
) -> Callable[[SubagentRequest], Awaitable[SubagentResponse]]:
633+
invoke = subagent_invoke
634+
for middleware in reversed(self._middleware or []):
635+
636+
def make_next(
637+
m: AgentMiddleware, h: SubagentMiddlewareHandler
638+
) -> SubagentMiddlewareHandler:
639+
async def next(r: SubagentRequest) -> SubagentResponse:
640+
return await m.subagent_middleware(r, h)
641+
642+
return next
643+
644+
invoke = make_next(middleware, invoke)
645+
646+
return invoke
616647

617648
@override
618649
async def awrap_model_call(
619650
self,
620651
request: LC_ModelRequest,
621652
handler: Callable[[LC_ModelRequest], Awaitable[LC_ModelCallResult]],
622653
) -> LC_ModelCallResult:
623-
if not self._is_overridden("model_middleware"):
624-
# Optimization: if not overridden, then skip the conversion overhead.
625-
return await handler(request)
626-
627-
sdk_response = await self._middleware.model_middleware(
628-
_convert_model_request_from_lc(request, self._model),
629-
_convert_model_handler_from_lc(handler, original_request=request),
654+
req = _convert_model_request_from_lc(request, self._model)
655+
final_handler = _convert_model_handler_from_lc(
656+
handler, original_request=request
630657
)
658+
sdk_response = await self._with_model_middleware(final_handler)(req)
631659
return _convert_model_response_to_model_result(sdk_response)
632660

633661
@override
@@ -641,14 +669,11 @@ async def awrap_tool_call(
641669
call = _map_tool_call_from_langchain(request.tool_call)
642670

643671
if isinstance(call, ToolCall):
644-
if not self._is_overridden("tool_middleware"):
645-
# Optimization: if not overridden, skip the conversion overhead.
646-
return await handler(request)
647-
648-
sdk_response = await self._middleware.tool_middleware(
649-
_convert_tool_request_from_lc(request, self._model),
650-
_convert_tool_handler_from_lc(handler, original_request=request),
672+
req = _convert_tool_request_from_lc(request, self._model)
673+
final_handler = _convert_tool_handler_from_lc(
674+
handler, original_request=request
651675
)
676+
sdk_response = await self._with_tool_call_middleware(final_handler)(req)
652677

653678
sdk_result = sdk_response.result
654679
match sdk_result:
@@ -672,14 +697,11 @@ async def awrap_tool_call(
672697
artifact=sdk_result,
673698
)
674699

675-
if not self._is_overridden("subagent_middleware"):
676-
# Optimization: if not overridden, skip the conversion overhead.
677-
return await handler(request)
678-
679-
sdk_response = await self._middleware.subagent_middleware(
680-
_convert_subagent_request_from_lc(request, self._model),
681-
_convert_subagent_handler_from_lc(handler, original_request=request),
700+
req = _convert_subagent_request_from_lc(request, self._model)
701+
final_handler = _convert_subagent_handler_from_lc(
702+
handler, original_request=request
682703
)
704+
sdk_response = await self._with_subagent_call_middleware(final_handler)(req)
683705

684706
sdk_result = sdk_response.result
685707
match sdk_result:

0 commit comments

Comments
 (0)