Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 55 additions & 2 deletions src/openjarvis/agents/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import concurrent.futures
import re
from typing import Any, List, Optional
from typing import Any, Callable, List, Optional

from openjarvis.agents._stubs import AgentContext, AgentResult, ToolUsingAgent
from openjarvis.core.events import EventBus
Expand Down Expand Up @@ -60,6 +60,7 @@ def __init__(
parallel_tools: bool = True,
interactive: bool = False,
confirm_callback=None,
before_tool_call: Optional[Callable[[str, dict], bool]] = None,
) -> None:
super().__init__(
engine,
Expand All @@ -75,6 +76,7 @@ def __init__(
self._mode = mode
self._system_prompt = system_prompt
self._parallel_tools = parallel_tools
self._before_tool_call = before_tool_call

def run(
self,
Expand All @@ -86,6 +88,36 @@ def run(
return self._run_structured(input, context, **kwargs)
return self._run_function_calling(input, context, **kwargs)

# ------------------------------------------------------------------
# Governance hook
# ------------------------------------------------------------------

def _check_tool_allowed(self, tc: "ToolCall") -> "Optional[ToolResult]":
"""Call before_tool_call hook if set.

Returns None to allow execution, or a denial ToolResult to inject
instead of running the tool.
"""
if self._before_tool_call is None:
return None
import json

try:
tool_args = json.loads(tc.arguments) if tc.arguments else {}
except (json.JSONDecodeError, TypeError):
tool_args = {}
allowed = self._before_tool_call(tc.name, tool_args)
if allowed:
return None
return ToolResult(
tool_name=tc.name,
content=(
f"[Governance] Tool '{tc.name}' was not approved. "
"Adjust your plan and try a different approach."
),
success=False,
)

# ------------------------------------------------------------------
# Structured mode (THOUGHT/TOOL/INPUT/FINAL_ANSWER)
# ------------------------------------------------------------------
Expand Down Expand Up @@ -142,7 +174,11 @@ def _run_structured(
name=parsed["tool"],
arguments=parsed["input"] or "{}",
)
tool_result = self._executor.execute(tool_call)
denial = self._check_tool_allowed(tool_call)
if denial is not None:
tool_result = denial
else:
tool_result = self._executor.execute(tool_call)
all_tool_results.append(tool_result)

observation = f"Observation: {tool_result.content}"
Expand Down Expand Up @@ -284,6 +320,9 @@ def _run_function_calling(
if self._parallel_tools and len(tool_calls) > 1:
# Parallel execution
def _exec_tool(tc: ToolCall) -> tuple:
denial = self._check_tool_allowed(tc)
if denial is not None:
return tc, denial
if self._loop_guard:
verdict = self._loop_guard.check_call(
tc.name,
Expand Down Expand Up @@ -321,6 +360,20 @@ def _exec_tool(tc: ToolCall) -> tuple:
else:
# Sequential execution
for tc in tool_calls:
# Governance hook check before execution
denial = self._check_tool_allowed(tc)
if denial is not None:
all_tool_results.append(denial)
messages.append(
Message(
role=Role.TOOL,
content=denial.content,
tool_call_id=tc.id,
name=tc.name,
)
)
continue

# Loop guard check before execution
if self._loop_guard:
verdict = self._loop_guard.check_call(
Expand Down
147 changes: 147 additions & 0 deletions tests/agents/test_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,3 +665,150 @@ def test_single_tool_call_no_parallel(self):
)
result = agent.run("What is 2+2?")
assert result.content == "The answer is 4."


class TestOrchestratorGovernanceHook:
"""Tests for the optional before_tool_call governance hook."""

def test_no_hook_runs_tool_normally(self):
"""Without a hook, tool execution is unchanged."""
engine = _make_engine_with_tool_call()
agent = OrchestratorAgent(
engine,
"test-model",
tools=[_CalculatorStub()],
)
result = agent.run("What is 2+2?")
assert result.content == "The answer is 4."
assert result.tool_results[0].success is True
assert result.tool_results[0].content == "4"

def test_hook_returning_true_allows_tool(self):
"""Hook returning True allows the tool to execute normally."""
engine = _make_engine_with_tool_call()
agent = OrchestratorAgent(
engine,
"test-model",
tools=[_CalculatorStub()],
before_tool_call=lambda name, args: True,
)
result = agent.run("What is 2+2?")
assert result.tool_results[0].success is True
assert result.tool_results[0].content == "4"

def test_hook_returning_false_blocks_tool(self):
"""Hook returning False injects a denial; the tool is never executed."""
executed = []

class _TrackedCalculator(_CalculatorStub):
def execute(self, **params) -> ToolResult:
executed.append(params)
return super().execute(**params)

engine = _make_engine_with_tool_call()
agent = OrchestratorAgent(
engine,
"test-model",
tools=[_TrackedCalculator()],
before_tool_call=lambda name, args: False,
)
result = agent.run("What is 2+2?")
assert executed == [], "tool must not be called when hook returns False"
assert result.tool_results[0].success is False
assert "[Governance]" in result.tool_results[0].content

def test_hook_receives_correct_tool_name_and_args(self):
"""The hook is called with the exact tool name and parsed argument dict."""
calls: list[tuple] = []

def _hook(name: str, args: dict) -> bool:
calls.append((name, args))
return True

engine = _make_engine_with_tool_call(
tool_name="calculator",
arguments='{"expression":"3*7"}',
)
agent = OrchestratorAgent(
engine,
"test-model",
tools=[_CalculatorStub()],
before_tool_call=_hook,
)
agent.run("Calculate")
assert len(calls) == 1
assert calls[0][0] == "calculator"
assert calls[0][1] == {"expression": "3*7"}

def test_hook_denial_in_parallel_mode(self):
"""Hook returning False blocks tools in the parallel execution path."""
executed = []

class _TrackedCalculator(_CalculatorStub):
def execute(self, **params) -> ToolResult:
executed.append(params)
return super().execute(**params)

engine = _make_engine_multi_tool()
agent = OrchestratorAgent(
engine,
"test-model",
tools=[_TrackedCalculator(), _ThinkStub()],
parallel_tools=True,
before_tool_call=lambda name, args: False,
)
result = agent.run("Think and calculate.")
assert executed == [], "no tool must execute when hook denies"
assert len(result.tool_results) == 2
assert all(not tr.success for tr in result.tool_results)
assert all("[Governance]" in tr.content for tr in result.tool_results)

def test_hook_denial_in_structured_mode(self):
"""Hook returning False blocks a tool in structured (ReAct) mode."""
executed = []

class _TrackedCalculator(_CalculatorStub):
def execute(self, **params) -> ToolResult:
executed.append(params)
return super().execute(**params)

engine = MagicMock()
engine.engine_id = "mock"
engine.generate.side_effect = [
{
"content": (
"THOUGHT: Need to calculate.\n"
"TOOL: calculator\n"
'INPUT: {"expression":"2+2"}'
),
"usage": {
"prompt_tokens": 10,
"completion_tokens": 10,
"total_tokens": 20,
},
"model": "test-model",
"finish_reason": "stop",
},
{
"content": "THOUGHT: Blocked.\nFINAL_ANSWER: Could not calculate.",
"usage": {
"prompt_tokens": 20,
"completion_tokens": 10,
"total_tokens": 30,
},
"model": "test-model",
"finish_reason": "stop",
},
]
agent = OrchestratorAgent(
engine,
"test-model",
tools=[_TrackedCalculator()],
mode="structured",
before_tool_call=lambda name, args: False,
)
result = agent.run("What is 2+2?")
assert executed == [], "tool must not execute when hook denies"
assert len(result.tool_results) == 1
assert result.tool_results[0].success is False
assert "[Governance]" in result.tool_results[0].content