diff --git a/src/strands_evals/simulation/tool_simulator.py b/src/strands_evals/simulation/tool_simulator.py index da844c5..a755611 100644 --- a/src/strands_evals/simulation/tool_simulator.py +++ b/src/strands_evals/simulation/tool_simulator.py @@ -11,6 +11,7 @@ from strands.models.model import Model from strands.tools.decorator import DecoratedFunctionTool, FunctionToolMetadata +from strands_evals.types.simulation.hook_events import PostCallHookEvent, PreCallHookEvent from strands_evals.types.simulation.tool import DefaultToolResponse, RegisteredTool from .prompt_templates.tool_response_generation import TOOL_RESPONSE_PROMPT_TEMPLATE @@ -166,6 +167,8 @@ def __init__( state_registry: StateRegistry | None = None, model: Model | str | None = None, max_tool_call_cache_size: int = 20, + pre_call_hook: Callable | None = None, + post_call_hook: Callable | None = None, ): """ Initialize a ToolSimulator instance. @@ -178,10 +181,21 @@ def __init__( Only used when creating a new StateRegistry (ignored if state_registry is provided). Older calls are automatically evicted when limit is exceeded. Default is 20. + pre_call_hook: Optional callable invoked before the LLM generates a tool response. + Receives a PreCallHookEvent with tool_name, parameters, state_key, + and previous_calls. If it returns a non-None dict, that dict is used + as the tool response (short-circuiting the LLM call) and cached in + the state registry. If it returns None, normal LLM simulation proceeds. + post_call_hook: Optional callable invoked after the LLM generates a tool response + but before it is cached. Receives a PostCallHookEvent with tool_name, + parameters, state_key, and response. Must return a (possibly modified) + response dict. 
""" self.model = model self.state_registry = state_registry or StateRegistry(max_tool_call_cache_size=max_tool_call_cache_size) self._registered_tools: dict[str, RegisteredTool] = {} + self._pre_call_hook = pre_call_hook + self._post_call_hook = post_call_hook def _create_tool_wrapper(self, registered_tool: RegisteredTool): """ @@ -245,7 +259,35 @@ def _parse_simulated_response(self, result: AgentResult) -> dict[str, Any]: return response_data def _call_tool(self, registered_tool: RegisteredTool, parameters_string: str, state_key: str) -> dict[str, Any]: - """Simulate a tool invocation and return the response.""" + """Simulate a tool invocation and return the response. + + If a pre_call_hook is configured and returns a non-None dict, that dict is used + as the tool response (short-circuiting the LLM call). The response is still cached. + + If a post_call_hook is configured, it receives the LLM-generated response before + caching and may modify it. + """ + parameters = json.loads(parameters_string) + current_state = self.state_registry.get_state(state_key) + + # Pre-call hook: may short-circuit the LLM call + if self._pre_call_hook is not None: + event = PreCallHookEvent( + tool_name=registered_tool.name, + parameters=parameters, + state_key=state_key, + previous_calls=current_state.get("previous_calls", []), + ) + hook_response = self._pre_call_hook(event) + if hook_response is not None: + if not isinstance(hook_response, dict): + raise TypeError(f"pre_call_hook must return a dict or None, got {type(hook_response).__name__}") + self.state_registry.cache_tool_call( + registered_tool.name, state_key, hook_response, parameters=parameters + ) + return hook_response + + # Normal LLM simulation # Get input schema from Strands tool decorator input_schema_dict = registered_tool.function.tool_spec.get("inputSchema", {}).get("json", {}) input_schema = json.dumps(input_schema_dict, indent=2) @@ -254,8 +296,6 @@ def _call_tool(self, registered_tool: RegisteredTool, 
parameters_string: str, st output_schema = registered_tool.output_schema.model_json_schema() output_schema_string = json.dumps(output_schema, indent=2) - current_state = self.state_registry.get_state(state_key) - prompt = TOOL_RESPONSE_PROMPT_TEMPLATE.format( tool_name=registered_tool.name, input_schema=input_schema, @@ -268,9 +308,19 @@ def _call_tool(self, registered_tool: RegisteredTool, parameters_string: str, st response_data = self._parse_simulated_response(result) - self.state_registry.cache_tool_call( - registered_tool.name, state_key, response_data, parameters=json.loads(parameters_string) - ) + # Post-call hook: may modify the response before caching + if self._post_call_hook is not None: + event = PostCallHookEvent( + tool_name=registered_tool.name, + parameters=parameters, + state_key=state_key, + response=response_data, + ) + response_data = self._post_call_hook(event) + if not isinstance(response_data, dict): + raise TypeError(f"post_call_hook must return a dict, got {type(response_data).__name__}") + + self.state_registry.cache_tool_call(registered_tool.name, state_key, response_data, parameters=parameters) return response_data def tool( diff --git a/src/strands_evals/types/simulation/hook_events.py b/src/strands_evals/types/simulation/hook_events.py new file mode 100644 index 0000000..e7ff63a --- /dev/null +++ b/src/strands_evals/types/simulation/hook_events.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class PreCallHookEvent: + """ + Event passed to pre_call_hook before the LLM generates a tool response. + + Attributes: + tool_name: Name of the tool being called. + parameters: Parsed parameters for the tool call. + state_key: Key for the state (tool_name or share_state_id). + previous_calls: List of previous tool call records from the state registry. 
+ """ + + tool_name: str + parameters: dict[str, Any] + state_key: str + previous_calls: list[dict[str, Any]] = field(default_factory=list) + + +@dataclass +class PostCallHookEvent: + """ + Event passed to post_call_hook after the LLM generates a tool response. + + Attributes: + tool_name: Name of the tool that was called. + parameters: Parsed parameters for the tool call. + state_key: Key for the state (tool_name or share_state_id). + response: The LLM-generated response dict, which the hook may modify. + """ + + tool_name: str + parameters: dict[str, Any] + state_key: str + response: dict[str, Any] = field(default_factory=dict) diff --git a/tests/strands_evals/simulation/test_tool_simulator.py b/tests/strands_evals/simulation/test_tool_simulator.py index a700085..adee9cf 100644 --- a/tests/strands_evals/simulation/test_tool_simulator.py +++ b/tests/strands_evals/simulation/test_tool_simulator.py @@ -465,3 +465,257 @@ def test_tool_no_params() -> dict: properties = schema.get("properties", {}) # Empty model should mean no properties assert len(properties) == 0, "Tool with empty schema should have no properties" + + +def test_pre_call_hook_short_circuits_llm(mock_model): + """Test that pre_call_hook can short-circuit the LLM call by returning a dict.""" + hook_calls = [] + + def fault_injector(event): + hook_calls.append({"tool_name": event.tool_name, "parameters": event.parameters}) + return {"error": {"code": "QuotaExceeded", "retryAfterSeconds": 2}} + + simulator = ToolSimulator(model=mock_model, pre_call_hook=fault_injector) + + @simulator.tool(output_schema=GenericOutput) + def my_tool(message: str) -> dict: + """A tool that should be intercepted.""" + pass + + # The LLM should never be called because the hook short-circuits + mock_agent_instance = MagicMock() + mock_agent_instance.return_value = MagicMock() + + with pytest.MonkeyPatch().context() as m: + m.setattr("strands_evals.simulation.tool_simulator.Agent", lambda **kwargs: mock_agent_instance) + + 
result = simulator.my_tool(message="hello") + + assert result == {"error": {"code": "QuotaExceeded", "retryAfterSeconds": 2}} + assert len(hook_calls) == 1 + assert hook_calls[0]["tool_name"] == "my_tool" + assert hook_calls[0]["parameters"] == {"message": "hello"} + # LLM agent should NOT have been called + assert not mock_agent_instance.called + + +def test_pre_call_hook_returns_none_proceeds_normally(mock_model): + """Test that pre_call_hook returning None lets normal LLM simulation proceed.""" + hook_calls = [] + + def passthrough_hook(event): + hook_calls.append(event.tool_name) + return None # Don't short-circuit + + simulator = ToolSimulator(model=mock_model, pre_call_hook=passthrough_hook) + + @simulator.tool(output_schema=GenericOutput) + def my_tool(message: str) -> dict: + """A tool.""" + pass + + mock_agent_instance = MagicMock() + mock_result = MagicMock() + mock_result.__str__ = MagicMock(return_value='{"result": "llm response"}') + mock_agent_instance.return_value = mock_result + + with pytest.MonkeyPatch().context() as m: + m.setattr("strands_evals.simulation.tool_simulator.Agent", lambda **kwargs: mock_agent_instance) + + result = simulator.my_tool(message="hello") + + assert result == {"result": "llm response"} + assert len(hook_calls) == 1 + # LLM agent SHOULD have been called + assert mock_agent_instance.called + + +def test_post_call_hook_modifies_response(mock_model): + """Test that post_call_hook can modify the LLM response before caching.""" + + def response_modifier(event): + event.response["_simulated_latency_ms"] = 42 + return event.response + + simulator = ToolSimulator(model=mock_model, post_call_hook=response_modifier) + + @simulator.tool(output_schema=GenericOutput) + def my_tool(message: str) -> dict: + """A tool.""" + pass + + mock_agent_instance = MagicMock() + mock_result = MagicMock() + mock_result.__str__ = MagicMock(return_value='{"result": "llm response"}') + mock_agent_instance.return_value = mock_result + + with 
pytest.MonkeyPatch().context() as m: + m.setattr("strands_evals.simulation.tool_simulator.Agent", lambda **kwargs: mock_agent_instance) + + result = simulator.my_tool(message="hello") + + assert result["result"] == "llm response" + assert result["_simulated_latency_ms"] == 42 + + # Verify the cached response also has the modification + state = simulator.get_state("my_tool") + cached = state["previous_calls"][0] + assert cached["response"]["_simulated_latency_ms"] == 42 + + +def test_pre_call_hook_response_is_cached(): + """Test that a short-circuited response from pre_call_hook is cached in state registry.""" + + def always_fail(event): + return {"error": "service_unavailable"} + + simulator = ToolSimulator(pre_call_hook=always_fail) + + @simulator.tool(output_schema=GenericOutput) + def my_tool(x: int) -> dict: + """A tool.""" + pass + + # No need to mock Agent since LLM is never called + result = simulator.my_tool(x=1) + + assert result == {"error": "service_unavailable"} + + state = simulator.get_state("my_tool") + assert len(state["previous_calls"]) == 1 + assert state["previous_calls"][0]["response"] == {"error": "service_unavailable"} + assert state["previous_calls"][0]["tool_name"] == "my_tool" + + +def test_pre_call_hook_receives_previous_calls(): + """Test that pre_call_hook receives the accumulated call history.""" + received_histories = [] + + call_count = 0 + + def counting_hook(event): + nonlocal call_count + received_histories.append(list(event.previous_calls)) + call_count += 1 + return {"call_number": call_count} + + simulator = ToolSimulator(pre_call_hook=counting_hook) + + @simulator.tool(output_schema=GenericOutput) + def my_tool(x: int) -> dict: + """A tool.""" + pass + + simulator.my_tool(x=1) + simulator.my_tool(x=2) + simulator.my_tool(x=3) + + # First call should see empty history + assert len(received_histories[0]) == 0 + # Second call should see 1 previous call + assert len(received_histories[1]) == 1 + # Third call should see 2 previous 
calls
+    assert len(received_histories[2]) == 2
+
+
+def test_both_hooks_together(mock_model):
+    """Test that pre_call_hook and post_call_hook work together when pre_call_hook returns None."""
+    pre_calls = []
+    post_calls = []
+
+    def pre_hook(event):
+        pre_calls.append(event.tool_name)
+        return None  # Let LLM proceed
+
+    def post_hook(event):
+        post_calls.append(event.tool_name)
+        event.response["modified"] = True
+        return event.response
+
+    simulator = ToolSimulator(model=mock_model, pre_call_hook=pre_hook, post_call_hook=post_hook)
+
+    @simulator.tool(output_schema=GenericOutput)
+    def my_tool(message: str) -> dict:
+        """A tool."""
+        pass
+
+    mock_agent_instance = MagicMock()
+    mock_result = MagicMock()
+    mock_result.__str__ = MagicMock(return_value='{"result": "ok"}')
+    mock_agent_instance.return_value = mock_result
+
+    with pytest.MonkeyPatch().context() as m:
+        m.setattr("strands_evals.simulation.tool_simulator.Agent", lambda **kwargs: mock_agent_instance)
+
+        result = simulator.my_tool(message="test")
+
+    assert result == {"result": "ok", "modified": True}
+    assert len(pre_calls) == 1
+    assert len(post_calls) == 1
+
+
+def test_post_call_hook_skipped_when_pre_hook_short_circuits():
+    """Test that post_call_hook is NOT called when pre_call_hook short-circuits."""
+    post_calls = []
+
+    def pre_hook(event):
+        return {"error": "injected fault"}
+
+    def post_hook(event):
+        post_calls.append(True)
+        return event.response
+
+    simulator = ToolSimulator(pre_call_hook=pre_hook, post_call_hook=post_hook)
+
+    @simulator.tool(output_schema=GenericOutput)
+    def my_tool(x: int) -> dict:
+        """A tool."""
+        pass
+
+    result = simulator.my_tool(x=1)
+
+    assert result == {"error": "injected fault"}
+    assert len(post_calls) == 0  # post_call_hook should not have been called
+
+
+def test_pre_call_hook_with_shared_state():
+    """Test that pre_call_hook works correctly with shared state between tools."""
+    call_counts: dict[str, int] = {}
+
+    def rate_limiter(event):
+        call_counts[event.tool_name] = call_counts.get(event.tool_name, 0) + 1
+        if call_counts[event.tool_name] > 2:
+            return {"error": {"code": "RateLimited", "tool": event.tool_name}}
+        return {"result": f"{event.tool_name} ok"}
+
+    simulator = ToolSimulator(pre_call_hook=rate_limiter)
+
+    @simulator.tool(output_schema=GenericOutput, share_state_id="shared")
+    def tool_a(x: int) -> dict:
+        """Tool A."""
+        pass
+
+    @simulator.tool(output_schema=GenericOutput, share_state_id="shared")
+    def tool_b(x: int) -> dict:
+        """Tool B."""
+        pass
+
+    # Both tools share state, but rate limiter tracks per-tool
+    assert simulator.tool_a(x=1) == {"result": "tool_a ok"}
+    assert simulator.tool_a(x=2) == {"result": "tool_a ok"}
+    assert simulator.tool_a(x=3) == {"error": {"code": "RateLimited", "tool": "tool_a"}}
+
+    # tool_b has its own counter
+    assert simulator.tool_b(x=1) == {"result": "tool_b ok"}
+
+    # Shared state should have all calls
+    state = simulator.get_state("shared")
+    assert len(state["previous_calls"]) == 4
+
+
+def test_init_without_hooks():
+    """Test that ToolSimulator works normally without hooks (backward compatibility)."""
+    simulator = ToolSimulator()
+
+    assert simulator._pre_call_hook is None
+    assert simulator._post_call_hook is None