diff --git a/pyproject.toml b/pyproject.toml
index 511b56e..d557f65 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,7 +16,7 @@ authors = [
dependencies = [
"pydantic>=2.0.0,<3.0.0",
"rich>=14.0.0,<15.0.0",
- "strands-agents>=1.0.0",
+ "strands-agents>=1.28.0",
"strands-agents-tools>=0.1.0,<1.0.0",
"typing-extensions>=4.0",
"opentelemetry-api>=1.20.0",
@@ -128,6 +128,7 @@ select = [
[tool.ruff.lint.per-file-ignores]
"src/strands_evals/evaluators/prompt_templates/*" = ["E501"] # line-length
"src/strands_evals/generators/prompt_template/*" = ["E501"] # line-length
+"src/strands_evals/plugins/prompt_templates/*" = ["E501"] # line-length
[tool.mypy]
# Disable strict checks that cause false positives with Generic classes
diff --git a/src/strands_evals/__init__.py b/src/strands_evals/__init__.py
index f5c600c..37994ae 100644
--- a/src/strands_evals/__init__.py
+++ b/src/strands_evals/__init__.py
@@ -1,4 +1,4 @@
-from . import evaluators, extractors, generators, providers, simulation, telemetry, types
+from . import evaluators, extractors, generators, plugins, providers, simulation, telemetry, types
from .case import Case
from .experiment import Experiment
from .simulation import ActorSimulator, UserSimulator
@@ -12,6 +12,7 @@
"providers",
"types",
"generators",
+ "plugins",
"simulation",
"telemetry",
"StrandsEvalsTelemetry",
diff --git a/src/strands_evals/plugins/__init__.py b/src/strands_evals/plugins/__init__.py
new file mode 100644
index 0000000..ca39e35
--- /dev/null
+++ b/src/strands_evals/plugins/__init__.py
@@ -0,0 +1,3 @@
+from .evaluation_plugin import EvaluationPlugin
+
+__all__ = ["EvaluationPlugin"]
diff --git a/src/strands_evals/plugins/evaluation_plugin.py b/src/strands_evals/plugins/evaluation_plugin.py
new file mode 100644
index 0000000..c59a8f9
--- /dev/null
+++ b/src/strands_evals/plugins/evaluation_plugin.py
@@ -0,0 +1,205 @@
+"""Plugin that evaluates agent invocations and retries with improvements on failure."""
+
+import logging
+from typing import Any, cast
+
+from pydantic import BaseModel
+from strands import Agent
+from strands.models import Model
+from strands.plugins.plugin import Plugin
+
+from strands_evals.evaluators.evaluator import Evaluator
+from strands_evals.plugins.prompt_templates.improvement_suggestion import (
+ IMPROVEMENT_SYSTEM_PROMPT,
+ compose_improvement_prompt,
+)
+from strands_evals.types.evaluation import EvaluationData, EvaluationOutput
+
+logger = logging.getLogger(__name__)
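+
+# Example usage, as a sketch mirroring the unit tests (``my_evaluator`` and
+# ``agent`` are placeholders for a real Evaluator and strands Agent):
+#
+#     plugin = EvaluationPlugin(evaluators=[my_evaluator], max_retries=2, expected_output="4")
+#     plugin.init_agent(agent)
+#     result = agent("What is 2+2?")
+#
+# Each call is evaluated; on failure the plugin retries with an LLM-improved
+# system prompt, up to max_retries extra attempts.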
+
+
+class ImprovementSuggestion(BaseModel):
+ """Structured output from the improvement suggestion LLM."""
+
+ reasoning: str
+ system_prompt: str
+
+
+class EvaluationPlugin(Plugin):
+ """Evaluates agent output after each invocation and retries with improved system prompts on failure."""
+
+    @property
+    def name(self) -> str:
+        """The plugin's name."""
+        return "strands-evals"
+
+ def __init__(
+ self,
+ evaluators: list[Evaluator],
+ max_retries: int = 1,
+ expected_output: Any = None,
+ expected_trajectory: list[Any] | None = None,
+ model: Model | str | None = None,
+ ):
+ """Initialize the evaluation plugin.
+
+ Args:
+ evaluators: Evaluators to run against agent output after each invocation.
+ max_retries: Maximum number of retry attempts when evaluation fails.
+ expected_output: Default expected output for evaluation. Can be overridden per-invocation
+ via ``invocation_state``.
+ expected_trajectory: Default expected trajectory for evaluation. Can be overridden
+ per-invocation via ``invocation_state``.
+ model: Model used by the improvement suggestion agent. Accepts a Model instance,
+ a model ID string, or None to use the default.
+ """
+ self._evaluators = evaluators
+ self._max_retries = max_retries
+ self._expected_output = expected_output
+ self._expected_trajectory = expected_trajectory
+ self._model = model
+ self._agent: Any = None
+
+ def init_agent(self, agent: Any) -> None:
+ """Wrap the agent's ``__call__`` to intercept invocations for evaluation and retry.
+
+ Creates a dynamic subclass of the agent's class with a wrapped ``__call__`` that runs
+ evaluators after each invocation and retries with an improved system prompt on failure.
+
+ Args:
+ agent: The agent instance whose invocations will be evaluated.
+ """
+ self._agent = agent
+ original_call = agent.__class__.__call__
+ plugin = self
+
+ def wrapped_call(self_agent: Any, prompt: Any = None, **kwargs: Any) -> Any:
+ return plugin._invoke_with_evaluation(self_agent, original_call, prompt, **kwargs)
+
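+        # Create a per-instance dynamic subclass so the wrapper affects only this
+        # agent and isinstance checks against the original class still pass.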
+ wrapped_class = type(
+ agent.__class__.__name__,
+ (agent.__class__,),
+ {"__call__": wrapped_call},
+ )
+ agent.__class__ = wrapped_class
+
+ def _invoke_with_evaluation(self, agent: Any, original_call: Any, prompt: Any, **kwargs: Any) -> Any:
+ """Run the agent, evaluate output, and retry with an improved system prompt on failure.
+
+        Resets the agent's messages before each retry and restores the original
+        system prompt once all attempts complete.
+
+ Args:
+ agent: The agent instance being invoked.
+ original_call: The unwrapped ``__call__`` method.
+ prompt: The user prompt passed to the agent.
+ **kwargs: Additional keyword arguments forwarded to the agent call.
+
+ Returns:
+ The result from the last agent invocation attempt.
+ """
+ original_system_prompt = agent.system_prompt
+ original_messages = list(agent.messages)
+ invocation_state = kwargs.get("invocation_state", {})
+
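+        # One initial attempt plus up to max_retries retries.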
+ for attempt in range(1 + self._max_retries):
+ if attempt > 0:
+ agent.messages = list(original_messages)
+
+ result = original_call(agent, prompt, **kwargs)
+
+ evaluation_data = self._build_evaluation_data(prompt, result, invocation_state)
+ outputs = self._run_evaluators(evaluation_data)
+ all_pass = all(o.test_pass for o in outputs)
+
+ if all_pass or attempt == self._max_retries:
+ break
+
+ logger.debug(
+ "attempt=<%s>, evaluation_pass=<%s> | evaluation failed, generating improvements",
+ attempt + 1,
+ all_pass,
+ )
+
+ expected_output = evaluation_data.expected_output
+ suggestion = self._suggest_improvements(prompt, str(result), outputs, agent.system_prompt, expected_output)
+ logger.debug(
+ "attempt=<%s>, reasoning=<%s> | applying improved system prompt", attempt + 1, suggestion.reasoning
+ )
+ agent.system_prompt = suggestion.system_prompt
+
+ agent.system_prompt = original_system_prompt
+ return result
+
+ def _build_evaluation_data(self, prompt: Any, result: Any, invocation_state: dict) -> EvaluationData:
+ """Assemble evaluation data from the invocation context.
+
+ Args:
+ prompt: The user prompt.
+ result: The agent's output.
+ invocation_state: Per-invocation overrides for expected values.
+
+ Returns:
+ An EvaluationData instance ready for evaluator consumption.
+ """
+ expected_output = invocation_state.get("expected_output", self._expected_output)
+ expected_trajectory = invocation_state.get("expected_trajectory", self._expected_trajectory)
+
+ return EvaluationData(
+ input=prompt,
+ actual_output=str(result),
+ expected_output=expected_output,
+ expected_trajectory=expected_trajectory,
+ )
+
+ def _suggest_improvements(
+ self,
+ prompt: Any,
+ actual_output: str,
+ outputs: list[EvaluationOutput],
+ current_system_prompt: str | None,
+ expected_output: Any = None,
+ ) -> ImprovementSuggestion:
+ """Ask an LLM to suggest an improved system prompt based on evaluation failures.
+
+ Args:
+ prompt: The original user prompt.
+ actual_output: The agent's output as a string.
+ outputs: Evaluation outputs from the failed attempt.
+ current_system_prompt: The agent's current system prompt.
+ expected_output: The expected output, if available.
+
+ Returns:
+ An ImprovementSuggestion containing the reasoning and a revised system prompt.
+ """
+ failure_reasons = [o.reason for o in outputs if not o.test_pass and o.reason]
+ improvement_prompt = compose_improvement_prompt(
+ user_prompt=str(prompt),
+ actual_output=actual_output,
+ failure_reasons=failure_reasons,
+ current_system_prompt=current_system_prompt,
+ expected_output=str(expected_output) if expected_output is not None else None,
+ )
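+        # callback_handler=None suppresses the default printing handler so this
+        # internal suggestion call does not stream output to the console.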
+ suggestion_agent = Agent(model=self._model, system_prompt=IMPROVEMENT_SYSTEM_PROMPT, callback_handler=None)
+ result = suggestion_agent(improvement_prompt, structured_output_model=ImprovementSuggestion)
+ return cast(ImprovementSuggestion, result.structured_output)
+
+ def _run_evaluators(self, evaluation_data: EvaluationData) -> list[EvaluationOutput]:
+ """Run all evaluators against the given evaluation data.
+
+ Exceptions raised by individual evaluators are caught, logged, and recorded as failures
+ so that a single broken evaluator does not prevent the others from running.
+
+ Args:
+ evaluation_data: The data to evaluate.
+
+ Returns:
+ A list of evaluation outputs from all evaluators.
+ """
+ all_outputs: list[EvaluationOutput] = []
+ for evaluator in self._evaluators:
+ try:
+ outputs = evaluator.evaluate(evaluation_data)
+ all_outputs.extend(outputs)
+ except Exception:
+ logger.exception("evaluator=<%s> | evaluator raised an exception", type(evaluator).__name__)
+ all_outputs.append(EvaluationOutput(score=0.0, test_pass=False, reason="evaluator raised an exception"))
+ return all_outputs
diff --git a/src/strands_evals/plugins/prompt_templates/__init__.py b/src/strands_evals/plugins/prompt_templates/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/strands_evals/plugins/prompt_templates/improvement_suggestion.py b/src/strands_evals/plugins/prompt_templates/improvement_suggestion.py
new file mode 100644
index 0000000..ac05a22
--- /dev/null
+++ b/src/strands_evals/plugins/prompt_templates/improvement_suggestion.py
@@ -0,0 +1,37 @@
+IMPROVEMENT_SYSTEM_PROMPT = """You are an expert at analyzing AI agent failures and suggesting system prompt improvements.
+
+Given an agent's current system prompt, the user's request, the agent's output, and evaluation failures, suggest a modified system prompt that addresses the failures while preserving the agent's core capabilities.
+
+Focus on:
+- Adding specific instructions that address the evaluation failure reasons
+- Preserving existing useful instructions from the current system prompt
+- Being concise and actionable
+- Not changing the fundamental purpose of the agent
+
+Return the complete improved system prompt that should replace the current one."""
+
+
+def compose_improvement_prompt(
+ user_prompt: str,
+ actual_output: str,
+ failure_reasons: list[str],
+ current_system_prompt: str | None,
+ expected_output: str | None = None,
+) -> str:
+    """Compose the prompt for the improvement suggestion agent.
+
+    Each section is wrapped in an XML-style tag so the model can tell apart the
+    current system prompt, the user prompt, the actual output, the expected
+    output (when provided), and the evaluation failure reasons.
+    """
+    parts = []
+
+    if current_system_prompt:
+        parts.append(f"<CurrentSystemPrompt>\n{current_system_prompt}\n</CurrentSystemPrompt>")
+    else:
+        parts.append("<CurrentSystemPrompt>\nNo system prompt set\n</CurrentSystemPrompt>")
+
+    parts.append(f"<UserPrompt>\n{user_prompt}\n</UserPrompt>")
+    parts.append(f"<ActualOutput>\n{actual_output}\n</ActualOutput>")
+
+    if expected_output is not None:
+        parts.append(f"<ExpectedOutput>\n{expected_output}\n</ExpectedOutput>")
+
+    reasons_text = "\n".join(f"- {reason}" for reason in failure_reasons)
+    parts.append(f"<FailureReasons>\n{reasons_text}\n</FailureReasons>")
+
+    return "\n\n".join(parts)
diff --git a/tests/strands_evals/plugins/__init__.py b/tests/strands_evals/plugins/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/strands_evals/plugins/test_evaluation_plugin.py b/tests/strands_evals/plugins/test_evaluation_plugin.py
new file mode 100644
index 0000000..9533fe6
--- /dev/null
+++ b/tests/strands_evals/plugins/test_evaluation_plugin.py
@@ -0,0 +1,480 @@
+from unittest.mock import Mock, patch
+
+import pytest
+
+from strands_evals.plugins import EvaluationPlugin
+from strands_evals.plugins.evaluation_plugin import ImprovementSuggestion
+from strands_evals.types import EvaluationData, EvaluationOutput
+
+
+@pytest.fixture
+def mock_evaluator():
+    return Mock()
+
+
+class FakeAgent:
+ """Minimal agent class for testing __class__ swap without real Agent internals."""
+
+ def __init__(self):
+ self.system_prompt = "original prompt"
+ self.messages: list = []
+ self._result = Mock()
+ self._result.__str__ = Mock(return_value="mock output")
+ self.call_count = 0
+
+ def __call__(self, prompt=None, **kwargs):
+ self.call_count += 1
+ return self._result
+
+
+@pytest.fixture
+def mock_agent():
+ return FakeAgent()
+
+
+def test_plugin_name(mock_evaluator):
+ plugin = EvaluationPlugin(evaluators=[mock_evaluator])
+ assert plugin.name == "strands-evals"
+
+
+def test_init_with_defaults(mock_evaluator):
+ plugin = EvaluationPlugin(evaluators=[mock_evaluator])
+ assert plugin._evaluators == [mock_evaluator]
+ assert plugin._max_retries == 1
+ assert plugin._expected_output is None
+ assert plugin._expected_trajectory is None
+ assert plugin._model is None
+
+
+def test_init_with_custom_values(mock_evaluator):
+ plugin = EvaluationPlugin(
+ evaluators=[mock_evaluator],
+ max_retries=3,
+ expected_output="expected",
+ expected_trajectory=["step1", "step2"],
+ model="us.anthropic.claude-sonnet-4-20250514-v1:0",
+ )
+ assert plugin._evaluators == [mock_evaluator]
+ assert plugin._max_retries == 3
+ assert plugin._expected_output == "expected"
+ assert plugin._expected_trajectory == ["step1", "step2"]
+ assert plugin._model == "us.anthropic.claude-sonnet-4-20250514-v1:0"
+
+
+# --- Step 2: init_agent + __class__ swap ---
+
+
+def test_init_agent_wraps_call(mock_evaluator, mock_agent):
+ """After init_agent, agent.__class__ should be a subclass of the original."""
+ original_class = mock_agent.__class__
+ plugin = EvaluationPlugin(evaluators=[mock_evaluator])
+ plugin.init_agent(mock_agent)
+ assert mock_agent.__class__ is not original_class
+ assert issubclass(mock_agent.__class__, original_class)
+
+
+def test_init_agent_preserves_isinstance(mock_evaluator, mock_agent):
+ """isinstance(agent, OriginalClass) should still be True after wrapping."""
+ plugin = EvaluationPlugin(evaluators=[mock_evaluator])
+ plugin.init_agent(mock_agent)
+ assert isinstance(mock_agent, FakeAgent)
+
+
+def test_wrapped_call_invokes_original(mock_evaluator, mock_agent):
+ """Calling agent(prompt) after wrapping should still invoke the original agent logic."""
+ mock_evaluator.evaluate.return_value = [Mock(test_pass=True)]
+ plugin = EvaluationPlugin(evaluators=[mock_evaluator], max_retries=0)
+ plugin.init_agent(mock_agent)
+
+ result = mock_agent("test prompt")
+
+    assert result is mock_agent._result
+    assert mock_agent.call_count == 1
+
+
+# --- Step 3: Evaluation execution on invocation ---
+
+
+def test_evaluators_run_after_invocation(mock_agent):
+ """Evaluators should be called with EvaluationData after invocation."""
+ evaluator = Mock()
+ evaluator.evaluate.return_value = [EvaluationOutput(score=1.0, test_pass=True, reason="good")]
+ plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=0)
+ plugin.init_agent(mock_agent)
+
+ mock_agent("What is 2+2?")
+
+ evaluator.evaluate.assert_called_once()
+ eval_data = evaluator.evaluate.call_args[0][0]
+ assert isinstance(eval_data, EvaluationData)
+ assert eval_data.input == "What is 2+2?"
+ assert eval_data.actual_output == str(mock_agent._result)
+
+
+def test_evaluation_data_uses_constructor_expected(mock_agent):
+ """EvaluationData should use expected values from constructor."""
+ evaluator = Mock()
+ evaluator.evaluate.return_value = [EvaluationOutput(score=1.0, test_pass=True)]
+ plugin = EvaluationPlugin(
+ evaluators=[evaluator],
+ max_retries=0,
+ expected_output="4",
+ expected_trajectory=["calculator"],
+ )
+ plugin.init_agent(mock_agent)
+
+ mock_agent("What is 2+2?")
+
+ eval_data = evaluator.evaluate.call_args[0][0]
+ assert eval_data.expected_output == "4"
+ assert eval_data.expected_trajectory == ["calculator"]
+
+
+def test_evaluation_data_uses_invocation_state_expected(mock_agent):
+ """invocation_state expected values should override constructor values."""
+ evaluator = Mock()
+ evaluator.evaluate.return_value = [EvaluationOutput(score=1.0, test_pass=True)]
+ plugin = EvaluationPlugin(
+ evaluators=[evaluator],
+ max_retries=0,
+ expected_output="constructor_value",
+ )
+ plugin.init_agent(mock_agent)
+
+ mock_agent("prompt", invocation_state={"expected_output": "state_value"})
+
+ eval_data = evaluator.evaluate.call_args[0][0]
+ assert eval_data.expected_output == "state_value"
+
+
+def test_passing_evaluation_returns_immediately(mock_agent):
+ """When evaluations pass, result should be returned without retry."""
+ evaluator = Mock()
+ evaluator.evaluate.return_value = [EvaluationOutput(score=1.0, test_pass=True)]
+ plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=3)
+ plugin.init_agent(mock_agent)
+
+ result = mock_agent("prompt")
+
+ assert result is mock_agent._result
+ evaluator.evaluate.assert_called_once()
+
+
+# --- Step 4: Retry on failure ---
+
+
+def test_retry_on_failure(mock_agent):
+ """Agent should be re-invoked when evaluation fails."""
+ evaluator = Mock()
+ evaluator.evaluate.side_effect = [
+ [EvaluationOutput(score=0.0, test_pass=False, reason="bad")],
+ [EvaluationOutput(score=1.0, test_pass=True, reason="good")],
+ ]
+ plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1)
+ plugin._suggest_improvements = Mock(
+ return_value=ImprovementSuggestion(reasoning="needs improvement", system_prompt="improved prompt")
+ )
+ plugin.init_agent(mock_agent)
+
+ mock_agent("prompt")
+
+ assert mock_agent.call_count == 2
+ assert evaluator.evaluate.call_count == 2
+
+
+def test_max_retries_respected(mock_agent):
+ """Should stop retrying after max_retries attempts."""
+ evaluator = Mock()
+ evaluator.evaluate.return_value = [EvaluationOutput(score=0.0, test_pass=False, reason="always bad")]
+ plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=2)
+ plugin._suggest_improvements = Mock(
+ return_value=ImprovementSuggestion(reasoning="needs improvement", system_prompt="improved prompt")
+ )
+ plugin.init_agent(mock_agent)
+
+ mock_agent("prompt")
+
+ # 1 initial + 2 retries = 3 total calls
+ assert mock_agent.call_count == 3
+ assert evaluator.evaluate.call_count == 3
+
+
+def test_max_retries_zero_no_retry(mock_agent):
+ """No retry when max_retries=0."""
+ evaluator = Mock()
+ evaluator.evaluate.return_value = [EvaluationOutput(score=0.0, test_pass=False)]
+ plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=0)
+ plugin.init_agent(mock_agent)
+
+ mock_agent("prompt")
+
+ assert mock_agent.call_count == 1
+ assert evaluator.evaluate.call_count == 1
+
+
+def test_messages_reset_between_retries(mock_agent):
+ """Agent messages should be restored to pre-invocation state before each retry."""
+ messages_during_calls = []
+
+ original_messages = ["pre-existing"]
+ mock_agent.messages = list(original_messages)
+
+ original_call = mock_agent.__class__.__call__
+
+ def tracking_call(self, prompt=None, **kwargs):
+ messages_during_calls.append(list(self.messages))
+ self.messages.append(f"response-{self.call_count}")
+ return original_call(self, prompt, **kwargs)
+
+    # Swap in a per-instance subclass rather than mutating FakeAgent.__call__,
+    # so the tracking wrapper does not leak into other tests.
+    mock_agent.__class__ = type("TrackingFakeAgent", (mock_agent.__class__,), {"__call__": tracking_call})
+
+ evaluator = Mock()
+ evaluator.evaluate.return_value = [EvaluationOutput(score=0.0, test_pass=False)]
+ plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1)
+ plugin._suggest_improvements = Mock(
+ return_value=ImprovementSuggestion(reasoning="needs improvement", system_prompt="improved prompt")
+ )
+ plugin.init_agent(mock_agent)
+
+ mock_agent("prompt")
+
+ # Both calls should start with the original messages
+ assert messages_during_calls[0] == original_messages
+ assert messages_during_calls[1] == original_messages
+
+
+def test_system_prompt_restored_after_all_attempts(mock_agent):
+ """Original system prompt should be restored after retries complete."""
+ original_prompt = mock_agent.system_prompt
+ evaluator = Mock()
+ evaluator.evaluate.return_value = [EvaluationOutput(score=0.0, test_pass=False)]
+ plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1)
+ plugin._suggest_improvements = Mock(
+ return_value=ImprovementSuggestion(reasoning="needs improvement", system_prompt="improved prompt")
+ )
+ plugin.init_agent(mock_agent)
+
+ mock_agent("prompt")
+
+ assert mock_agent.system_prompt == original_prompt
+
+
+# --- Step 5: Improvement suggestion generation ---
+
+
+@patch("strands_evals.plugins.evaluation_plugin.Agent")
+def test_improvement_suggestion_called_on_failure(mock_agent_class, mock_agent):
+ """LLM should be called to generate suggestions when evaluation fails."""
+ mock_suggestion_agent = Mock()
+ mock_suggestion_result = Mock()
+ mock_suggestion_result.structured_output = ImprovementSuggestion(
+ reasoning="Output was wrong", system_prompt="Be more careful"
+ )
+ mock_suggestion_agent.return_value = mock_suggestion_result
+ mock_agent_class.return_value = mock_suggestion_agent
+
+ evaluator = Mock()
+ evaluator.evaluate.side_effect = [
+ [EvaluationOutput(score=0.0, test_pass=False, reason="incorrect")],
+ [EvaluationOutput(score=1.0, test_pass=True, reason="correct")],
+ ]
+ plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1, model="test-model")
+ plugin.init_agent(mock_agent)
+
+ mock_agent("prompt")
+
+ mock_agent_class.assert_called_once()
+ mock_suggestion_agent.assert_called_once()
+
+
+@patch("strands_evals.plugins.evaluation_plugin.Agent")
+def test_improvement_suggestion_prompt_contains_failures(mock_agent_class, mock_agent):
+ """Improvement prompt should include evaluation failure reasons."""
+ mock_suggestion_agent = Mock()
+ mock_suggestion_result = Mock()
+ mock_suggestion_result.structured_output = ImprovementSuggestion(reasoning="Needs fix", system_prompt="improved")
+ mock_suggestion_agent.return_value = mock_suggestion_result
+ mock_agent_class.return_value = mock_suggestion_agent
+
+ evaluator = Mock()
+ evaluator.evaluate.side_effect = [
+ [EvaluationOutput(score=0.0, test_pass=False, reason="answer was factually wrong")],
+ [EvaluationOutput(score=1.0, test_pass=True)],
+ ]
+ plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1)
+ plugin.init_agent(mock_agent)
+
+ mock_agent("prompt")
+
+ prompt_arg = mock_suggestion_agent.call_args[0][0]
+ assert "answer was factually wrong" in prompt_arg
+ assert "original prompt" in prompt_arg
+
+
+@patch("strands_evals.plugins.evaluation_plugin.Agent")
+def test_improvement_suggestion_prompt_contains_expected_output(mock_agent_class, mock_agent):
+ """Improvement prompt should include expected_output when available."""
+ mock_suggestion_agent = Mock()
+ mock_suggestion_result = Mock()
+ mock_suggestion_result.structured_output = ImprovementSuggestion(reasoning="fix", system_prompt="improved")
+ mock_suggestion_agent.return_value = mock_suggestion_result
+ mock_agent_class.return_value = mock_suggestion_agent
+
+ evaluator = Mock()
+ evaluator.evaluate.side_effect = [
+ [EvaluationOutput(score=0.0, test_pass=False, reason="does not match")],
+ [EvaluationOutput(score=1.0, test_pass=True)],
+ ]
+ plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1, expected_output="Paris")
+ plugin.init_agent(mock_agent)
+
+ mock_agent("What is the capital of France?")
+
+ prompt_arg = mock_suggestion_agent.call_args[0][0]
+ assert "Paris" in prompt_arg
+ assert "ExpectedOutput" in prompt_arg
+
+
+@patch("strands_evals.plugins.evaluation_plugin.Agent")
+def test_improvement_suggestion_prompt_omits_expected_output_when_none(mock_agent_class, mock_agent):
+ """Improvement prompt should not include ExpectedOutput tag when no expected_output is set."""
+ mock_suggestion_agent = Mock()
+ mock_suggestion_result = Mock()
+ mock_suggestion_result.structured_output = ImprovementSuggestion(reasoning="fix", system_prompt="improved")
+ mock_suggestion_agent.return_value = mock_suggestion_result
+ mock_agent_class.return_value = mock_suggestion_agent
+
+ evaluator = Mock()
+ evaluator.evaluate.side_effect = [
+ [EvaluationOutput(score=0.0, test_pass=False, reason="bad")],
+ [EvaluationOutput(score=1.0, test_pass=True)],
+ ]
+ plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1)
+ plugin.init_agent(mock_agent)
+
+ mock_agent("prompt")
+
+ prompt_arg = mock_suggestion_agent.call_args[0][0]
+ assert "ExpectedOutput" not in prompt_arg
+
+
+@patch("strands_evals.plugins.evaluation_plugin.Agent")
+def test_improved_system_prompt_applied_before_retry(mock_agent_class, mock_agent):
+ """Agent system prompt should be updated with suggestion before retry."""
+ mock_suggestion_agent = Mock()
+ mock_suggestion_result = Mock()
+ mock_suggestion_result.structured_output = ImprovementSuggestion(
+ reasoning="Needs specificity", system_prompt="Always answer with the city name only"
+ )
+ mock_suggestion_agent.return_value = mock_suggestion_result
+ mock_agent_class.return_value = mock_suggestion_agent
+
+ system_prompts_during_calls = []
+ original_call = mock_agent.__class__.__call__
+
+ def tracking_call(self, prompt=None, **kwargs):
+ system_prompts_during_calls.append(self.system_prompt)
+ return original_call(self, prompt, **kwargs)
+
+    # Swap in a per-instance subclass rather than mutating FakeAgent.__call__,
+    # so the tracking wrapper does not leak into other tests.
+    mock_agent.__class__ = type("TrackingFakeAgent", (mock_agent.__class__,), {"__call__": tracking_call})
+
+ evaluator = Mock()
+ evaluator.evaluate.side_effect = [
+ [EvaluationOutput(score=0.0, test_pass=False, reason="bad")],
+ [EvaluationOutput(score=1.0, test_pass=True)],
+ ]
+ plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1)
+ plugin.init_agent(mock_agent)
+
+ mock_agent("prompt")
+
+ assert system_prompts_during_calls[0] == "original prompt"
+ assert system_prompts_during_calls[1] == "Always answer with the city name only"
+
+
+# --- Step 7: Multiple evaluators + edge cases ---
+
+
+def test_multiple_evaluators_all_must_pass(mock_agent):
+ """All evaluators must pass for the invocation to be considered successful."""
+ evaluator1 = Mock()
+ evaluator1.evaluate.return_value = [EvaluationOutput(score=1.0, test_pass=True)]
+ evaluator2 = Mock()
+ evaluator2.evaluate.return_value = [EvaluationOutput(score=1.0, test_pass=True)]
+ plugin = EvaluationPlugin(evaluators=[evaluator1, evaluator2], max_retries=0)
+ plugin.init_agent(mock_agent)
+
+ mock_agent("prompt")
+
+ evaluator1.evaluate.assert_called_once()
+ evaluator2.evaluate.assert_called_once()
+ assert mock_agent.call_count == 1
+
+
+def test_partial_evaluator_failure_triggers_retry(mock_agent):
+ """If any evaluator fails, retry should be triggered."""
+ evaluator1 = Mock()
+ evaluator1.evaluate.return_value = [EvaluationOutput(score=1.0, test_pass=True)]
+ evaluator2 = Mock()
+ evaluator2.evaluate.side_effect = [
+ [EvaluationOutput(score=0.0, test_pass=False, reason="failed")],
+ [EvaluationOutput(score=1.0, test_pass=True)],
+ ]
+ plugin = EvaluationPlugin(evaluators=[evaluator1, evaluator2], max_retries=1)
+ plugin._suggest_improvements = Mock(
+ return_value=ImprovementSuggestion(reasoning="needs improvement", system_prompt="improved")
+ )
+ plugin.init_agent(mock_agent)
+
+ mock_agent("prompt")
+
+ assert mock_agent.call_count == 2
+
+
+def test_evaluator_exception_recorded_as_failure(mock_agent):
+ """Evaluator exceptions should be caught and treated as failures."""
+ evaluator = Mock()
+ evaluator.evaluate.side_effect = [
+ RuntimeError("evaluator crashed"),
+ [EvaluationOutput(score=1.0, test_pass=True)],
+ ]
+ plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1)
+ plugin._suggest_improvements = Mock(
+ return_value=ImprovementSuggestion(reasoning="needs improvement", system_prompt="improved")
+ )
+ plugin.init_agent(mock_agent)
+
+ mock_agent("prompt")
+
+ assert mock_agent.call_count == 2
+
+
+def test_returns_final_attempt_result(mock_agent):
+ """Should return the result from the final attempt."""
+ results = [Mock(__str__=Mock(return_value="first")), Mock(__str__=Mock(return_value="second"))]
+ call_idx = [0]
+
+ original_call = mock_agent.__class__.__call__
+
+ def multi_result_call(self, prompt=None, **kwargs):
+ idx = call_idx[0]
+ call_idx[0] += 1
+ original_call(self, prompt, **kwargs)
+ return results[idx]
+
+    # Swap in a per-instance subclass rather than mutating FakeAgent.__call__,
+    # so the multi-result wrapper does not leak into other tests.
+    mock_agent.__class__ = type("MultiResultFakeAgent", (mock_agent.__class__,), {"__call__": multi_result_call})
+
+ evaluator = Mock()
+ evaluator.evaluate.side_effect = [
+ [EvaluationOutput(score=0.0, test_pass=False)],
+ [EvaluationOutput(score=1.0, test_pass=True)],
+ ]
+ plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1)
+ plugin._suggest_improvements = Mock(
+ return_value=ImprovementSuggestion(reasoning="needs improvement", system_prompt="improved")
+ )
+ plugin.init_agent(mock_agent)
+
+ result = mock_agent("prompt")
+
+ assert result is results[1]