diff --git a/.gitignore b/.gitignore index 267a7d6..afe1435 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ build .venv .mypy_cache .ruff_cache +.claude *.bak .vscode dist @@ -15,4 +16,4 @@ dataset_files report_files .venv *.DS_Store* -uv.lock +uv.lock \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index a3b3b23..6cdde6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -134,6 +134,7 @@ select = [ [tool.ruff.lint.per-file-ignores] "src/strands_evals/evaluators/prompt_templates/*" = ["E501"] # line-length +"src/strands_evals/detectors/prompt_templates/*" = ["E501"] # line-length "src/strands_evals/generators/prompt_template/*" = ["E501"] # line-length [tool.mypy] diff --git a/src/strands_evals/__init__.py b/src/strands_evals/__init__.py index f5c600c..db11ca7 100644 --- a/src/strands_evals/__init__.py +++ b/src/strands_evals/__init__.py @@ -1,4 +1,4 @@ -from . import evaluators, extractors, generators, providers, simulation, telemetry, types +from . import detectors, evaluators, extractors, generators, providers, simulation, telemetry, types from .case import Case from .experiment import Experiment from .simulation import ActorSimulator, UserSimulator @@ -7,6 +7,7 @@ __all__ = [ "Experiment", "Case", + "detectors", "evaluators", "extractors", "providers", diff --git a/src/strands_evals/detectors/__init__.py b/src/strands_evals/detectors/__init__.py new file mode 100644 index 0000000..9657130 --- /dev/null +++ b/src/strands_evals/detectors/__init__.py @@ -0,0 +1,32 @@ +"""Detectors for analyzing agent execution traces. + +Detectors answer "why did my agent behave this way?" by analyzing Session +traces for failures, summarizing execution, extracting user requests, and +performing root cause analysis. +""" + +from ..types.detector import ( + ConfidenceLevel, + DiagnosisResult, + FailureDetectionStructuredOutput, + FailureItem, + FailureOutput, + RCAItem, + RCAOutput, + RCAStructuredOutput, +) +from .failure_detector import detect_failures + +__all__ = [ + # Core detectors + "detect_failures", + # Types + "ConfidenceLevel", + "DiagnosisResult", + "FailureOutput", + "FailureItem", + "FailureDetectionStructuredOutput", + "RCAOutput", + "RCAItem", + "RCAStructuredOutput", +] diff --git a/src/strands_evals/detectors/chunking.py b/src/strands_evals/detectors/chunking.py new file mode 100644 index 0000000..9821138 --- /dev/null +++ b/src/strands_evals/detectors/chunking.py @@ -0,0 +1,124 @@ +"""Shared utilities for context window management and span chunking. + +Used by detect_failures, summarize_execution, and analyze_root_cause +to handle sessions that exceed LLM context limits. + +Ported from AgentCoreLens failure_detector._split_spans_by_tokens() and +related helpers, replacing litellm.token_counter() with a conservative +character-based estimate. +""" + +import json +import logging + +from ..types.detector import ConfidenceLevel, FailureItem +from ..types.trace import SpanUnion + +logger = logging.getLogger(__name__) + +CHARS_PER_TOKEN = 4 +CONTEXT_SAFETY_MARGIN = 0.75 +DEFAULT_MAX_INPUT_TOKENS = 128_000 +CHUNK_OVERLAP_SPANS = 2 +MIN_CHUNK_SIZE = 5 +_CONFIDENCE_RANK: dict[ConfidenceLevel, int] = {"low": 0, "medium": 1, "high": 2} + + +def estimate_tokens(text: str) -> int: + """Conservative token estimate from text length.""" + return len(text) // CHARS_PER_TOKEN + + +def would_exceed_context(prompt_text: str, max_input_tokens: int = DEFAULT_MAX_INPUT_TOKENS) -> bool: + """Pre-flight check: would this prompt likely exceed context? + + If wrong (false negative), the model returns a context-exceeded error, + which triggers the chunking fallback in each detector. + """ + return estimate_tokens(prompt_text) > int(max_input_tokens * CONTEXT_SAFETY_MARGIN) + + +def split_spans_by_tokens( + spans: list[SpanUnion], + max_tokens: int = DEFAULT_MAX_INPUT_TOKENS, + overlap_spans: int = CHUNK_OVERLAP_SPANS, +) -> list[list[SpanUnion]]: + """Split spans into chunks fitting within token limits. + + Adjacent chunks share ``overlap_spans`` spans for context continuity. + Ported from Lens ``failure_detector._split_spans_by_tokens()``. + + Args: + spans: Flat list of span objects to chunk. + max_tokens: Maximum model input tokens. + overlap_spans: Number of spans shared between adjacent chunks. + + Returns: + List of span chunks. Each chunk fits within the effective token limit. + """ + effective_limit = int(max_tokens * CONTEXT_SAFETY_MARGIN) + chunks: list[list[SpanUnion]] = [] + current_chunk: list[SpanUnion] = [] + current_tokens = 0 + + for span in spans: + span_tokens = estimate_tokens(json.dumps(span.model_dump(), default=str)) + if current_tokens + span_tokens > effective_limit and len(current_chunk) >= MIN_CHUNK_SIZE: + chunks.append(current_chunk) + overlap = current_chunk[-overlap_spans:] if overlap_spans > 0 else [] + current_chunk = list(overlap) + current_tokens = sum(estimate_tokens(json.dumps(s.model_dump(), default=str)) for s in overlap) + current_chunk.append(span) + current_tokens += span_tokens + + if current_chunk: + chunks.append(current_chunk) + + logger.info( + "Split %d spans into %d chunks (max_tokens=%d, overlap=%d)", + len(spans), + len(chunks), + max_tokens, + overlap_spans, + ) + return chunks + + +def merge_chunk_failures(chunk_results: list[list[FailureItem]]) -> list[FailureItem]: + """Merge failures from overlapping chunks, deduplicating by span_id. + + Keeps highest confidence per category when the same span_id appears + in multiple chunks. Ported from Lens ``_merge_chunk_failures()``. + + Args: + chunk_results: List of failure lists, one per chunk. + + Returns: + Deduplicated list of FailureItems. + """ + seen: dict[str, FailureItem] = {} + for chunk in chunk_results: + for failure in chunk: + if failure.span_id not in seen: + # Deep copy to avoid mutating originals + seen[failure.span_id] = FailureItem( + span_id=failure.span_id, + category=list(failure.category), + confidence=list(failure.confidence), + evidence=list(failure.evidence), + ) + else: + existing = seen[failure.span_id] + for i, cat in enumerate(failure.category): + if cat in existing.category: + idx = existing.category.index(cat) + if _CONFIDENCE_RANK.get(failure.confidence[i], 0) > _CONFIDENCE_RANK.get( + existing.confidence[idx], 0 + ): + existing.confidence[idx] = failure.confidence[i] + existing.evidence[idx] = failure.evidence[i] + else: + existing.category.append(cat) + existing.confidence.append(failure.confidence[i]) + existing.evidence.append(failure.evidence[i]) + return list(seen.values()) diff --git a/src/strands_evals/detectors/failure_detector.py b/src/strands_evals/detectors/failure_detector.py new file mode 100644 index 0000000..cc2851d --- /dev/null +++ b/src/strands_evals/detectors/failure_detector.py @@ -0,0 +1,172 @@ +"""Failure detection for agent execution sessions. + +Identifies semantic failures (hallucinations, task errors, tool misuse, etc.) +in Session traces using LLM-based analysis with automatic chunking fallback +for sessions exceeding context limits. + +Ported from AgentCoreLens tools/failure_detector.py. +""" + +import json +import logging + +from strands import Agent +from strands.models.model import Model +from strands.types.exceptions import ContextWindowOverflowException +from typing_extensions import Union, cast + +from ..types.detector import ConfidenceLevel, FailureDetectionStructuredOutput, FailureItem, FailureOutput +from ..types.trace import Session, SpanUnion +from .chunking import merge_chunk_failures, split_spans_by_tokens, would_exceed_context +from .prompt_templates.failure_detection import get_template + +logger = logging.getLogger(__name__) + +DEFAULT_DETECTOR_MODEL = "us.anthropic.claude-haiku-4-5-20251001-v1:0" +CONFIDENCE_ORDER: dict[ConfidenceLevel, int] = {"low": 0, "medium": 1, "high": 2} + + +def detect_failures( + session: Session, + *, + confidence_threshold: ConfidenceLevel = "low", + model: Union[Model, str, None] = None, +) -> FailureOutput: + """Detect semantic failures in an agent execution session. + + Args: + session: The Session object to analyze. + confidence_threshold: Minimum confidence level ("low", "medium", "high") + to include a failure. Defaults to "low" (include all). + model: Any Strands model provider. None uses default Haiku. + + Returns: + FailureOutput with list of FailureItems, each with span_id, category, + confidence, and evidence. + """ + effective_model = model if model is not None else DEFAULT_DETECTOR_MODEL + template = get_template("v0") + session_json = _serialize_session(session) + user_prompt = template.build_prompt(session_json=session_json) + + if would_exceed_context(user_prompt): + raw = _detect_chunked(session, effective_model, template) + else: + try: + raw = _detect_direct(user_prompt, effective_model, template) + except Exception as e: + if _is_context_exceeded(e): + logger.warning("Context exceeded despite pre-flight check, falling back to chunking") + raw = _detect_chunked(session, effective_model, template) + else: + raise + + threshold_rank = CONFIDENCE_ORDER[confidence_threshold] + filtered = [f for f in raw if _max_confidence_rank(f) >= threshold_rank] + return FailureOutput(session_id=session.session_id, failures=filtered) + + +def _detect_direct(user_prompt: str, model: Union[Model, str], template: object) -> list[FailureItem]: + """Attempt direct LLM detection on the full session.""" + agent = Agent(model=model, system_prompt=template.SYSTEM_PROMPT, callback_handler=None) + result = agent(user_prompt, structured_output_model=FailureDetectionStructuredOutput) + return _parse_structured_result(cast(FailureDetectionStructuredOutput, result.structured_output)) + + +def _detect_chunked( + session: Session, + model: Union[Model, str], + template: object, +) -> list[FailureItem]: + """Chunk session and detect failures per chunk, then merge.""" + spans = _flatten_traces_to_spans(session.traces) + chunks = split_spans_by_tokens(spans) + + logger.info("Chunked detection: %d spans -> %d chunks", len(spans), len(chunks)) + + chunk_results: list[list[FailureItem]] = [] + for i, chunk_spans in enumerate(chunks): + try: + chunk_json = _serialize_spans(chunk_spans, session.session_id) + user_prompt = template.build_prompt(session_json=chunk_json) + agent = Agent(model=model, system_prompt=template.SYSTEM_PROMPT, callback_handler=None) + result = agent(user_prompt, structured_output_model=FailureDetectionStructuredOutput) + chunk_results.append( + _parse_structured_result(cast(FailureDetectionStructuredOutput, result.structured_output)) + ) + logger.info("Chunk %d/%d: processed %d spans", i + 1, len(chunks), len(chunk_spans)) + except Exception as e: + if _is_context_exceeded(e): + logger.warning("Chunk %d/%d still exceeds context, skipping", i + 1, len(chunks)) + else: + raise + + return merge_chunk_failures(chunk_results) + + +def _parse_structured_result(output: FailureDetectionStructuredOutput) -> list[FailureItem]: + """Convert LLM structured output to list[FailureItem].""" + items = [] + for err in output.errors: + # Validate element-wise correspondence + if not (len(err.category) == len(err.evidence) == len(err.confidence)): + logger.warning( + "Mismatched array lengths for location %s: category=%d, evidence=%d, confidence=%d. Skipping.", + err.location, + len(err.category), + len(err.evidence), + len(err.confidence), + ) + continue + + items.append( + FailureItem( + span_id=err.location, + category=list(err.category), + confidence=list(err.confidence), + evidence=list(err.evidence), + ) + ) + return items + + +def _max_confidence_rank(item: FailureItem) -> int: + """Return the maximum confidence rank across all failure modes.""" + if not item.confidence: + return -1 + return max(CONFIDENCE_ORDER.get(c, 0) for c in item.confidence) + + +def _is_context_exceeded(exception: Exception) -> bool: + """Check if the exception indicates a context window overflow. + + Strands Agent raises ContextWindowOverflowException for context overflows. + Also checks error message strings as a fallback for providers that raise + generic errors (e.g., ValidationException from Bedrock). + """ + if isinstance(exception, ContextWindowOverflowException): + return True + msg = str(exception).lower() + return any(p in msg for p in ["context", "too long", "max_tokens", "input_limit"]) + + +def _flatten_traces_to_spans(traces: list) -> list[SpanUnion]: + """Flatten all traces into a single span list.""" + return [span for trace in traces for span in trace.spans] + + +def _serialize_session(session: Session) -> str: + """Serialize a full Session to JSON for the prompt.""" + return json.dumps(session.model_dump(), indent=2, default=str) + + +def _serialize_spans(spans: list[SpanUnion], session_id: str) -> str: + """Serialize a list of spans as a minimal session JSON for chunk prompts.""" + return json.dumps( + { + "session_id": session_id, + "traces": [{"spans": [s.model_dump() for s in spans]}], + }, + indent=2, + default=str, + ) diff --git a/src/strands_evals/detectors/prompt_templates/__init__.py b/src/strands_evals/detectors/prompt_templates/__init__.py new file mode 100644 index 0000000..cadb738 --- /dev/null +++ b/src/strands_evals/detectors/prompt_templates/__init__.py @@ -0,0 +1,5 @@ +"""Prompt templates for detectors. + +Each subdirectory contains versioned prompt modules following the same +pattern as evaluators/prompt_templates/. +""" diff --git a/src/strands_evals/detectors/prompt_templates/failure_detection/__init__.py b/src/strands_evals/detectors/prompt_templates/failure_detection/__init__.py new file mode 100644 index 0000000..13b8796 --- /dev/null +++ b/src/strands_evals/detectors/prompt_templates/failure_detection/__init__.py @@ -0,0 +1,13 @@ +from . import failure_detection_v0 + +VERSIONS = { + "v0": failure_detection_v0, +} + +DEFAULT_VERSION = "v0" + + +def get_template(version: str = DEFAULT_VERSION): + if version not in VERSIONS: + raise ValueError(f"Unknown version: {version}. Available: {list(VERSIONS.keys())}") + return VERSIONS[version] diff --git a/src/strands_evals/detectors/prompt_templates/failure_detection/failure_detection_v0.py b/src/strands_evals/detectors/prompt_templates/failure_detection/failure_detection_v0.py new file mode 100644 index 0000000..0a1b785 --- /dev/null +++ b/src/strands_evals/detectors/prompt_templates/failure_detection/failure_detection_v0.py @@ -0,0 +1,338 @@ +"""Failure detection prompt template v0. + +Ported verbatim from AgentCoreLens failure_detection.j2. +Only the rendering mechanism changed (Jinja2 -> Python f-strings). +""" + +SYSTEM_PROMPT = """\ +# LLM Annotation Prompt for Agent Session Data + +## System Role and Task Overview + +You are an expert AI systems evaluator tasked with annotating failure patterns \ +in AI agent session data. Your goal is to identify and categorize errors that \ +occur during agent interactions, following a systematic approach to failure \ +detection and classification. + +### Task Overview + +You will be provided with a session containing a list of traces. Each trace \ +represents a complete interaction cycle and contains spans representing atomic \ +operations (agent invocations, tool calls, LLM inference). Your task is to: + +1. **Analyze each span** in the session for potential failures +2. **Categorize failures** using the predefined taxonomy +3. **Skip spans** where no errors are identified +4. **Output annotations** in the specified JSON format""" + +DEFAULT_ANNOTATION_GUIDELINES = """\ +### Core Principles +- **Be consistent**: Apply the same standards across all sessions +- **Be critical**: Don't assume the agent's actions are correct +- **Be precise**: Use the most specific category available +- **Be exhaustive**: Record every distinct failure you observe +- **Be objective**: Focus on observable evidence, not inferred causes +- **Use "other"**: If uncertain about category, use "other" with explanation + +### Special Attention: execution-error +**Don't overlook obvious execution failures.** While maintaining thorough \ +coverage of all error categories, be particularly vigilant for explicit \ +"execution-error-*" category failures across all `span_type`s, as these \ +manifested failures can sometimes be dismissed despite being clearly \ +observable in system outputs and feedback. + +### Span-by-Span Analysis Process + +1. **Read the complete session** to understand context, goals, and agent behavior +2. **For each span** (within each trace): + - If no failure detected -> Skip this span (do not annotate) + - If failure detected -> Assign appropriate category and provide evidence""" + +CATEGORIES = """\ +Use the specific category names (not parent categories) in your annotations. \ +The details for each category are described in the following table organized \ +by parent categories: + +|Parent Category|Category|Category Definition|Examples| +|---|---|---|---| +|**execution-error**|execution-error-category-authentication|Failed access attempts due to credential or permission issues\ +

\u2022 HTTP 401/403 errors
\u2022 Invalid/expired tokens
\u2022 Missing credentials|\ +\u2022 "Invalid API key provided"
\u2022 "Access token expired"
\u2022 "Unauthorized access to Spotify API"| +||execution-error-category-resource-not-found|Requested resource does not exist at specified location\ +

\u2022 HTTP 404 errors
\u2022 Missing file/endpoint
\u2022 Invalid identifier|\ +\u2022 "Playlist ID not found"
"Endpoint /api/v2/users does not exist"
\u2022 "Image file missing from specified path"| +||execution-error-category-service-errors|Upstream service failure or unavailability\ +

\u2022 HTTP 500 errors
\u2022 Service disruption
\u2022 System failures|\ +\u2022 "Internal server error in database connection"
\u2022 "AWS service temporarily unavailable"
\ +"Google API service disruption"| +||execution-error-category-rate-limiting|Request frequency exceeds allowed quota\ +

\u2022 HTTP 429 errors
\u2022 Quota exceeded messages
\u2022 Rate threshold warnings|\ +\u2022 "Too many requests to Twitter API"
\u2022 "Rate limit exceeded: try again in 60 seconds"
\ +\u2022 "Daily quota exceeded for Google Maps API"| +||execution-error-category-formatting|Failure to produce correctly structured output according to expected \ +format or syntax

\u2022 Syntax errors in structured data
\u2022 Missing required structural elements
\ +\u2022 Invalid format for specified type|\u2022 JSON missing closing bracket
\u2022 SQL query with incorrect \ +syntax
\u2022 HTML with unclosed tags| +||execution-error-category-timeout|Request or operation exceeded time constraints\ +

\u2022 Explicit timeout message
\u2022 Duration threshold exceeded
\u2022 Connection time limit|\ +\u2022 "Request timed out after 30 seconds"
\u2022 "Function execution exceeded time limit"
\ +\u2022 "Connection timeout to external API"| +||execution-error-category-resource-exhaustion|System resource capacity limits reached\ +

\u2022 Memory/disk/CPU limits
\u2022 Resource quota exceeded
\u2022 System capacity error|\ +\u2022 "Out of memory error during image processing"
\u2022 "Disk space full during file download"
\ +\u2022 "CPU quota exceeded"| +||execution-error-category-environment|Missing or incorrect system setup requirements\ +

\u2022 Missing environment variables
\u2022 Invalid configuration
\u2022 Setup requirement errors|\ +\u2022 "OPENAI_API_KEY not found in environment"
\u2022 "Missing write permissions for database"
\ +\u2022 "Required configuration file not present"| +||execution-error-category-tool-schema|Tool input/output does not match required schema\ +

\u2022 Parameter validation errors
\u2022 Missing required fields
\u2022 Invalid data types|\ +\u2022 "Expected parameter 'date' to be ISO format"
\u2022 "Required field 'user_id' missing"
\ +\u2022 "Invalid enum value for 'sort_order'"| +|**task-instruction**|task-instruction-category-non-compliance|Failure to follow explicit directives from the \ +user or system prompts/constraints

Not all datasets have this information|System: "Always authenticate \ +user before providing account information"
User: "What's my account balance?"
Agent: *immediately provides \ +balance without authentication*| +||task-instruction-category-problem-id|Agent correctly identifies the end goal but fails to determine the \ +appropriate path to achieve it. The agent applies incorrect reasoning about how to reach the goal, leading to \ +misaligned actions.|Goal: Fix SQL indentation issue
User: "There's an extra space appearing in WITH clause \ +indentation in L003.py"

Agent: *Opens and analyzes core/parser/segments.py*
"I'll help fix the \ +indentation issue. Looking at segments.py which handles SQL parsing..."
[Proceeds to analyze general parsing \ +logic]

[\\u2713 Goal: Fix indentation spacing issue
\\u2717 Path: Wrong file - should examine L003.py \ +which handles indentation rules
\\u2717 Reasoning: Agent incorrectly assumes indentation is handled in general \ +parser rather than specific rule file]| +|**incorrect-actions**|incorrect-actions-category-tool-selection|Using an incorrect but available tool when a \ +more appropriate tool exists in the agent's toolbox

\u2022 Tool exists in toolbox
\u2022 Wrong tool \ +selected for task
\u2022 More appropriate tool available
\u2022 Not about hallucinated tools|\ +\u2022 Using search_flights() when check_flight_status() exists and is appropriate
\u2022 Using \ +general_search() when specific domain_search() available
\u2022 Using text_analysis() when \ +numerical_calculator() needed| +||incorrect-actions-category-poor-information-retrieval|Incorrect or ineffective use of a correctly selected \ +tool

\u2022 Right tool, wrong query
\u2022 Irrelevant search terms
\u2022 Missing key search \ +parameters
\u2022 Overly broad/narrow queries|\u2022 Searching "airplane" when looking for specific flight \ +status
\u2022 Using too generic terms in domain-specific search
\u2022 Missing critical search parameters| +||incorrect-actions-category-clarification|Proceeding with action despite unclear/incomplete information\ +

\u2022 Missing critical info
\u2022 Ambiguous user input
\u2022 Assumptions made instead of asking|\ +\u2022 Booking flight without confirming date
\u2022 Processing refund without asking reason
\ +\u2022 Modifying order without confirming which one| +||incorrect-actions-category-inappropriate-info-request|Agent makes an information request that is inappropriate \ +for at least one of these reasons:

1. Tool-Available: Information should be retrieved through agent's \ +available tools/APIs
-> Is there an API/tool that should be used instead?

2.. Task-Irrelevant: \ +Information is not needed for any stage of the current task
-> Is the information unnecessary for the task?\ +


Differences compared to other categories:
\u2022 First-time request (if repeated, see Information \ +Seeking Repetition)
|\u2022 Agent to user: "What's the current exchange rate?" (should use currency_api \ +instead)
\u2022 Agent to user: "What's your passport number?" (when just browsing flights, not booking)
\ +\u2022 Agent to user: "What's the weather like in Paris?" (when agent has weather_api access)| +|**context-handling-error**|context-handling-error-category-context-handling-failures|Sudden and unwarranted \ +loss of recent interaction context while retaining older context, or complete and unexpected restart of \ +conversation state, losing all accumulated context|\u2022 Agent forgets user's last selection but remembers \ +initial request
\u2022 Agent asks for information provided in last 2-3 turns
\u2022 Agent reverts to \ +earlier conversation state
\u2022 Agent suddenly greets user as new in middle of booking
\u2022 Agent \ +loses all progress in multi-step process
\u2022 Agent restarts task from beginning| +|**hallucination**|hallucination-category-hall-capabilities|Agent claims or attempts to use tool features that \ +don't exist in the API specification|\u2022 "I'll use search_flights.filter_by_meal_preference()"
\u2022 \ +"The email_tool can automatically translate messages"
\u2022 "Let me use the AI image generator" (when no \ +such tool exists)| +||hallucination-category-hall-misunderstand|Agent misinterprets the meaning or structure of actual tool responses \ +while the response itself is valid.|\u2022 Interpreting error code as success
\u2022 Treating warning as \ +confirmation
\u2022 Misreading data structure| +||hallucination-category-hall-usage|Agent claims to have used tools when it hasn't|\u2022 "I've checked the \ +database" (without tool call)
\u2022 "I've verified with the API" (no verification done)
\u2022 "The \ +tool confirmed..." (no tool used)| +||hallucination-category-hall-history|Agent references conversations or user inputs that never occurred|\u2022 \ +"As you mentioned earlier..." (when not mentioned)
\u2022 "Based on your preference for..." (never stated)\ +
\u2022 "Following up on your request..." (no such request)| +||hallucination-category-hall-params|Agent uses parameter values that conflict with established previous \ +trajectory context|\u2022 User: "I want to fly to Paris"
[later] Agent: "Where would you like to fly to?"\ +
\u2022 User: "My account number is [REDACTED:BANK_ACCOUNT_NUMBER]"
[later] Agent: "Can you provide \ +your account number?"
\u2022 User: "I prefer morning flights"
[later] Agent: "Do you have a preference \ +for flight time?"| +||hallucination-category-fabricate-tool-outputs|Agent fabricates or makes up tool outputs that were not actually \ +returned|\u2022 Agent claims API returned data that was never received
\u2022 Agent invents success messages\ +
\u2022 Agent creates fake error responses| +|**repetitive-behavior**|repetitive-behavior-category-repetition-tool|Making identical API/tool calls multiple \ +times without justification

\u2022 Same tool, same parameters
\u2022 No user request ("provide me \ +with 2 more options") nor justification (the result of "check account balance" would be different before and \ +after making a withdrawal, so calling it twice could be justified)|\u2022 Calling check_flight_status(\ +flight_id="123") repeatedly
\u2022 Multiple identical search queries
\u2022 Repeating API \ +authentication calls| +||repetitive-behavior-category-repetition-info|Requesting information from user that was previously provided in \ +the conversation

\u2022 Asking for provided data
\u2022 Ignoring prior responses
\u2022 \ +Redundant user queries
\u2022 Forgetting stated preferences

The key distinction from Inappropriate \ +Information Request is that the information being requested:
- Would be appropriate if asked for the first \ +time
- Cannot be obtained through tools/APIs
- Is relevant to the current task
- Was already directly \ +stated by the user in the conversation|\u2022 "What's your destination?" (after already asked)
\u2022 \ +"Please provide your user ID" (already given)
\u2022 "What's your preferred date?" (previously stated)| +||repetitive-behavior-category-step-repetition|Repeating the same action steps or workflow stages without \ +progress or justification|\u2022 Agent repeats same verification step multiple times
\u2022 Cycling through \ +same process without advancing
\u2022 Re-executing completed steps| +|**orchestration-related-errors**|orchestration-related-errors-category-reasoning-mismatch|Disconnect between \ +agent's stated reasoning/plan and actual executed actions|\u2022 Planning to verify but skipping verification\ +
\u2022 Stating one approach but executing another
\u2022 Reasoning about safety but ignoring checks| +||orchestration-related-errors-category-goal-deviation|Agent diverges from original task objective or user \ +intent|\u2022 Booking flight becomes travel planning
\u2022 Simple query becomes complex analysis
\ +\u2022 Support request becomes sales opportunity| +||orchestration-related-errors-category-premature-termination|Task or interaction ended before completion of \ +necessary steps or objectives|\u2022 Ending before payment confirmation
\u2022 Missing final verifications\ +
\u2022 Incomplete data collection
\u2022 Skipped error handling| +||orchestration-related-errors-category-unaware-termination|Failure to recognize or properly handle task \ +completion or continuation criteria|\u2022 Continuing after task completion
\u2022 Missing natural end \ +points
\u2022 Ignoring completion signals
\u2022 Redundant actions| +|**llm-output**|llm-output-category-nonsensical|Exposing internal system details, implementation logic, or \ +debug information in user-facing responses; producing malformed, illogical, or endlessly looping tool calls or \ +JSON structures; or displaying redacted, placeholder, or special tokens|\u2022 "As per my system prompt..."
\ +\u2022 "According to my training..."
\u2022 "Error in function call_api() line 234"
\u2022 "Internal \ +state: processing_step_2"
\u2022 Showing API keys/endpoints
\u2022 {{"tool": {{"tool": {{"tool": \ +{{...}}}}}}}}
\u2022 Mixing user text and JSON
\u2022 Invalid tool call sequences
\u2022 "Hello \ +[MASK], how are you?"
\u2022 "Your balance is "| +|**configuration-mismatch**|configuration-mismatch-category-tool-definition|Tool functionality differs from its \ +declared purpose or capabilities
Disconnect between API requirements and agent's actual capabilities

\ +The root cause of the issue is in the set up|\u2022 Calculator tool declared as search engine
\u2022 Text \ +processor declared as numerical tool
\u2022 Simple lookup declared as complex analysis

-----------\ +

\u2022 POST body requirements for GET-only agent
\u2022 Complex JSON expected for simple calls
\ +\u2022 Binary data handling for text-only agent| +|**coding-use-case-specific-failure-types**|coding-use-case-specific-failure-types-category-edge-case-oversights|\ +Failure to handle non-standard inputs, boundary conditions, or exceptional scenarios in code generation or \ +modification|| +||coding-use-case-specific-failure-types-category-dependency-issues|Failures in handling code dependencies, \ +imports, or external library requirements||""" + +OUTPUT_FORMAT = """\ +Your output must be a valid JSON object following this exact structure: + +```json +{ + "errors": [ + { + "location": "the span ID where failure occurred", + "category": ["category name"], + "evidence": ["brief explanation of the failure"], + "confidence": ["low | medium | high"] + } + ] +} +```""" + +FORMAT_REQUIREMENTS = """\ +- Return valid JSON only, no additional text +- Use the exact field names: "location", "category", "evidence", "confidence" +- "category" must be an array of strings +- "evidence" must be an array of strings +- "confidence" must be an array where each element is one of: "low", "medium", "high" +- When multiple failure modes exist at the same location, arrays must maintain \ +element-wise correspondence: category[i] corresponds to evidence[i] and confidence[i] +- If no failures are detected, return `{"errors": []}`""" + +EXAMPLES = """\ +### Single Failure Mode Example +If an agent repeatedly calls the same API without justification: +```json +{ + "errors": [ + { + "location": "span_id_where_repetition_occurs", + "category": ["repetitive-behavior-category-repetition-tool"], + "evidence": ["Agent called check_flight_status API 3 times with identical parameters \ +without user request or justification"], + "confidence": ["high"] + } + ] +} +``` + +### Multiple Failure Modes at Same Location Example +If an agent fabricates tool output AND repeats the same incorrect action at the same span: +```json +{ + "errors": [ + { + "location": "span_id_with_multiple_failures", + "category": ["hallucination-category-hall-usage", "repetitive-behavior-category-repetition-tool"], + "evidence": ["Agent fabricated a successful API response claiming flight was booked", \ +"Agent repeated the same invalid booking attempt 3 times without addressing the underlying error"], + "confidence": ["high", "high"] + } + ] +} +``` + +**Note:** When multiple failure modes occur at the same location, maintain element-wise \ +correspondence across arrays: the first category corresponds to the first evidence and \ +first confidence, and so on.""" + +QUALITY_CHECKS = """\ +- All identified failures have clear evidence +- Categories accurately describe the failure type +- Confidence levels reflect certainty of the failure +- JSON is valid and follows the exact schema +- Location IDs match actual spans in the trajectory""" + +TASK_INSTRUCTIONS = """\ +Now analyze the session data provided between and . The session \ +contains traces, each with spans representing atomic operations. Return your annotation \ +in the exact JSON format specified above, without additional text or markdown formatting \ +like ```json""" + + +def build_prompt( + session_json: str, + annotation_guidelines: str | None = None, + category_descriptions: str | None = None, +) -> str: + """Build the user message with session data for failure analysis. + + Args: + session_json: Serialized session trace as JSON string. + annotation_guidelines: Custom guidelines. Uses DEFAULT_ANNOTATION_GUIDELINES if None. + category_descriptions: Custom taxonomy. Uses CATEGORIES if None. + """ + guidelines = annotation_guidelines or DEFAULT_ANNOTATION_GUIDELINES + categories = category_descriptions or CATEGORIES + + return f"""--- + +## Annotation Guidelines + +{guidelines} + +--- + +## Failure Category Definitions + +{categories} + +--- + +## Output Format + +{OUTPUT_FORMAT} + +### Format Requirements: + +{FORMAT_REQUIREMENTS} + +--- + +## Examples + +{EXAMPLES} + +--- + +## Quality Checks + +Ensure these are verified before submitting your annotation: + +{QUALITY_CHECKS} + +--- + +## Task Instructions + +{TASK_INSTRUCTIONS} + +--- + + +{session_json} +""" diff --git a/src/strands_evals/types/__init__.py b/src/strands_evals/types/__init__.py index 9c4d57b..14755f9 100644 --- a/src/strands_evals/types/__init__.py +++ b/src/strands_evals/types/__init__.py @@ -1,3 +1,13 @@ +from .detector import ( + DiagnosisResult, + FailureDetectionStructuredOutput, + FailureItem, + FailureOutput, + RCAItem, + RCAOutput, + RCAStructuredOutput, + SummaryOutput, +) from .evaluation import EnvironmentState, EvaluationData, EvaluationOutput, InputT, Interaction, OutputT, TaskOutput from .simulation import ActorProfile, ActorResponse @@ -11,4 +21,12 @@ "ActorResponse", "InputT", "OutputT", + "DiagnosisResult", + "FailureDetectionStructuredOutput", + "FailureItem", + "FailureOutput", + "RCAItem", + "RCAOutput", + "RCAStructuredOutput", + "SummaryOutput", ] diff --git a/src/strands_evals/types/detector.py b/src/strands_evals/types/detector.py new file mode 100644 index 0000000..e4379f6 --- /dev/null +++ b/src/strands_evals/types/detector.py @@ -0,0 +1,159 @@ +"""Pydantic models for detectors. + +Includes both output types (FailureOutput, RCAOutput, etc.) and +LLM structured output schemas (FailureDetectionStructuredOutput, etc.). +""" + +from typing import Literal + +from pydantic import BaseModel, Field + +# Confidence levels used across detectors +ConfidenceLevel = Literal["low", "medium", "high"] + +# --- Output types (what users consume) --- + + +class FailureItem(BaseModel): + """A single detected failure.""" + + span_id: str = Field(description="Span where failure occurred") + category: list[str] = Field(description="Failure classifications") + confidence: list[ConfidenceLevel] = Field(description="Confidence per category") + evidence: list[str] = Field(description="Evidence per category") + + +class FailureOutput(BaseModel): + """Output from detect_failures().""" + + session_id: str + failures: list[FailureItem] = Field(default_factory=list) + + +class RCAItem(BaseModel): + """A single root cause finding.""" + + failure_span_id: str = Field(description="The failure span this explains") + location: str = Field(description="Span where root cause originated") + causality: str = Field(description="PRIMARY_FAILURE | SECONDARY_FAILURE | TERTIARY_FAILURE") + propagation_impact: list[str] = Field(default_factory=list) + root_cause_explanation: str + fix_type: str = Field(description="SYSTEM_PROMPT_FIX | TOOL_DESCRIPTION_FIX | OTHERS") + fix_recommendation: str + + +class RCAOutput(BaseModel): + """Output from analyze_root_cause().""" + + root_causes: list[RCAItem] = Field(default_factory=list) + + +class DiagnosisResult(BaseModel): + """Output from diagnose_session(). + + Contains failures detected and their root causes. + """ + + session_id: str + failures: list[FailureItem] = Field(default_factory=list) + root_causes: list[RCAItem] = Field(default_factory=list) + + +# --- LLM structured output schemas (used with Agent structured_output_model) --- + + +class FailureError(BaseModel): + """LLM output schema: single failure entry.""" + + location: str + category: list[str] + confidence: list[Literal["low", "medium", "high"]] + evidence: list[str] + + +class FailureDetectionStructuredOutput(BaseModel): + """LLM output schema: failure detection result.""" + + errors: list[FailureError] + + +class FixRecommendation(BaseModel): + """LLM output schema: fix recommendation.""" + + model_config = {"populate_by_name": True} + + fix_type: Literal["SYSTEM_PROMPT_FIX", "TOOL_DESCRIPTION_FIX", "OTHERS"] = Field( + ..., alias="Fix Type", description="Type of fix needed" + ) + recommendation: str = Field( + ..., + alias="Recommendation", + description="Brief, actionable fix suggestion (1-2 sentences)", + ) + + +class RootCauseItem(BaseModel): + """LLM output schema: single root cause entry.""" + + model_config = {"populate_by_name": True} + + failure_span_id: str = Field( + ..., + alias="Failure Span ID", + description="The span_id from execution_failures that this root cause addresses (1:1 mapping)", + ) + location: str = Field( + ..., + alias="Location", + description="Exact span_id where the root cause failure occurred", + ) + failure_causality: Literal["PRIMARY_FAILURE", "SECONDARY_FAILURE", "TERTIARY_FAILURE", "UNCLEAR"] = Field( + ..., + alias="Failure Causality", + description="Causality classification of the failure", + ) + failure_propagation_impact: list[ + Literal[ + "TASK_TERMINATION", + "QUALITY_DEGRADATION", + "INCORRECT_PATH", + "STATE_CORRUPTION", + "NO_PROPAGATION", + "UNCLEAR", + ] + ] = Field( + ..., + alias="Failure Propagation Impact", + description="List of impact types on task execution", + ) + failure_detection_timing: Literal[ + "IMMEDIATELY_AT_OCCURRENCE", + "SEVERAL_STEPS_LATER", + "ONLY_AT_TASK_END", + "SILENT_UNDETECTED", + ] = Field( + ..., + alias="Failure Detection Timing", + description="When the failure was detected in the execution", + ) + completion_status: Literal["COMPLETE_SUCCESS", "PARTIAL_SUCCESS", "COMPLETE_FAILURE"] = Field( + ..., alias="Completion Status", description="Overall task completion status" + ) + root_cause_explanation: str = Field( + ..., + alias="Root Cause Explanation", + description="Concise explanation of the fundamental issue (2-3 sentences)", + ) + fix_recommendation: FixRecommendation = Field( + ..., + alias="Fix Recommendation", + description="Structured recommendation for addressing the root cause", + ) + + +class RCAStructuredOutput(BaseModel): + """LLM output schema: root cause analysis result.""" + + root_causes: list[RootCauseItem] = Field( + ..., description="List of all identified root causes in the execution trace" + ) diff --git a/src/strands_evals/types/trace.py b/src/strands_evals/types/trace.py index 9486ee0..8719a22 100644 --- a/src/strands_evals/types/trace.py +++ b/src/strands_evals/types/trace.py @@ -118,6 +118,7 @@ class AgentInvocationSpan(BaseSpan): user_prompt: str agent_response: str available_tools: list[ToolConfig] + system_prompt: str | None = None SpanUnion: TypeAlias = Union[InferenceSpan, ToolExecutionSpan, AgentInvocationSpan] diff --git a/tests/strands_evals/detectors/__init__.py b/tests/strands_evals/detectors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/strands_evals/detectors/test_chunking.py b/tests/strands_evals/detectors/test_chunking.py new file mode 100644 index 0000000..7dd37fd --- /dev/null +++ b/tests/strands_evals/detectors/test_chunking.py @@ -0,0 +1,216 @@ +"""Tests for chunking utilities.""" + +from datetime import datetime + +from strands_evals.detectors.chunking import ( + CHARS_PER_TOKEN, + DEFAULT_MAX_INPUT_TOKENS, + estimate_tokens, + merge_chunk_failures, + split_spans_by_tokens, + would_exceed_context, +) +from strands_evals.types.detector import FailureItem +from strands_evals.types.trace import ( + InferenceSpan, + SpanInfo, + TextContent, + ToolCall, + ToolExecutionSpan, + ToolResult, + UserMessage, +) + + +def _make_span_info(span_id: str) -> SpanInfo: + now = datetime.now() + return SpanInfo( + session_id="sess_1", + span_id=span_id, + trace_id="trace_1", + start_time=now, + end_time=now, + ) + + +def _make_tool_span(span_id: str, content_size: int = 100) -> ToolExecutionSpan: + """Create a ToolExecutionSpan with controllable content size.""" + return ToolExecutionSpan( + span_info=_make_span_info(span_id), + tool_call=ToolCall(name="test_tool", arguments={"data": "x" * content_size}), + tool_result=ToolResult(content="y" * content_size), + ) + + +def _make_inference_span(span_id: str) -> InferenceSpan: + return InferenceSpan( + span_info=_make_span_info(span_id), + messages=[UserMessage(content=[TextContent(text="Hello")])], + ) + + +# --- estimate_tokens --- + + +def test_estimate_tokens_empty(): + assert estimate_tokens("") == 0 + + +def test_estimate_tokens_short(): + text = "Hello world" + assert estimate_tokens(text) == len(text) // CHARS_PER_TOKEN + + +def test_estimate_tokens_long(): + text = "a" * 1000 + assert estimate_tokens(text) == 250 + + +# --- would_exceed_context --- + + +def test_would_exceed_context_small(): + assert not would_exceed_context("Hello world") + + +def test_would_exceed_context_large(): + # Create text that exceeds 128K * 0.75 = 96K tokens = 384K chars + big_text = "x" * 400_000 + assert would_exceed_context(big_text) + + +def test_would_exceed_context_custom_limit(): + # 100 chars = 25 tokens; limit 30 * 0.75 = 22.5 -> exceeds + assert would_exceed_context("x" * 100, max_input_tokens=30) + + +def test_would_exceed_context_just_under(): + # effective limit = 100 * 0.75 = 75 tokens = 300 chars + assert not would_exceed_context("x" * 299, max_input_tokens=100) + + +# --- split_spans_by_tokens --- + + +def test_split_spans_single_chunk(): + """Small spans should stay in one chunk.""" + spans = [_make_tool_span(f"span_{i}", content_size=10) for i in range(3)] + chunks = split_spans_by_tokens(spans, max_tokens=DEFAULT_MAX_INPUT_TOKENS) + assert len(chunks) == 1 + assert len(chunks[0]) == 3 + + +def test_split_spans_multiple_chunks(): + """Large spans should be split into multiple chunks.""" + # Each span ~10K chars = ~2500 tokens. With limit of 5000 tokens (effective 3750), + # we need multiple chunks. But MIN_CHUNK_SIZE=5, so we need enough spans. + spans = [_make_tool_span(f"span_{i}", content_size=2000) for i in range(10)] + chunks = split_spans_by_tokens(spans, max_tokens=5000, overlap_spans=1) + assert len(chunks) > 1 + # All spans should be present across chunks + all_span_ids = set() + for chunk in chunks: + for span in chunk: + all_span_ids.add(span.span_info.span_id) + assert len(all_span_ids) == 10 + + +def test_split_spans_overlap(): + """Adjacent chunks should share overlap spans.""" + spans = [_make_tool_span(f"span_{i}", content_size=2000) for i in range(10)] + chunks = split_spans_by_tokens(spans, max_tokens=5000, overlap_spans=2) + if len(chunks) > 1: + # Last 2 spans of chunk 0 should be first 2 of chunk 1 + last_of_first = [s.span_info.span_id for s in chunks[0][-2:]] + first_of_second = [s.span_info.span_id for s in chunks[1][:2]] + assert last_of_first == first_of_second + + +def test_split_spans_no_overlap(): + """With overlap=0, chunks should not share spans.""" + spans = [_make_tool_span(f"span_{i}", content_size=2000) for i in range(10)] + chunks = split_spans_by_tokens(spans, max_tokens=5000, overlap_spans=0) + if len(chunks) > 1: + first_ids = {s.span_info.span_id for s in chunks[0]} + second_ids = {s.span_info.span_id for s in chunks[1]} + assert first_ids.isdisjoint(second_ids) + + +def test_split_spans_empty(): + chunks = split_spans_by_tokens([], max_tokens=1000) + assert chunks == [] + + +# --- merge_chunk_failures --- + + +def test_merge_no_overlap(): + """Failures from different spans should be preserved.""" + chunk1 = [ + FailureItem(span_id="span_1", category=["error_a"], confidence=["high"], evidence=["ev_a"]), + ] + chunk2 = [ + FailureItem(span_id="span_2", category=["error_b"], confidence=["high"], evidence=["ev_b"]), + ] + merged = merge_chunk_failures([chunk1, chunk2]) + assert len(merged) == 2 + + +def test_merge_same_span_different_category(): + """Same span, different categories should be combined.""" + chunk1 = [ + FailureItem(span_id="span_1", category=["error_a"], confidence=["high"], evidence=["ev_a"]), + ] + chunk2 = [ + FailureItem(span_id="span_1", category=["error_b"], confidence=["high"], evidence=["ev_b"]), + ] + merged = merge_chunk_failures([chunk1, chunk2]) + assert len(merged) == 1 + assert len(merged[0].category) == 2 + assert "error_a" in merged[0].category + assert "error_b" in merged[0].category + + +def test_merge_same_span_same_category_keeps_highest(): + """Same span+category: keep the higher confidence.""" + chunk1 = [ + FailureItem(span_id="span_1", category=["error_a"], confidence=["low"], evidence=["weak"]), + ] + chunk2 = [ + FailureItem(span_id="span_1", category=["error_a"], confidence=["high"], evidence=["strong"]), + ] + merged = merge_chunk_failures([chunk1, chunk2]) + assert len(merged) == 1 + assert merged[0].confidence[0] == "high" + assert merged[0].evidence[0] == "strong" + + +def test_merge_same_span_same_category_no_downgrade(): + """Same span+category: lower confidence should not replace higher.""" + chunk1 = [ + FailureItem(span_id="span_1", category=["error_a"], confidence=["high"], evidence=["strong"]), + ] + chunk2 = [ + FailureItem(span_id="span_1", category=["error_a"], confidence=["low"], evidence=["weak"]), + ] + merged = merge_chunk_failures([chunk1, chunk2]) + assert merged[0].confidence[0] == "high" + assert merged[0].evidence[0] == "strong" + + +def test_merge_empty(): + assert merge_chunk_failures([]) == [] + assert merge_chunk_failures([[]]) == [] + + +def test_merge_does_not_mutate_originals(): + """Merging should not modify the input FailureItems.""" + item = FailureItem(span_id="span_1", category=["error_a"], confidence=["high"], evidence=["ev_a"]) + chunk1 = [item] + chunk2 = [ + FailureItem(span_id="span_1", category=["error_b"], confidence=["high"], evidence=["ev_b"]), + ] + merge_chunk_failures([chunk1, chunk2]) + # Original item should be unchanged + assert len(item.category) == 1 + assert item.category[0] == "error_a" diff --git a/tests/strands_evals/detectors/test_failure_detector.py b/tests/strands_evals/detectors/test_failure_detector.py new file mode 100644 index 0000000..637572b --- /dev/null +++ b/tests/strands_evals/detectors/test_failure_detector.py @@ -0,0 +1,282 @@ +"""Tests for failure detector.""" + +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest + +from strands_evals.detectors.failure_detector import ( + CONFIDENCE_ORDER, + _is_context_exceeded, + _max_confidence_rank, + _parse_structured_result, + _serialize_session, + _serialize_spans, + detect_failures, +) +from strands_evals.types.detector import FailureDetectionStructuredOutput, FailureError, FailureItem, FailureOutput +from strands_evals.types.trace import ( + AgentInvocationSpan, + InferenceSpan, + Session, + SpanInfo, + TextContent, + ToolCall, + ToolConfig, + ToolExecutionSpan, + ToolResult, + Trace, + UserMessage, +) + + +def _span_info(span_id: str = "span_1") -> SpanInfo: + now = datetime.now() + return SpanInfo(session_id="sess_1", span_id=span_id, trace_id="trace_1", start_time=now, end_time=now) + + +def _make_session(spans=None) -> Session: + if spans is None: + spans = [ + AgentInvocationSpan( + span_info=_span_info("span_1"), + user_prompt="Hello", + agent_response="Hi there", + available_tools=[ToolConfig(name="search")], + ), + InferenceSpan( + span_info=_span_info("span_2"), + messages=[UserMessage(content=[TextContent(text="Hello")])], + ), + ] + return Session( + session_id="sess_1", + traces=[Trace(trace_id="trace_1", session_id="sess_1", spans=spans)], + ) + + +# --- _parse_structured_result --- + + +def test_parse_structured_result_basic(): + output = FailureDetectionStructuredOutput( + errors=[ + FailureError( + location="span_1", + category=["hallucination-category-hall-usage"], + confidence=["high"], + evidence=["Agent claimed to use tool without calling it"], + ) + ] + ) + result = _parse_structured_result(output) + assert len(result) == 1 + assert result[0].span_id == "span_1" + assert result[0].category == ["hallucination-category-hall-usage"] + assert result[0].confidence == ["high"] + assert result[0].evidence == ["Agent claimed to use tool without calling it"] + + +def test_parse_structured_result_multiple_modes(): + output = FailureDetectionStructuredOutput( + errors=[ + FailureError( + location="span_1", + category=["error_a", "error_b"], + confidence=["low", "high"], + evidence=["ev_a", "ev_b"], + ) + ] + ) + result = _parse_structured_result(output) + assert len(result) == 1 + assert len(result[0].category) == 2 + assert result[0].confidence == ["low", "high"] + + +def test_parse_structured_result_mismatched_arrays(): + output = FailureDetectionStructuredOutput( + errors=[ + FailureError( + location="span_1", + category=["error_a", "error_b"], + confidence=["high"], + evidence=["only one"], + ) + ] + ) + result = _parse_structured_result(output) + assert len(result) == 0 # skipped due to mismatch + + +def test_parse_structured_result_empty(): + output = FailureDetectionStructuredOutput(errors=[]) + result = _parse_structured_result(output) + assert result == [] + + +# --- _max_confidence_rank --- + + +def test_max_confidence_rank(): + item = FailureItem(span_id="s", category=["a", "b"], confidence=["low", "high"], evidence=["x", "y"]) + assert _max_confidence_rank(item) == CONFIDENCE_ORDER["high"] + + +def test_max_confidence_rank_empty(): + item = FailureItem(span_id="s", category=[], confidence=[], evidence=[]) + assert _max_confidence_rank(item) == -1 + + +# --- _is_context_exceeded --- + + +def test_is_context_exceeded_strands_exception(): + from strands.types.exceptions import ContextWindowOverflowException + + assert _is_context_exceeded(ContextWindowOverflowException("too big")) + + +def test_is_context_exceeded_string_match(): + assert _is_context_exceeded(Exception("The context window is exceeded")) + assert _is_context_exceeded(Exception("Input too long for model")) + assert _is_context_exceeded(Exception("max_tokens limit reached")) + + +def test_is_context_exceeded_unrelated(): + assert not _is_context_exceeded(Exception("Something else went wrong")) + assert not _is_context_exceeded(ValueError("bad value")) + + +# --- _serialize_session / _serialize_spans --- + + +def test_serialize_session(): + session = _make_session() + result = _serialize_session(session) + assert "sess_1" in result + assert "span_1" in result + + +def test_serialize_spans(): + span = ToolExecutionSpan( + span_info=_span_info("span_10"), + tool_call=ToolCall(name="test", arguments={}), + tool_result=ToolResult(content="ok"), + ) + result = _serialize_spans([span], "sess_1") + assert "sess_1" in result + assert "span_10" in result + + +# --- detect_failures (with mocked Agent) --- + + +@patch("strands_evals.detectors.failure_detector.Agent") +def test_detect_failures_no_failures(mock_agent_cls): + mock_agent = MagicMock() + mock_agent_cls.return_value = mock_agent + mock_result = MagicMock() + mock_result.structured_output = FailureDetectionStructuredOutput(errors=[]) + mock_agent.return_value = mock_result + + session = _make_session() + output = detect_failures(session) + + assert isinstance(output, FailureOutput) + assert output.session_id == "sess_1" + assert output.failures == [] + + +@patch("strands_evals.detectors.failure_detector.Agent") +def test_detect_failures_with_failures(mock_agent_cls): + mock_agent = MagicMock() + mock_agent_cls.return_value = mock_agent + mock_result = MagicMock() + mock_result.structured_output = FailureDetectionStructuredOutput( + errors=[ + FailureError( + location="span_1", + category=["hallucination-category-hall-usage"], + confidence=["high"], + evidence=["Fabricated tool output"], + ), + ] + ) + mock_agent.return_value = mock_result + + session = _make_session() + output = detect_failures(session) + + assert len(output.failures) == 1 + assert output.failures[0].span_id == "span_1" + assert output.failures[0].confidence == ["high"] + + +@patch("strands_evals.detectors.failure_detector.Agent") +def test_detect_failures_confidence_threshold(mock_agent_cls): + mock_agent = MagicMock() + mock_agent_cls.return_value = mock_agent + mock_result = MagicMock() + mock_result.structured_output = FailureDetectionStructuredOutput( + errors=[ + FailureError(location="span_1", category=["err_a"], confidence=["low"], evidence=["weak"]), + FailureError(location="span_2", category=["err_b"], confidence=["high"], evidence=["strong"]), + ] + ) + mock_agent.return_value = mock_result + + session = _make_session() + output = detect_failures(session, confidence_threshold="medium") + + # "low" is below "medium" threshold + assert len(output.failures) == 1 + assert output.failures[0].span_id == "span_2" + + +@patch("strands_evals.detectors.failure_detector.Agent") +def test_detect_failures_context_overflow_fallback(mock_agent_cls): + """When direct call raises context overflow, should fall back to chunking.""" + from strands.types.exceptions import ContextWindowOverflowException + + mock_agent = MagicMock() + mock_agent_cls.return_value = mock_agent + + # First call raises context overflow, subsequent calls succeed (chunking) + chunk_result = MagicMock() + chunk_result.structured_output = FailureDetectionStructuredOutput(errors=[]) + mock_agent.side_effect = [ContextWindowOverflowException("too big"), chunk_result, chunk_result] + + session = _make_session() + output = detect_failures(session) + + assert isinstance(output, FailureOutput) + + +@patch("strands_evals.detectors.failure_detector.Agent") +def test_detect_failures_non_context_error_raises(mock_agent_cls): + """Non-context errors should propagate.""" + mock_agent = MagicMock() + mock_agent_cls.return_value = mock_agent + mock_agent.side_effect = RuntimeError("Something else broke") + + session = _make_session() + with pytest.raises(RuntimeError, match="Something else broke"): + detect_failures(session) + + +@patch("strands_evals.detectors.failure_detector.Agent") +def test_detect_failures_passes_model(mock_agent_cls): + mock_agent = MagicMock() + mock_agent_cls.return_value = mock_agent + mock_result = MagicMock() + mock_result.structured_output = FailureDetectionStructuredOutput(errors=[]) + mock_agent.return_value = mock_result + + session = _make_session() + detect_failures(session, model="us.anthropic.claude-sonnet-4-20250514-v1:0") + + # Verify the Agent was created with the custom model + mock_agent_cls.assert_called_once() + call_kwargs = mock_agent_cls.call_args + assert call_kwargs.kwargs["model"] == "us.anthropic.claude-sonnet-4-20250514-v1:0" diff --git a/tests/strands_evals/detectors/test_types.py b/tests/strands_evals/detectors/test_types.py new file mode 100644 index 0000000..5f427b8 --- /dev/null +++ b/tests/strands_evals/detectors/test_types.py @@ -0,0 +1,219 @@ +"""Tests for detector Pydantic models.""" + +from strands_evals.types.detector import ( + DiagnosisResult, + FailureDetectionStructuredOutput, + FailureError, + FailureItem, + FailureOutput, + FixRecommendation, + RCAItem, + RCAOutput, + RCAStructuredOutput, + RootCauseItem, + SummaryOutput, +) + +# --- Output types --- + + +def test_failure_item_creation(): + item = FailureItem( + span_id="span_123", + category=["hallucination", "incomplete_task"], + confidence=["high", "medium"], + evidence=["Made up product ID", "Did not finish checkout"], + ) + assert item.span_id == "span_123" + assert len(item.category) == 2 + assert item.confidence[0] == "high" + assert item.evidence[1] == "Did not finish checkout" + + +def test_failure_output_empty(): + output = FailureOutput(session_id="sess_1") + assert output.session_id == "sess_1" + assert output.failures == [] + + +def test_failure_output_with_failures(): + item = FailureItem( + span_id="span_1", + category=["error"], + confidence=["high"], + evidence=["timeout"], + ) + output = FailureOutput(session_id="sess_1", failures=[item]) + assert len(output.failures) == 1 + assert output.failures[0].span_id == "span_1" + + +def test_rca_item_creation(): + item = RCAItem( + failure_span_id="span_1", + location="span_0", + causality="PRIMARY_FAILURE", + propagation_impact=["TASK_TERMINATION"], + root_cause_explanation="The tool returned ambiguous results", + fix_type="TOOL_DESCRIPTION_FIX", + fix_recommendation="Add disambiguation instructions", + ) + assert item.failure_span_id == "span_1" + assert item.causality == "PRIMARY_FAILURE" + + +def test_rca_output_empty(): + output = RCAOutput() + assert output.root_causes == [] + + +def test_summary_output(): + output = SummaryOutput( + primary_goal={"description": "Book a flight", "evidence": ["span_1"]}, + approach_taken={"description": "Used search tool", "evidence": ["span_2"]}, + final_outcome={"description": "Successfully booked", "success": True, "evidence": ["span_5"]}, + ) + assert output.primary_goal["description"] == "Book a flight" + assert output.tools_used == [] + assert output.observed_failures == [] + + +def test_diagnosis_result(): + item = FailureItem( + span_id="span_1", + category=["error"], + confidence=["high"], + evidence=["timeout"], + ) + rca = RCAItem( + failure_span_id="span_1", + location="span_0", + causality="PRIMARY_FAILURE", + root_cause_explanation="Timeout due to slow API", + fix_type="OTHERS", + fix_recommendation="Add retry logic", + ) + result = DiagnosisResult( + session_id="sess_1", + user_requests=["Book a flight"], + failures=[item], + summary="The agent attempted to book a flight but timed out.", + root_causes=[rca], + ) + assert result.session_id == "sess_1" + assert len(result.user_requests) == 1 + assert len(result.failures) == 1 + assert result.summary.startswith("The agent") + assert len(result.root_causes) == 1 + + +def test_diagnosis_result_empty(): + result = DiagnosisResult(session_id="sess_1") + assert result.user_requests == [] + assert result.failures == [] + assert result.summary == "" + assert result.root_causes == [] + + +# --- LLM structured output schemas --- + + +def test_failure_error(): + err = FailureError( + location="span_1", + category=["hallucination"], + confidence=["high"], + evidence=["Made up data"], + ) + assert err.location == "span_1" + assert err.confidence == ["high"] + + +def test_failure_detection_structured_output(): + output = FailureDetectionStructuredOutput( + errors=[ + FailureError( + location="span_1", + category=["error"], + confidence=["low"], + evidence=["Timed out"], + ) + ] + ) + assert len(output.errors) == 1 + + +def test_root_cause_item_with_aliases(): + """RootCauseItem should accept both alias names and field names.""" + item = RootCauseItem( + failure_span_id="span_1", + location="span_0", + failure_causality="PRIMARY_FAILURE", + failure_propagation_impact=["TASK_TERMINATION"], + failure_detection_timing="IMMEDIATELY_AT_OCCURRENCE", + completion_status="COMPLETE_FAILURE", + root_cause_explanation="Tool returned bad data", + fix_recommendation=FixRecommendation( + fix_type="TOOL_DESCRIPTION_FIX", + recommendation="Add validation", + ), + ) + assert item.failure_span_id == "span_1" + assert item.failure_causality == "PRIMARY_FAILURE" + + +def test_root_cause_item_from_llm_aliases(): + """RootCauseItem should parse from LLM output using JSON aliases.""" + data = { + "Failure Span ID": "span_1", + "Location": "span_0", + "Failure Causality": "SECONDARY_FAILURE", + "Failure Propagation Impact": ["QUALITY_DEGRADATION"], + "Failure Detection Timing": "SEVERAL_STEPS_LATER", + "Completion Status": "PARTIAL_SUCCESS", + "Root Cause Explanation": "Ambiguous tool output", + "Fix Recommendation": { + "Fix Type": "SYSTEM_PROMPT_FIX", + "Recommendation": "Add disambiguation", + }, + } + item = RootCauseItem.model_validate(data) + assert item.failure_span_id == "span_1" + assert item.failure_causality == "SECONDARY_FAILURE" + assert item.fix_recommendation.fix_type == "SYSTEM_PROMPT_FIX" + + +def test_rca_structured_output(): + output = RCAStructuredOutput( + root_causes=[ + RootCauseItem( + failure_span_id="span_1", + location="span_0", + failure_causality="PRIMARY_FAILURE", + failure_propagation_impact=["TASK_TERMINATION"], + failure_detection_timing="IMMEDIATELY_AT_OCCURRENCE", + completion_status="COMPLETE_FAILURE", + root_cause_explanation="Error", + fix_recommendation=FixRecommendation( + fix_type="OTHERS", + recommendation="Fix it", + ), + ) + ] + ) + assert len(output.root_causes) == 1 + + +def test_failure_output_serialization_roundtrip(): + """Test that models can serialize and deserialize.""" + item = FailureItem( + span_id="span_1", + category=["hallucination"], + confidence=["high"], + evidence=["Made up data"], + ) + output = FailureOutput(session_id="sess_1", failures=[item]) + json_str = output.model_dump_json() + restored = FailureOutput.model_validate_json(json_str) + assert restored.session_id == "sess_1" + assert restored.failures[0].confidence[0] == "high"