55import time
66from typing import List , Optional , Union , Dict , Any
77from pydantic .dataclasses import dataclass
8- from langchain_core .messages import BaseMessage , ToolMessage , AIMessageChunk
8+ from langchain_core .messages import BaseMessage , ToolMessage , AIMessageChunk , AIMessage
99from langchain_core .outputs import Generation , ChatGeneration
1010
1111
@@ -122,11 +122,30 @@ def __init__(self, messages: List[Union[BaseMessage, List[BaseMessage]]], invoca
122122 elif isinstance (inner_messages , List ):
123123 for message in inner_messages :
124124 process_messages .append (message )
125+
126+ tool_call_id_name_map = {}
127+ for message in process_messages :
128+ if isinstance (message , (AIMessageChunk , AIMessage )):
129+ for tool_call in message .additional_kwargs .get ('tool_calls' , []):
130+ if tool_call .get ('id' , '' ):
131+ tool_call_id_name_map [tool_call .get ('id' , '' )] = tool_call .get ('function' , {}).get ('name' , '' )
132+ for tool_call in message .tool_calls :
133+ if tool_call .get ('id' , '' ):
134+ tool_call_id_name_map [tool_call .get ('id' , '' )] = tool_call .get ('name' , '' )
135+
125136 for message in process_messages :
126- if isinstance (message , AIMessageChunk ):
127- self ._messages .append (Message (role = message .type , content = message .content , tool_calls = convert_tool_calls (message .additional_kwargs .get ('tool_calls' , []))))
137+ if isinstance (message , (AIMessageChunk , AIMessage )):
138+ tool_calls = convert_tool_calls_by_additional_kwargs (message .additional_kwargs .get ('tool_calls' , []))
139+ if len (tool_calls ) == 0 :
140+ tool_calls = convert_tool_calls_by_raw (message .tool_calls )
141+ self ._messages .append (Message (role = message .type , content = message .content , tool_calls = tool_calls ))
128142 elif isinstance (message , ToolMessage ):
129- tool_call = ToolCall (id = message .tool_call_id , type = message .type , function = ToolFunction (name = message .additional_kwargs .get ('name' , '' )))
143+ name = ''
144+ if tool_call_id_name_map .get (message .tool_call_id , None ) is not None :
145+ name = tool_call_id_name_map [message .tool_call_id ]
146+ if message .additional_kwargs .get ('name' , '' ):
147+ name = message .additional_kwargs .get ('name' , '' )
148+ tool_call = ToolCall (id = message .tool_call_id , type = message .type , function = ToolFunction (name = name ))
130149 self ._messages .append (Message (role = message .type , content = message .content , tool_calls = [tool_call ]))
131150 else :
132151 self ._messages .append (Message (role = message .type , content = message .content ))
@@ -161,7 +180,7 @@ def to_json(self):
161180 for i , generation in enumerate (self .generations ):
162181 choice : Choice = None
163182 if isinstance (generation , ChatGeneration ):
164- tool_calls = convert_tool_calls (generation .message .additional_kwargs .get ('tool_calls' , []))
183+ tool_calls = convert_tool_calls_by_additional_kwargs (generation .message .additional_kwargs .get ('tool_calls' , []))
165184 if len (tool_calls ) == 0 and 'function_call' in generation .message .additional_kwargs :
166185 function_call = generation .message .additional_kwargs .get ('function_call' , {})
167186 function = ToolFunction (name = function_call .get ('name' , '' ), arguments = json .loads (function_call .get ('arguments' , {})))
@@ -178,9 +197,17 @@ def to_json(self):
178197 ensure_ascii = False )
179198
180199
181- def convert_tool_calls (tool_calls : list ) -> List [ToolCall ]:
200+ def convert_tool_calls_by_raw (tool_calls : list ) -> List [ToolCall ]:
201+ format_tool_calls : List [ToolCall ] = []
202+ for tool_call in tool_calls :
203+ function = ToolFunction (name = tool_call .get ('name' , '' ), arguments = tool_call .get ('args' , {}))
204+ format_tool_calls .append (ToolCall (id = tool_call .get ('id' , '' ), type = tool_call .get ('type' , '' ), function = function ))
205+ return format_tool_calls
206+
207+
208+ def convert_tool_calls_by_additional_kwargs (tool_calls : list ) -> List [ToolCall ]:
182209 format_tool_calls : List [ToolCall ] = []
183210 for tool_call in tool_calls :
184- function = ToolFunction (name = tool_call .get ('function' , {}).get ('name' , '' ), arguments = json .loads (tool_call .get ('function' , {}).get ('arguments' , {} )))
211+ function = ToolFunction (name = tool_call .get ('function' , {}).get ('name' , '' ), arguments = json .loads (tool_call .get ('function' , {}).get ('arguments' , '{}' )))
185212 format_tool_calls .append (ToolCall (id = tool_call .get ('id' , '' ), type = tool_call .get ('type' , '' ), function = function ))
186213 return format_tool_calls
0 commit comments