Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ build
.venv
.mypy_cache
.ruff_cache
.claude
*.bak
.vscode
dist
Expand All @@ -15,4 +16,4 @@ dataset_files
report_files
.venv
*.DS_Store*
uv.lock
uv.lock
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ select = [

[tool.ruff.lint.per-file-ignores]
"src/strands_evals/evaluators/prompt_templates/*" = ["E501"] # line-length
"src/strands_evals/detectors/prompt_templates/*" = ["E501"] # line-length
"src/strands_evals/generators/prompt_template/*" = ["E501"] # line-length

[tool.mypy]
Expand Down
3 changes: 2 additions & 1 deletion src/strands_evals/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import evaluators, extractors, generators, providers, simulation, telemetry, types
from . import detectors, evaluators, extractors, generators, providers, simulation, telemetry, types
from .case import Case
from .experiment import Experiment
from .simulation import ActorSimulator, UserSimulator
Expand All @@ -7,6 +7,7 @@
__all__ = [
"Experiment",
"Case",
"detectors",
"evaluators",
"extractors",
"providers",
Expand Down
32 changes: 32 additions & 0 deletions src/strands_evals/detectors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Detectors for analyzing agent execution traces.

Detectors answer "why did my agent behave this way?" by analyzing Session
traces for failures, summarizing execution, extracting user requests, and
performing root cause analysis.
"""

from ..types.detector import (
ConfidenceLevel,
DiagnosisResult,
FailureDetectionStructuredOutput,
FailureItem,
FailureOutput,
RCAItem,
RCAOutput,
RCAStructuredOutput,
)
from .failure_detector import detect_failures

__all__ = [
# Core detectors
"detect_failures",
# Types
"ConfidenceLevel",
"DiagnosisResult",
"FailureOutput",
"FailureItem",
"FailureDetectionStructuredOutput",
"RCAOutput",
"RCAItem",
"RCAStructuredOutput",
]
124 changes: 124 additions & 0 deletions src/strands_evals/detectors/chunking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""Shared utilities for context window management and span chunking.

Used by detect_failures, summarize_execution, and analyze_root_cause
to handle sessions that exceed LLM context limits.

Ported from AgentCoreLens failure_detector._split_spans_by_tokens() and
related helpers, replacing litellm.token_counter() with a conservative
character-based estimate.
"""

import json
import logging

from ..types.detector import ConfidenceLevel, FailureItem
from ..types.trace import SpanUnion

logger = logging.getLogger(__name__)

# Conservative chars-per-token ratio for estimating tokens without a tokenizer;
# under-counting is the safe direction for pre-flight context checks.
CHARS_PER_TOKEN = 4
# Fraction of the context window prompts may fill, leaving headroom for output
# tokens and estimation error.
CONTEXT_SAFETY_MARGIN = 0.75
# Fallback input limit when the model's true context size is unknown.
DEFAULT_MAX_INPUT_TOKENS = 128_000
# Spans shared between adjacent chunks for context continuity.
CHUNK_OVERLAP_SPANS = 2
# Never close a chunk with fewer spans than this, even when over the token
# limit. Keeping CHUNK_OVERLAP_SPANS < MIN_CHUNK_SIZE guarantees each chunk
# advances past its overlap, so chunking always makes forward progress.
MIN_CHUNK_SIZE = 5
# Ordering of confidence levels for comparisons (higher rank = more confident).
_CONFIDENCE_RANK: dict[ConfidenceLevel, int] = {"low": 0, "medium": 1, "high": 2}


def estimate_tokens(text: str) -> int:
    """Conservative token estimate from text length."""
    return len(text) // CHARS_PER_TOKEN


def would_exceed_context(prompt_text: str, max_input_tokens: int = DEFAULT_MAX_INPUT_TOKENS) -> bool:
    """Pre-flight check: would this prompt likely exceed context?

    If wrong (false negative), the model returns a context-exceeded error,
    which triggers the chunking fallback in each detector.
    """
    return estimate_tokens(prompt_text) > int(max_input_tokens * CONTEXT_SAFETY_MARGIN)


def split_spans_by_tokens(
    spans: list[SpanUnion],
    max_tokens: int = DEFAULT_MAX_INPUT_TOKENS,
    overlap_spans: int = CHUNK_OVERLAP_SPANS,
) -> list[list[SpanUnion]]:
    """Split spans into chunks fitting within token limits.

    Adjacent chunks share ``overlap_spans`` spans for context continuity.
    Ported from Lens ``failure_detector._split_spans_by_tokens()``.

    Args:
        spans: Flat list of span objects to chunk.
        max_tokens: Maximum model input tokens.
        overlap_spans: Number of spans shared between adjacent chunks.

    Returns:
        List of span chunks. Each chunk fits within the effective token limit,
        except that a chunk is never closed with fewer than ``MIN_CHUNK_SIZE``
        spans, so a run of oversized spans may still exceed it.
    """
    effective_limit = int(max_tokens * CONTEXT_SAFETY_MARGIN)

    # Serialize and estimate each span exactly once. The previous version
    # re-dumped the overlap spans at every chunk boundary, repeating the
    # model_dump()/json.dumps() work O(overlap * chunks) times.
    token_costs = [estimate_tokens(json.dumps(s.model_dump(), default=str)) for s in spans]

    chunks: list[list[SpanUnion]] = []
    current: list[int] = []  # indices into spans / token_costs
    current_tokens = 0

    for idx, span_tokens in enumerate(token_costs):
        # Close the chunk when adding this span would overflow, but only once
        # it holds at least MIN_CHUNK_SIZE spans (progress guarantee above).
        if current_tokens + span_tokens > effective_limit and len(current) >= MIN_CHUNK_SIZE:
            chunks.append([spans[i] for i in current])
            overlap = current[-overlap_spans:] if overlap_spans > 0 else []
            current = list(overlap)
            current_tokens = sum(token_costs[i] for i in overlap)
        current.append(idx)
        current_tokens += span_tokens

    if current:
        chunks.append([spans[i] for i in current])

    logger.info(
        "Split %d spans into %d chunks (max_tokens=%d, overlap=%d)",
        len(spans),
        len(chunks),
        max_tokens,
        overlap_spans,
    )
    return chunks


def merge_chunk_failures(chunk_results: list[list[FailureItem]]) -> list[FailureItem]:
    """Merge failures from overlapping chunks, deduplicating by span_id.

    Keeps highest confidence per category when the same span_id appears
    in multiple chunks. Ported from Lens ``_merge_chunk_failures()``.

    Args:
        chunk_results: List of failure lists, one per chunk.

    Returns:
        Deduplicated list of FailureItems.
    """
    merged: dict[str, FailureItem] = {}
    for failure in (item for chunk in chunk_results for item in chunk):
        existing = merged.get(failure.span_id)
        if existing is None:
            # Store a deep copy so later merges never mutate caller-owned items.
            merged[failure.span_id] = FailureItem(
                span_id=failure.span_id,
                category=list(failure.category),
                confidence=list(failure.confidence),
                evidence=list(failure.evidence),
            )
            continue
        for i, cat in enumerate(failure.category):
            if cat not in existing.category:
                # New failure mode for this span: append all three fields.
                existing.category.append(cat)
                existing.confidence.append(failure.confidence[i])
                existing.evidence.append(failure.evidence[i])
            else:
                # Duplicate mode: keep whichever sighting is more confident.
                idx = existing.category.index(cat)
                if _CONFIDENCE_RANK.get(failure.confidence[i], 0) > _CONFIDENCE_RANK.get(
                    existing.confidence[idx], 0
                ):
                    existing.confidence[idx] = failure.confidence[i]
                    existing.evidence[idx] = failure.evidence[i]
    return list(merged.values())
172 changes: 172 additions & 0 deletions src/strands_evals/detectors/failure_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
"""Failure detection for agent execution sessions.

Identifies semantic failures (hallucinations, task errors, tool misuse, etc.)
in Session traces using LLM-based analysis with automatic chunking fallback
for sessions exceeding context limits.

Ported from AgentCoreLens tools/failure_detector.py.
"""

import json
import logging

from strands import Agent
from strands.models.model import Model
from strands.types.exceptions import ContextWindowOverflowException
from typing_extensions import Union, cast

from ..types.detector import ConfidenceLevel, FailureDetectionStructuredOutput, FailureItem, FailureOutput
from ..types.trace import Session, SpanUnion
from .chunking import merge_chunk_failures, split_spans_by_tokens, would_exceed_context
from .prompt_templates.failure_detection import get_template

logger = logging.getLogger(__name__)

# Default Bedrock model id used when the caller does not pass a model.
DEFAULT_DETECTOR_MODEL = "us.anthropic.claude-haiku-4-5-20251001-v1:0"
# Ordering used to compare confidence levels; higher rank wins.
CONFIDENCE_ORDER: dict[ConfidenceLevel, int] = {"low": 0, "medium": 1, "high": 2}


def detect_failures(
    session: Session,
    *,
    confidence_threshold: ConfidenceLevel = "low",
    model: Union[Model, str, None] = None,
) -> FailureOutput:
    """Detect semantic failures in an agent execution session.

    Args:
        session: The Session object to analyze.
        confidence_threshold: Minimum confidence level ("low", "medium", "high")
            to include a failure. Defaults to "low" (include all).
        model: Any Strands model provider. None uses default Haiku.

    Returns:
        FailureOutput with list of FailureItems, each with span_id, category,
        confidence, and evidence.
    """
    effective_model = DEFAULT_DETECTOR_MODEL if model is None else model
    template = get_template("v0")
    user_prompt = template.build_prompt(session_json=_serialize_session(session))

    if would_exceed_context(user_prompt):
        # Known-too-big sessions skip the direct attempt entirely.
        raw = _detect_chunked(session, effective_model, template)
    else:
        try:
            raw = _detect_direct(user_prompt, effective_model, template)
        except Exception as e:
            if not _is_context_exceeded(e):
                raise
            # The character-based estimate was too optimistic; retry chunked.
            logger.warning("Context exceeded despite pre-flight check, falling back to chunking")
            raw = _detect_chunked(session, effective_model, template)

    min_rank = CONFIDENCE_ORDER[confidence_threshold]
    kept = [item for item in raw if _max_confidence_rank(item) >= min_rank]
    return FailureOutput(session_id=session.session_id, failures=kept)


def _detect_direct(user_prompt: str, model: Union[Model, str], template: object) -> list[FailureItem]:
    """Run a single LLM pass over the full serialized session."""
    detector = Agent(model=model, system_prompt=template.SYSTEM_PROMPT, callback_handler=None)
    response = detector(user_prompt, structured_output_model=FailureDetectionStructuredOutput)
    structured = cast(FailureDetectionStructuredOutput, response.structured_output)
    return _parse_structured_result(structured)


def _detect_chunked(
    session: Session,
    model: Union[Model, str],
    template: object,
) -> list[FailureItem]:
    """Chunk the session's spans, detect failures per chunk, and merge results."""
    spans = _flatten_traces_to_spans(session.traces)
    chunks = split_spans_by_tokens(spans)

    logger.info("Chunked detection: %d spans -> %d chunks", len(spans), len(chunks))

    total = len(chunks)
    per_chunk: list[list[FailureItem]] = []
    for index, chunk in enumerate(chunks, start=1):
        try:
            chunk_json = _serialize_spans(chunk, session.session_id)
            prompt = template.build_prompt(session_json=chunk_json)
            detector = Agent(model=model, system_prompt=template.SYSTEM_PROMPT, callback_handler=None)
            response = detector(prompt, structured_output_model=FailureDetectionStructuredOutput)
            structured = cast(FailureDetectionStructuredOutput, response.structured_output)
            per_chunk.append(_parse_structured_result(structured))
            logger.info("Chunk %d/%d: processed %d spans", index, total, len(chunk))
        except Exception as e:
            if not _is_context_exceeded(e):
                raise
            # A single chunk can still be oversized (MIN_CHUNK_SIZE floor);
            # drop it rather than failing the whole session.
            logger.warning("Chunk %d/%d still exceeds context, skipping", index, total)

    return merge_chunk_failures(per_chunk)


def _parse_structured_result(output: FailureDetectionStructuredOutput) -> list[FailureItem]:
    """Convert LLM structured output to list[FailureItem]."""
    parsed: list[FailureItem] = []
    for err in output.errors:
        # The three arrays must correspond element-wise; drop malformed entries
        # rather than guessing at the pairing.
        if not (len(err.category) == len(err.evidence) == len(err.confidence)):
            logger.warning(
                "Mismatched array lengths for location %s: category=%d, evidence=%d, confidence=%d. Skipping.",
                err.location,
                len(err.category),
                len(err.evidence),
                len(err.confidence),
            )
            continue

        parsed.append(
            FailureItem(
                span_id=err.location,
                category=list(err.category),
                confidence=list(err.confidence),
                evidence=list(err.evidence),
            )
        )
    return parsed


def _max_confidence_rank(item: FailureItem) -> int:
    """Highest confidence rank across the item's failure modes (-1 if none)."""
    ranks = [CONFIDENCE_ORDER.get(level, 0) for level in item.confidence]
    return max(ranks) if ranks else -1


def _is_context_exceeded(exception: Exception) -> bool:
    """Check whether *exception* indicates a context window overflow.

    Strands Agent raises ContextWindowOverflowException for context overflows.
    Some providers surface overflow only through generic errors (e.g. Bedrock
    ValidationException), so common message substrings are matched as a
    fallback.
    """
    if isinstance(exception, ContextWindowOverflowException):
        return True
    message = str(exception).lower()
    markers = ("context", "too long", "max_tokens", "input_limit")
    return any(marker in message for marker in markers)


def _flatten_traces_to_spans(traces: list) -> list[SpanUnion]:
    """Concatenate every trace's span list into one flat list, preserving order."""
    flattened: list[SpanUnion] = []
    for trace in traces:
        flattened.extend(trace.spans)
    return flattened


def _serialize_session(session: Session) -> str:
    """Serialize a full Session to pretty-printed JSON for the prompt."""
    payload = session.model_dump()
    # default=str stringifies datetimes and other non-JSON-native values.
    return json.dumps(payload, indent=2, default=str)


def _serialize_spans(spans: list[SpanUnion], session_id: str) -> str:
    """Serialize spans as a minimal single-trace session JSON for chunk prompts."""
    minimal_session = {
        "session_id": session_id,
        "traces": [{"spans": [span.model_dump() for span in spans]}],
    }
    # default=str stringifies datetimes and other non-JSON-native values.
    return json.dumps(minimal_session, indent=2, default=str)
5 changes: 5 additions & 0 deletions src/strands_evals/detectors/prompt_templates/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Prompt templates for detectors.

Each subdirectory contains versioned prompt modules following the same
pattern as evaluators/prompt_templates/.
"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from . import failure_detection_v0

# Registry of prompt-template versions; add new modules here as prompts evolve.
VERSIONS = {
    "v0": failure_detection_v0,
}

DEFAULT_VERSION = "v0"


def get_template(version: str = DEFAULT_VERSION):
    """Return the prompt-template module registered for *version*.

    Raises:
        ValueError: If *version* is not a key in ``VERSIONS``.
    """
    if version not in VERSIONS:
        raise ValueError(f"Unknown version: {version}. Available: {list(VERSIONS.keys())}")
    return VERSIONS[version]
Loading
Loading