Skip to content

Commit 9d49f57

Browse files
authored
enhance langchain callback exception (#11)
1. enhance callback exception
1 parent 8a447a2 commit 9d49f57

2 files changed

Lines changed: 60 additions & 36 deletions

File tree

cozeloop/integration/langchain/trace_callback.py

Lines changed: 59 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import traceback
88
from typing import List, Dict, Union, Any, Optional
99

10-
from pydantic import Field
10+
from pydantic import Field, BaseModel
1111
from langchain.callbacks.base import BaseCallbackHandler
1212
from langchain.schema import AgentFinish, AgentAction, LLMResult
1313
from langchain_core.prompt_values import PromptValue, ChatPromptValue
@@ -93,38 +93,53 @@ async def on_llm_new_token(self, token: str, *, chunk: Optional[Union[Generation
9393

9494
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
9595
flow_span = self._get_flow_span(**kwargs)
96-
# set output span_tag
97-
flow_span.set_tags({'output': ModelTraceOutput(response.generations).to_json()})
96+
try:
97+
# set output span_tag
98+
flow_span.set_tags({'output': ModelTraceOutput(response.generations).to_json()})
99+
except Exception as e:
100+
flow_span.set_error(e)
98101
# calculate token usage,and set span_tag
99-
if response.llm_output is not None and 'token_usage' in response.llm_output:
102+
if response.llm_output is not None and 'token_usage' in response.llm_output and response.llm_output['token_usage']:
100103
self._set_span_tags(flow_span, response.llm_output['token_usage'], need_convert_tag_value=False)
101104
else:
102-
run_info = self.run_map[str(kwargs['run_id'])]
103-
if run_info is not None and run_info.model_meta is not None:
104-
model_name = run_info.model_meta.model_name
105-
input_messages = run_info.model_meta.message
106-
flow_span.set_input_tokens(calc_token_usage(input_messages, model_name))
107-
flow_span.set_output_tokens(calc_token_usage(response, model_name))
105+
try:
106+
run_info = self.run_map[str(kwargs['run_id'])]
107+
if run_info is not None and run_info.model_meta is not None:
108+
model_name = run_info.model_meta.model_name
109+
input_messages = run_info.model_meta.message
110+
flow_span.set_input_tokens(calc_token_usage(input_messages, model_name))
111+
flow_span.set_output_tokens(calc_token_usage(response, model_name))
112+
except Exception as e:
113+
flow_span.set_error(e)
108114
# finish flow_span
109115
flow_span.finish()
110116

111117
def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any:
112-
if kwargs.get('run_type', '') == 'prompt' or kwargs.get('name', '') == 'ChatPromptTemplate':
113-
self._on_prompt_start(serialized, inputs, **kwargs)
114-
else:
115-
flow_span = self._new_flow_span(kwargs['name'], kwargs['name'], **kwargs)
116-
flow_span.set_tags({'input': _convert_2_json(inputs)})
118+
flow_span = None
119+
try:
120+
if kwargs.get('run_type', '') == 'prompt' or kwargs.get('name', '') == 'ChatPromptTemplate':
121+
flow_span = self._new_flow_span(kwargs['name'], kwargs['name'], **kwargs)
122+
self._on_prompt_start(flow_span, serialized, inputs, **kwargs)
123+
else:
124+
flow_span = self._new_flow_span(kwargs['name'], kwargs['name'], **kwargs)
125+
flow_span.set_tags({'input': _convert_2_json(inputs)})
126+
except Exception as e:
127+
if flow_span is not None:
128+
flow_span.set_error(e)
117129

118130
def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> Any:
119131
flow_span = self.run_map[str(kwargs['run_id'])].span
120-
if self.run_map[str(kwargs['run_id'])].span_type == 'prompt' and isinstance(outputs, ChatPromptValue):
121-
messages: List[Message] = []
122-
for message in outputs.messages:
123-
messages.append(Message(role=message.type, content=message.content))
124-
trace_output = PromptTraceOutput(prompts=messages)
125-
flow_span.set_tags({'output': trace_output.to_json()})
126-
else:
127-
flow_span.set_tags({'output': _convert_2_json(outputs)})
132+
try:
133+
if self.run_map[str(kwargs['run_id'])].span_type == 'prompt' and isinstance(outputs, ChatPromptValue):
134+
messages: List[Message] = []
135+
for message in outputs.messages:
136+
messages.append(Message(role=message.type, content=message.content))
137+
trace_output = PromptTraceOutput(prompts=messages)
138+
flow_span.set_tags({'output': trace_output.to_json()})
139+
else:
140+
flow_span.set_tags({'output': _convert_2_json(outputs)})
141+
except Exception as e:
142+
flow_span.set_error(e)
128143
flow_span.finish()
129144

130145
def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
@@ -146,7 +161,10 @@ def on_tool_start(
146161

147162
def on_tool_end(self, output: str, **kwargs: Any) -> Any:
148163
flow_span = self._get_flow_span(**kwargs)
149-
flow_span.set_tags({'output': _convert_2_json(output)})
164+
try:
165+
flow_span.set_tags({'output': _convert_2_json(output)})
166+
except Exception as e:
167+
flow_span.set_error(e)
150168
flow_span.finish()
151169

152170
def on_tool_error(
@@ -169,8 +187,7 @@ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
169187
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
170188
return
171189

172-
def _on_prompt_start(self, serialized: Dict[str, Any], inputs: (Dict[str, Any], str), **kwargs: Any) -> None:
173-
flow_span = self._new_flow_span(serialized['name'], 'prompt', **kwargs)
190+
def _on_prompt_start(self, flow_span, serialized: Dict[str, Any], inputs: (Dict[str, Any], str), **kwargs: Any) -> None:
174191
# get inputs
175192
params: List[Argument] = []
176193
if isinstance(inputs, str):
@@ -219,7 +236,7 @@ def _new_flow_span(self, span_name: str, span_type: str, **kwargs: Any) -> Span:
219236
span_type = _span_type_mapping(span_type)
220237
# set parent span
221238
parent_span: Span = None
222-
if 'parent_run_id' in kwargs and kwargs['parent_run_id'] is not None:
239+
if 'parent_run_id' in kwargs and kwargs['parent_run_id'] is not None and str(kwargs['parent_run_id']) in self.run_map:
223240
parent_span = self.run_map[str(kwargs['parent_run_id'])].span
224241
# new span
225242
flow_span = _trace_callback_client.start_span(span_name, span_type, child_of=parent_span)
@@ -236,6 +253,8 @@ def _get_flow_span(self, **kwargs: Any) -> Span:
236253
return None
237254

238255
def _set_span_tags(self, flow_span: Span, tags: Dict[str, Any], need_convert_tag_value=True) -> None:
256+
if tags is None:
257+
return
239258
for key, value in tags.items():
240259
report_value = value
241260
if need_convert_tag_value:
@@ -369,13 +388,16 @@ def _get_model_span_tags(**kwargs: Any) -> dict:
369388

370389

371390
def _convert_2_json(inputs: Any) -> str:
372-
format_input = _convert_inputs(inputs)
373-
if isinstance(format_input, str):
374-
return format_input
375-
else:
376-
return json.dumps(format_input,
377-
default=lambda o: dict((key, value) for key, value in o.__dict__.items() if value),
378-
ensure_ascii=False)
391+
try:
392+
format_input = _convert_inputs(inputs)
393+
if isinstance(format_input, str):
394+
return format_input
395+
else:
396+
return json.dumps(format_input,
397+
default=lambda o: dict((key, value) for key, value in o.__dict__.items() if value),
398+
ensure_ascii=False)
399+
except Exception as e:
400+
return repr(e)
379401

380402

381403
def _convert_inputs(inputs: Any) -> Any:
@@ -388,7 +410,7 @@ def _convert_inputs(inputs: Any) -> Any:
388410
for key, val in inputs.items():
389411
format_inputs[key] = _convert_inputs(val)
390412
return format_inputs
391-
if isinstance(inputs, list):
413+
if isinstance(inputs, list) or isinstance(inputs, set):
392414
format_inputs = []
393415
for each in inputs:
394416
format_inputs.append(_convert_inputs(each))
@@ -422,6 +444,8 @@ def _convert_inputs(inputs: Any) -> Any:
422444
return format_inputs
423445
if isinstance(inputs, PromptValue):
424446
return _convert_inputs(inputs.to_messages())
447+
if isinstance(inputs, BaseModel):
448+
return inputs.model_dump_json()
425449
if inputs is None:
426450
return 'None'
427451
return 'type of inputs is not supported'

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "cozeloop"
3-
version = "0.1.7"
3+
version = "0.1.8"
44
description = "coze loop sdk"
55
authors = ["JiangQi715 <jiangqi.rrt@bytedance.com>"]
66
license = "MIT"

0 commit comments

Comments
 (0)