feat: Add EvaluationPlugin for agent invocation evaluation and retry #166
Open

afarntrog wants to merge 4 commits into strands-agents:main from afarntrog:wip/evaluation-plugin (base: main).
Commits:

- 20f2b36 (afarntrog): feat: add EvaluationPlugin for agent invocation evaluation and retry
- 2e8d943 (afarntrog): rm print
- 9aaa441 (afarntrog): feat: add docstrings and return structured ImprovementSuggestion type
- e2fa123 (afarntrog): feat: add plugins module and bump strands-agents to >=1.28.0
`src/strands_evals/plugins/__init__.py` (new file, +3):

```python
from .evaluation_plugin import EvaluationPlugin

__all__ = ["EvaluationPlugin"]
```
`src/strands_evals/plugins/evaluation_plugin.py` (new file, +205):

```python
"""Plugin that evaluates agent invocations and retries with improvements on failure."""

import logging
from typing import Any, cast

from pydantic import BaseModel
from strands import Agent
from strands.models import Model
from strands.plugins.plugin import Plugin

from strands_evals.evaluators.evaluator import Evaluator
from strands_evals.plugins.prompt_templates.improvement_suggestion import (
    IMPROVEMENT_SYSTEM_PROMPT,
    compose_improvement_prompt,
)
from strands_evals.types.evaluation import EvaluationData, EvaluationOutput

logger = logging.getLogger(__name__)


class ImprovementSuggestion(BaseModel):
    """Structured output from the improvement suggestion LLM."""

    reasoning: str
    system_prompt: str


class EvaluationPlugin(Plugin):
    """Evaluates agent output after each invocation and retries with improved system prompts on failure."""

    @property
    def name(self) -> str:
        return "strands-evals"

    def __init__(
        self,
        evaluators: list[Evaluator],
        max_retries: int = 1,
        expected_output: Any = None,
        expected_trajectory: list[Any] | None = None,
        model: Model | str | None = None,
    ):
        """Initialize the evaluation plugin.

        Args:
            evaluators: Evaluators to run against agent output after each invocation.
            max_retries: Maximum number of retry attempts when evaluation fails.
            expected_output: Default expected output for evaluation. Can be overridden
                per-invocation via ``invocation_state``.
            expected_trajectory: Default expected trajectory for evaluation. Can be
                overridden per-invocation via ``invocation_state``.
            model: Model used by the improvement suggestion agent. Accepts a Model
                instance, a model ID string, or None to use the default.
        """
        self._evaluators = evaluators
        self._max_retries = max_retries
        self._expected_output = expected_output
        self._expected_trajectory = expected_trajectory
        self._model = model
        self._agent: Any = None

    def init_agent(self, agent: Any) -> None:
        """Wrap the agent's ``__call__`` to intercept invocations for evaluation and retry.

        Creates a dynamic subclass of the agent's class with a wrapped ``__call__``
        that runs evaluators after each invocation and retries with an improved
        system prompt on failure.

        Args:
            agent: The agent instance whose invocations will be evaluated.
        """
        self._agent = agent
        original_call = agent.__class__.__call__
        plugin = self

        def wrapped_call(self_agent: Any, prompt: Any = None, **kwargs: Any) -> Any:
            return plugin._invoke_with_evaluation(self_agent, original_call, prompt, **kwargs)

        wrapped_class = type(
            agent.__class__.__name__,
            (agent.__class__,),
            {"__call__": wrapped_call},
        )
        agent.__class__ = wrapped_class

    def _invoke_with_evaluation(self, agent: Any, original_call: Any, prompt: Any, **kwargs: Any) -> Any:
        """Run the agent, evaluate output, and retry with an improved system prompt on failure.

        Restores the original system prompt and messages after all attempts complete.

        Args:
            agent: The agent instance being invoked.
            original_call: The unwrapped ``__call__`` method.
            prompt: The user prompt passed to the agent.
            **kwargs: Additional keyword arguments forwarded to the agent call.

        Returns:
            The result from the last agent invocation attempt.
        """
        original_system_prompt = agent.system_prompt
        original_messages = list(agent.messages)
        invocation_state = kwargs.get("invocation_state", {})

        for attempt in range(1 + self._max_retries):
            if attempt > 0:
                agent.messages = list(original_messages)

            result = original_call(agent, prompt, **kwargs)

            evaluation_data = self._build_evaluation_data(prompt, result, invocation_state)
            outputs = self._run_evaluators(evaluation_data)
            all_pass = all(o.test_pass for o in outputs)

            if all_pass or attempt == self._max_retries:
                break

            logger.debug(
                "attempt=<%s>, evaluation_pass=<%s> | evaluation failed, generating improvements",
                attempt + 1,
                all_pass,
            )

            expected_output = evaluation_data.expected_output
            suggestion = self._suggest_improvements(
                prompt, str(result), outputs, agent.system_prompt, expected_output
            )
            logger.debug(
                "attempt=<%s>, reasoning=<%s> | applying improved system prompt",
                attempt + 1,
                suggestion.reasoning,
            )
            agent.system_prompt = suggestion.system_prompt

        agent.system_prompt = original_system_prompt
        return result

    def _build_evaluation_data(self, prompt: Any, result: Any, invocation_state: dict) -> EvaluationData:
        """Assemble evaluation data from the invocation context.

        Args:
            prompt: The user prompt.
            result: The agent's output.
            invocation_state: Per-invocation overrides for expected values.

        Returns:
            An EvaluationData instance ready for evaluator consumption.
        """
        expected_output = invocation_state.get("expected_output", self._expected_output)
        expected_trajectory = invocation_state.get("expected_trajectory", self._expected_trajectory)

        return EvaluationData(
            input=prompt,
            actual_output=str(result),
            expected_output=expected_output,
            expected_trajectory=expected_trajectory,
        )

    def _suggest_improvements(
        self,
        prompt: Any,
        actual_output: str,
        outputs: list[EvaluationOutput],
        current_system_prompt: str | None,
        expected_output: Any = None,
    ) -> ImprovementSuggestion:
        """Ask an LLM to suggest an improved system prompt based on evaluation failures.

        Args:
            prompt: The original user prompt.
            actual_output: The agent's output as a string.
            outputs: Evaluation outputs from the failed attempt.
            current_system_prompt: The agent's current system prompt.
            expected_output: The expected output, if available.

        Returns:
            An ImprovementSuggestion containing the reasoning and a revised system prompt.
        """
        failure_reasons = [o.reason for o in outputs if not o.test_pass and o.reason]
        improvement_prompt = compose_improvement_prompt(
            user_prompt=str(prompt),
            actual_output=actual_output,
            failure_reasons=failure_reasons,
            current_system_prompt=current_system_prompt,
            expected_output=str(expected_output) if expected_output is not None else None,
        )
        suggestion_agent = Agent(model=self._model, system_prompt=IMPROVEMENT_SYSTEM_PROMPT, callback_handler=None)
        result = suggestion_agent(improvement_prompt, structured_output_model=ImprovementSuggestion)
        return cast(ImprovementSuggestion, result.structured_output)

    def _run_evaluators(self, evaluation_data: EvaluationData) -> list[EvaluationOutput]:
        """Run all evaluators against the given evaluation data.

        Exceptions raised by individual evaluators are caught, logged, and recorded
        as failures so that a single broken evaluator does not prevent the others
        from running.

        Args:
            evaluation_data: The data to evaluate.

        Returns:
            A list of evaluation outputs from all evaluators.
        """
        all_outputs: list[EvaluationOutput] = []
        for evaluator in self._evaluators:
            try:
                outputs = evaluator.evaluate(evaluation_data)
                all_outputs.extend(outputs)
            except Exception:
                logger.exception("evaluator=<%s> | evaluator raised an exception", type(evaluator).__name__)
                all_outputs.append(
                    EvaluationOutput(score=0.0, test_pass=False, reason="evaluator raised an exception")
                )
        return all_outputs
```
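Stripped of the strands-specific types, the evaluate-then-retry control flow in `_invoke_with_evaluation` can be sketched in isolation. The names below (`call_agent`, `evaluate`, `improve_prompt`, the `state` dict) are hypothetical stand-ins for illustration, not the PR's API:

```python
from dataclasses import dataclass


@dataclass
class EvalOutput:
    """Minimal stand-in for the plugin's EvaluationOutput."""
    test_pass: bool
    reason: str = ""


def invoke_with_evaluation(call_agent, evaluate, improve_prompt, state, max_retries=1):
    """Call the agent, evaluate, and retry with an improved system prompt on failure."""
    original_prompt = state["system_prompt"]
    result = None
    for attempt in range(1 + max_retries):
        result = call_agent(state)
        outputs = evaluate(result)
        if all(o.test_pass for o in outputs) or attempt == max_retries:
            break
        # Evaluation failed and retries remain: swap in an improved system prompt.
        state["system_prompt"] = improve_prompt(state["system_prompt"], outputs)
    state["system_prompt"] = original_prompt  # always restore, as the plugin does
    return result


# Toy agent that only succeeds once its system prompt mentions "JSON".
calls = []

def call_agent(state):
    calls.append(state["system_prompt"])
    return "ok" if "JSON" in state["system_prompt"] else "bad"

result = invoke_with_evaluation(
    call_agent,
    evaluate=lambda r: [EvalOutput(test_pass=(r == "ok"), reason="wrong format")],
    improve_prompt=lambda p, outs: p + " Respond in JSON.",
    state={"system_prompt": "Be helpful."},
    max_retries=2,
)
```

Note the two exit conditions mirror the PR: a passing evaluation or exhausting `max_retries`, and the caller always receives the result of the final attempt, even if it still fails evaluation.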
Empty file (new).

`src/strands_evals/plugins/prompt_templates/improvement_suggestion.py` (new file, 37 additions, 0 deletions):
```python
IMPROVEMENT_SYSTEM_PROMPT = """You are an expert at analyzing AI agent failures and suggesting system prompt improvements.

Given an agent's current system prompt, the user's request, the agent's output, and evaluation failures, suggest a modified system prompt that addresses the failures while preserving the agent's core capabilities.

Focus on:
- Adding specific instructions that address the evaluation failure reasons
- Preserving existing useful instructions from the current system prompt
- Being concise and actionable
- Not changing the fundamental purpose of the agent

Return the complete improved system prompt that should replace the current one."""


def compose_improvement_prompt(
    user_prompt: str,
    actual_output: str,
    failure_reasons: list[str],
    current_system_prompt: str | None,
    expected_output: str | None = None,
) -> str:
    parts = []

    if current_system_prompt:
        parts.append(f"<CurrentSystemPrompt>{current_system_prompt}</CurrentSystemPrompt>")
    else:
        parts.append("<CurrentSystemPrompt>No system prompt set</CurrentSystemPrompt>")

    parts.append(f"<UserRequest>{user_prompt}</UserRequest>")
    parts.append(f"<AgentOutput>{actual_output}</AgentOutput>")

    if expected_output is not None:
        parts.append(f"<ExpectedOutput>{expected_output}</ExpectedOutput>")

    reasons_text = "\n".join(f"- {reason}" for reason in failure_reasons)
    parts.append(f"<EvaluationFailures>\n{reasons_text}\n</EvaluationFailures>")

    return "\n\n".join(parts)
```
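The composed prompt is a sequence of XML-style tagged sections joined by blank lines. A standalone sketch of that shape (restating the helper's logic so it runs without the PR's package; `compose` is a hypothetical name):

```python
def compose(user_prompt, actual_output, failure_reasons, current_system_prompt=None):
    # Same tag layout as compose_improvement_prompt, restated self-contained.
    parts = [
        f"<CurrentSystemPrompt>{current_system_prompt or 'No system prompt set'}</CurrentSystemPrompt>",
        f"<UserRequest>{user_prompt}</UserRequest>",
        f"<AgentOutput>{actual_output}</AgentOutput>",
    ]
    reasons = "\n".join(f"- {r}" for r in failure_reasons)
    parts.append(f"<EvaluationFailures>\n{reasons}\n</EvaluationFailures>")
    return "\n\n".join(parts)


prompt = compose("What is 2+2?", "five", ["answer is not numeric"])
```

The tags give the improvement agent an unambiguous way to distinguish the current prompt, the request, the output, and the failure reasons within a single user message.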
Empty file (new).
We need to do this monkey patching because hooks alone will not give us the retry pattern we need.

What `EvaluationPlugin` needs to do: call the agent, evaluate the result, and if evaluation fails, re-run the whole invocation, up to `max_retries` times.

What hooks give you:

- `BeforeInvocationEvent` fires before the agent runs. You can modify messages. No result yet.
- `AfterInvocationEvent` fires after the agent runs. You get the result, but you're inside `_run_loop`'s `finally` block, and `stream_async` still holds the invocation lock. If you call `agent()` from here → deadlock.

Hooks give you "before" and "after", not "around". There's no way to say "evaluate this result and, if it's bad, run the whole thing again" from within a hook. The retry requires calling `agent()`, which requires the lock, which is held by the invocation that's firing your hook.

The `__class__` swap helps because it replaces `__call__`: every call to `agent(prompt)` now goes through `_invoke_with_evaluation`, which runs evaluators and handles retries.
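The `__class__`-swap pattern the comment describes can be demonstrated in isolation. This is a generic sketch, not strands code; `wrap_agent` and `around` are hypothetical names:

```python
class Agent:
    def __call__(self, prompt):
        return f"answer:{prompt}"


def wrap_agent(agent, around):
    """Dynamically subclass the agent's class and swap __class__ to intercept __call__."""
    original_call = agent.__class__.__call__

    def wrapped_call(self_agent, prompt=None, **kwargs):
        # 'around' receives the unwrapped __call__, so it can re-invoke the
        # agent (e.g. to retry) without recursing back into the wrapper.
        return around(self_agent, original_call, prompt, **kwargs)

    agent.__class__ = type(agent.__class__.__name__, (agent.__class__,), {"__call__": wrapped_call})


agent = Agent()
seen = []

def around(self_agent, original_call, prompt, **kwargs):
    seen.append(prompt)  # observe the call; a real plugin would evaluate and maybe loop here
    return original_call(self_agent, prompt, **kwargs)

wrap_agent(agent, around)
result = agent("hi")
```

The dynamic subclass is necessary because Python looks up dunder methods like `__call__` on the type, not the instance: assigning `agent.__call__ = …` directly would be ignored by the `agent(prompt)` syntax.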