Skip to content

Commit 35e5f47

Browse files
committed
enhance tool_calls
1 parent d0345e2 commit 35e5f47

2 files changed

Lines changed: 11 additions & 3 deletions

File tree

cozeloop/integration/langchain/trace_callback.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from langchain.callbacks.base import BaseCallbackHandler
1313
from langchain.schema import AgentFinish, AgentAction, LLMResult
1414
from langchain_core.prompt_values import PromptValue, ChatPromptValue
15-
from langchain_core.messages import BaseMessage, AIMessageChunk
15+
from langchain_core.messages import BaseMessage, AIMessageChunk, AIMessage
1616
from langchain_core.prompts import AIMessagePromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
1717
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
1818

@@ -416,7 +416,7 @@ def _convert_inputs(inputs: Any) -> Any:
416416
for each in inputs:
417417
format_inputs.append(_convert_inputs(each))
418418
return format_inputs
419-
if isinstance(inputs, AIMessageChunk):
419+
if isinstance(inputs, (AIMessageChunk, AIMessage)):
420420
"""
421421
Must be before BaseMessage.
422422
"""

cozeloop/integration/langchain/trace_model/llm_model.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-License-Identifier: MIT
33

44
import json
5+
import logging
56
import time
67
from typing import List, Optional, Union, Dict, Any
78
from pydantic.dataclasses import dataclass
@@ -208,6 +209,13 @@ def convert_tool_calls_by_raw(tool_calls: list) -> List[ToolCall]:
208209
def convert_tool_calls_by_additional_kwargs(tool_calls: list) -> List[ToolCall]:
209210
format_tool_calls: List[ToolCall] = []
210211
for tool_call in tool_calls:
211-
function = ToolFunction(name=tool_call.get('function', {}).get('name', ''), arguments=json.loads(tool_call.get('function', {}).get('arguments', '{}')))
212+
raw_args = tool_call.get('function', {}).get('arguments', '{}')
213+
final_args = None
214+
try:
215+
final_args = json.loads(raw_args)
216+
except Exception as e:
217+
final_args = raw_args
218+
logging.error(f"convert_tool_calls_by_additional_kwargs failed, error: {e}, tool_call.function.arguments: {raw_args}")
219+
function = ToolFunction(name=tool_call.get('function', {}).get('name', ''), arguments=final_args)
212220
format_tool_calls.append(ToolCall(id=tool_call.get('id', ''), type=tool_call.get('type', ''), function=function))
213221
return format_tool_calls

0 commit comments

Comments
 (0)