From 17483ce1efd5cf2e9a999f526b7c55907dd6deb2 Mon Sep 17 00:00:00 2001
From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com>
Date: Wed, 1 Oct 2025 08:32:12 +0200
Subject: [PATCH] feat: add reasoning trace extraction from llm calls

---
 nemoguardrails/actions/llm/utils.py  | 28 +++++++++++++++++++---------
 nemoguardrails/rails/llm/llmrails.py |  6 +++---
 2 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/nemoguardrails/actions/llm/utils.py b/nemoguardrails/actions/llm/utils.py
index b71adb76c..1cf6fc388 100644
--- a/nemoguardrails/actions/llm/utils.py
+++ b/nemoguardrails/actions/llm/utils.py
@@ -110,6 +110,7 @@
             generation_llm, prompt, all_callbacks
         )
 
+    _store_reasoning_traces(response)
     _store_tool_calls(response)
     _store_response_metadata(response)
     return _extract_content(response)
@@ -172,6 +173,18 @@
     return dicts_to_messages(prompt)
 
 
+def _store_reasoning_traces(response) -> None:
+    if hasattr(response, "additional_kwargs"):
+        additional_kwargs = response.additional_kwargs
+        if (
+            isinstance(additional_kwargs, dict)
+            and "reasoning_content" in additional_kwargs
+        ):
+            reasoning_content = additional_kwargs["reasoning_content"]
+            if reasoning_content:
+                reasoning_trace_var.set(reasoning_content)
+
+
 def _store_tool_calls(response) -> None:
     """Extract and store tool calls from response in context."""
     tool_calls = getattr(response, "tool_calls", None)
@@ -192,15 +205,6 @@
                 metadata[field_name] = getattr(response, field_name)
         llm_response_metadata_var.set(metadata)
 
-        if hasattr(response, "additional_kwargs"):
-            additional_kwargs = response.additional_kwargs
-            if (
-                isinstance(additional_kwargs, dict)
-                and "reasoning_content" in additional_kwargs
-            ):
-                reasoning_content = additional_kwargs["reasoning_content"]
-                if reasoning_content:
-                    reasoning_trace_var.set(reasoning_content)
 
     else:
         llm_response_metadata_var.set(None)
@@ -704,6 +708,12 @@ def extract_tool_calls_from_events(events: list) -> Optional[list]:
     return None
 
 
+def extract_bot_thinking_from_events(events: list):
+    for event in events:
+        if event.get("type") == "BotThinking":
+            return event.get("content")
+
+
 def get_and_clear_response_metadata_contextvar() -> Optional[dict]:
     """Get the current response metadata and clear it from the context.
 
diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py
index fe56bcf08..e736a32df 100644
--- a/nemoguardrails/rails/llm/llmrails.py
+++ b/nemoguardrails/rails/llm/llmrails.py
@@ -43,8 +43,8 @@
 
 from nemoguardrails.actions.llm.generation import LLMGenerationActions
 from nemoguardrails.actions.llm.utils import (
+    extract_bot_thinking_from_events,
     extract_tool_calls_from_events,
-    get_and_clear_reasoning_trace_contextvar,
     get_and_clear_response_metadata_contextvar,
     get_colang_history,
 )
@@ -1037,7 +1037,7 @@ async def generate_async(
         else:
             res = GenerationResponse(response=[new_message])
 
-        if reasoning_trace := get_and_clear_reasoning_trace_contextvar():
+        if reasoning_trace := extract_bot_thinking_from_events(events):
             if prompt:
                 # For prompt mode, response should be a string
                 if isinstance(res.response, str):
@@ -1182,7 +1182,7 @@
 
         else:
             # If a prompt is used, we only return the content of the message.
-            if reasoning_trace := get_and_clear_reasoning_trace_contextvar():
+            if reasoning_trace := extract_bot_thinking_from_events(events):
                 new_message["content"] = reasoning_trace + new_message["content"]
 
             if prompt: