33
44import json
55import time
6- from typing import List , Optional , Union
6+ from typing import List , Optional , Union , Dict , Any
77from pydantic .dataclasses import dataclass
88from langchain_core .messages import BaseMessage , ToolMessage , AIMessageChunk
99from langchain_core .outputs import Generation , ChatGeneration
@@ -14,7 +14,7 @@ class ToolFunction:
1414 name : Optional [str ] = None
1515 description : Optional [str ] = None
1616 parameters : Optional [dict ] = None
17- arguments : Optional [dict ] = None
17+ arguments : Optional [Union [ dict , str ] ] = None
1818
1919
2020@dataclass
@@ -50,20 +50,31 @@ class Message:
5050 tool_calls : List [ToolCall ] = None
5151
5252 def __post_init__ (self ):
53- if self .role is not None and self .role == 'AIMessageChunk' :
53+ if self .role is not None and ( self .role == 'AIMessageChunk' or self . role == 'ai' ) :
5454 self .role = 'assistant'
5555 parts : Optional [List [Parts ]] = []
5656 if isinstance (self .content , List ) and all (isinstance (x , dict ) for x in self .content ):
57+ is_parts = False
5758 for each in self .content :
5859 text = each .get ('text' , None )
5960 url = each .get ('url' , each .get ('image_url' , {}).get ('url' , None ))
61+ if text is None and url is None :
62+ continue
63+ is_parts = True
6064 parts .append (Parts (type = each .get ('type' , '' ), text = text , image_url = ImageUrl (url = url ) if url is not None else None ))
61- self .content = None
65+ if is_parts :
66+ self .content = None
67+ else :
68+ self .content = self .content .__str__ ()
6269 elif isinstance (self .content , dict ):
70+ is_part = False
6371 text = self .content .get ('text' , None )
6472 url = self .content .get ('url' , self .content .get ('image_url' , {}).get ('url' , None ))
65- parts .append (Parts (type = self .content .get ('type' , '' ), text = text , image_url = ImageUrl (url = url ) if url is not None else None ))
66- self .content = None
73+ if text is not None or url is not None :
74+ parts .append (Parts (type = self .content .get ('type' , '' ), text = text , image_url = ImageUrl (url = url ) if url is not None else None ))
75+ self .content = None
76+ else :
77+ self .content = self .content .__str__ ()
6778 elif isinstance (self .content , List ) and all (type (x , Parts ) for x in self .content ):
6879 parts = self .content
6980 self .content = None
@@ -172,4 +183,4 @@ def convert_tool_calls(tool_calls: list) -> List[ToolCall]:
172183 for tool_call in tool_calls :
173184 function = ToolFunction (name = tool_call .get ('function' , {}).get ('name' , '' ), arguments = json .loads (tool_call .get ('function' , {}).get ('arguments' , {})))
174185 format_tool_calls .append (ToolCall (id = tool_call .get ('id' , '' ), type = tool_call .get ('type' , '' ), function = function ))
175- return format_tool_calls
186+ return format_tool_calls
0 commit comments