Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions nemoguardrails/actions/llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ async def llm_call(
generation_llm, prompt, all_callbacks
)

_store_reasoning_traces(response)
_store_tool_calls(response)
_store_response_metadata(response)
return _extract_content(response)
Expand Down Expand Up @@ -172,6 +173,18 @@ def _convert_messages_to_langchain_format(prompt: List[dict]) -> List:
return dicts_to_messages(prompt)


def _store_reasoning_traces(response) -> None:
    """Persist the model's reasoning trace, when the response carries one.

    Looks for a non-empty ``reasoning_content`` entry inside the response's
    ``additional_kwargs`` dict and, if found, stores it in
    ``reasoning_trace_var`` for later retrieval.
    """
    kwargs = getattr(response, "additional_kwargs", None)
    if not isinstance(kwargs, dict):
        return
    trace = kwargs.get("reasoning_content")
    if trace:
        reasoning_trace_var.set(trace)


def _store_tool_calls(response) -> None:
"""Extract and store tool calls from response in context."""
tool_calls = getattr(response, "tool_calls", None)
Expand All @@ -192,15 +205,6 @@ def _store_response_metadata(response) -> None:
metadata[field_name] = getattr(response, field_name)
llm_response_metadata_var.set(metadata)

if hasattr(response, "additional_kwargs"):
additional_kwargs = response.additional_kwargs
if (
isinstance(additional_kwargs, dict)
and "reasoning_content" in additional_kwargs
):
reasoning_content = additional_kwargs["reasoning_content"]
if reasoning_content:
reasoning_trace_var.set(reasoning_content)
else:
llm_response_metadata_var.set(None)

Expand Down Expand Up @@ -704,6 +708,12 @@ def extract_tool_calls_from_events(events: list) -> Optional[list]:
return None


def extract_bot_thinking_from_events(events: list) -> Optional[str]:
    """Return the content of the first ``BotThinking`` event, if any.

    Args:
        events: A list of event dicts, each expected to carry a ``"type"`` key.

    Returns:
        The ``"content"`` value of the first event whose type is
        ``"BotThinking"``, or ``None`` when no such event exists.
    """
    for event in events:
        if event.get("type") == "BotThinking":
            return event.get("content")
    # Explicit None keeps the "no thinking event" case visible to readers.
    return None


def get_and_clear_response_metadata_contextvar() -> Optional[dict]:
"""Get the current response metadata and clear it from the context.

Expand Down
6 changes: 3 additions & 3 deletions nemoguardrails/rails/llm/llmrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@

from nemoguardrails.actions.llm.generation import LLMGenerationActions
from nemoguardrails.actions.llm.utils import (
extract_bot_thinking_from_events,
extract_tool_calls_from_events,
get_and_clear_reasoning_trace_contextvar,
get_and_clear_response_metadata_contextvar,
get_colang_history,
)
Expand Down Expand Up @@ -1037,7 +1037,7 @@ async def generate_async(
else:
res = GenerationResponse(response=[new_message])

if reasoning_trace := get_and_clear_reasoning_trace_contextvar():
if reasoning_trace := extract_bot_thinking_from_events(events):
if prompt:
# For prompt mode, response should be a string
if isinstance(res.response, str):
Expand Down Expand Up @@ -1182,7 +1182,7 @@ async def generate_async(
else:
# If a prompt is used, we only return the content of the message.

if reasoning_trace := get_and_clear_reasoning_trace_contextvar():
if reasoning_trace := extract_bot_thinking_from_events(events):
new_message["content"] = reasoning_trace + new_message["content"]

if prompt:
Expand Down