Skip to content

Commit 83e04dd

Browse files
committed
fix(runnable-rails): preserve message metadata in RunnableRails tool calling (#1405)
* fix: preserve message metadata in RunnableRails tool calling - Extract message conversion logic to centralized message_utils module - Dynamically preserve all LangChain message fields (tool_calls, additional_kwargs, etc.) - Fix tool calling metadata loss in passthrough mode - Add comprehensive unit tests for message conversions
1 parent 7680ce6 commit 83e04dd

File tree

6 files changed

+898
-124
lines changed

6 files changed

+898
-124
lines changed

nemoguardrails/actions/llm/utils.py

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@
1818

1919
from langchain.base_language import BaseLanguageModel
2020
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackManager
21-
from langchain.prompts.base import StringPromptValue
22-
from langchain.prompts.chat import ChatPromptValue
23-
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
2421

2522
from nemoguardrails.colang.v2_x.lang.colang_ast import Flow
2623
from nemoguardrails.colang.v2_x.runtime.flows import InternalEvent, InternalEvents
@@ -30,6 +27,7 @@
3027
reasoning_trace_var,
3128
tool_calls_var,
3229
)
30+
from nemoguardrails.integrations.langchain.message_utils import dicts_to_messages
3331
from nemoguardrails.logging.callbacks import logging_callbacks
3432
from nemoguardrails.logging.explain import LLMCallInfo
3533

