diff --git a/pyproject.toml b/pyproject.toml
index 511b56e..d557f65 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,7 +16,7 @@ authors = [
 dependencies = [
     "pydantic>=2.0.0,<3.0.0",
     "rich>=14.0.0,<15.0.0",
-    "strands-agents>=1.0.0",
+    "strands-agents>=1.28.0",
     "strands-agents-tools>=0.1.0,<1.0.0",
    "typing-extensions>=4.0",
     "opentelemetry-api>=1.20.0",
@@ -128,6 +128,7 @@ select = [
 [tool.ruff.lint.per-file-ignores]
 "src/strands_evals/evaluators/prompt_templates/*" = ["E501"] # line-length
 "src/strands_evals/generators/prompt_template/*" = ["E501"] # line-length
+"src/strands_evals/plugins/prompt_templates/*" = ["E501"] # line-length
 
 [tool.mypy]
 # Disable strict checks that cause false positives with Generic classes
diff --git a/src/strands_evals/__init__.py b/src/strands_evals/__init__.py
index f5c600c..37994ae 100644
--- a/src/strands_evals/__init__.py
+++ b/src/strands_evals/__init__.py
@@ -1,4 +1,4 @@
-from . import evaluators, extractors, generators, providers, simulation, telemetry, types
+from . import evaluators, extractors, generators, plugins, providers, simulation, telemetry, types
 from .case import Case
 from .experiment import Experiment
 from .simulation import ActorSimulator, UserSimulator
@@ -12,6 +12,7 @@
     "providers",
     "types",
     "generators",
+    "plugins",
     "simulation",
     "telemetry",
     "StrandsEvalsTelemetry",
diff --git a/src/strands_evals/plugins/__init__.py b/src/strands_evals/plugins/__init__.py
new file mode 100644
index 0000000..ca39e35
--- /dev/null
+++ b/src/strands_evals/plugins/__init__.py
@@ -0,0 +1,3 @@
+from .evaluation_plugin import EvaluationPlugin
+
+__all__ = ["EvaluationPlugin"]
diff --git a/src/strands_evals/plugins/evaluation_plugin.py b/src/strands_evals/plugins/evaluation_plugin.py
new file mode 100644
index 0000000..c59a8f9
--- /dev/null
+++ b/src/strands_evals/plugins/evaluation_plugin.py
@@ -0,0 +1,205 @@
+"""Plugin that evaluates agent invocations and retries with improvements on failure."""
+
+import logging
+from typing import Any, cast
+
+from pydantic import BaseModel
+from strands import Agent
+from strands.models import Model
+from strands.plugins.plugin import Plugin
+
+from strands_evals.evaluators.evaluator import Evaluator
+from strands_evals.plugins.prompt_templates.improvement_suggestion import (
+    IMPROVEMENT_SYSTEM_PROMPT,
+    compose_improvement_prompt,
+)
+from strands_evals.types.evaluation import EvaluationData, EvaluationOutput
+
+logger = logging.getLogger(__name__)
+
+
+class ImprovementSuggestion(BaseModel):
+    """Structured output from the improvement suggestion LLM."""
+
+    reasoning: str
+    system_prompt: str
+
+
+class EvaluationPlugin(Plugin):
+    """Evaluates agent output after each invocation and retries with improved system prompts on failure."""
+
+    @property
+    def name(self) -> str:
+        return "strands-evals"
+
+    def __init__(
+        self,
+        evaluators: list[Evaluator],
+        max_retries: int = 1,
+        expected_output: Any = None,
+        expected_trajectory: list[Any] | None = None,
+        model: Model | str | None = None,
+    ):
+        """Initialize the evaluation plugin.
+
+        Args:
+            evaluators: Evaluators to run against agent output after each invocation.
+            max_retries: Maximum number of retry attempts when evaluation fails.
+            expected_output: Default expected output for evaluation. Can be overridden per-invocation
+                via ``invocation_state``.
+            expected_trajectory: Default expected trajectory for evaluation. Can be overridden
+                per-invocation via ``invocation_state``.
+            model: Model used by the improvement suggestion agent. Accepts a Model instance,
+                a model ID string, or None to use the default.
+        """
+        self._evaluators = evaluators
+        self._max_retries = max_retries
+        self._expected_output = expected_output
+        self._expected_trajectory = expected_trajectory
+        self._model = model
+        self._agent: Any = None
+
+    def init_agent(self, agent: Any) -> None:
+        """Wrap the agent's ``__call__`` to intercept invocations for evaluation and retry.
+
+        Creates a dynamic subclass of the agent's class with a wrapped ``__call__`` that runs
+        evaluators after each invocation and retries with an improved system prompt on failure.
+
+        Args:
+            agent: The agent instance whose invocations will be evaluated.
+        """
+        self._agent = agent
+        original_call = agent.__class__.__call__
+        plugin = self
+
+        def wrapped_call(self_agent: Any, prompt: Any = None, **kwargs: Any) -> Any:
+            return plugin._invoke_with_evaluation(self_agent, original_call, prompt, **kwargs)
+
+        wrapped_class = type(
+            agent.__class__.__name__,
+            (agent.__class__,),
+            {"__call__": wrapped_call},
+        )
+        agent.__class__ = wrapped_class
+
+    def _invoke_with_evaluation(self, agent: Any, original_call: Any, prompt: Any, **kwargs: Any) -> Any:
+        """Run the agent, evaluate output, and retry with an improved system prompt on failure.
+
+        Restores the original system prompt and messages after all attempts complete.
+
+        Args:
+            agent: The agent instance being invoked.
+            original_call: The unwrapped ``__call__`` method.
+            prompt: The user prompt passed to the agent.
+            **kwargs: Additional keyword arguments forwarded to the agent call.
+
+        Returns:
+            The result from the last agent invocation attempt.
+        """
+        original_system_prompt = agent.system_prompt
+        original_messages = list(agent.messages)
+        invocation_state = kwargs.get("invocation_state", {})
+
+        for attempt in range(1 + self._max_retries):
+            if attempt > 0:
+                agent.messages = list(original_messages)
+
+            result = original_call(agent, prompt, **kwargs)
+
+            evaluation_data = self._build_evaluation_data(prompt, result, invocation_state)
+            outputs = self._run_evaluators(evaluation_data)
+            all_pass = all(o.test_pass for o in outputs)
+
+            if all_pass or attempt == self._max_retries:
+                break
+
+            logger.debug(
+                "attempt=<%s>, evaluation_pass=<%s> | evaluation failed, generating improvements",
+                attempt + 1,
+                all_pass,
+            )
+
+            expected_output = evaluation_data.expected_output
+            suggestion = self._suggest_improvements(prompt, str(result), outputs, agent.system_prompt, expected_output)
+            logger.debug(
+                "attempt=<%s>, reasoning=<%s> | applying improved system prompt", attempt + 1, suggestion.reasoning
+            )
+            agent.system_prompt = suggestion.system_prompt
+
+        agent.system_prompt = original_system_prompt
+        return result
+
+    def _build_evaluation_data(self, prompt: Any, result: Any, invocation_state: dict) -> EvaluationData:
+        """Assemble evaluation data from the invocation context.
+
+        Args:
+            prompt: The user prompt.
+            result: The agent's output.
+            invocation_state: Per-invocation overrides for expected values.
+
+        Returns:
+            An EvaluationData instance ready for evaluator consumption.
+ """ + expected_output = invocation_state.get("expected_output", self._expected_output) + expected_trajectory = invocation_state.get("expected_trajectory", self._expected_trajectory) + + return EvaluationData( + input=prompt, + actual_output=str(result), + expected_output=expected_output, + expected_trajectory=expected_trajectory, + ) + + def _suggest_improvements( + self, + prompt: Any, + actual_output: str, + outputs: list[EvaluationOutput], + current_system_prompt: str | None, + expected_output: Any = None, + ) -> ImprovementSuggestion: + """Ask an LLM to suggest an improved system prompt based on evaluation failures. + + Args: + prompt: The original user prompt. + actual_output: The agent's output as a string. + outputs: Evaluation outputs from the failed attempt. + current_system_prompt: The agent's current system prompt. + expected_output: The expected output, if available. + + Returns: + An ImprovementSuggestion containing the reasoning and a revised system prompt. + """ + failure_reasons = [o.reason for o in outputs if not o.test_pass and o.reason] + improvement_prompt = compose_improvement_prompt( + user_prompt=str(prompt), + actual_output=actual_output, + failure_reasons=failure_reasons, + current_system_prompt=current_system_prompt, + expected_output=str(expected_output) if expected_output is not None else None, + ) + suggestion_agent = Agent(model=self._model, system_prompt=IMPROVEMENT_SYSTEM_PROMPT, callback_handler=None) + result = suggestion_agent(improvement_prompt, structured_output_model=ImprovementSuggestion) + return cast(ImprovementSuggestion, result.structured_output) + + def _run_evaluators(self, evaluation_data: EvaluationData) -> list[EvaluationOutput]: + """Run all evaluators against the given evaluation data. + + Exceptions raised by individual evaluators are caught, logged, and recorded as failures + so that a single broken evaluator does not prevent the others from running. + + Args: + evaluation_data: The data to evaluate. + + Returns: + A list of evaluation outputs from all evaluators. + """ + all_outputs: list[EvaluationOutput] = [] + for evaluator in self._evaluators: + try: + outputs = evaluator.evaluate(evaluation_data) + all_outputs.extend(outputs) + except Exception: + logger.exception("evaluator=<%s> | evaluator raised an exception", type(evaluator).__name__) + all_outputs.append(EvaluationOutput(score=0.0, test_pass=False, reason="evaluator raised an exception")) + return all_outputs diff --git a/src/strands_evals/plugins/prompt_templates/__init__.py b/src/strands_evals/plugins/prompt_templates/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/strands_evals/plugins/prompt_templates/improvement_suggestion.py b/src/strands_evals/plugins/prompt_templates/improvement_suggestion.py new file mode 100644 index 0000000..ac05a22 --- /dev/null +++ b/src/strands_evals/plugins/prompt_templates/improvement_suggestion.py @@ -0,0 +1,37 @@ +IMPROVEMENT_SYSTEM_PROMPT = """You are an expert at analyzing AI agent failures and suggesting system prompt improvements. + +Given an agent's current system prompt, the user's request, the agent's output, and evaluation failures, suggest a modified system prompt that addresses the failures while preserving the agent's core capabilities. 
+
+Focus on:
+- Adding specific instructions that address the evaluation failure reasons
+- Preserving existing useful instructions from the current system prompt
+- Being concise and actionable
+- Not changing the fundamental purpose of the agent
+
+Return the complete improved system prompt that should replace the current one."""
+
+
+def compose_improvement_prompt(
+    user_prompt: str,
+    actual_output: str,
+    failure_reasons: list[str],
+    current_system_prompt: str | None,
+    expected_output: str | None = None,
+) -> str:
+    parts = []
+
+    if current_system_prompt:
+        parts.append(f"{current_system_prompt}")
+    else:
+        parts.append("No system prompt set")
+
+    parts.append(f"{user_prompt}")
+    parts.append(f"{actual_output}")
+
+    if expected_output is not None:
+        parts.append(f"<ExpectedOutput>{expected_output}</ExpectedOutput>")
+
+    reasons_text = "\n".join(f"- {reason}" for reason in failure_reasons)
+    parts.append(f"\n{reasons_text}\n")
+
+    return "\n\n".join(parts)
diff --git a/tests/strands_evals/plugins/__init__.py b/tests/strands_evals/plugins/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/strands_evals/plugins/test_evaluation_plugin.py b/tests/strands_evals/plugins/test_evaluation_plugin.py
new file mode 100644
index 0000000..9533fe6
--- /dev/null
+++ b/tests/strands_evals/plugins/test_evaluation_plugin.py
@@ -0,0 +1,480 @@
+from unittest.mock import Mock, patch
+
+import pytest
+
+from strands_evals.plugins import EvaluationPlugin
+from strands_evals.plugins.evaluation_plugin import ImprovementSuggestion
+from strands_evals.types import EvaluationData, EvaluationOutput
+
+
+@pytest.fixture
+def mock_evaluator():
+    evaluator = Mock()
+    return evaluator
+
+
+class FakeAgent:
+    """Minimal agent class for testing __class__ swap without real Agent internals."""
+
+    def __init__(self):
+        self.system_prompt = "original prompt"
+        self.messages: list = []
+        self._result = Mock()
+        self._result.__str__ = Mock(return_value="mock output")
+        self.call_count = 0
+
+    def __call__(self, prompt=None, **kwargs):
+        self.call_count += 1
+        return self._result
+
+
+@pytest.fixture
+def mock_agent():
+    return FakeAgent()
+
+
+def test_plugin_name(mock_evaluator):
+    plugin = EvaluationPlugin(evaluators=[mock_evaluator])
+    assert plugin.name == "strands-evals"
+
+
+def test_init_with_defaults(mock_evaluator):
+    plugin = EvaluationPlugin(evaluators=[mock_evaluator])
+    assert plugin._evaluators == [mock_evaluator]
+    assert plugin._max_retries == 1
+    assert plugin._expected_output is None
+    assert plugin._expected_trajectory is None
+    assert plugin._model is None
+
+
+def test_init_with_custom_values(mock_evaluator):
+    plugin = EvaluationPlugin(
+        evaluators=[mock_evaluator],
+        max_retries=3,
+        expected_output="expected",
+        expected_trajectory=["step1", "step2"],
+        model="us.anthropic.claude-sonnet-4-20250514-v1:0",
+    )
+    assert plugin._evaluators == [mock_evaluator]
+    assert plugin._max_retries == 3
+    assert plugin._expected_output == "expected"
+    assert plugin._expected_trajectory == ["step1", "step2"]
+    assert plugin._model == "us.anthropic.claude-sonnet-4-20250514-v1:0"
+
+
+# --- Step 2: init_agent + __class__ swap ---
+
+
+def test_init_agent_wraps_call(mock_evaluator, mock_agent):
+    """After init_agent, agent.__class__ should be a subclass of the original."""
+    original_class = mock_agent.__class__
+    plugin = EvaluationPlugin(evaluators=[mock_evaluator])
+    plugin.init_agent(mock_agent)
+    assert mock_agent.__class__ is not original_class
+    assert issubclass(mock_agent.__class__, original_class)
+
+
+def test_init_agent_preserves_isinstance(mock_evaluator, mock_agent):
+    """isinstance(agent, OriginalClass) should still be True after wrapping."""
+    plugin = EvaluationPlugin(evaluators=[mock_evaluator])
+    plugin.init_agent(mock_agent)
+    assert isinstance(mock_agent, FakeAgent)
+
+
+def test_wrapped_call_invokes_original(mock_evaluator, mock_agent):
+    """Calling agent(prompt) after wrapping should still invoke the original agent logic."""
+    mock_evaluator.evaluate.return_value = [Mock(test_pass=True)]
+    plugin = EvaluationPlugin(evaluators=[mock_evaluator], max_retries=0)
+    plugin.init_agent(mock_agent)
+
+    result = mock_agent("test prompt")
+
+    assert result is not None
+
+
+# --- Step 3: Evaluation execution on invocation ---
+
+
+def test_evaluators_run_after_invocation(mock_agent):
+    """Evaluators should be called with EvaluationData after invocation."""
+    evaluator = Mock()
+    evaluator.evaluate.return_value = [EvaluationOutput(score=1.0, test_pass=True, reason="good")]
+    plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=0)
+    plugin.init_agent(mock_agent)
+
+    mock_agent("What is 2+2?")
+
+    evaluator.evaluate.assert_called_once()
+    eval_data = evaluator.evaluate.call_args[0][0]
+    assert isinstance(eval_data, EvaluationData)
+    assert eval_data.input == "What is 2+2?"
+    assert eval_data.actual_output == str(mock_agent._result)
+
+
+def test_evaluation_data_uses_constructor_expected(mock_agent):
+    """EvaluationData should use expected values from constructor."""
+    evaluator = Mock()
+    evaluator.evaluate.return_value = [EvaluationOutput(score=1.0, test_pass=True)]
+    plugin = EvaluationPlugin(
+        evaluators=[evaluator],
+        max_retries=0,
+        expected_output="4",
+        expected_trajectory=["calculator"],
+    )
+    plugin.init_agent(mock_agent)
+
+    mock_agent("What is 2+2?")
+
+    eval_data = evaluator.evaluate.call_args[0][0]
+    assert eval_data.expected_output == "4"
+    assert eval_data.expected_trajectory == ["calculator"]
+
+
+def test_evaluation_data_uses_invocation_state_expected(mock_agent):
+    """invocation_state expected values should override constructor values."""
+    evaluator = Mock()
+    evaluator.evaluate.return_value = [EvaluationOutput(score=1.0, test_pass=True)]
+    plugin = EvaluationPlugin(
+        evaluators=[evaluator],
+        max_retries=0,
+        expected_output="constructor_value",
+    )
+    plugin.init_agent(mock_agent)
+
+    mock_agent("prompt", invocation_state={"expected_output": "state_value"})
+
+    eval_data = evaluator.evaluate.call_args[0][0]
+    assert eval_data.expected_output == "state_value"
+
+
+def test_passing_evaluation_returns_immediately(mock_agent):
+    """When evaluations pass, result should be returned without retry."""
+    evaluator = Mock()
+    evaluator.evaluate.return_value = [EvaluationOutput(score=1.0, test_pass=True)]
+    plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=3)
+    plugin.init_agent(mock_agent)
+
+    result = mock_agent("prompt")
+
+    assert result is mock_agent._result
+    evaluator.evaluate.assert_called_once()
+
+
+# --- Step 4: Retry on failure ---
+
+
+def test_retry_on_failure(mock_agent):
+    """Agent should be re-invoked when evaluation fails."""
+    evaluator = Mock()
+    evaluator.evaluate.side_effect = [
+        [EvaluationOutput(score=0.0, test_pass=False, reason="bad")],
+        [EvaluationOutput(score=1.0, test_pass=True, reason="good")],
+    ]
+    plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1)
+    plugin._suggest_improvements = Mock(
+        return_value=ImprovementSuggestion(reasoning="needs improvement", system_prompt="improved prompt")
+    )
+    plugin.init_agent(mock_agent)
+
+    mock_agent("prompt")
+
+    assert mock_agent.call_count == 2
+    assert evaluator.evaluate.call_count == 2
+
+
+def test_max_retries_respected(mock_agent):
+    """Should stop retrying after max_retries attempts."""
+    evaluator = Mock()
+    evaluator.evaluate.return_value = [EvaluationOutput(score=0.0, test_pass=False, reason="always bad")]
+    plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=2)
+    plugin._suggest_improvements = Mock(
+        return_value=ImprovementSuggestion(reasoning="needs improvement", system_prompt="improved prompt")
+    )
+    plugin.init_agent(mock_agent)
+
+    mock_agent("prompt")
+
+    # 1 initial + 2 retries = 3 total calls
+    assert mock_agent.call_count == 3
+    assert evaluator.evaluate.call_count == 3
+
+
+def test_max_retries_zero_no_retry(mock_agent):
+    """No retry when max_retries=0."""
+    evaluator = Mock()
+    evaluator.evaluate.return_value = [EvaluationOutput(score=0.0, test_pass=False)]
+    plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=0)
+    plugin.init_agent(mock_agent)
+
+    mock_agent("prompt")
+
+    assert mock_agent.call_count == 1
+    assert evaluator.evaluate.call_count == 1
+
+
+def test_messages_reset_between_retries(mock_agent):
+    """Agent messages should be restored to pre-invocation state before each retry."""
+    messages_during_calls = []
+
+    original_messages = ["pre-existing"]
+    mock_agent.messages = list(original_messages)
+
+    original_call = mock_agent.__class__.__call__
+
+    def tracking_call(self, prompt=None, **kwargs):
+        messages_during_calls.append(list(self.messages))
+        self.messages.append(f"response-{self.call_count}")
+        return original_call(self, prompt, **kwargs)
+
+    mock_agent.__class__.__call__ = tracking_call
+
+    evaluator = Mock()
+    evaluator.evaluate.return_value = [EvaluationOutput(score=0.0, test_pass=False)]
+    plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1)
+    plugin._suggest_improvements = Mock(
+        return_value=ImprovementSuggestion(reasoning="needs improvement", system_prompt="improved prompt")
+    )
+    plugin.init_agent(mock_agent)
+
+    mock_agent("prompt")
+
+    # Both calls should start with the original messages
+    assert messages_during_calls[0] == original_messages
+    assert messages_during_calls[1] == original_messages
+
+
+def test_system_prompt_restored_after_all_attempts(mock_agent):
+    """Original system prompt should be restored after retries complete."""
+    original_prompt = mock_agent.system_prompt
+    evaluator = Mock()
+    evaluator.evaluate.return_value = [EvaluationOutput(score=0.0, test_pass=False)]
+    plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1)
+    plugin._suggest_improvements = Mock(
+        return_value=ImprovementSuggestion(reasoning="needs improvement", system_prompt="improved prompt")
+    )
+    plugin.init_agent(mock_agent)
+
+    mock_agent("prompt")
+
+    assert mock_agent.system_prompt == original_prompt
+
+
+# --- Step 5: Improvement suggestion generation ---
+
+
+@patch("strands_evals.plugins.evaluation_plugin.Agent")
+def test_improvement_suggestion_called_on_failure(mock_agent_class, mock_agent):
+    """LLM should be called to generate suggestions when evaluation fails."""
+    mock_suggestion_agent = Mock()
+    mock_suggestion_result = Mock()
+    mock_suggestion_result.structured_output = ImprovementSuggestion(
+        reasoning="Output was wrong", system_prompt="Be more careful"
+    )
+    mock_suggestion_agent.return_value = mock_suggestion_result
+    mock_agent_class.return_value = mock_suggestion_agent
+
+    evaluator = Mock()
+    evaluator.evaluate.side_effect = [
+        [EvaluationOutput(score=0.0, test_pass=False, reason="incorrect")],
+        [EvaluationOutput(score=1.0, test_pass=True, reason="correct")],
+    ]
+    plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1, model="test-model")
+    plugin.init_agent(mock_agent)
+
+    mock_agent("prompt")
+
+    mock_agent_class.assert_called_once()
+    mock_suggestion_agent.assert_called_once()
+
+
+@patch("strands_evals.plugins.evaluation_plugin.Agent")
+def test_improvement_suggestion_prompt_contains_failures(mock_agent_class, mock_agent):
+    """Improvement prompt should include evaluation failure reasons."""
+    mock_suggestion_agent = Mock()
+    mock_suggestion_result = Mock()
+    mock_suggestion_result.structured_output = ImprovementSuggestion(reasoning="Needs fix", system_prompt="improved")
+    mock_suggestion_agent.return_value = mock_suggestion_result
+    mock_agent_class.return_value = mock_suggestion_agent
+
+    evaluator = Mock()
+    evaluator.evaluate.side_effect = [
+        [EvaluationOutput(score=0.0, test_pass=False, reason="answer was factually wrong")],
+        [EvaluationOutput(score=1.0, test_pass=True)],
+    ]
+    plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1)
+    plugin.init_agent(mock_agent)
+
+    mock_agent("prompt")
+
+    prompt_arg = mock_suggestion_agent.call_args[0][0]
+    assert "answer was factually wrong" in prompt_arg
+    assert "original prompt" in prompt_arg
+
+
+@patch("strands_evals.plugins.evaluation_plugin.Agent")
+def test_improvement_suggestion_prompt_contains_expected_output(mock_agent_class, mock_agent):
+    """Improvement prompt should include expected_output when available."""
+    mock_suggestion_agent = Mock()
+    mock_suggestion_result = Mock()
+    mock_suggestion_result.structured_output = ImprovementSuggestion(reasoning="fix", system_prompt="improved")
+    mock_suggestion_agent.return_value = mock_suggestion_result
+    mock_agent_class.return_value = mock_suggestion_agent
+
+    evaluator = Mock()
+    evaluator.evaluate.side_effect = [
+        [EvaluationOutput(score=0.0, test_pass=False, reason="does not match")],
+        [EvaluationOutput(score=1.0, test_pass=True)],
+    ]
+    plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1, expected_output="Paris")
+    plugin.init_agent(mock_agent)
+
+    mock_agent("What is the capital of France?")
+
+    prompt_arg = mock_suggestion_agent.call_args[0][0]
+    assert "Paris" in prompt_arg
+    assert "ExpectedOutput" in prompt_arg
+
+
+@patch("strands_evals.plugins.evaluation_plugin.Agent")
+def test_improvement_suggestion_prompt_omits_expected_output_when_none(mock_agent_class, mock_agent):
+    """Improvement prompt should not include ExpectedOutput tag when no expected_output is set."""
+    mock_suggestion_agent = Mock()
+    mock_suggestion_result = Mock()
+    mock_suggestion_result.structured_output = ImprovementSuggestion(reasoning="fix", system_prompt="improved")
+    mock_suggestion_agent.return_value = mock_suggestion_result
+    mock_agent_class.return_value = mock_suggestion_agent
+
+    evaluator = Mock()
+    evaluator.evaluate.side_effect = [
+        [EvaluationOutput(score=0.0, test_pass=False, reason="bad")],
+        [EvaluationOutput(score=1.0, test_pass=True)],
+    ]
+    plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1)
+    plugin.init_agent(mock_agent)
+
+    mock_agent("prompt")
+
+    prompt_arg = mock_suggestion_agent.call_args[0][0]
+    assert "ExpectedOutput" not in prompt_arg
+
+
+@patch("strands_evals.plugins.evaluation_plugin.Agent")
+def test_improved_system_prompt_applied_before_retry(mock_agent_class, mock_agent):
+    """Agent system prompt should be updated with suggestion before retry."""
+    mock_suggestion_agent = Mock()
+    mock_suggestion_result = Mock()
+    mock_suggestion_result.structured_output = ImprovementSuggestion(
+        reasoning="Needs specificity", system_prompt="Always answer with the city name only"
+    )
+    mock_suggestion_agent.return_value = mock_suggestion_result
+    mock_agent_class.return_value = mock_suggestion_agent
+
+    system_prompts_during_calls = []
+    original_call = mock_agent.__class__.__call__
+
+    def tracking_call(self, prompt=None, **kwargs):
+        system_prompts_during_calls.append(self.system_prompt)
+        return original_call(self, prompt, **kwargs)
+
+    mock_agent.__class__.__call__ = tracking_call
+
+    evaluator = Mock()
+    evaluator.evaluate.side_effect = [
+        [EvaluationOutput(score=0.0, test_pass=False, reason="bad")],
+        [EvaluationOutput(score=1.0, test_pass=True)],
+    ]
+    plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1)
+    plugin.init_agent(mock_agent)
+
+    mock_agent("prompt")
+
+    assert system_prompts_during_calls[0] == "original prompt"
+    assert system_prompts_during_calls[1] == "Always answer with the city name only"
+
+
+# --- Step 7: Multiple evaluators + edge cases ---
+
+
+def test_multiple_evaluators_all_must_pass(mock_agent):
+    """All evaluators must pass for the invocation to be considered successful."""
+    evaluator1 = Mock()
+    evaluator1.evaluate.return_value = [EvaluationOutput(score=1.0, test_pass=True)]
+    evaluator2 = Mock()
+    evaluator2.evaluate.return_value = [EvaluationOutput(score=1.0, test_pass=True)]
+    plugin = EvaluationPlugin(evaluators=[evaluator1, evaluator2], max_retries=0)
+    plugin.init_agent(mock_agent)
+
+    mock_agent("prompt")
+
+    evaluator1.evaluate.assert_called_once()
+    evaluator2.evaluate.assert_called_once()
+    assert mock_agent.call_count == 1
+
+
+def test_partial_evaluator_failure_triggers_retry(mock_agent):
+    """If any evaluator fails, retry should be triggered."""
+    evaluator1 = Mock()
+    evaluator1.evaluate.return_value = [EvaluationOutput(score=1.0, test_pass=True)]
+    evaluator2 = Mock()
+    evaluator2.evaluate.side_effect = [
+        [EvaluationOutput(score=0.0, test_pass=False, reason="failed")],
+        [EvaluationOutput(score=1.0, test_pass=True)],
+    ]
+    plugin = EvaluationPlugin(evaluators=[evaluator1, evaluator2], max_retries=1)
+    plugin._suggest_improvements = Mock(
+        return_value=ImprovementSuggestion(reasoning="needs improvement", system_prompt="improved")
+    )
+    plugin.init_agent(mock_agent)
+
+    mock_agent("prompt")
+
+    assert mock_agent.call_count == 2
+
+
+def test_evaluator_exception_recorded_as_failure(mock_agent):
+    """Evaluator exceptions should be caught and treated as failures."""
+    evaluator = Mock()
+    evaluator.evaluate.side_effect = [
+        RuntimeError("evaluator crashed"),
+        [EvaluationOutput(score=1.0, test_pass=True)],
+    ]
+    plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1)
+    plugin._suggest_improvements = Mock(
+        return_value=ImprovementSuggestion(reasoning="needs improvement", system_prompt="improved")
+    )
+    plugin.init_agent(mock_agent)
+
+    mock_agent("prompt")
+
+    assert mock_agent.call_count == 2
+
+
+def test_returns_final_attempt_result(mock_agent):
+    """Should return the result from the final attempt."""
+    results = [Mock(__str__=Mock(return_value="first")), Mock(__str__=Mock(return_value="second"))]
+    call_idx = [0]
+
+    original_call = mock_agent.__class__.__call__
+
+    def multi_result_call(self, prompt=None, **kwargs):
+        idx = call_idx[0]
+        call_idx[0] += 1
+        original_call(self, prompt, **kwargs)
+        return results[idx]
+
+    mock_agent.__class__.__call__ = multi_result_call
+
+    evaluator = Mock()
+    evaluator.evaluate.side_effect = [
+        [EvaluationOutput(score=0.0, test_pass=False)],
+        [EvaluationOutput(score=1.0, test_pass=True)],
+    ]
+    plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1)
+    plugin._suggest_improvements = Mock(
+        return_value=ImprovementSuggestion(reasoning="needs improvement", system_prompt="improved")
+    )
+    plugin.init_agent(mock_agent)
+
+    result = mock_agent("prompt")
+
+    assert result is results[1]
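
# --- Usage sketch (reviewer note, not part of the patch) ---
# A minimal example of wiring up the new EvaluationPlugin, using only API that appears
# in this diff (EvaluationPlugin, init_agent, EvaluationData, EvaluationOutput, and the
# Agent(system_prompt=...) constructor). ContainsParisEvaluator is a hypothetical
# stand-in for a concrete Evaluator: any object exposing
# evaluate(EvaluationData) -> list[EvaluationOutput] is treated the same way by the plugin.
from strands import Agent

from strands_evals.plugins import EvaluationPlugin
from strands_evals.types import EvaluationData, EvaluationOutput


class ContainsParisEvaluator:
    def evaluate(self, data: EvaluationData) -> list[EvaluationOutput]:
        passed = "Paris" in (data.actual_output or "")
        reason = None if passed else "output does not mention Paris"
        return [EvaluationOutput(score=1.0 if passed else 0.0, test_pass=passed, reason=reason)]


agent = Agent(system_prompt="You are a concise geography assistant.")
plugin = EvaluationPlugin(evaluators=[ContainsParisEvaluator()], max_retries=1, expected_output="Paris")
plugin.init_agent(agent)  # wraps agent.__call__ with evaluate-and-retry behavior

result = agent("What is the capital of France?")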