feat(mappers): cloudwatch change for openinference#174

Open
poshinchen wants to merge 2 commits into strands-agents:main from
poshinchen:feat/mappers-cloudwatch-change-openinference

Conversation

@poshinchen
Contributor

Description

constants.py

As discussed in the previous PR, moved the internal variables into constants.py

CloudWatchProvider

  • Added end_time parameter to narrow the query time window (defaults to now()).

OpenInferenceSessionMapper: Multi-Agent LangGraph Fixes

  • Fixed user content extraction
    • iterates llm.input_messages by role instead of blindly grabbing index 0 (which was the system prompt, not the user message).
  • Fixed tool parameter parsing
    • uses ast.literal_eval as fallback for Python repr strings (OpenInference stores input.value with single quotes).
  • Fixed missing AgentInvocationSpan
    • backward search now skips empty AI messages in LangGraph output.
  • Simplified agent invocation detection to CHAIN + name=LangGraph only, removed false-positive-prone AGENT + langgraph_step heuristic.
  • Added multi-agent dedup — when multiple LangGraph CHAIN spans exist (nested sub-graphs), keeps only the root (last one).
  • Added ADOT/CloudWatch support
    • detects root LangGraph node via "messages" without "remaining_steps" in input body (intermediate nodes always have remaining_steps).
  • Filters empty-response InferenceSpans
    • drops LangGraph internal re-invocations where the assistant returns only empty text with no tool calls.
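The tool-parameter fallback above can be sketched like this — try JSON first, then `ast.literal_eval` for Python-repr strings. A minimal illustration of the approach, not the mapper's actual code:

```python
import ast
import json


def parse_tool_input(raw: str) -> dict:
    """Parse a tool-call input that may be JSON or a Python repr string.

    OpenInference can store input.value with single quotes (Python repr),
    which json.loads rejects, so fall back to ast.literal_eval.
    """
    try:
        return json.loads(raw)
    except (json.JSONDecodeError, ValueError):
        try:
            # literal_eval only evaluates literals, never arbitrary code
            result = ast.literal_eval(raw)
            return result if isinstance(result, dict) else {}
        except (ValueError, SyntaxError):
            return {}
```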

LangChainOtelSessionMapper: ADOT Fixes

  • Moved the internal variables into constants.py

What tests have been executed

  • (in_memory and ADOT) Verified with real OpenInference multi-agent-with-tools, single-turn LangGraph traces
  • (in_memory and ADOT) Verified with real OpenInference multi-agent-with-tools, multi-turn traces
    Both the in-memory and CloudWatch paths produce valid sessions with correct user prompts, tool calls, and agent responses.
    The evaluators return valid scores and reasoning.

Related Issues

#91

Documentation PR

N/A, TBD

Type of Change

Bug fix
New feature

Testing

How have you tested the change? Verify that the changes do not break functionality or introduce warnings in consuming repositories: agents-docs, agents-tools, agents-cli

  • I ran hatch run prepare

Checklist

  • I have read the CONTRIBUTING document
  • I have added any necessary tests that prove my fix is effective or my feature works
  • I have updated the documentation accordingly
  • I have added an appropriate example to the documentation to outline the feature, or no new docs are needed
  • My changes generate no new warnings
  • Any dependent changes have been merged and published

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

@poshinchen poshinchen changed the title (WIP) Feat/mappers cloudwatch change openinference feat(mappers): cloudwatch change for openinference Mar 23, 2026
@poshinchen
Contributor Author

I used the following code for testing the performance (multi-agent, multi-tools).

"""
LangChain Multi-Agent Evaluation - Openinference Instrumentation

Creates a multi-agent LangGraph system with a supervisor routing to specialist
agents (math, research, writer), each with multiple tools. Runs the agents,
captures OTEL traces via Openinference's opentelemetry-instrumentation-langchain,
and evaluates using strands-evals.

This generates rich, deep traces with:
- Supervisor routing decisions
- Nested ReAct agent tool-calling loops
- Multiple tool invocations per request
- Multi-hop agent chains (e.g., research -> writer)

Requirements:
    pip install langchain langchain-aws langgraph openinference-instrumentation-langchain opentelemetry-sdk
"""

import json
import math
import random
from datetime import datetime
from typing import Any

from langchain_aws import ChatBedrock
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import create_react_agent
from opentelemetry import trace
from openinference.instrumentation.langchain import LangChainInstrumentor
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from strands_evals import Case, Experiment
from strands_evals.types.evaluation_report import EvaluationReport
from strands_evals.evaluators import ToolSelectionAccuracyEvaluator, HelpfulnessEvaluator, GoalSuccessRateEvaluator
from strands_evals.mappers import detect_otel_mapper, readable_spans_to_dicts


# =============================================================================
# Math Tools
# =============================================================================

@tool
def add_numbers(a: int, b: int) -> int:
    """Return the sum of two numbers."""
    return a + b


@tool
def multiply_numbers(a: int, b: int) -> int:
    """Return the product of two numbers."""
    return a * b


@tool
def divide_numbers(a: float, b: float) -> str:
    """Return a divided by b. Returns error message if b is zero."""
    if b == 0:
        return "Error: division by zero"
    return str(a / b)


@tool
def square_root(n: float) -> str:
    """Return the square root of a non-negative number."""
    if n < 0:
        return "Error: cannot take square root of negative number"
    return str(math.sqrt(n))


math_tools = [add_numbers, multiply_numbers, divide_numbers, square_root]


# =============================================================================
# Research / Knowledge Tools
# =============================================================================

@tool
def get_weather(city: str) -> str:
    """Get the current weather for a city.

    Args:
        city: The city name to get weather for

    Returns:
        Weather information string
    """
    weather_data = {
        "seattle": "Rainy, 55F",
        "new york": "Sunny, 72F",
        "london": "Cloudy, 60F",
        "tokyo": "Clear, 68F",
        "paris": "Partly cloudy, 63F",
        "sydney": "Warm, 82F",
    }
    city_lower = city.lower()
    for c, w in weather_data.items():
        if c in city_lower:
            return f"Weather in {city}: {w}"
    return f"Weather in {city}: Partly cloudy, 65F"


@tool
def lookup_stock_price(ticker: str) -> str:
    """Look up the current stock price for a given ticker symbol. Returns simulated data."""
    random.seed(hash(ticker))
    price = round(random.uniform(10, 500), 2)
    change_pct = round(random.uniform(-5, 5), 2)
    return json.dumps({
        "ticker": ticker.upper(),
        "price_usd": price,
        "change_pct": change_pct,
        "timestamp": datetime.utcnow().isoformat(),
    })


@tool
def search_knowledge_base(query: str) -> str:
    """Search an internal knowledge base for information on a topic. Returns simulated results."""
    return json.dumps({
        "query": query,
        "results": [
            {"title": f"Overview of {query}", "snippet": f"{query} is a broad topic with many facets."},
            {"title": f"Advanced {query}", "snippet": f"Recent research in {query} has led to breakthroughs."},
        ],
        "total_results": 2,
    })


research_tools = [get_weather, lookup_stock_price, search_knowledge_base]


# =============================================================================
# Writing / Formatting Tools
# =============================================================================

@tool
def summarize_text(text: str) -> str:
    """Summarize a given text into a few key sentences."""
    sentences = [s.strip() for s in text.replace("\n", " ").split(".") if s.strip()]
    summary = ". ".join(sentences[:3])
    return summary + "." if summary else text


@tool
def translate_text(text: str, target_language: str) -> str:
    """Translate text to a target language. Returns a placeholder translation with metadata."""
    return json.dumps({
        "original": text[:200],
        "target_language": target_language,
        "translated": f"[Simulated {target_language} translation of: {text[:80]}...]",
    })


writer_tools = [summarize_text, translate_text]


# =============================================================================
# Supervisor Tools
# =============================================================================

@tool
def route_to_agent(agent_name: str) -> str:
    """Route the current task to a specialist agent. Valid names: math_agent, research_agent, writer_agent."""
    valid = {"math_agent", "research_agent", "writer_agent"}
    if agent_name not in valid:
        return f"Error: unknown agent '{agent_name}'. Valid: {', '.join(sorted(valid))}"
    return f"Routing to {agent_name}"


@tool
def finish(answer: str) -> str:
    """Complete the task and return the final answer to the user."""
    return answer


supervisor_tools = [route_to_agent, finish]


# =============================================================================
# Setup OTEL Tracing
# =============================================================================

memory_exporter = InMemorySpanExporter()
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(memory_exporter))
trace.set_tracer_provider(provider)

LangChainInstrumentor().instrument()


# =============================================================================
# Build Multi-Agent Graph
# =============================================================================

def create_llm():
    return ChatBedrock(
        model_id="us.anthropic.claude-3-5-haiku-20241022-v1:0",
        region_name="us-west-2",
        model_kwargs={"temperature": 0},
    )


def build_multi_agent_graph():
    """Build a LangGraph multi-agent workflow with supervisor routing to specialists."""
    llm = create_llm()

    math_agent = create_react_agent(
        llm,
        tools=math_tools,
        prompt="You are a math specialist. Use your tools to solve math problems. Show your work.",
    )

    research_agent = create_react_agent(
        llm,
        tools=research_tools,
        prompt="You are a research specialist. Use your tools to look up weather, stock prices, and search knowledge bases.",
    )

    writer_agent = create_react_agent(
        llm,
        tools=writer_tools,
        prompt="You are a writing specialist. Summarize, translate, and polish text clearly.",
    )

    class SupervisorState(MessagesState):
        next_agent: str
        final_answer: str

    def _filter_non_system(messages):
        return [m for m in messages if not isinstance(m, SystemMessage)]

    def supervisor_node(state: SupervisorState):
        messages = state["messages"]
        supervisor_prompt = SystemMessage(content=(
            "You are a supervisor agent coordinating specialists:\n"
            "- 'math_agent': calculations, arithmetic, numerical problems\n"
            "- 'research_agent': weather, stock prices, knowledge search\n"
            "- 'writer_agent': summarizing, translating, polishing text\n\n"
            "Call 'route_to_agent' with the specialist name, or 'finish' with the final answer.\n"
            "You may chain multiple agents for complex tasks."
        ))

        supervisor = create_react_agent(llm, tools=supervisor_tools)
        filtered = _filter_non_system(messages)
        result = supervisor.invoke({"messages": [supervisor_prompt] + filtered})
        last_msg = result["messages"][-1].content

        for msg in reversed(result["messages"]):
            if hasattr(msg, "tool_calls") and msg.tool_calls:
                for tc in msg.tool_calls:
                    if tc["name"] == "route_to_agent":
                        agent_name = tc["args"].get("agent_name", "")
                        return {"next_agent": agent_name, "messages": result["messages"]}
                    elif tc["name"] == "finish":
                        answer = tc["args"].get("answer", last_msg)
                        return {"next_agent": "FINISH", "final_answer": answer, "messages": result["messages"]}

        return {"next_agent": "FINISH", "final_answer": last_msg, "messages": result["messages"]}

    def math_node(state: SupervisorState):
        result = math_agent.invoke({"messages": _filter_non_system(state["messages"])})
        return {"messages": result["messages"], "next_agent": "supervisor"}

    def research_node(state: SupervisorState):
        result = research_agent.invoke({"messages": _filter_non_system(state["messages"])})
        return {"messages": result["messages"], "next_agent": "supervisor"}

    def writer_node(state: SupervisorState):
        result = writer_agent.invoke({"messages": _filter_non_system(state["messages"])})
        return {"messages": result["messages"], "next_agent": "supervisor"}

    def route_from_supervisor(state: SupervisorState):
        next_agent = state.get("next_agent", "FINISH")
        if next_agent in ("math_agent", "research_agent", "writer_agent"):
            return next_agent
        return END

    graph = StateGraph(SupervisorState)
    graph.add_node("supervisor", supervisor_node)
    graph.add_node("math_agent", math_node)
    graph.add_node("research_agent", research_node)
    graph.add_node("writer_agent", writer_node)

    graph.add_edge(START, "supervisor")
    graph.add_conditional_edges("supervisor", route_from_supervisor, {
        "math_agent": "math_agent",
        "research_agent": "research_agent",
        "writer_agent": "writer_agent",
        END: END,
    })
    graph.add_edge("math_agent", "supervisor")
    graph.add_edge("research_agent", "supervisor")
    graph.add_edge("writer_agent", "supervisor")

    return graph.compile()


# =============================================================================
# Run Agent and Capture Traces
# =============================================================================

def run_agent_and_capture(query: str) -> tuple[str, list[dict]]:
    """Run the multi-agent system and return response + captured spans."""
    memory_exporter.clear()

    graph = build_multi_agent_graph()
    result = graph.invoke({"messages": [HumanMessage(content=query)]})

    final_answer = result.get("final_answer") or ""
    if not final_answer:
        messages = result.get("messages", [])
        final_answer = messages[-1].content if messages else ""

    spans = readable_spans_to_dicts(memory_exporter.get_finished_spans())
    return final_answer, spans


# =============================================================================
# Debug: Print Raw Trace Structure
# =============================================================================

def print_trace_structure(spans: list[dict]):
    """Print the trace structure for debugging."""
    print("\n--- RAW TRACE STRUCTURE ---")
    print(f"Total spans: {len(spans)}")

    scopes = set(s.get("scope", {}).get("name", "unknown") for s in spans)
    print(f"Scopes: {scopes}")

    for i, span in enumerate(spans):
        name = span.get("name", "?")
        scope = span.get("scope", {}).get("name", "?")
        attrs = span.get("attributes", {})
        print(f"\n  [{i}] {name}  (scope: {scope})")
        for k in list(attrs.keys())[:10]:
            v = attrs[k]
            if isinstance(v, str) and len(v) > 80:
                v = v[:80] + "..."
            print(f"       {k}: {v}")


# =============================================================================
# Run Evaluation
# =============================================================================

def run_evaluation():
    """Run strands-evals evaluation on multi-agent traces."""
    print("\n" + "=" * 60)
    print("RUNNING STRANDS-EVALS EVALUATION (MULTI-AGENT)")
    print("=" * 60)

    test_cases = [
        Case[str, str](
            name="weather and diff",
            input="What's the weather difference between New York and Seattle?",
            # metadata={"expected_tool": "square_root"},
        ),
    ]

    def task_function(case: Case) -> dict[str, Any]:
        response, spans = run_agent_and_capture(case.input)

        # Save raw spans to JSON for debugging
        spans_file = f"debug_raw_spans_openinference_{case.name}.json"
        with open(spans_file, "w") as f:
            json.dump(spans, f, indent=2, default=str)
        print(f"Saved raw spans to: {spans_file}")

        mapper = detect_otel_mapper(spans)
        session = mapper.map_to_session(spans, session_id=case.session_id)

        # Save session to JSON for debugging
        session_file = f"debug_session_openinference_{case.name}.json"
        session_dict = session.model_dump() if hasattr(session, "model_dump") else session.dict()
        with open(session_file, "w") as f:
            json.dump(session_dict, f, indent=2, default=str)
        print(f"Saved session to: {session_file}")

        return {
            "output": response,
            "trajectory": session,
        }

    evaluators = [
        HelpfulnessEvaluator(),
        ToolSelectionAccuracyEvaluator(),
        GoalSuccessRateEvaluator(),
    ]
    experiment = Experiment[str, str](cases=test_cases, evaluators=evaluators)

    print("\nRunning evaluations...")
    reports = experiment.run_evaluations(task_function)

    print("\n" + "=" * 60)
    print("EVALUATION RESULTS")
    print("=" * 60)
    for report in reports:
        print(f"\nOverall Score: {report.overall_score:.2f}")
        pass_count = sum(report.test_passes)
        total = len(report.test_passes)
        print(f"Pass Rate: {pass_count}/{total} ({pass_count/total*100:.0f}%)")
        print("\nTest Cases:")
        for i, (case, score, passed, reason) in enumerate(
            zip(report.cases, report.scores, report.test_passes, report.reasons)
        ):
            status = "PASS" if passed else "FAIL"
            case_name = case.get("name", f"case-{i}")
            case_input = str(case.get("input", ""))[:50]
            print(f"  [{status}] {case_name}: {score:.2f}")
            print(f"         Input: {case_input}...")
            print(f"         Reason: {reason[:80]}...")
    
    flatten_report = EvaluationReport.flatten(reports)
    flatten_report.run_display(include_actual_output=True)


# =============================================================================
# Main
# =============================================================================

def main():
    print("=" * 60)
    print("LANGCHAIN MULTI-AGENT EVALUATION - Openinference")
    print("=" * 60)
    print("\nMulti-agent system: supervisor -> math/research/writer agents")
    print("Tools: add_numbers, multiply_numbers, divide_numbers, square_root,")
    print("       get_weather, lookup_stock_price, search_knowledge_base,")
    print("       summarize_text, translate_text")

    run_evaluation()


if __name__ == "__main__":
    main()

@poshinchen
Contributor Author

And I used the following code for multi-turn testing:

import json
from datetime import datetime, timedelta
from typing import Any

from langchain_aws import ChatBedrock
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from langgraph.prebuilt import create_react_agent
from opentelemetry import trace
from openinference.instrumentation.langchain import LangChainInstrumentor
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from strands_evals import Case, Experiment
from strands_evals.types.evaluation_report import EvaluationReport
from strands_evals.evaluators import (
    ToolSelectionAccuracyEvaluator,
    HelpfulnessEvaluator,
    GoalSuccessRateEvaluator,
)
from strands_evals.mappers import detect_otel_mapper, readable_spans_to_dicts


# =============================================================================
# Calculator Tools
# =============================================================================

@tool
def calculate(expression: str) -> str:
    """Evaluate a mathematical expression. Supports +, -, *, /, **, ().
    
    Args:
        expression: A math expression like "2 + 3 * 4" or "(10 - 2) / 4"
    
    Returns:
        The result as a string, or an error message
    """
    try:
        # Restrict characters before eval (demo only -- a character
        # whitelist does not make eval fully safe)
        allowed = set("0123456789+-*/(). ")
        if not all(c in allowed for c in expression):
            return "Error: Invalid characters in expression"
        result = eval(expression)
        return str(result)
    except ZeroDivisionError:
        return "Error: Division by zero"
    except Exception as e:
        return f"Error: {str(e)}"


@tool
def convert_units(value: float, from_unit: str, to_unit: str) -> str:
    """Convert between common units.
    
    Supported conversions:
    - Length: km <-> miles, m <-> feet, cm <-> inches
    - Weight: kg <-> lbs
    - Temperature: celsius <-> fahrenheit
    
    Args:
        value: The numeric value to convert
        from_unit: Source unit (e.g., "km", "celsius")
        to_unit: Target unit (e.g., "miles", "fahrenheit")
    
    Returns:
        Converted value as string, or error message
    """
    conversions = {
        ("km", "miles"): lambda x: x * 0.621371,
        ("miles", "km"): lambda x: x / 0.621371,
        ("m", "feet"): lambda x: x * 3.28084,
        ("feet", "m"): lambda x: x / 3.28084,
        ("cm", "inches"): lambda x: x / 2.54,
        ("inches", "cm"): lambda x: x * 2.54,
        ("kg", "lbs"): lambda x: x * 2.20462,
        ("lbs", "kg"): lambda x: x / 2.20462,
        ("celsius", "fahrenheit"): lambda x: x * 9/5 + 32,
        ("fahrenheit", "celsius"): lambda x: (x - 32) * 5/9,
    }
    
    key = (from_unit.lower(), to_unit.lower())
    if key not in conversions:
        return f"Error: Cannot convert from {from_unit} to {to_unit}"
    
    result = conversions[key](value)
    return f"{value} {from_unit} = {result:.2f} {to_unit}"


# =============================================================================
# Date/Time Tools
# =============================================================================

@tool
def get_current_datetime() -> str:
    """Get the current date and time in ISO format."""
    return datetime.now().isoformat()


@tool
def calculate_date_difference(date1: str, date2: str) -> str:
    """Calculate the difference between two dates.
    
    Args:
        date1: First date in YYYY-MM-DD format
        date2: Second date in YYYY-MM-DD format
    
    Returns:
        Number of days between the dates
    """
    try:
        d1 = datetime.strptime(date1, "%Y-%m-%d")
        d2 = datetime.strptime(date2, "%Y-%m-%d")
        diff = abs((d2 - d1).days)
        return f"{diff} days between {date1} and {date2}"
    except ValueError as e:
        return f"Error: Invalid date format. Use YYYY-MM-DD. Details: {e}"


@tool
def add_days_to_date(date: str, days: int) -> str:
    """Add or subtract days from a date.
    
    Args:
        date: Starting date in YYYY-MM-DD format
        days: Number of days to add (negative to subtract)
    
    Returns:
        The resulting date
    """
    try:
        d = datetime.strptime(date, "%Y-%m-%d")
        result = d + timedelta(days=days)
        return result.strftime("%Y-%m-%d")
    except ValueError as e:
        return f"Error: Invalid date format. Use YYYY-MM-DD. Details: {e}"


# =============================================================================
# String/Text Tools
# =============================================================================

@tool
def count_words(text: str) -> str:
    """Count the number of words in a text.
    
    Args:
        text: The text to count words in
    
    Returns:
        Word count as string
    """
    words = text.split()
    return f"{len(words)} words"


@tool
def reverse_string(text: str) -> str:
    """Reverse a string.
    
    Args:
        text: The text to reverse
    
    Returns:
        Reversed text
    """
    return text[::-1]


@tool
def extract_numbers(text: str) -> str:
    """Extract all numbers from a text.
    
    Args:
        text: Text containing numbers
    
    Returns:
        JSON list of extracted numbers
    """
    import re
    numbers = re.findall(r'-?\d+\.?\d*', text)
    return json.dumps({"numbers": [float(n) if '.' in n else int(n) for n in numbers]})


# =============================================================================
# Weather & Translation Tools
# =============================================================================

@tool
def get_weather(city: str) -> str:
    """Get the current weather for a city.

    Args:
        city: The city name to get weather for

    Returns:
        Weather information string
    """
    weather_data = {
        "seattle": "Rainy, 55°F (13°C)",
        "new york": "Sunny, 72°F (22°C)",
        "london": "Cloudy, 60°F (16°C)",
        "tokyo": "Clear, 68°F (20°C)",
        "paris": "Partly cloudy, 63°F (17°C)",
        "sydney": "Warm, 82°F (28°C)",
        "taipei": "Humid, 85°F (29°C)",
        "berlin": "Cool, 58°F (14°C)",
    }
    city_lower = city.lower()
    for c, w in weather_data.items():
        if c in city_lower:
            return f"Weather in {city}: {w}"
    return f"Weather in {city}: Partly cloudy, 65°F (18°C)"


@tool
def translate_text(text: str, target_language: str) -> str:
    """Translate text to a target language (simulated).
    
    Args:
        text: The text to translate
        target_language: Target language (e.g., "Spanish", "French", "Japanese")
    
    Returns:
        Simulated translation
    """
    # Simulated translations for demo
    translations = {
        "spanish": {
            "hello": "hola",
            "weather": "clima",
            "rainy": "lluvioso",
            "sunny": "soleado",
            "cloudy": "nublado",
        },
        "french": {
            "hello": "bonjour",
            "weather": "météo",
            "rainy": "pluvieux",
            "sunny": "ensoleillé",
            "cloudy": "nuageux",
        },
        "japanese": {
            "hello": "こんにちは",
            "weather": "天気",
            "rainy": "雨",
            "sunny": "晴れ",
            "cloudy": "曇り",
        },
    }
    
    lang = target_language.lower()
    return f"[{target_language} translation]: {text}"


@tool
def summarize_text(text: str) -> str:
    """Summarize a given text into key points.
    
    Args:
        text: The text to summarize
    
    Returns:
        A brief summary
    """
    # Simple summarization: take first 2 sentences or truncate
    sentences = [s.strip() for s in text.replace("\n", " ").split(".") if s.strip()]
    if len(sentences) <= 2:
        return text
    summary = ". ".join(sentences[:2]) + "."
    return f"Summary: {summary}"


# =============================================================================
# Lookup Tools (simulated)
# =============================================================================

@tool
def lookup_country_capital(country: str) -> str:
    """Look up the capital city of a country.
    
    Args:
        country: Country name
    
    Returns:
        Capital city name or error
    """
    capitals = {
        "france": "Paris",
        "germany": "Berlin",
        "japan": "Tokyo",
        "brazil": "Brasília",
        "australia": "Canberra",
        "canada": "Ottawa",
        "italy": "Rome",
        "spain": "Madrid",
        "united states": "Washington, D.C.",
        "united kingdom": "London",
        "china": "Beijing",
        "india": "New Delhi",
    }
    
    result = capitals.get(country.lower())
    if result:
        return f"The capital of {country} is {result}"
    return f"Error: Capital not found for '{country}'"


@tool
def lookup_element(symbol_or_name: str) -> str:
    """Look up information about a chemical element.
    
    Args:
        symbol_or_name: Element symbol (e.g., "H") or name (e.g., "Hydrogen")
    
    Returns:
        Element information as JSON
    """
    elements = {
        "h": {"name": "Hydrogen", "symbol": "H", "atomic_number": 1, "atomic_mass": 1.008},
        "hydrogen": {"name": "Hydrogen", "symbol": "H", "atomic_number": 1, "atomic_mass": 1.008},
        "o": {"name": "Oxygen", "symbol": "O", "atomic_number": 8, "atomic_mass": 15.999},
        "oxygen": {"name": "Oxygen", "symbol": "O", "atomic_number": 8, "atomic_mass": 15.999},
        "c": {"name": "Carbon", "symbol": "C", "atomic_number": 6, "atomic_mass": 12.011},
        "carbon": {"name": "Carbon", "symbol": "C", "atomic_number": 6, "atomic_mass": 12.011},
        "fe": {"name": "Iron", "symbol": "Fe", "atomic_number": 26, "atomic_mass": 55.845},
        "iron": {"name": "Iron", "symbol": "Fe", "atomic_number": 26, "atomic_mass": 55.845},
        "au": {"name": "Gold", "symbol": "Au", "atomic_number": 79, "atomic_mass": 196.967},
        "gold": {"name": "Gold", "symbol": "Au", "atomic_number": 79, "atomic_mass": 196.967},
    }
    
    result = elements.get(symbol_or_name.lower())
    if result:
        return json.dumps(result)
    return f"Error: Element '{symbol_or_name}' not found"


# =============================================================================
# All Tools
# =============================================================================

all_tools = [
    calculate,
    convert_units,
    get_current_datetime,
    calculate_date_difference,
    add_days_to_date,
    count_words,
    reverse_string,
    extract_numbers,
    get_weather,
    translate_text,
    summarize_text,
    lookup_country_capital,
    lookup_element,
]


# =============================================================================
# Setup OTEL Tracing
# =============================================================================

memory_exporter = InMemorySpanExporter()
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(memory_exporter))
trace.set_tracer_provider(provider)

LangChainInstrumentor().instrument()


# =============================================================================
# Build Simple Agent
# =============================================================================

def create_llm():
    return ChatBedrock(
        model_id="us.anthropic.claude-3-5-haiku-20241022-v1:0",
        region_name="us-west-2",
        model_kwargs={"temperature": 0},
    )


def build_simple_agent():
    """Build a simple ReAct agent with all tools."""
    llm = create_llm()
    
    agent = create_react_agent(
        llm,
        tools=all_tools,
        prompt=(
            "You are a helpful assistant with access to various tools. "
            "Use the appropriate tool(s) to answer the user's question. "
            "If no tool is needed, respond directly. "
            "Be concise and accurate."
        ),
    )
    return agent


# =============================================================================
# Run Agent and Capture Traces
# =============================================================================

def run_agent_and_capture(query: str) -> tuple[str, list[dict]]:
    """Run the agent and return response + captured spans."""
    memory_exporter.clear()
    
    agent = build_simple_agent()
    result = agent.invoke({"messages": [HumanMessage(content=query)]})
    
    messages = result.get("messages", [])
    final_answer = messages[-1].content if messages else ""
    
    spans = readable_spans_to_dicts(memory_exporter.get_finished_spans())
    return final_answer, spans


def run_multi_turn_agent(queries: list[str]) -> tuple[list[str], list[dict]]:
    """Run the agent with multiple turns (queries), maintaining conversation history.
    
    Args:
        queries: List of user queries for each turn
    
    Returns:
        Tuple of (list of responses, all captured spans)
    """
    memory_exporter.clear()
    
    agent = build_simple_agent()
    messages = []
    responses = []
    
    for query in queries:
        messages.append(HumanMessage(content=query))
        result = agent.invoke({"messages": messages})
        
        # Get the new messages from this turn
        result_messages = result.get("messages", [])
        
        # Find the final AI response
        final_answer = result_messages[-1].content if result_messages else ""
        responses.append(final_answer)
        
        # Update messages for next turn (keep conversation history)
        messages = result_messages
    
    spans = readable_spans_to_dicts(memory_exporter.get_finished_spans())
    return responses, spans


# =============================================================================
# Run Evaluation
# =============================================================================

def run_evaluation():
    """Run strands-evals evaluation on simple agent traces."""
    print("\n" + "=" * 60)
    print("RUNNING STRANDS-EVALS EVALUATION (SIMPLE AGENT)")
    print("=" * 60)

    # Single-turn test cases
    single_turn_cases = [
        # --- Calculator cases ---
        Case[str, str](
            name="calc-simple",
            input="What is 15 * 7 + 23?",
        )
    ]
    
    # Multi-turn test cases - input is a list of queries
    multi_turn_cases = [
        # --- Multi-turn: weather then summarize ---
        Case[list[str], str](
            name="multi-turn-weather-summarize",
            input=[
                "What's the weather in Seattle and New York?",
                "Summarize the weather differences between those cities",
            ],
        ),
    ]

    def single_turn_task_function(case: Case) -> dict[str, Any]:
        """Task function for single-turn cases."""
        response, spans = run_agent_and_capture(case.input)

        # Save raw spans for debugging
        spans_file = f"debug_raw_spans_simple_{case.name}.json"
        with open(spans_file, "w") as f:
            json.dump(spans, f, indent=2, default=str)
        print(f"Saved raw spans to: {spans_file}")

        mapper = detect_otel_mapper(spans)
        session = mapper.map_to_session(spans, session_id=case.session_id)

        # Save session for debugging
        session_file = f"debug_session_simple_{case.name}.json"
        session_dict = session.model_dump() if hasattr(session, "model_dump") else session.dict()
        with open(session_file, "w") as f:
            json.dump(session_dict, f, indent=2, default=str)
        print(f"Saved session to: {session_file}")

        return {
            "output": response,
            "trajectory": session,
        }

    def multi_turn_task_function(case: Case) -> dict[str, Any]:
        """Task function for multi-turn cases - calls agent multiple times with same session."""
        queries = case.input  # List of queries
        responses, spans = run_multi_turn_agent(queries)

        # Save raw spans for debugging
        spans_file = f"debug_raw_spans_multiturn_{case.name}.json"
        with open(spans_file, "w") as f:
            json.dump(spans, f, indent=2, default=str)
        print(f"Saved raw spans to: {spans_file}")

        mapper = detect_otel_mapper(spans)
        # Use the same session_id for all turns - this creates a multi-turn session
        session = mapper.map_to_session(spans, session_id=case.session_id)

        # Save session for debugging
        session_file = f"debug_session_multiturn_{case.name}.json"
        session_dict = session.model_dump() if hasattr(session, "model_dump") else session.dict()
        with open(session_file, "w") as f:
            json.dump(session_dict, f, indent=2, default=str)
        print(f"Saved session to: {session_file}")

        # Return the final response (last turn) as output
        final_response = responses[-1] if responses else ""
        
        return {
            "output": final_response,
            "trajectory": session,
            "all_responses": responses,  # Include all responses for reference
        }

    evaluators = [
        HelpfulnessEvaluator(),
        ToolSelectionAccuracyEvaluator(),
        GoalSuccessRateEvaluator(),
    ]

    # Run multi-turn evaluations
    print("\n--- Multi-Turn Cases ---")
    multi_turn_experiment = Experiment[list[str], str](cases=multi_turn_cases, evaluators=evaluators)
    multi_turn_reports = multi_turn_experiment.run_evaluations(multi_turn_task_function)

    print("\n" + "=" * 60)
    print("MULTI-TURN EVALUATION RESULTS")
    print("=" * 60)
    for report in multi_turn_reports:
        print(f"\nEvaluator: {report.evaluator_name}")
        print(f"Overall Score: {report.overall_score:.2f}")
        pass_count = sum(report.test_passes)
        total = len(report.test_passes)
        rate = pass_count / total * 100 if total else 0.0  # guard: no cases => 0%
        print(f"Pass Rate: {pass_count}/{total} ({rate:.0f}%)")
        print("\nTest Cases:")
        for i, (case, score, passed, reason) in enumerate(
            zip(report.cases, report.scores, report.test_passes, report.reasons)
        ):
            status = "PASS" if passed else "FAIL"
            case_name = case.get("name", f"case-{i}")
            case_input = case.get("input", [])
            turns = len(case_input) if isinstance(case_input, list) else 1
            print(f"  [{status}] {case_name} ({turns} turns): {score:.2f}")
            reason_text = str(reason)
            suffix = "..." if len(reason_text) > 100 else ""
            print(f"         Reason: {reason_text[:100]}{suffix}")

    # Flatten all reports (single-turn evaluations are currently disabled,
    # so only the multi-turn reports are flattened)
    # all_reports = single_turn_reports + multi_turn_reports
    all_reports = multi_turn_reports
    flatten_report = EvaluationReport.flatten(all_reports)
    flatten_report.run_display(include_actual_output=True)

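The pass-rate arithmetic in the report loop above can also be expressed as a small pure function. This is a sketch, not part of the strands-evals API; the name `pass_rate` is illustrative. It relies on the fact that Python booleans sum as integers, and returns 0.0 for an empty report rather than dividing by zero:

```python
def pass_rate(test_passes: list[bool]) -> float:
    """Fraction of passing cases; 0.0 when there are no cases."""
    total = len(test_passes)
    return sum(test_passes) / total if total else 0.0
```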

# =============================================================================
# Main
# =============================================================================

def main():
    print("=" * 60)
    print("LANGCHAIN SIMPLE AGENT EVALUATION - OpenInference")
    print("=" * 60)
    print("\nSingle ReAct agent with diverse tools")
    print("Tools: calculate, convert_units, get_current_datetime,")
    print("       calculate_date_difference, add_days_to_date,")
    print("       count_words, reverse_string, extract_numbers,")
    print("       lookup_country_capital, lookup_element")
    print("\nTest cases cover:")
    print("  - Single-turn: tool calls, errors, no-tool-needed")
    print("  - Multi-turn: follow-up queries, context-dependent, error recovery")
    print("  - Multi-turn sessions share the same session_id across turns")

    run_evaluation()


if __name__ == "__main__":
    main()