@@ -146,34 +144,7 @@ async def _invoke_with_message_list(
146144

147145
def _convert_messages_to_langchain_format(prompt: List[dict]) -> List:
148146
"""Convert message list to LangChain message format."""
149-
messages = []
150-
for msg in prompt:
151-
msg_type = msg["type"] if "type" in msg else msg["role"]
152-
153-
if msg_type == "user":
154-
messages.append(HumanMessage(content=msg["content"]))
155-
elif msg_type in ["bot", "assistant"]:
156-
tool_calls = msg.get("tool_calls")
157-
if tool_calls:
158-
messages.append(
159-
AIMessage(content=msg["content"], tool_calls=tool_calls)
160-
)
161-
else:
162-
messages.append(AIMessage(content=msg["content"]))
163-
elif msg_type == "system":
164-
messages.append(SystemMessage(content=msg["content"]))
165-
elif msg_type == "tool":
166-
tool_message = ToolMessage(
167-
content=msg["content"],
168-
tool_call_id=msg.get("tool_call_id", ""),
169-
)
170-
if msg.get("name"):
171-
tool_message.name = msg["name"]
172-
messages.append(tool_message)
173-
else:
174-
raise ValueError(f"Unknown message type {msg_type}")
175-
176-
return messages
147+
return dicts_to_messages(prompt)
177148

178149

179150
def _store_tool_calls(response) -> None:
Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Utilities for converting between LangChain messages and dictionary format."""
17+
18+
from typing import Any, Dict, List, Optional, Type
19+
20+
from langchain_core.messages import (
21+
AIMessage,
22+
AIMessageChunk,
23+
BaseMessage,
24+
HumanMessage,
25+
SystemMessage,
26+
ToolMessage,
27+
)
28+
29+
30+
def get_message_role(msg: BaseMessage) -> str:
31+
"""Get the role string for a BaseMessage."""
32+
if isinstance(msg, AIMessage):
33+
return "assistant"
34+
elif isinstance(msg, HumanMessage):
35+
return "user"
36+
elif isinstance(msg, SystemMessage):
37+
return "system"
38+
elif isinstance(msg, ToolMessage):
39+
return "tool"
40+
else:
41+
return getattr(msg, "type", "user")
42+
43+
44+
def get_message_class(msg_type: str) -> Type[BaseMessage]:
45+
"""Get the appropriate message class for a given type/role."""
46+
if msg_type == "user":
47+
return HumanMessage
48+
elif msg_type in ["bot", "assistant"]:
49+
return AIMessage
50+
elif msg_type in ["system", "developer"]:
51+
return SystemMessage
52+
elif msg_type == "tool":
53+
return ToolMessage
54+
else:
55+
raise ValueError(f"Unknown message type: {msg_type}")
56+
57+
58+
def message_to_dict(msg: BaseMessage) -> Dict[str, Any]:
59+
"""
60+
Convert a BaseMessage to dictionary format, preserving all model fields.
61+
62+
Args:
63+
msg: The BaseMessage to convert
64+
65+
Returns:
66+
Dictionary representation with role, content, and all other fields
67+
"""
68+
result = {"role": get_message_role(msg), "content": msg.content}
69+
70+
if isinstance(msg, ToolMessage):
71+
result["tool_call_id"] = msg.tool_call_id
72+
73+
exclude_fields = {"type", "content", "example"}
74+
75+
if hasattr(msg, "model_fields"):
76+
for field_name in msg.model_fields:
77+
if field_name not in exclude_fields and field_name not in result:
78+
value = getattr(msg, field_name, None)
79+
if value is not None:
80+
result[field_name] = value
81+
82+
return result
83+
84+
85+
def dict_to_message(msg_dict: Dict[str, Any]) -> BaseMessage:
86+
"""
87+
Convert a dictionary to the appropriate BaseMessage type.
88+
89+
Args:
90+
msg_dict: Dictionary with role/type, content, and optional fields
91+
92+
Returns:
93+
The appropriate BaseMessage instance
94+
"""
95+
msg_type = msg_dict.get("type") or msg_dict.get("role")
96+
if not msg_type:
97+
raise ValueError("Message dictionary must have 'type' or 'role' field")
98+
99+
content = msg_dict.get("content", "")
100+
message_class = get_message_class(msg_type)
101+
102+
exclude_keys = {"role", "type", "content"}
103+
104+
valid_fields = (
105+
set(message_class.model_fields.keys())
106+
if hasattr(message_class, "model_fields")
107+
else set()
108+
)
109+
110+
kwargs = {
111+
k: v
112+
for k, v in msg_dict.items()
113+
if k not in exclude_keys and k in valid_fields and v is not None
114+
}
115+
116+
if message_class == ToolMessage:
117+
kwargs["tool_call_id"] = msg_dict.get("tool_call_id", "")
118+
119+
return message_class(content=content, **kwargs)
120+
121+
122+
def messages_to_dicts(messages: List[BaseMessage]) -> List[Dict[str, Any]]:
123+
"""
124+
Convert a list of BaseMessage objects to dictionary format.
125+
126+
Args:
127+
messages: List of BaseMessage objects
128+
129+
Returns:
130+
List of dictionary representations
131+
"""
132+
return [message_to_dict(msg) for msg in messages]
133+
134+
135+
def dicts_to_messages(msg_dicts: List[Dict[str, Any]]) -> List[BaseMessage]:
136+
"""
137+
Convert a list of dictionaries to BaseMessage objects.
138+
139+
Args:
140+
msg_dicts: List of message dictionaries
141+
142+
Returns:
143+
List of appropriate BaseMessage instances
144+
"""
145+
return [dict_to_message(msg_dict) for msg_dict in msg_dicts]
146+
147+
148+
def is_message_type(obj: Any, message_type: Type[BaseMessage]) -> bool:
149+
"""Check if an object is an instance of a specific message type."""
150+
return isinstance(obj, message_type)
151+
152+
153+
def is_base_message(obj: Any) -> bool:
154+
"""Check if an object is any type of BaseMessage."""
155+
return isinstance(obj, BaseMessage)
156+
157+
158+
def is_ai_message(obj: Any) -> bool:
159+
"""Check if an object is an AIMessage."""
160+
return isinstance(obj, AIMessage)
161+
162+
163+
def is_human_message(obj: Any) -> bool:
164+
"""Check if an object is a HumanMessage."""
165+
return isinstance(obj, HumanMessage)
166+
167+
168+
def is_system_message(obj: Any) -> bool:
169+
"""Check if an object is a SystemMessage."""
170+
return isinstance(obj, SystemMessage)
171+
172+
173+
def is_tool_message(obj: Any) -> bool:
174+
"""Check if an object is a ToolMessage."""
175+
return isinstance(obj, ToolMessage)
176+
177+
178+
def all_base_messages(items: List[Any]) -> bool:
179+
"""Check if all items in a list are BaseMessage instances."""
180+
return all(isinstance(item, BaseMessage) for item in items)
181+
182+
183+
def create_ai_message(
184+
content: str,
185+
tool_calls: Optional[list] = None,
186+
additional_kwargs: Optional[dict] = None,
187+
response_metadata: Optional[dict] = None,
188+
id: Optional[str] = None,
189+
name: Optional[str] = None,
190+
usage_metadata: Optional[dict] = None,
191+
**extra_kwargs,
192+
) -> AIMessage:
193+
"""Create an AIMessage with optional fields."""
194+
kwargs = {}
195+
if tool_calls is not None:
196+
kwargs["tool_calls"] = tool_calls
197+
if additional_kwargs is not None:
198+
kwargs["additional_kwargs"] = additional_kwargs
199+
if response_metadata is not None:
200+
kwargs["response_metadata"] = response_metadata
201+
if id is not None:
202+
kwargs["id"] = id
203+
if name is not None:
204+
kwargs["name"] = name
205+
if usage_metadata is not None:
206+
kwargs["usage_metadata"] = usage_metadata
207+
208+
valid_fields = (
209+
set(AIMessage.model_fields.keys())
210+
if hasattr(AIMessage, "model_fields")
211+
else set()
212+
)
213+
for key, value in extra_kwargs.items():
214+
if key in valid_fields and key not in kwargs:
215+
kwargs[key] = value
216+
217+
return AIMessage(content=content, **kwargs)
218+
219+
220+
def create_ai_message_chunk(content: str, **metadata) -> AIMessageChunk:
221+
"""Create an AIMessageChunk with optional metadata."""
222+
return AIMessageChunk(content=content, **metadata)
223+
224+
225+
def create_human_message(
226+
content: str,
227+
additional_kwargs: Optional[dict] = None,
228+
response_metadata: Optional[dict] = None,
229+
id: Optional[str] = None,
230+
name: Optional[str] = None,
231+
) -> HumanMessage:
232+
"""Create a HumanMessage with optional fields."""
233+
kwargs = {}
234+
if additional_kwargs is not None:
235+
kwargs["additional_kwargs"] = additional_kwargs
236+
if response_metadata is not None:
237+
kwargs["response_metadata"] = response_metadata
238+
if id is not None:
239+
kwargs["id"] = id
240+
if name is not None:
241+
kwargs["name"] = name
242+
243+
return HumanMessage(content=content, **kwargs)
244+
245+
246+
def create_system_message(
247+
content: str,
248+
additional_kwargs: Optional[dict] = None,
249+
response_metadata: Optional[dict] = None,
250+
id: Optional[str] = None,
251+
name: Optional[str] = None,
252+
) -> SystemMessage:
253+
"""Create a SystemMessage with optional fields."""
254+
kwargs = {}
255+
if additional_kwargs is not None:
256+
kwargs["additional_kwargs"] = additional_kwargs
257+
if response_metadata is not None:
258+
kwargs["response_metadata"] = response_metadata
259+
if id is not None:
260+
kwargs["id"] = id
261+
if name is not None:
262+
kwargs["name"] = name
263+
264+
return SystemMessage(content=content, **kwargs)
265+
266+
267+
def create_tool_message(
268+
content: str,
269+
tool_call_id: str,
270+
name: Optional[str] = None,
271+
additional_kwargs: Optional[dict] = None,
272+
response_metadata: Optional[dict] = None,
273+
id: Optional[str] = None,
274+
artifact: Optional[Any] = None,
275+
status: Optional[str] = None,
276+
) -> ToolMessage:
277+
"""Create a ToolMessage with optional fields."""
278+
kwargs = {"tool_call_id": tool_call_id}
279+
if name is not None:
280+
kwargs["name"] = name
281+
if additional_kwargs is not None:
282+
kwargs["additional_kwargs"] = additional_kwargs
283+
if response_metadata is not None:
284+
kwargs["response_metadata"] = response_metadata
285+
if id is not None:
286+
kwargs["id"] = id
287+
if artifact is not None:
288+
kwargs["artifact"] = artifact
289+
if status is not None:
290+
kwargs["status"] = status
291+
292+
return ToolMessage(content=content, **kwargs)

0 commit comments

Comments
 (0)