From f5955c681a2bc3b5aa0bce285c659a18de305f50 Mon Sep 17 00:00:00 2001 From: Gerard Kavanagh Date: Thu, 28 May 2026 21:22:02 +0100 Subject: [PATCH 1/7] feat(reliability): report iteration/budget limits to user (PRD-141 US-009) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously two limits failed silently: a chat agent that hit CHATBOT_MAX_TOOL_ITERATIONS just returned its best-effort answer with no signal, and a mission that blew past its token budget only logged a diagnostic event. Both now surface a user-facing reason. - streaming.py: add format_aisdk_limit_reached() — a typed limit_reached AI SDK data frame (limit/value/message) consistent with the other format_aisdk_* events and safe for the text/plain StreamingResponse. - chatbot service _run_tool_loop: emit that frame at the max-iterations cap before forcing the final response, telling the user a cap was hit and how an admin can raise it. - coordinator _record_task_result: enrich the existing 1.5x-overage BUDGET_WARNING with limit_type/spent/limit/message (kept the diagnostic tokens_used/ratio fields). Purely additive — no pause/flow change. - tests: test_us009_limit_reporting.py covers the SSE frame contract and the coordinator emit (over/under/no-budget cases). Loads streaming.py in isolation so the formatter test needs no chatbot package deps. 
--- orchestrator/consumers/chatbot/service.py | 10 + orchestrator/consumers/chatbot/streaming.py | 9 + orchestrator/services/coordinator_service.py | 12 ++ .../tests/test_us009_limit_reporting.py | 187 ++++++++++++++++++ 4 files changed, 218 insertions(+) create mode 100644 orchestrator/tests/test_us009_limit_reporting.py diff --git a/orchestrator/consumers/chatbot/service.py b/orchestrator/consumers/chatbot/service.py index fa6283d35..29f3d5618 100644 --- a/orchestrator/consumers/chatbot/service.py +++ b/orchestrator/consumers/chatbot/service.py @@ -1482,6 +1482,16 @@ async def _run_tool_loop( # Max iterations reached if iteration >= max_iterations: logger.warning(f"Max tool iterations ({max_iterations}) reached. Forcing final response.") + yield self.streaming_handler.format_aisdk_limit_reached( + limit="max_tool_iterations", + value=max_iterations, + message=( + f"I reached the maximum of {max_iterations} tool steps for a " + "single response, so I'm answering with what I have so far. " + "An admin can raise this via the CHATBOT_MAX_TOOL_ITERATIONS " + "setting (or the workspace power-mode caps)." + ), + ) final = await agent_runtime.llm_manager.generate_response( messages=llm_messages, tools=None, ) diff --git a/orchestrator/consumers/chatbot/streaming.py b/orchestrator/consumers/chatbot/streaming.py index 12c7d0d69..9a7c60309 100644 --- a/orchestrator/consumers/chatbot/streaming.py +++ b/orchestrator/consumers/chatbot/streaming.py @@ -114,6 +114,15 @@ def format_aisdk_data(self, event_type: str, data: Dict[str, Any] = None) -> str payload["data"] = data return f'd:{json.dumps(payload)}\n' + def format_aisdk_limit_reached(self, limit: str, value: int, message: str) -> str: + """Format a limit_reached event so the user is told an agent stopped + because it hit a cap (instead of silently bailing). 
Carries limit/value + under the AI SDK data envelope like every other data event.""" + return self.format_aisdk_data( + "limit_reached", + {"limit": limit, "value": value, "message": message}, + ) + def format_aisdk_chat_id(self, chat_id: str) -> str: """Format chat-id data event.""" return f'd:{{"type":"chat-id","chatId":"{chat_id}"}}\n' diff --git a/orchestrator/services/coordinator_service.py b/orchestrator/services/coordinator_service.py index 9c5c3cf68..e7eda5e2e 100644 --- a/orchestrator/services/coordinator_service.py +++ b/orchestrator/services/coordinator_service.py @@ -1481,6 +1481,18 @@ async def _record_task_result( actor_type=ActorType.COORDINATOR, actor_id="coordinator", payload={ + # US-009 user-facing fields: tell the user what limit was + # hit and how to raise it, before any budget-driven pause. + "limit_type": "mission_token_budget", + "spent": run.tokens_used, + "limit": run.token_budget_estimate, + "message": ( + f"This mission has used {run.tokens_used:,} tokens, over its " + f"estimated budget of {run.token_budget_estimate:,}. It will keep " + "running; an admin can raise the mission token budget or the " + "power-mode caps in Settings > Coordination." + ), + # Established diagnostic fields (parity with reconciler emit). "tokens_used": run.tokens_used, "token_budget_estimate": run.token_budget_estimate, "ratio": round(run.tokens_used / run.token_budget_estimate, 2), diff --git a/orchestrator/tests/test_us009_limit_reporting.py b/orchestrator/tests/test_us009_limit_reporting.py new file mode 100644 index 000000000..78a81005f --- /dev/null +++ b/orchestrator/tests/test_us009_limit_reporting.py @@ -0,0 +1,187 @@ +""" +PRD-141 US-009: Report iteration/budget limits to the user. +============================================================ + +Two independent user-facing signals are proven here: + +1. 
Chatbot side — when ``_run_tool_loop`` exhausts ``CHATBOT_MAX_TOOL_ITERATIONS`` + it now emits a ``limit_reached`` SSE frame so the user is told the agent + stopped on a cap (instead of silently answering). We assert the formatter + produces the AI SDK data envelope the stream consumer expects. + +2. Coordinator side — when a mission crosses 1.5x its token budget, + ``_record_task_result`` emits a ``BUDGET_WARNING`` carrying the user-facing + ``limit_type/spent/limit/message`` fields (the mission keeps running; this + is a purely additive signal, no pause). +""" +import importlib.util +import json +import sys +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# Ensure orchestrator package is importable +_orchestrator_root = Path(__file__).resolve().parent.parent +if str(_orchestrator_root) not in sys.path: + sys.path.insert(0, str(_orchestrator_root)) + +from core.models.orchestration_enums import EventType, TaskState +from services.coordinator_service import CoordinatorService + + +def _load_streaming_handler_class(): + """Load StreamingHandler from the leaf module file directly. + + Importing ``consumers.chatbot.streaming`` the normal way runs + ``consumers/chatbot/__init__.py``, which eagerly imports the full chatbot + service (camelot/PDF + DB deps) — none of which the formatter needs. The + module file itself only uses stdlib, so we load it in isolation. + """ + streaming_path = _orchestrator_root / "consumers" / "chatbot" / "streaming.py" + spec = importlib.util.spec_from_file_location( + "_us009_streaming", streaming_path, + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module.StreamingHandler + + +StreamingHandler = _load_streaming_handler_class() + + +# --------------------------------------------------------------------------- +# 1. 
Chatbot iteration-limit reporting +# --------------------------------------------------------------------------- + +def _parse_aisdk_data_frame(frame: str) -> dict: + """Strip the ``d:`` prefix + trailing newline and JSON-decode the payload.""" + assert frame.startswith("d:"), f"expected AI SDK data frame, got: {frame!r}" + assert frame.endswith("\n"), f"AI SDK frame must end with newline, got: {frame!r}" + return json.loads(frame[2:].strip()) + + +def test_iteration_limit_reports_to_user(): + """The max-iterations cap surfaces as a typed limit_reached SSE frame.""" + handler = StreamingHandler() + message = ( + "I reached the maximum of 10 tool steps for a single response, so I'm " + "answering with what I have so far." + ) + + frame = handler.format_aisdk_limit_reached( + limit="max_tool_iterations", + value=10, + message=message, + ) + + payload = _parse_aisdk_data_frame(frame) + assert payload["type"] == "limit_reached" + data = payload["data"] + assert data["limit"] == "max_tool_iterations" + assert data["value"] == 10 + assert data["message"] == message + + +def test_iteration_limit_frame_is_stream_safe_string(): + """Consumer feeds a text/plain StreamingResponse — frame must be a str.""" + handler = StreamingHandler() + frame = handler.format_aisdk_limit_reached( + limit="max_tool_iterations", value=50, message="hit the cap", + ) + assert isinstance(frame, str) + + +# --------------------------------------------------------------------------- +# 2. 
Coordinator budget-overage reporting +# --------------------------------------------------------------------------- + +def _make_run(*, token_budget_estimate, tokens_used=0): + run = MagicMock() + run.id = "run-123" + run.token_budget_estimate = token_budget_estimate + run.tokens_used = tokens_used # real int so the += arithmetic works + return run + + +def _make_task(): + task = MagicMock() + task.id = "task-1" + task.title = "Some task" + task.state = TaskState.RUNNING.value # anything that is NOT FAILED + return task + + +async def _call_record_task_result(run, result): + """Invoke the unbound method with a mock ``self`` + targeted patches. + + ``_record_task_result`` does heavy work before the budget block + (record_task_completion, mission-event dispatch, field injection); we + patch those out so the test isolates the budget-warning emit. + """ + mock_self = MagicMock() + mock_self._inject_task_output_into_field = AsyncMock() + + db = MagicMock() + task = _make_task() + + with patch("services.coordinator_service.MissionDispatcher"), \ + patch("services.coordinator_service._dispatch_mission_event", new=AsyncMock()), \ + patch("services.coordinator_service.emit_event") as mock_emit: + await CoordinatorService._record_task_result( + mock_self, db, run, task, agent_id=1, result=result, + ) + return mock_emit + + +def _budget_warnings(mock_emit): + return [ + c for c in mock_emit.call_args_list + if c.kwargs.get("event_type") == EventType.BUDGET_WARNING + ] + + +@pytest.mark.asyncio +async def test_budget_exceeded_emits_event(): + """Crossing 1.5x budget emits BUDGET_WARNING with user-facing fields.""" + run = _make_run(token_budget_estimate=1000, tokens_used=0) + result = {"status": "success", "execution": {"tokens_used": 2000}} + + mock_emit = await _call_record_task_result(run, result) + + warnings = _budget_warnings(mock_emit) + assert len(warnings) == 1, "expected exactly one BUDGET_WARNING" + + payload = warnings[0].kwargs["payload"] + # User-facing fields 
(the US-009 increment) + assert payload["limit_type"] == "mission_token_budget" + assert payload["spent"] == 2000 + assert payload["limit"] == 1000 + assert "tokens" in payload["message"].lower() + # Established diagnostic fields retained + assert payload["tokens_used"] == 2000 + assert payload["token_budget_estimate"] == 1000 + assert payload["ratio"] == 2.0 + + +@pytest.mark.asyncio +async def test_under_budget_does_not_emit_warning(): + """A task that stays under 1.5x budget emits no BUDGET_WARNING.""" + run = _make_run(token_budget_estimate=10000, tokens_used=0) + result = {"status": "success", "execution": {"tokens_used": 2000}} + + mock_emit = await _call_record_task_result(run, result) + + assert _budget_warnings(mock_emit) == [] + + +@pytest.mark.asyncio +async def test_no_budget_estimate_does_not_emit_warning(): + """A mission with no token_budget_estimate never warns.""" + run = _make_run(token_budget_estimate=None, tokens_used=0) + result = {"status": "success", "execution": {"tokens_used": 9999}} + + mock_emit = await _call_record_task_result(run, result) + + assert _budget_warnings(mock_emit) == [] From 3033ab1e11381337e2a0f798c3a81f07874bcea6 Mon Sep 17 00:00:00 2001 From: Gerard Kavanagh Date: Thu, 28 May 2026 21:30:08 +0100 Subject: [PATCH 2/7] feat(reliability): configurable power-mode caps via system_settings (PRD-141 US-010) Replace the hardcoded _POWER_MODE_CAPS dict with _get_power_mode_caps(power_mode, db), which reads system_settings('power_modes', power_mode) and merges stored overrides over hardcoded _POWER_MODE_DEFAULTS. Stored keys win; absent keys fall back; unknown modes resolve to 'standard'; any DB error falls back to defaults (never raises on the mission hot path). Caps are resolved once on the serial _prepare_task path and threaded through the prep dict to the DB-free _run_agent_io, preserving the Phase 2 no-DB invariant. 
--- orchestrator/services/coordinator_service.py | 66 +++++++++++-- .../tests/test_us010_power_mode_caps.py | 94 +++++++++++++++++++ 2 files changed, 151 insertions(+), 9 deletions(-) create mode 100644 orchestrator/tests/test_us010_power_mode_caps.py diff --git a/orchestrator/services/coordinator_service.py b/orchestrator/services/coordinator_service.py index e7eda5e2e..b9ae295a9 100644 --- a/orchestrator/services/coordinator_service.py +++ b/orchestrator/services/coordinator_service.py @@ -17,6 +17,7 @@ """ import asyncio +import json import logging import re from datetime import datetime, timedelta, timezone @@ -28,6 +29,7 @@ from config import COMPLEXITY_TOKEN_BUDGET, Config, config from core.models.core import Agent +from core.models.system_settings import SystemSetting from core.models.orchestration import ( OrchestrationArchive, OrchestrationEvent, @@ -72,14 +74,57 @@ # --------------------------------------------------------------------------- # Mission Power Modes — per-mode caps for model, tokens, and tool iterations. # "standard" is the default when power_mode is absent from mission config. +# These are the hardcoded FALLBACK only. Operators retune live values via +# system_settings (category 'power_modes', key = the power-mode name); see +# _get_power_mode_caps(). Stored settings win; absent keys fall back here. # --------------------------------------------------------------------------- -_POWER_MODE_CAPS: Dict[str, Dict[str, Any]] = { +_POWER_MODE_DEFAULTS: Dict[str, Dict[str, Any]] = { "light": {"max_tokens": 2_000, "max_tool_iterations": 5, "force_llm_tier": "system_llm"}, "standard": {"max_tokens": 4_000, "max_tool_iterations": 10, "force_llm_tier": None}, "max": {"max_tokens": 16_000, "max_tool_iterations": 50, "force_llm_tier": "orchestrator_llm"}, } +def _get_power_mode_caps(power_mode: str, db: Session) -> Dict[str, Any]: + """Resolve power-mode caps: ``system_settings('power_modes', power_mode)`` merged + over ``_POWER_MODE_DEFAULTS``. 
+ + Operators can retune caps at runtime (no deploy) by storing a JSON object + under category ``power_modes``, key ``power_mode`` (the mode name) — e.g. + ``{"max_tool_iterations": 20}``. Stored keys override the defaults; absent + keys fall back. An unknown mode falls back to ``standard``. + + Must run on the serial DB path (e.g. ``_prepare_task``). Do NOT call from + ``_run_agent_io`` — that runs concurrently via ``asyncio.gather`` with no DB + access; pass the already-resolved caps down instead. + """ + defaults = _POWER_MODE_DEFAULTS.get(power_mode, _POWER_MODE_DEFAULTS["standard"]) + caps: Dict[str, Any] = dict(defaults) # copy — never mutate the module default + + try: + setting = ( + db.query(SystemSetting) + .filter( + SystemSetting.category == "power_modes", + SystemSetting.key == power_mode, + ) + .first() + ) + if setting and setting.value: + override = json.loads(setting.value) + if isinstance(override, dict): + caps.update(override) + except Exception: + logger.warning( + "Could not load power_mode caps for '%s' from system_settings; " + "using hardcoded defaults.", + power_mode, + exc_info=True, + ) + + return caps + + # --------------------------------------------------------------------------- # Synthesis model override (Fix 1) # --------------------------------------------------------------------------- @@ -1017,7 +1062,7 @@ async def _process_run( agent_coros = [ self._run_agent_io(p["factory"], p["agent_id"], p["prompt"], p["task"], p["attachment_ids"], - run_config=p.get("run_config"), + mode_caps=p["mode_caps"], agent_runtime=p.get("agent_runtime")) for p in prepared ] @@ -1321,7 +1366,7 @@ def _sanitize_upstream(raw: str) -> str: factory = AgentFactory(db_session=db) run_config = run.config or {} power_mode = run_config.get("power_mode", "standard") - mode_caps = _POWER_MODE_CAPS.get(power_mode, _POWER_MODE_CAPS["standard"]) + mode_caps = _get_power_mode_caps(power_mode, db) force_tier = mode_caps.get("force_llm_tier") if force_tier: @@ -1360,7 +1405,9 @@ def 
_sanitize_upstream(raw: str) -> str: "prompt": prompt, "factory": factory, "attachment_ids": task_attachment_ids, - "run_config": run_config, + # Caps resolved here (serial DB path) so the concurrent I/O phase + # never touches the DB — see _get_power_mode_caps / _run_agent_io. + "mode_caps": mode_caps, } async def _run_agent_io( @@ -1370,16 +1417,17 @@ async def _run_agent_io( prompt: str, task: Any, attachment_ids: List[str], - run_config: Optional[Dict[str, Any]] = None, + mode_caps: Optional[Dict[str, Any]] = None, agent_runtime: Optional[Any] = None, ) -> Dict[str, Any]: """Execute agent I/O — safe to run concurrently via asyncio.gather(). - No DB access here — only the LLM + tool loop. + No DB access here — only the LLM + tool loop. ``mode_caps`` is resolved + upstream in _prepare_task (the serial DB phase) and passed in, so this + concurrent path never reads system_settings. """ - power_mode = (run_config or {}).get("power_mode", "standard") - mode_caps = _POWER_MODE_CAPS.get(power_mode, _POWER_MODE_CAPS["standard"]) - max_iters = mode_caps["max_tool_iterations"] + caps = mode_caps or _POWER_MODE_DEFAULTS["standard"] + max_iters = caps["max_tool_iterations"] # Pass runtime directly when we have it so the factory cache can't # swap in a stale cached runtime under us mid-flight. diff --git a/orchestrator/tests/test_us010_power_mode_caps.py b/orchestrator/tests/test_us010_power_mode_caps.py new file mode 100644 index 000000000..b06c4ba6f --- /dev/null +++ b/orchestrator/tests/test_us010_power_mode_caps.py @@ -0,0 +1,94 @@ +""" +PRD-141 US-010: Configurable power-mode caps via system_settings. +================================================================= + +_get_power_mode_caps(power_mode, db) resolves caps as +``system_settings('power_modes', power_mode)`` merged over the hardcoded +``_POWER_MODE_DEFAULTS``. Stored keys win; absent keys fall back; an unknown +mode falls back to 'standard'; any DB error falls back to defaults (never +raises on the mission hot path). 
+""" +import json +import sys +from pathlib import Path +from unittest.mock import MagicMock + +# Ensure orchestrator package is importable +_orchestrator_root = Path(__file__).resolve().parent.parent +if str(_orchestrator_root) not in sys.path: + sys.path.insert(0, str(_orchestrator_root)) + +from services.coordinator_service import _get_power_mode_caps, _POWER_MODE_DEFAULTS + + +def _db_returning(setting): + """Mock db whose query(...).filter(...).first() yields ``setting``.""" + db = MagicMock() + db.query.return_value.filter.return_value.first.return_value = setting + return db + + +def _setting_with(value_dict): + setting = MagicMock() + setting.value = json.dumps(value_dict) + return setting + + +def test_power_mode_reads_system_settings(): + """Stored JSON overrides win over the hardcoded defaults.""" + db = _db_returning(_setting_with({"max_tool_iterations": 99, "max_tokens": 12345})) + + caps = _get_power_mode_caps("standard", db) + + assert caps["max_tool_iterations"] == 99 # overridden + assert caps["max_tokens"] == 12345 # overridden + assert caps["force_llm_tier"] is None # untouched 'standard' default + + +def test_power_mode_partial_override_keeps_other_defaults(): + """A partial override only replaces the keys it names.""" + db = _db_returning(_setting_with({"max_tool_iterations": 25})) + + caps = _get_power_mode_caps("light", db) + + assert caps["max_tool_iterations"] == 25 # overridden + assert caps["max_tokens"] == _POWER_MODE_DEFAULTS["light"]["max_tokens"] + assert caps["force_llm_tier"] == "system_llm" # default kept + + +def test_power_mode_falls_back_to_defaults(): + """No stored setting → exactly the hardcoded defaults for that mode.""" + db = _db_returning(None) + + caps = _get_power_mode_caps("max", db) + + assert caps == _POWER_MODE_DEFAULTS["max"] + + +def test_unknown_power_mode_falls_back_to_standard(): + """An unrecognised mode resolves to the 'standard' defaults.""" + db = _db_returning(None) + + caps = _get_power_mode_caps("turbo", 
db) + + assert caps == _POWER_MODE_DEFAULTS["standard"] + + +def test_power_mode_db_error_falls_back_to_defaults(): + """A DB failure must not raise on the mission path — defaults are used.""" + db = MagicMock() + db.query.side_effect = RuntimeError("db down") + + caps = _get_power_mode_caps("light", db) + + assert caps == _POWER_MODE_DEFAULTS["light"] + + +def test_get_power_mode_caps_does_not_mutate_defaults(): + """Resolution copies the default dict; it never mutates the module const.""" + db = _db_returning(_setting_with({"max_tokens": 999999})) + before = dict(_POWER_MODE_DEFAULTS["standard"]) + + _get_power_mode_caps("standard", db) + + assert _POWER_MODE_DEFAULTS["standard"] == before From 269c775095086993157cb5cfb350163600b20445 Mon Sep 17 00:00:00 2001 From: Gerard Kavanagh Date: Thu, 28 May 2026 21:37:39 +0100 Subject: [PATCH 3/7] feat(reliability): model-proportional context budgets (PRD-141 US-011) Add ContextRouter._compute_budgets(context_window): when the model context window is known, each section gets a fixed proportion of the usable window (usable = 80% of raw); when unknown, fall back to the static CONTEXT_BUDGET_* config values. Weights: session 10%, long_term 15%, temporal 10%, daily 8%, awareness 5%, tools 20%, system_prompt 12% (sum 0.80 of usable, leaving slack). retrieve_context now takes an optional context_window and sources its budgets from _compute_budgets, so config is fallback-only when a window is available. Threaded the optional param through UnifiedMemoryService.retrieve_context. Added CONTEXT_BUDGET_TOOLS/SYSTEM_PROMPT to config to complete the fallback set. 
--- orchestrator/config.py | 7 +- orchestrator/modules/memory/context_router.py | 69 +++++++++++++-- .../modules/memory/unified_memory_service.py | 2 + .../tests/test_us011_context_budgets.py | 86 +++++++++++++++++++ 4 files changed, 157 insertions(+), 7 deletions(-) create mode 100644 orchestrator/tests/test_us011_context_budgets.py diff --git a/orchestrator/config.py b/orchestrator/config.py index 4bfc7014d..fc6b6cb01 100644 --- a/orchestrator/config.py +++ b/orchestrator/config.py @@ -87,12 +87,17 @@ def REDIS_URL(self) -> str: MEMORY_SESSION_CONSOLIDATION_TTL_SECONDS: int = int(os.getenv("MEMORY_SESSION_CONSOLIDATION_TTL_SECONDS", "3600")) # L3 Cache: TTL for Mem0 search result caching in Redis MEMORY_CACHE_TTL_SECONDS: int = int(os.getenv("MEMORY_CACHE_TTL_SECONDS", "300")) - # Context Router: per-source sub-budgets (tokens) + # Context Router: per-source sub-budgets (tokens). + # Fallback only — used when the model context window is unknown. When the + # window is known, ContextRouter._compute_budgets derives budgets as a + # proportion of the usable window instead (PRD-141 US-011). 
CONTEXT_BUDGET_SESSION: int = int(os.getenv("CONTEXT_BUDGET_SESSION", "500")) CONTEXT_BUDGET_LONG_TERM: int = int(os.getenv("CONTEXT_BUDGET_LONG_TERM", "800")) CONTEXT_BUDGET_TEMPORAL: int = int(os.getenv("CONTEXT_BUDGET_TEMPORAL", "600")) CONTEXT_BUDGET_DAILY: int = int(os.getenv("CONTEXT_BUDGET_DAILY", "400")) CONTEXT_BUDGET_AWARENESS: int = int(os.getenv("CONTEXT_BUDGET_AWARENESS", "200")) + CONTEXT_BUDGET_TOOLS: int = int(os.getenv("CONTEXT_BUDGET_TOOLS", "1000")) + CONTEXT_BUDGET_SYSTEM_PROMPT: int = int(os.getenv("CONTEXT_BUDGET_SYSTEM_PROMPT", "600")) # Knowledge awareness: TTL for per-workspace capability map cached in Redis MEMORY_AWARENESS_CACHE_TTL_SECONDS: int = int(os.getenv("MEMORY_AWARENESS_CACHE_TTL_SECONDS", "600")) # L2 Decay: Ebbinghaus decay rate (higher = faster forgetting) diff --git a/orchestrator/modules/memory/context_router.py b/orchestrator/modules/memory/context_router.py index fc6825fd6..9dfc36846 100644 --- a/orchestrator/modules/memory/context_router.py +++ b/orchestrator/modules/memory/context_router.py @@ -33,6 +33,27 @@ logger = logging.getLogger(__name__) +# --------------------------------------------------------------------------- +# Context budget weights (PRD-141 US-011) +# --------------------------------------------------------------------------- +# Each section gets a fixed proportion of the *usable* context window +# (usable = 80% of the raw window, reserving 20% for the model's response). +# Weights sum to 0.80 of the usable window — the remaining 0.20 is slack for +# estimator error and untracked overhead. ``tools`` and ``system_prompt`` are +# reserved headroom the router does not fill itself — they keep the memory +# sections from claiming space the prompt assembler needs. 
+_CONTEXT_BUDGET_WEIGHTS: Dict[str, float] = { + "session": 0.10, + "long_term": 0.15, + "temporal": 0.10, + "daily": 0.08, + "awareness": 0.05, + "tools": 0.20, + "system_prompt": 0.12, +} +_USABLE_WINDOW_FRACTION = 0.80 + + # --------------------------------------------------------------------------- # ContextSignals — output of query analysis # --------------------------------------------------------------------------- @@ -345,6 +366,41 @@ def _estimate_tokens(text: str) -> int: """Cheap token estimate: ~4 chars per token.""" return len(text) // 4 if text else 0 + @staticmethod + def _compute_budgets(context_window: Optional[int]) -> Dict[str, int]: + """Resolve per-section token budgets for context assembly. + + When the model's ``context_window`` is known (a positive int), each + section gets a fixed proportion of the *usable* window + (``usable = int(context_window * 0.80)``), so budgets scale with the + model — a 128K model gets far larger sections than an 8K model. + + When the window is unknown (``None`` / non-positive), fall back to the + static ``CONTEXT_BUDGET_*`` config values. The config values are + therefore a fallback only, never the primary source when a window is + available. + + Returns a dict keyed by section name with token budgets. 
+ """ + from config import config + + if not context_window or context_window <= 0: + return { + "session": config.CONTEXT_BUDGET_SESSION, + "long_term": config.CONTEXT_BUDGET_LONG_TERM, + "temporal": config.CONTEXT_BUDGET_TEMPORAL, + "daily": config.CONTEXT_BUDGET_DAILY, + "awareness": config.CONTEXT_BUDGET_AWARENESS, + "tools": config.CONTEXT_BUDGET_TOOLS, + "system_prompt": config.CONTEXT_BUDGET_SYSTEM_PROMPT, + } + + usable = int(context_window * _USABLE_WINDOW_FRACTION) + return { + name: int(usable * weight) + for name, weight in _CONTEXT_BUDGET_WEIGHTS.items() + } + @staticmethod def _truncate_to_budget(text: str, token_budget: int) -> str: """Truncate *text* so its estimated token count fits within *token_budget*.""" @@ -384,6 +440,7 @@ async def retrieve_context( agent_id: int, query: str, conversation_id: Optional[str] = None, + context_window: Optional[int] = None, ) -> ContextBundle: """ Assemble a budget-constrained context bundle by fetching from L1/L2/L3 @@ -400,17 +457,17 @@ async def retrieve_context( All layer fetches are concurrent via ``asyncio.gather``. Any single-layer failure is logged and skipped — never breaks the bundle. 
""" - from config import config from modules.memory.unified_memory_service import get_unified_memory_service service = get_unified_memory_service() signals = self.analyze_query(query) - budget_session = config.CONTEXT_BUDGET_SESSION - budget_long_term = config.CONTEXT_BUDGET_LONG_TERM - budget_temporal = config.CONTEXT_BUDGET_TEMPORAL - budget_daily = config.CONTEXT_BUDGET_DAILY - budget_awareness = config.CONTEXT_BUDGET_AWARENESS + budgets = self._compute_budgets(context_window) + budget_session = budgets["session"] + budget_long_term = budgets["long_term"] + budget_temporal = budgets["temporal"] + budget_daily = budgets["daily"] + budget_awareness = budgets["awareness"] # ----- Determine which fetches to launch ----- fetch_session = ( diff --git a/orchestrator/modules/memory/unified_memory_service.py b/orchestrator/modules/memory/unified_memory_service.py index af571bab7..ed6fc5511 100644 --- a/orchestrator/modules/memory/unified_memory_service.py +++ b/orchestrator/modules/memory/unified_memory_service.py @@ -1662,6 +1662,7 @@ async def retrieve_context( agent_id: int, query: str, conversation_id: Optional[str] = None, + context_window: Optional[int] = None, ) -> "ContextBundle": """ Assemble a budget-constrained context bundle across all memory layers. @@ -1683,6 +1684,7 @@ async def retrieve_context( agent_id=agent_id, query=query, conversation_id=conversation_id, + context_window=context_window, ) except Exception: logger.error( diff --git a/orchestrator/tests/test_us011_context_budgets.py b/orchestrator/tests/test_us011_context_budgets.py new file mode 100644 index 000000000..6da668742 --- /dev/null +++ b/orchestrator/tests/test_us011_context_budgets.py @@ -0,0 +1,86 @@ +""" +PRD-141 US-011: Model-proportional context budgets. +==================================================== + +``ContextRouter._compute_budgets(context_window)`` resolves per-section token +budgets. 
When the model context window is known, each section is a fixed +proportion of the *usable* window (80% of the raw window). When the window is +unknown, the static ``CONTEXT_BUDGET_*`` config values are the fallback — +never the primary source when a window is available. +""" +import sys +from pathlib import Path + +# Ensure orchestrator package is importable +_orchestrator_root = Path(__file__).resolve().parent.parent +if str(_orchestrator_root) not in sys.path: + sys.path.insert(0, str(_orchestrator_root)) + +from config import config +from modules.memory.context_router import ( + ContextRouter, + _CONTEXT_BUDGET_WEIGHTS, + _USABLE_WINDOW_FRACTION, +) + +_SECTIONS = ("session", "long_term", "temporal", "daily", "awareness", "tools", "system_prompt") + + +def test_context_budgets_scale_with_model(): + """A 128K model gets strictly larger section budgets than an 8K model.""" + small = ContextRouter._compute_budgets(8_000) + large = ContextRouter._compute_budgets(128_000) + + for section in _SECTIONS: + assert large[section] > small[section], ( + f"{section}: 128K budget ({large[section]}) should exceed " + f"8K budget ({small[section]})" + ) + + +def test_context_budgets_are_proportions_of_usable_window(): + """Each section equals its weight × usable window (usable = 80% of raw).""" + window = 128_000 + usable = int(window * _USABLE_WINDOW_FRACTION) + + budgets = ContextRouter._compute_budgets(window) + + for section, weight in _CONTEXT_BUDGET_WEIGHTS.items(): + assert budgets[section] == int(usable * weight) + + +def test_context_budgets_fallback_to_defaults(): + """An unknown window falls back to exactly the static config values.""" + budgets = ContextRouter._compute_budgets(None) + + assert budgets["session"] == config.CONTEXT_BUDGET_SESSION + assert budgets["long_term"] == config.CONTEXT_BUDGET_LONG_TERM + assert budgets["temporal"] == config.CONTEXT_BUDGET_TEMPORAL + assert budgets["daily"] == config.CONTEXT_BUDGET_DAILY + assert budgets["awareness"] == 
config.CONTEXT_BUDGET_AWARENESS + assert budgets["tools"] == config.CONTEXT_BUDGET_TOOLS + assert budgets["system_prompt"] == config.CONTEXT_BUDGET_SYSTEM_PROMPT + + +def test_context_budgets_nonpositive_window_falls_back(): + """Zero / negative windows are treated as unknown → config fallback.""" + fallback = ContextRouter._compute_budgets(None) + + assert ContextRouter._compute_budgets(0) == fallback + assert ContextRouter._compute_budgets(-1) == fallback + + +def test_context_budgets_always_cover_all_sections(): + """Both the proportional and fallback paths return all seven sections.""" + proportional = ContextRouter._compute_budgets(32_000) + fallback = ContextRouter._compute_budgets(None) + + assert set(proportional) == set(_SECTIONS) + assert set(fallback) == set(_SECTIONS) + + +def test_context_budget_weights_leave_response_slack(): + """Weights sum to 0.80 of usable — the rest is slack, never over-allocated.""" + total = sum(_CONTEXT_BUDGET_WEIGHTS.values()) + assert abs(total - 0.80) < 1e-9 + assert total < 1.0 # must never claim the whole usable window From be5a4cd4c0e6a1d75b7eff505dda2f33527cbbe1 Mon Sep 17 00:00:00 2001 From: Gerard Kavanagh Date: Thu, 28 May 2026 21:43:09 +0100 Subject: [PATCH 4/7] feat(reliability): adaptive context-guard thresholds (PRD-141 US-012) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add _thresholds_for_model(context_window) -> (compact_threshold, keep_recent_turns), tiered: >=200K (0.90,12), >=100K (0.85,8), >=32K (0.80,6), >=8K (0.75,4), else (0.70,3); unknown/non-positive window falls back to the static COMPACT_THRESHOLD / KEEP_RECENT_TURNS. check_and_compact now derives both values from the model window and threads keep_recent_turns into _compact, so large-context models compact later and keep more turns while small-context models compact earlier — preventing context_length_exceeded and runaway memory. 
Also fix a latent bug in _compact's tombstone: it called self.count_tokens / self.count_message_tokens (module-level functions, not methods) and passed a single dict where a list was expected — any real compaction would have raised AttributeError. Now uses count_tokens(summary) / count_message_tokens(old_turns). --- orchestrator/core/context_guard.py | 55 ++++++-- orchestrator/tests/test_context_guard.py | 154 +++++++++++++++++++++++ 2 files changed, 200 insertions(+), 9 deletions(-) create mode 100644 orchestrator/tests/test_context_guard.py diff --git a/orchestrator/core/context_guard.py b/orchestrator/core/context_guard.py index 5e90878f1..fb92c5694 100644 --- a/orchestrator/core/context_guard.py +++ b/orchestrator/core/context_guard.py @@ -7,8 +7,11 @@ Strategy: 1. Count tokens in the full message payload (system + user + assistant + tool) -2. If below 80% of model context → pass through unchanged -3. If above 80% → compact: summarize older turns, keep recent context +2. Resolve a model-aware compaction threshold + kept-turns from the context + window (_thresholds_for_model): bigger windows compact later and keep more + turns; small windows compact earlier and keep fewer +3. If below the threshold → pass through unchanged; above → compact: summarize + older turns, keep recent context 4. Flush key facts to Mem0 before discarding messages This prevents context_length_exceeded errors and keeps conversations going @@ -137,11 +140,42 @@ def get_context_window(model_name: str, db_session=None) -> int: # Context Guard # --------------------------------------------------------------------------- -# Thresholds +# Thresholds — static fallbacks only. The model-aware values come from +# _thresholds_for_model(); these are used when the context window is unknown. 
COMPACT_THRESHOLD = 0.80 # Compact when >80% of context used KEEP_RECENT_TURNS = 6 # Always keep the last N user+assistant messages SUMMARY_MAX_TOKENS = 500 # Max tokens for the compaction summary + +def _thresholds_for_model(context_window: int) -> Tuple[float, int]: + """Resolve (compact_threshold, keep_recent_turns) for a model's window. + + Large-context models can safely fill a higher fraction of the window and + keep more recent turns; small-context models must compact earlier and keep + fewer turns to avoid context_length_exceeded (provider 400s). + + Tiers (PRD-141 US-012): + >=200K -> (0.90, 12) + >=100K -> (0.85, 8) + >= 32K -> (0.80, 6) + >= 8K -> (0.75, 4) + else -> (0.70, 3) + + An unknown / non-positive window falls back to the static COMPACT_THRESHOLD + and KEEP_RECENT_TURNS constants. + """ + if not context_window or context_window <= 0: + return COMPACT_THRESHOLD, KEEP_RECENT_TURNS + if context_window >= 200_000: + return 0.90, 12 + if context_window >= 100_000: + return 0.85, 8 + if context_window >= 32_000: + return 0.80, 6 + if context_window >= 8_000: + return 0.75, 4 + return 0.70, 3 + # PRD-123 Pattern #7: Proactive compaction thresholds PROACTIVE_COMPACT_AFTER_TURNS = int( __import__("os").getenv("PROACTIVE_COMPACT_AFTER_TURNS", "8") @@ -182,10 +216,11 @@ async def check_and_compact( (messages, was_compacted, tools) — tools may be None if they don't fit """ context_window = get_context_window(model_name, db_session) + compact_threshold, keep_recent_turns = _thresholds_for_model(context_window) tool_tokens = count_tool_tokens(tools) current_tokens = count_message_tokens(messages) total_tokens = current_tokens + tool_tokens - threshold = int(context_window * COMPACT_THRESHOLD) + threshold = int(context_window * compact_threshold) logger.debug( "[ContextGuard] tokens=%d (msgs=%d tools=%d) / %d (%.0f%% of %d window)", @@ -217,6 +252,7 @@ async def check_and_compact( llm_manager=llm_manager, workspace_id=workspace_id, agent_id=agent_id, + 
keep_recent_turns=keep_recent_turns, ) new_tokens = count_message_tokens(compacted) @@ -233,6 +269,7 @@ async def _compact( llm_manager: Any, workspace_id: Optional[str] = None, agent_id: Optional[int] = None, + keep_recent_turns: int = KEEP_RECENT_TURNS, ) -> List[Dict[str, Any]]: """ Compact messages by summarizing older turns. @@ -253,12 +290,12 @@ async def _compact( conversation.append(msg) # If conversation is short enough, keep everything - if len(conversation) <= KEEP_RECENT_TURNS: + if len(conversation) <= keep_recent_turns: return messages # Split: old turns (to summarize) | recent turns (to keep) - old_turns = conversation[:-KEEP_RECENT_TURNS] - recent_turns = conversation[-KEEP_RECENT_TURNS:] + old_turns = conversation[:-keep_recent_turns] + recent_turns = conversation[-keep_recent_turns:] # Build text from old turns for summarization old_text = self._turns_to_text(old_turns) @@ -283,8 +320,8 @@ async def _compact( "_compact_tombstone": { "compacted_count": len(old_turns), "compacted_roles": [m.get("role") for m in old_turns], - "summary_token_est": self.count_tokens(summary), - "original_token_est": sum(self.count_message_tokens(m) for m in old_turns), + "summary_token_est": count_tokens(summary), + "original_token_est": count_message_tokens(old_turns), }, } diff --git a/orchestrator/tests/test_context_guard.py b/orchestrator/tests/test_context_guard.py new file mode 100644 index 000000000..b8580f47e --- /dev/null +++ b/orchestrator/tests/test_context_guard.py @@ -0,0 +1,154 @@ +""" +PRD-141 US-012: Adaptive context-guard thresholds. +=================================================== + +``_thresholds_for_model(context_window)`` returns a +``(compact_threshold, keep_recent_turns)`` tuple that scales with the model's +context window — large windows compact later and keep more turns; small +windows compact earlier and keep fewer, avoiding context_length_exceeded +(provider 400s) and runaway memory. 
+ +Tiers: + >=200K -> (0.90, 12) + >=100K -> (0.85, 8) + >= 32K -> (0.80, 6) + >= 8K -> (0.75, 4) + else -> (0.70, 3) + +An unknown / non-positive window falls back to the static COMPACT_THRESHOLD / +KEEP_RECENT_TURNS constants. +""" +import sys +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + +import pytest + +# Ensure orchestrator package is importable +_orchestrator_root = Path(__file__).resolve().parent.parent +if str(_orchestrator_root) not in sys.path: + sys.path.insert(0, str(_orchestrator_root)) + +from core.context_guard import ( + ContextGuard, + COMPACT_THRESHOLD, + KEEP_RECENT_TURNS, + _thresholds_for_model, +) + + +# --------------------------------------------------------------------------- +# Pure threshold resolution +# --------------------------------------------------------------------------- + +def test_thresholds_exact_tiers(): + """Each documented tier returns its exact (threshold, keep_turns) tuple.""" + assert _thresholds_for_model(200_000) == (0.90, 12) + assert _thresholds_for_model(128_000) == (0.85, 8) + assert _thresholds_for_model(100_000) == (0.85, 8) + assert _thresholds_for_model(32_000) == (0.80, 6) + assert _thresholds_for_model(8_000) == (0.75, 4) + assert _thresholds_for_model(4_000) == (0.70, 3) + + +def test_tier_boundaries_are_inclusive_lower_bounds(): + """A window exactly on a boundary takes that tier; one below drops a tier.""" + assert _thresholds_for_model(200_000) == (0.90, 12) + assert _thresholds_for_model(199_999) == (0.85, 8) + assert _thresholds_for_model(100_000) == (0.85, 8) + assert _thresholds_for_model(99_999) == (0.80, 6) + assert _thresholds_for_model(32_000) == (0.80, 6) + assert _thresholds_for_model(31_999) == (0.75, 4) + assert _thresholds_for_model(8_000) == (0.75, 4) + assert _thresholds_for_model(7_999) == (0.70, 3) + + +def test_compact_threshold_adapts_to_context(): + """The compaction threshold rises monotonically with the context window.""" + windows = [4_000, 8_000, 
32_000, 100_000, 200_000] + thresholds = [_thresholds_for_model(w)[0] for w in windows] + + assert thresholds == sorted(thresholds) # non-decreasing + assert thresholds[0] < thresholds[-1] # genuinely adapts + assert _thresholds_for_model(200_000)[0] > _thresholds_for_model(8_000)[0] + + +def test_keep_recent_turns_adapts(): + """Kept-turns rises monotonically with the context window.""" + windows = [4_000, 8_000, 32_000, 100_000, 200_000] + kept = [_thresholds_for_model(w)[1] for w in windows] + + assert kept == sorted(kept) # non-decreasing + assert kept[0] < kept[-1] # genuinely adapts + assert _thresholds_for_model(200_000)[1] > _thresholds_for_model(8_000)[1] + + +def test_thresholds_fallback_on_unknown_window(): + """An unknown / non-positive window falls back to the static constants.""" + fallback = (COMPACT_THRESHOLD, KEEP_RECENT_TURNS) + assert _thresholds_for_model(None) == fallback + assert _thresholds_for_model(0) == fallback + assert _thresholds_for_model(-1) == fallback + + +def test_small_window_compacts_more_aggressively_than_large(): + """Safety property: a small window must NOT use a larger budget than a big one.""" + small_thr, small_keep = _thresholds_for_model(8_000) + large_thr, large_keep = _thresholds_for_model(200_000) + assert small_thr <= large_thr + assert small_keep <= large_keep + + +# --------------------------------------------------------------------------- +# keep_recent_turns actually flows into compaction +# --------------------------------------------------------------------------- + +def _conversation(n_turns: int): + """A system message + n alternating user/assistant turns.""" + msgs = [{"role": "system", "content": "You are a helpful agent."}] + for i in range(n_turns): + role = "user" if i % 2 == 0 else "assistant" + msgs.append({"role": role, "content": f"message {i}"}) + return msgs + + +@pytest.mark.asyncio +async def test_compact_respects_keep_recent_turns(): + """_compact keeps exactly keep_recent_turns recent messages 
and compacts the rest.""" + guard = ContextGuard() + guard._summarize = AsyncMock(return_value="SUMMARY") + messages = _conversation(20) # 1 system + 20 turns + + compacted = await guard._compact( + messages=messages, + llm_manager=MagicMock(), + workspace_id=None, # skip memory flush + keep_recent_turns=4, + ) + + # result = system_msgs + [tombstone] + recent_turns + recent = [m for m in compacted if not m.get("_compact_tombstone") and m.get("role") != "system"] + assert len(recent) == 4 + assert recent[-1]["content"] == "message 19" + tombstone = next(m for m in compacted if m.get("_compact_tombstone")) + assert tombstone["_compact_tombstone"]["compacted_count"] == 16 # 20 - 4 + + +@pytest.mark.asyncio +async def test_larger_keep_recent_turns_compacts_fewer(): + """A larger keep_recent_turns preserves more turns → compacts fewer.""" + guard = ContextGuard() + guard._summarize = AsyncMock(return_value="SUMMARY") + messages = _conversation(20) + + compacted = await guard._compact( + messages=messages, + llm_manager=MagicMock(), + workspace_id=None, + keep_recent_turns=12, + ) + + recent = [m for m in compacted if not m.get("_compact_tombstone") and m.get("role") != "system"] + assert len(recent) == 12 + tombstone = next(m for m in compacted if m.get("_compact_tombstone")) + assert tombstone["_compact_tombstone"]["compacted_count"] == 8 # 20 - 12 From f8ffabbc1deecc3e3ecf3be7ffd00083548e056f Mon Sep 17 00:00:00 2001 From: Gerard Kavanagh Date: Thu, 28 May 2026 22:18:15 +0100 Subject: [PATCH 5/7] feat(reliability): add ActionRegistry.get_by_tags (PRD-141 US-013) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add get_by_tags(tags) -> List[ActionDefinition] with OR semantics (an action matches if it carries any requested tag); calls _ensure_initialized() first, matching get_by_category (which already existed). Pure-additive — no behaviour change to existing registry consumers. 
Unblocks US-015's registry-driven intent->category filtering. --- .../tools/discovery/action_registry.py | 6 + .../test_us013_action_registry_lookups.py | 119 ++++++++++++++++++ 2 files changed, 125 insertions(+) create mode 100644 orchestrator/tests/test_us013_action_registry_lookups.py diff --git a/orchestrator/modules/tools/discovery/action_registry.py b/orchestrator/modules/tools/discovery/action_registry.py index cced2226f..0939812b2 100644 --- a/orchestrator/modules/tools/discovery/action_registry.py +++ b/orchestrator/modules/tools/discovery/action_registry.py @@ -93,6 +93,12 @@ def get_by_category(self, category: str) -> List[ActionDefinition]: self._ensure_initialized() return [a for a in self._actions.values() if a.category == category] + def get_by_tags(self, tags: List[str]) -> List[ActionDefinition]: + """Get actions whose tags intersect any of *tags* (OR semantics).""" + self._ensure_initialized() + wanted = set(tags) + return [a for a in self._actions.values() if wanted.intersection(a.tags)] + def get_by_permission(self, level: str) -> List[ActionDefinition]: """Get actions filtered by permission level.""" self._ensure_initialized() diff --git a/orchestrator/tests/test_us013_action_registry_lookups.py b/orchestrator/tests/test_us013_action_registry_lookups.py new file mode 100644 index 000000000..c7706d743 --- /dev/null +++ b/orchestrator/tests/test_us013_action_registry_lookups.py @@ -0,0 +1,119 @@ +""" +PRD-141 US-013: ActionRegistry category/tag lookups. +===================================================== + +``get_by_category(category)`` and ``get_by_tags(tags)`` let the tool pipeline +filter actions without a hardcoded dict. Both call ``_ensure_initialized()`` +first. ``get_by_tags`` uses OR semantics — an action matches if it carries any +of the requested tags. 
+ +The ``modules.tools`` package ``__init__`` eagerly imports the executor chain +(DB-backed), so we load ``action_registry.py`` as an isolated leaf module — it +only uses stdlib at import time. Tests register actions directly and flip the +``_initialized`` flag, so ``_ensure_initialized`` never triggers the (DB-backed) +platform-action load. +""" +import importlib.util +import sys +from pathlib import Path + +_orchestrator_root = Path(__file__).resolve().parent.parent +if str(_orchestrator_root) not in sys.path: + sys.path.insert(0, str(_orchestrator_root)) + + +def _load_action_registry_module(): + path = _orchestrator_root / "modules" / "tools" / "discovery" / "action_registry.py" + spec = importlib.util.spec_from_file_location("_us013_action_registry", path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +_mod = _load_action_registry_module() +ActionRegistry = _mod.ActionRegistry +ActionDefinition = _mod.ActionDefinition + + +def _action(name, category, tags=None): + return ActionDefinition( + name=name, + description=f"{name} description", + category=category, + parameters={"type": "object", "properties": {}}, + tags=tags or [], + ) + + +def _registry_with(*actions): + """A registry pre-loaded with *actions*, with init short-circuited.""" + reg = ActionRegistry() + for a in actions: + reg.register(a) + reg._initialized = True # skip the DB-backed platform-action load + return reg + + +def test_action_registry_get_by_category(): + """get_by_category('harness') returns only harness-category actions.""" + reg = _registry_with( + _action("harness_status", "harness", tags=["monitoring"]), + _action("harness_restart", "harness"), + _action("list_agents", "agents"), + ) + + result = reg.get_by_category("harness") + + names = {a.name for a in result} + assert names == {"harness_status", "harness_restart"} + assert all(a.category == "harness" for a in result) + + +def test_action_registry_get_by_tags(): + 
"""get_by_tags(['monitoring']) returns only actions tagged 'monitoring'.""" + reg = _registry_with( + _action("harness_status", "harness", tags=["monitoring", "health"]), + _action("metrics_read", "analytics", tags=["monitoring"]), + _action("list_agents", "agents", tags=["agents"]), + ) + + result = reg.get_by_tags(["monitoring"]) + + names = {a.name for a in result} + assert names == {"harness_status", "metrics_read"} + + +def test_get_by_tags_is_or_semantics(): + """Multiple requested tags union their matches (OR, not AND).""" + reg = _registry_with( + _action("a", "x", tags=["monitoring"]), + _action("b", "x", tags=["billing"]), + _action("c", "x", tags=["unrelated"]), + ) + + result = reg.get_by_tags(["monitoring", "billing"]) + + assert {a.name for a in result} == {"a", "b"} + + +def test_lookups_return_empty_when_no_match(): + reg = _registry_with(_action("a", "x", tags=["monitoring"])) + assert reg.get_by_category("nope") == [] + assert reg.get_by_tags(["nope"]) == [] + + +def test_lookups_trigger_initialization(): + """A fresh (uninitialized) registry initializes lazily on lookup.""" + reg = ActionRegistry() + calls = {"n": 0} + + def fake_init(): + calls["n"] += 1 + reg._initialized = True + + reg._ensure_initialized = fake_init + + reg.get_by_category("harness") + reg.get_by_tags(["monitoring"]) + + assert calls["n"] == 2 # both lookups went through _ensure_initialized From c82437bc0ac991573367532eeb0798edf3943dc1 Mon Sep 17 00:00:00 2001 From: Gerard Kavanagh Date: Thu, 28 May 2026 22:31:24 +0100 Subject: [PATCH 6/7] refactor(routing): SmartToolRouter delegates ranking to GraphRouter (PRD-141 US-014) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Delete the chatbot router's own embedding path (_ensure_embeddings, _rank_tools_by_similarity, the 4 embedding-state attrs, and the core.math.vector_operations / core.llm.embedding_manager imports). 
When SEMANTIC_TOOL_ROUTING is on, route() now delegates to GraphRouter.rank_chains — the single tool-selection pipeline — and filters available_tools to the ranked chain actions, always preserving CORE_TOOLS, ALWAYS_INCLUDE, and the classifier's suggested tools. agent_id is threaded from ContextService so GraphRouter can scope per-agent edges/affinities. On GraphRouter failure we emit a structured record_error(subsystem="routing") and fall back to category filtering instead of a bare warning. --- .../consumers/chatbot/smart_tool_router.py | 168 +++-------- .../modules/context/sections/tools.py | 1 + .../test_us014_graph_router_delegation.py | 281 ++++++++++++++++++ 3 files changed, 325 insertions(+), 125 deletions(-) create mode 100644 orchestrator/tests/test_us014_graph_router_delegation.py diff --git a/orchestrator/consumers/chatbot/smart_tool_router.py b/orchestrator/consumers/chatbot/smart_tool_router.py index 3bc91f181..e67cf4ca3 100644 --- a/orchestrator/consumers/chatbot/smart_tool_router.py +++ b/orchestrator/consumers/chatbot/smart_tool_router.py @@ -8,7 +8,9 @@ - HOW to prioritize tool selection Works with the Intent Classifier to route appropriately. -Supports embedding-based semantic ranking (PRD-64) with fallback to keyword matching. +Delegates semantic ranking to GraphRouter (PRD-141 US-014) — the single +tool-selection pipeline — and falls back to keyword/category matching when the +graph is unavailable. NOTE: This module is consumed by ContextService (ToolLoadingStrategy.FILTERED). It is the only caller. Future work: absorb filtering logic into @@ -16,9 +18,8 @@ See PRD-81 Task 5.3 and system audit R1 finding. 
""" -import asyncio import logging -from typing import Dict, List, Optional, Any, Tuple +from typing import Dict, List, Optional, Any from dataclasses import dataclass from .intent_classifier import Intent, IntentResult, get_intent_classifier @@ -46,8 +47,9 @@ class SmartToolRouter: - Prefer internal tools for internal data - Use Composio for external app actions - PRD-64: Supports embedding-based semantic ranking when SEMANTIC_TOOL_ROUTING=true. - Falls back to keyword-based category matching if embeddings are unavailable. + PRD-141 US-014: when SEMANTIC_TOOL_ROUTING=true, ranking is delegated to + GraphRouter (the single tool-selection pipeline). Falls back to keyword-based + category matching if the graph is unavailable. """ # Core tools that are almost always useful @@ -130,110 +132,13 @@ class SmartToolRouter: def __init__(self): self.classifier = get_intent_classifier() - # Embedding-based semantic ranking state (PRD-64) - self._embedding_manager = None - self._tool_embeddings: Dict[str, List[float]] = {} # tool_name -> embedding - self._embeddings_initialized = False - self._embeddings_init_lock = asyncio.Lock() - - async def _ensure_embeddings(self, available_tools: List[Dict[str, Any]]) -> bool: - """ - Lazy-initialize tool embeddings on first route() call. - Returns True if semantic ranking is available. 
- """ - from config import config - if not config.SEMANTIC_TOOL_ROUTING: - return False - - async with self._embeddings_init_lock: - # Collect tool names that need embedding - tools_needing_embed = [] - for tool in available_tools: - fn = tool.get("function", {}) - name = fn.get("name", "") - if name and name not in self._tool_embeddings: - desc = fn.get("description", "") - tools_needing_embed.append((name, f"{name}: {desc}")) - - if not tools_needing_embed and self._embeddings_initialized: - return True - - try: - if self._embedding_manager is None: - from core.llm.embedding_manager import get_embedding_manager - self._embedding_manager = get_embedding_manager() - - if tools_needing_embed: - texts = [t[1] for t in tools_needing_embed] - embeddings = await self._embedding_manager.generate_embeddings_batch(texts) - for (name, _), embedding in zip(tools_needing_embed, embeddings): - self._tool_embeddings[name] = embedding - - logger.info( - f"[ToolRouter] Embedded {len(tools_needing_embed)} new tools " - f"(total cached: {len(self._tool_embeddings)})" - ) - - self._embeddings_initialized = True - return True - - except Exception as e: - logger.warning(f"[ToolRouter] Embedding init failed, falling back to keyword matching: {e}") - return False - - async def _rank_tools_by_similarity( - self, - query: str, - available_tools: List[Dict[str, Any]], - intent_result: IntentResult, - max_tools: int = 30, - ) -> List[Dict[str, Any]]: - """ - Rank tools by cosine similarity between query embedding and tool embeddings. - Applies intent boost and core tool boost. 
- """ - from core.math.vector_operations import VectorOperations - - # Generate query embedding - query_embedding = await self._embedding_manager.generate_embedding(query) - - # Score each tool - scored: List[Tuple[float, Dict[str, Any]]] = [] - for tool in available_tools: - fn = tool.get("function", {}) - name = fn.get("name", "") - - tool_emb = self._tool_embeddings.get(name) - if not tool_emb: - continue - - score = float(VectorOperations.cosine_similarity(query_embedding, tool_emb)) - - # Boost for intent-suggested tools - if name in (intent_result.suggested_tools or []): - score += 0.15 - - # Slight boost for core tools - if name in self.CORE_TOOLS: - score += 0.05 - - scored.append((score, tool)) - - # Sort by score descending - scored.sort(key=lambda x: -x[0]) - - # Log top 5 for debugging - top_debug = [(t.get("function", {}).get("name", "?"), f"{s:.3f}") for s, t in scored[:5]] - logger.info(f"[ToolRouter] Semantic ranking top-5: {top_debug}") - - return [tool for _, tool in scored[:max_tools]] - async def route( self, query: str, available_tools: List[Dict[str, Any]], conversation_context: Optional[List[Dict]] = None, tool_hints: Optional[List[str]] = None, + agent_id: Optional[int] = None, ) -> ToolRoutingResult: """ Route a query to appropriate tools. @@ -243,6 +148,7 @@ async def route( available_tools: All tools available to the agent conversation_context: Recent conversation history tool_hints: PRD-68 hint keywords from AutoBrain (e.g. ["email", "github"]) + agent_id: Owning agent — scopes GraphRouter's per-agent edges/affinities Returns: ToolRoutingResult with filtered tools and guidance @@ -291,32 +197,44 @@ async def route( reasoning=intent_result.reasoning ) - # Try semantic ranking first (PRD-64) - semantic_available = await self._ensure_embeddings(available_tools) - - if semantic_available: + # PRD-141 US-014: delegate semantic ranking to GraphRouter — the single + # tool-selection pipeline. 
It wraps ActionSemanticIndex + the tool-routing + # graph and internally falls back to embedding-only ranking when the graph + # is empty. An empty result or any failure drops through to category + # filtering below. CORE_TOOLS / ALWAYS_INCLUDE / classifier-suggested tools + # are always kept so the graph path never strips a tool the agent needs. + from config import config + if config.SEMANTIC_TOOL_ROUTING: try: - filtered = await self._rank_tools_by_similarity( - query, available_tools, intent_result - ) - # Ensure ALWAYS_INCLUDE tools are present even after semantic ranking - filtered_names = {t.get("function", {}).get("name") for t in filtered} - for tool in available_tools: - name = tool.get("function", {}).get("name", "") - if name in self.ALWAYS_INCLUDE and name not in filtered_names: - filtered.append(tool) - tool_choice = self._determine_tool_choice(intent_result, filtered) - priority = intent_result.suggested_tools or [] + from modules.tools.discovery.graph_router import get_graph_router - return ToolRoutingResult( - should_include_tools=True, - filtered_tools=filtered, - priority_tools=priority, - tool_choice=tool_choice, - reasoning=f"Semantic ranking: {intent_result.reasoning}" + chains = await get_graph_router().rank_chains( + query=query, agent_id=agent_id, top_k=30, ) + if chains: + keep = {name for _, _, chain in chains for name in chain} + keep |= self.CORE_TOOLS | self.ALWAYS_INCLUDE + keep |= set(intent_result.suggested_tools or []) + filtered = [ + t for t in available_tools + if t.get("function", {}).get("name", "") in keep + ] + if filtered: + return ToolRoutingResult( + should_include_tools=True, + filtered_tools=filtered, + priority_tools=intent_result.suggested_tools or [], + tool_choice=self._determine_tool_choice(intent_result, filtered), + reasoning=f"Graph routing: {intent_result.reasoning}", + ) except Exception as e: - logger.warning(f"[ToolRouter] Semantic ranking failed, falling back: {e}") + from core.utils.exception_telemetry 
import record_error + record_error( + subsystem="routing", + operation="graph_rank_chains", + error=e, + agent_id=agent_id, + ) # Fallback: keyword-based category matching relevant_categories = self.INTENT_TO_TOOLS.get( diff --git a/orchestrator/modules/context/sections/tools.py b/orchestrator/modules/context/sections/tools.py index f1bc2382c..2896a1156 100644 --- a/orchestrator/modules/context/sections/tools.py +++ b/orchestrator/modules/context/sections/tools.py @@ -203,6 +203,7 @@ async def _load_filtered( available_tools=all_tools, conversation_context=conversation_context, tool_hints=tool_hints, + agent_id=agent_id, ) if not result.should_include_tools: diff --git a/orchestrator/tests/test_us014_graph_router_delegation.py b/orchestrator/tests/test_us014_graph_router_delegation.py new file mode 100644 index 000000000..91f193753 --- /dev/null +++ b/orchestrator/tests/test_us014_graph_router_delegation.py @@ -0,0 +1,281 @@ +""" +PRD-141 US-014: SmartToolRouter delegates ranking to GraphRouter. +================================================================= + +The chatbot tool router no longer carries its own embedding path. When +``SEMANTIC_TOOL_ROUTING`` is on, ``route()`` delegates ranking to +``GraphRouter.rank_chains`` (the single tool-selection pipeline) and filters +``available_tools`` to the ranked names, always preserving CORE_TOOLS / +ALWAYS_INCLUDE / the classifier's suggested tools. On GraphRouter failure it +records a structured ``routing`` error and falls back to category filtering. + +``consumers.chatbot.__init__`` eagerly imports the DB-backed chat service, so +we leaf-load ``intent_classifier`` + ``smart_tool_router`` under a synthetic +package (both are stdlib-only at import time). The GraphRouter and +exception-telemetry imports live *inside* ``route()``; tests inject fakes for +them into ``sys.modules`` so no DB-backed module is ever imported. 
+""" +import importlib.util +import sys +import types +from pathlib import Path +from types import SimpleNamespace + +import pytest + +_orchestrator_root = Path(__file__).resolve().parent.parent +if str(_orchestrator_root) not in sys.path: + sys.path.insert(0, str(_orchestrator_root)) + +_chatbot_dir = _orchestrator_root / "consumers" / "chatbot" +_PKG = "_us014_chatbot" + + +def _load_modules(): + """Leaf-load intent_classifier + smart_tool_router under a synthetic package.""" + if _PKG not in sys.modules: + pkg = types.ModuleType(_PKG) + pkg.__path__ = [str(_chatbot_dir)] + sys.modules[_PKG] = pkg + + def _leaf(mod_name): + full = f"{_PKG}.{mod_name}" + if full in sys.modules: + return sys.modules[full] + spec = importlib.util.spec_from_file_location(full, _chatbot_dir / f"{mod_name}.py") + module = importlib.util.module_from_spec(spec) + module.__package__ = _PKG + sys.modules[full] = module + spec.loader.exec_module(module) + return module + + intent_mod = _leaf("intent_classifier") + router_mod = _leaf("smart_tool_router") + return router_mod, intent_mod + + +_router_mod, _intent_mod = _load_modules() +SmartToolRouter = _router_mod.SmartToolRouter +ToolRoutingResult = _router_mod.ToolRoutingResult +Intent = _intent_mod.Intent +IntentResult = _intent_mod.IntentResult + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _intent_result(primary=Intent.SEARCH, requires_tools=True, suggested=None, + reasoning="stub intent"): + return IntentResult( + primary_intent=primary, + confidence=0.9, + requires_tools=requires_tools, + requires_memory=False, + suggested_tools=suggested or [], + reasoning=reasoning, + is_simple=False, + ) + + +class _FakeClassifier: + def __init__(self, result): + self._result = result + + def classify(self, query, conversation_context=None): + return self._result + + +def _tool(name): + return {"type": "function", 
"function": {"name": name, "description": f"{name} desc"}} + + +@pytest.fixture +def graph_env(monkeypatch): + """Force semantic routing on and inject fake GraphRouter + telemetry modules.""" + from config import config as real_config + monkeypatch.setattr(real_config, "SEMANTIC_TOOL_ROUTING", True) + + recorder = {"errors": [], "rank_calls": []} + + def install_rank_chains(fn): + fake_router = SimpleNamespace(rank_chains=fn) + fake_mod = types.ModuleType("modules.tools.discovery.graph_router") + fake_mod.get_graph_router = lambda: fake_router + monkeypatch.setitem(sys.modules, "modules.tools.discovery.graph_router", fake_mod) + + def fake_record_error(**kwargs): + recorder["errors"].append(kwargs) + + fake_tel = types.ModuleType("core.utils.exception_telemetry") + fake_tel.record_error = fake_record_error + monkeypatch.setitem(sys.modules, "core.utils.exception_telemetry", fake_tel) + + recorder["install_rank_chains"] = install_rank_chains + return recorder + + +# --------------------------------------------------------------------------- +# Deletion of the embedding path +# --------------------------------------------------------------------------- + +def test_no_embedding_manager_on_smart_router(): + """The PRD-64 embedding state + methods are gone — GraphRouter owns ranking now.""" + r = SmartToolRouter() + for attr in ( + "_embedding_manager", + "_tool_embeddings", + "_embeddings_initialized", + "_embeddings_init_lock", + "_ensure_embeddings", + "_rank_tools_by_similarity", + ): + assert not hasattr(r, attr), f"{attr} should have been deleted in US-014" + + +def test_smart_router_module_has_no_embedding_imports(): + """No residual import of the local embedding stack.""" + import inspect + src = inspect.getsource(_router_mod) + assert "core.math.vector_operations" not in src + assert "core.llm.embedding_manager" not in src + + +# --------------------------------------------------------------------------- +# Delegation +# 
--------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_smart_router_delegates_to_graph_router(graph_env): + """route() ranks via GraphRouter and filters available_tools to ranked names.""" + async def ok_rank(query, agent_id=None, top_k=15, **kw): + graph_env["rank_calls"].append((query, agent_id, top_k)) + return [ + ("platform_get_system_health", 0.9, ["platform_get_system_health"]), + ("platform_browse_marketplace_agents", 0.8, ["platform_browse_marketplace_agents"]), + ] + + graph_env["install_rank_chains"](ok_rank) + + r = SmartToolRouter() + r.classifier = _FakeClassifier(_intent_result(suggested=[])) + tools = [ + _tool("platform_get_system_health"), # ranked + _tool("platform_browse_marketplace_agents"), # ranked + _tool("search_knowledge"), # CORE_TOOLS — preserved + _tool("platform_list_agents"), # ALWAYS_INCLUDE — preserved + _tool("totally_unrelated_tool"), # dropped + ] + + result = await r.route(query="how is the system doing", available_tools=tools, agent_id=42) + + names = {t["function"]["name"] for t in result.filtered_tools} + assert "platform_get_system_health" in names + assert "platform_browse_marketplace_agents" in names + assert "search_knowledge" in names + assert "platform_list_agents" in names + assert "totally_unrelated_tool" not in names + assert result.reasoning.startswith("Graph routing") + # agent_id threaded through, top_k pinned to 30 per the PRD + assert graph_env["rank_calls"] == [("how is the system doing", 42, 30)] + + +@pytest.mark.asyncio +async def test_always_include_tools_present(graph_env): + """Every ALWAYS_INCLUDE tool survives graph filtering even when unranked.""" + async def ok_rank(query, agent_id=None, top_k=15, **kw): + return [("platform_get_system_health", 0.9, ["platform_get_system_health"])] + + graph_env["install_rank_chains"](ok_rank) + + r = SmartToolRouter() + r.classifier = _FakeClassifier(_intent_result(suggested=[])) + always = 
sorted(SmartToolRouter.ALWAYS_INCLUDE) + tools = [_tool("platform_get_system_health"), _tool("totally_unrelated_tool")] + tools += [_tool(n) for n in always] + + result = await r.route(query="status?", available_tools=tools, agent_id=1) + + names = {t["function"]["name"] for t in result.filtered_tools} + for n in always: + assert n in names, f"ALWAYS_INCLUDE tool {n} dropped" + assert "platform_get_system_health" in names + + +@pytest.mark.asyncio +async def test_suggested_tools_preserved(graph_env): + """The classifier's suggested tools survive graph filtering even when unranked.""" + async def ok_rank(query, agent_id=None, top_k=15, **kw): + return [("platform_get_system_health", 0.9, ["platform_get_system_health"])] + + graph_env["install_rank_chains"](ok_rank) + + r = SmartToolRouter() + r.classifier = _FakeClassifier(_intent_result(suggested=["generate_document"])) + tools = [_tool("platform_get_system_health"), _tool("generate_document"), _tool("unrelated")] + + result = await r.route(query="make me a report", available_tools=tools, agent_id=5) + + names = {t["function"]["name"] for t in result.filtered_tools} + assert "generate_document" in names # suggested → preserved + assert "platform_get_system_health" in names + assert "unrelated" not in names + + +# --------------------------------------------------------------------------- +# Fallback +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_graph_router_fallback_on_failure(graph_env): + """A GraphRouter failure records a 'routing' error and falls back to category filtering.""" + async def boom_rank(query, agent_id=None, top_k=15, **kw): + raise RuntimeError("graph down") + + graph_env["install_rank_chains"](boom_rank) + + r = SmartToolRouter() + r.classifier = _FakeClassifier( + _intent_result(primary=Intent.MULTI_STEP, suggested=["search_knowledge"]) + ) + tools = [_tool("search_knowledge"), _tool("composio_execute"), 
_tool("totally_unrelated_tool")] + + result = await r.route(query="do a multi step thing", available_tools=tools, agent_id=7) + + # Structured error recorded under the routing subsystem (not a bare warning) + assert len(graph_env["errors"]) == 1 + err = graph_env["errors"][0] + assert err["subsystem"] == "routing" + assert err["operation"] == "graph_rank_chains" + assert err["agent_id"] == 7 + assert isinstance(err["error"], RuntimeError) + + # Fell through to category filtering — not the graph path + assert result.should_include_tools is True + assert not result.reasoning.startswith("Graph routing") + names = {t["function"]["name"] for t in result.filtered_tools} + assert "search_knowledge" in names + + +@pytest.mark.asyncio +async def test_semantic_off_skips_graph_router(graph_env, monkeypatch): + """With SEMANTIC_TOOL_ROUTING off, GraphRouter is never called.""" + from config import config as real_config + monkeypatch.setattr(real_config, "SEMANTIC_TOOL_ROUTING", False) + + called = {"n": 0} + + async def tracking_rank(query, agent_id=None, top_k=15, **kw): + called["n"] += 1 + return [("x", 1.0, ["x"])] + + graph_env["install_rank_chains"](tracking_rank) + + r = SmartToolRouter() + r.classifier = _FakeClassifier(_intent_result(primary=Intent.MULTI_STEP, suggested=["search_knowledge"])) + tools = [_tool("search_knowledge"), _tool("composio_execute")] + + result = await r.route(query="anything", available_tools=tools, agent_id=3) + + assert called["n"] == 0 # delegation gated off → straight to category filtering + assert result.should_include_tools is True From 69c9b2f41b599c369d38dd62e1a86ac97ae872f7 Mon Sep 17 00:00:00 2001 From: Gerard Kavanagh Date: Thu, 28 May 2026 22:41:01 +0100 Subject: [PATCH 7/7] refactor(routing): ActionRegistry-backed intent filtering (PRD-141 US-015) Delete the hardcoded TOOL_CATEGORIES / INTENT_TO_TOOLS class dicts and the dead _filter_tools_by_categories / _tool_matches_query helpers. 
The category fallback now maps intent -> ActionRegistry category names via a module-level _INTENT_TO_REGISTRY_CATEGORIES dict and pulls action names from the registry at call time, so a new action under an already-mapped category is auto-discoverable with no router edit. Kept set is unioned with the classifier's suggested tools + CORE_TOOLS + ALWAYS_INCLUDE. Registry lookup is wrapped so a registry hiccup degrades gracefully instead of crashing routing. --- .../consumers/chatbot/smart_tool_router.py | 181 +++++++---------- .../test_us015_registry_intent_filter.py | 186 ++++++++++++++++++ 2 files changed, 255 insertions(+), 112 deletions(-) create mode 100644 orchestrator/tests/test_us015_registry_intent_filter.py diff --git a/orchestrator/consumers/chatbot/smart_tool_router.py b/orchestrator/consumers/chatbot/smart_tool_router.py index e67cf4ca3..02cb60c7f 100644 --- a/orchestrator/consumers/chatbot/smart_tool_router.py +++ b/orchestrator/consumers/chatbot/smart_tool_router.py @@ -27,6 +27,30 @@ logger = logging.getLogger(__name__) +# Intent → ActionRegistry *category* names (PRD-141 US-015). +# Replaces the old hardcoded per-tool category dicts: instead of pinning literal +# tool names, each tool-requiring intent maps to one or more ActionRegistry +# categories, and the matching action names are pulled from the registry at call +# time. An action registered under an already-mapped category is therefore +# auto-discoverable with no edit to this router. +# Values are REAL registry category names (verified against modules/tools/). +# GREETING / CHITCHAT / FACTUAL are intentionally absent — they classify as +# requires_tools=False and never reach the filter. 
+_INTENT_TO_REGISTRY_CATEGORIES: Dict[Intent, List[str]] = { + Intent.DATA_QUERY: ["analytics", "database", "graph", "field", "reports"], + Intent.SEARCH: ["discovery", "documents", "graph", "memory", "workspace_files"], + Intent.EXTERNAL_ACTION: ["integrations", "notifications", "marketplace", "skills"], + Intent.CREATION: ["documents", "reports", "blog", "workspace_files", "playbooks"], + Intent.MEMORY_RECALL: ["memory", "field"], + Intent.MULTI_STEP: [ + "agents", "missions", "tasks", "playbooks", "workspace", "workspace_files", + "analytics", "reports", "documents", "graph", "memory", "field", + "marketplace", "skills", "monitoring", "infrastructure", "integrations", + "discovery", "scheduling", "governance", "notifications", "blog", + ], +} + + @dataclass class ToolRoutingResult: """Result of tool routing decision.""" @@ -77,58 +101,6 @@ class SmartToolRouter: "platform_field_inject", }) - # Tool categories - TOOL_CATEGORIES = { - "data": ["query_database", "smart_query_database", "sql_query"], - "search": ["search_knowledge", "semantic_search", "search_codebase", "search_multimodal", - "search_tables", "search_images", "search_formulas"], - "web_search": [ - "TAVILY_TAVILY_SEARCH", "COMPOSIO_SEARCH_FETCH_URL_CONTENT", - "COMPOSIO_SEARCH_SEC_FILINGS", "composio_execute", - ], - "files": ["workspace_read_file", "workspace_write_file", "workspace_list_dir", "workspace_grep"], - "external": ["composio_execute", "composio_actions"], - "creation": ["workspace_write_file", "generate_document"], - "document": ["generate_document", "workspace_write_file"], - "code": ["search_codebase", "execute_code", "run_command"], - # Promoted platform tool categories (PRD-122 US-010) - "platform_management": [ - "platform_list_agents", "platform_get_agent", - "platform_create_agent", "platform_update_agent", - ], - "marketplace": [ - "platform_browse_marketplace_agents", - "platform_browse_marketplace_skills", - "platform_browse_marketplace_plugins", - "platform_install_skill", 
"platform_install_plugin", - ], - "monitoring": [ - "platform_get_system_health", "platform_get_activity_feed", - ], - "memory": [ - "platform_search_memory", "platform_store_memory", - ], - "fields": [ - "platform_field_query", "platform_field_inject", - ], - } - - # Intent to tool category mapping - INTENT_TO_TOOLS = { - Intent.DATA_QUERY: ["data", "search", "web_search", "fields"], - Intent.SEARCH: ["search", "web_search", "code", "memory"], - Intent.EXTERNAL_ACTION: ["external", "web_search", "document", "platform_management"], - Intent.CREATION: ["files", "creation", "document", "external", "platform_management"], - Intent.MULTI_STEP: [ - "data", "search", "web_search", "files", "external", "document", "code", - "platform_management", "marketplace", "monitoring", "memory", "fields", - ], - Intent.MEMORY_RECALL: ["memory"], # Memory tools for recall intents - Intent.GREETING: [], # No tools needed - Intent.CHITCHAT: [], # No tools needed - Intent.FACTUAL: [], # Try without tools first - } - def __init__(self): self.classifier = get_intent_classifier() @@ -236,17 +208,8 @@ async def route( agent_id=agent_id, ) - # Fallback: keyword-based category matching - relevant_categories = self.INTENT_TO_TOOLS.get( - intent_result.primary_intent, - [] - ) - - filtered = self._filter_tools_by_categories( - available_tools, - relevant_categories, - intent_result.suggested_tools - ) + # Fallback: ActionRegistry category matching (PRD-141 US-015) + filtered = self._filter_tools_by_intent(available_tools, intent_result) tool_choice = self._determine_tool_choice(intent_result, filtered) priority = intent_result.suggested_tools or [] @@ -259,42 +222,56 @@ async def route( reasoning=intent_result.reasoning ) - def _filter_tools_by_categories( + def _filter_tools_by_intent( self, all_tools: List[Dict[str, Any]], - categories: List[str], - suggested: List[str] + intent_result: IntentResult, ) -> List[Dict[str, Any]]: - """Filter tools to only include relevant categories.""" + 
"""Filter tools to the ActionRegistry categories mapped to the intent. + + PRD-141 US-015: maps the classified intent to ActionRegistry *category + names* via ``_INTENT_TO_REGISTRY_CATEGORIES`` and pulls the matching + action names from the registry at call time. The kept set is unioned + with the classifier's suggested tools plus the always-on CORE_TOOLS / + ALWAYS_INCLUDE sets, so an action registered under an already-mapped + category is auto-discoverable with no edit to this router. + """ + categories = _INTENT_TO_REGISTRY_CATEGORIES.get(intent_result.primary_intent, []) + suggested = set(intent_result.suggested_tools or []) + + # Nothing to narrow on (unmapped intent, no suggestions) — can't filter + # meaningfully, so keep the full set. if not categories and not suggested: - # For multi-step or complex queries, include all tools return all_tools - # Build set of relevant tool names - relevant_names = set(suggested) if suggested else set() - for category in categories: - relevant_names.update(self.TOOL_CATEGORIES.get(category, [])) - - # Always include core tools and promoted always-include tools - relevant_names.update(self.CORE_TOOLS) - relevant_names.update(self.ALWAYS_INCLUDE) - - # Filter - filtered = [] - for tool in all_tools: - if not isinstance(tool, dict): - continue - fn = tool.get("function", {}) - name = fn.get("name", "") - - if name in relevant_names: - filtered.append(tool) - elif self._tool_matches_query(name, fn.get("description", ""), categories): - filtered.append(tool) - - # Limit to reasonable number + relevant_names = suggested | set(self.CORE_TOOLS) | set(self.ALWAYS_INCLUDE) + + if categories: + # Lazy import — the registry pulls in heavy modules.tools.* deps. + # This is already the fallback-of-last-resort (the graph path failed + # or is off), so a registry hiccup must not crash routing: degrade + # to suggested ∪ CORE_TOOLS ∪ ALWAYS_INCLUDE. 
+ try: + from modules.tools.discovery.action_registry import get_action_registry + registry = get_action_registry() + for category in categories: + for action in registry.get_by_category(category): + relevant_names.add(action.name) + except Exception: + logger.warning( + "[ToolRouter] ActionRegistry lookup failed during category " + "fallback — degrading to core/always/suggested tools", + exc_info=True, + ) + + filtered = [ + t for t in all_tools + if isinstance(t, dict) + and t.get("function", {}).get("name", "") in relevant_names + ] + + # Limit to a reasonable number, keeping suggested + core first. if len(filtered) > 30: - # Keep suggested + core + first N others priority_tools = [] other_tools = [] for tool in filtered: @@ -308,26 +285,6 @@ def _filter_tools_by_categories( logger.debug(f"[ToolRouter] Filtered {len(all_tools)} tools to {len(filtered)}") return filtered - def _tool_matches_query(self, name: str, description: str, categories: List[str]) -> bool: - """Check if a tool might be relevant based on its name/description.""" - text = f"{name} {description}".lower() - - category_keywords = { - "data": ["database", "query", "sql", "data", "analytics"], - "search": ["search", "find", "lookup", "knowledge"], - "files": ["file", "directory", "folder", "write", "read"], - "external": ["email", "slack", "github", "compose"], - "document": ["document", "report", "pdf", "docx", "xlsx", "invoice", "export", "generate"], - "code": ["code", "execute", "run", "command"], - } - - for category in categories: - keywords = category_keywords.get(category, []) - if any(kw in text for kw in keywords): - return True - - return False - def _determine_tool_choice( self, intent: IntentResult, diff --git a/orchestrator/tests/test_us015_registry_intent_filter.py b/orchestrator/tests/test_us015_registry_intent_filter.py new file mode 100644 index 000000000..0d9f2d37a --- /dev/null +++ b/orchestrator/tests/test_us015_registry_intent_filter.py @@ -0,0 +1,186 @@ +""" +PRD-141 US-015: 
intent->tool filtering reads ActionRegistry categories. +======================================================================= + +The chatbot router's category fallback no longer carries hardcoded +``TOOL_CATEGORIES`` / ``INTENT_TO_TOOLS`` dicts. ``_filter_tools_by_intent`` +maps the classified intent to ActionRegistry *category names* via +``_INTENT_TO_REGISTRY_CATEGORIES`` and pulls the matching action names from the +registry at call time — so an action registered under an already-mapped +category is auto-discoverable with no edit to the router. The kept set is +unioned with the classifier's suggested tools plus the always-on CORE_TOOLS / +ALWAYS_INCLUDE sets. + +Leaf-load pattern as in US-013/US-014: ``consumers.chatbot.__init__`` pulls the +DB-backed chat service, so we load ``intent_classifier`` + ``smart_tool_router`` +under a synthetic package and inject a fake ``action_registry`` module into +``sys.modules`` (the registry import inside ``_filter_tools_by_intent`` is lazy). +""" +import importlib.util +import sys +import types +from pathlib import Path +from types import SimpleNamespace + +import pytest + +_orchestrator_root = Path(__file__).resolve().parent.parent +if str(_orchestrator_root) not in sys.path: + sys.path.insert(0, str(_orchestrator_root)) + +_chatbot_dir = _orchestrator_root / "consumers" / "chatbot" +_PKG = "_us015_chatbot" + + +def _load_modules(): + if _PKG not in sys.modules: + pkg = types.ModuleType(_PKG) + pkg.__path__ = [str(_chatbot_dir)] + sys.modules[_PKG] = pkg + + def _leaf(mod_name): + full = f"{_PKG}.{mod_name}" + if full in sys.modules: + return sys.modules[full] + spec = importlib.util.spec_from_file_location(full, _chatbot_dir / f"{mod_name}.py") + module = importlib.util.module_from_spec(spec) + module.__package__ = _PKG + sys.modules[full] = module + spec.loader.exec_module(module) + return module + + intent_mod = _leaf("intent_classifier") + router_mod = _leaf("smart_tool_router") + return router_mod, intent_mod + + 
+_router_mod, _intent_mod = _load_modules() +SmartToolRouter = _router_mod.SmartToolRouter +_INTENT_TO_REGISTRY_CATEGORIES = _router_mod._INTENT_TO_REGISTRY_CATEGORIES +Intent = _intent_mod.Intent +IntentResult = _intent_mod.IntentResult + + +def _intent_result(primary=Intent.DATA_QUERY, suggested=None): + return IntentResult( + primary_intent=primary, + confidence=0.9, + requires_tools=True, + requires_memory=False, + suggested_tools=suggested or [], + reasoning="stub", + is_simple=False, + ) + + +def _tool(name): + return {"type": "function", "function": {"name": name, "description": f"{name} desc"}} + + +@pytest.fixture +def registry_env(monkeypatch): + """Inject a fake ActionRegistry whose get_by_category is test-controlled.""" + state = {"by_category": {}} + + class _FakeRegistry: + def get_by_category(self, category): + return state["by_category"].get(category, []) + + fake_mod = types.ModuleType("modules.tools.discovery.action_registry") + fake_mod.get_action_registry = lambda: _FakeRegistry() + monkeypatch.setitem(sys.modules, "modules.tools.discovery.action_registry", fake_mod) + return state + + +# --------------------------------------------------------------------------- +# The hardcoded dicts are gone +# --------------------------------------------------------------------------- + +def test_legacy_dicts_deleted(): + assert not hasattr(SmartToolRouter, "TOOL_CATEGORIES") + assert not hasattr(SmartToolRouter, "INTENT_TO_TOOLS") + assert not hasattr(SmartToolRouter, "_filter_tools_by_categories") + assert not hasattr(SmartToolRouter, "_tool_matches_query") + + +# --------------------------------------------------------------------------- +# Mapping +# --------------------------------------------------------------------------- + +def test_intent_to_registry_categories_mapping(): + """Every tool-requiring intent maps to >=1 ActionRegistry category.""" + tool_intents = [ + Intent.DATA_QUERY, + Intent.SEARCH, + Intent.EXTERNAL_ACTION, + Intent.CREATION, + 
Intent.MEMORY_RECALL, + Intent.MULTI_STEP, + ] + for intent in tool_intents: + cats = _INTENT_TO_REGISTRY_CATEGORIES.get(intent) + assert cats, f"{intent} must map to >=1 registry category" + assert all(len(cats) >= 1 for cats in _INTENT_TO_REGISTRY_CATEGORIES.values()) + + +# --------------------------------------------------------------------------- +# Registry-backed filtering +# --------------------------------------------------------------------------- + +def test_category_filter_uses_registry(registry_env): + """Tool names come from ActionRegistry.get_by_category for the mapped categories.""" + registry_env["by_category"]["analytics"] = [ + SimpleNamespace(name="metrics_read"), + SimpleNamespace(name="kpi_dump"), + ] + r = SmartToolRouter() + intent = _intent_result(primary=Intent.DATA_QUERY, suggested=[]) + tools = [_tool("metrics_read"), _tool("kpi_dump"), _tool("totally_unrelated_tool")] + + filtered = r._filter_tools_by_intent(tools, intent) + + names = {t["function"]["name"] for t in filtered} + assert "metrics_read" in names + assert "kpi_dump" in names + assert "totally_unrelated_tool" not in names + + +def test_new_action_auto_discoverable(registry_env): + """A new action under an already-mapped category appears with NO router code change.""" + registry_env["by_category"]["analytics"] = [SimpleNamespace(name="brand_new_kpi_action")] + r = SmartToolRouter() + intent = _intent_result(primary=Intent.DATA_QUERY, suggested=[]) + tools = [_tool("brand_new_kpi_action")] + + filtered = r._filter_tools_by_intent(tools, intent) + + assert {t["function"]["name"] for t in filtered} >= {"brand_new_kpi_action"} + + +def test_filter_always_keeps_core_and_always_include(registry_env): + """CORE_TOOLS, ALWAYS_INCLUDE, and suggested tools survive filtering.""" + registry_env["by_category"]["analytics"] = [SimpleNamespace(name="metrics_read")] + r = SmartToolRouter() + intent = _intent_result(primary=Intent.DATA_QUERY, suggested=["search_codebase"]) + core = 
sorted(SmartToolRouter.CORE_TOOLS) + always = sorted(SmartToolRouter.ALWAYS_INCLUDE) + tools = [_tool("metrics_read")] + [_tool(n) for n in core] + [_tool(n) for n in always] + + filtered = r._filter_tools_by_intent(tools, intent) + + names = {t["function"]["name"] for t in filtered} + for n in core + always: + assert n in names + assert "metrics_read" in names # from registry + + +def test_no_categories_no_suggested_returns_all(registry_env): + """An intent with no category mapping and no suggestions can't narrow → keep all.""" + r = SmartToolRouter() + # GREETING is not a key in _INTENT_TO_REGISTRY_CATEGORIES and has no suggestions + intent = _intent_result(primary=Intent.GREETING, suggested=[]) + tools = [_tool("a"), _tool("b")] + + filtered = r._filter_tools_by_intent(tools, intent) + + assert {t["function"]["name"] for t in filtered} == {"a", "b"}