diff --git a/code_review_graph/main.py b/code_review_graph/main.py index 733336e..94c5f43 100644 --- a/code_review_graph/main.py +++ b/code_review_graph/main.py @@ -942,10 +942,14 @@ def _apply_tool_filter(tools: str | None = None) -> None: allowed = {t.strip() for t in raw.split(",") if t.strip()} if not allowed: return - registered = list(mcp._tool_manager._tools.keys()) + registered = [ + k.split(":")[1].split("@")[0] + for k in mcp.local_provider._components + if k.startswith("tool:") + ] for name in registered: if name not in allowed: - mcp.remove_tool(name) + mcp.local_provider.remove_tool(name) diff --git a/tests/test_main.py b/tests/test_main.py index 61fbbf4..2d17488 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -225,63 +225,65 @@ class TestApplyToolFilter: def _restore_tools(self): """Snapshot registered tools before test, restore after. - _apply_tool_filter calls ``mcp.remove_tool()`` which is + _apply_tool_filter calls ``local_provider.remove_tool()`` which is permanent. We restore by re-adding from the saved snapshot. """ - original = dict(crg_main.mcp._tool_manager._tools) + original = dict(crg_main.mcp.local_provider._components) yield - crg_main.mcp._tool_manager._tools.clear() - crg_main.mcp._tool_manager._tools.update(original) + crg_main.mcp.local_provider._components.clear() + crg_main.mcp.local_provider._components.update(original) @pytest.fixture(autouse=True) def _clean_env(self, monkeypatch): """Ensure CRG_TOOLS is not set from the outer environment.""" monkeypatch.delenv("CRG_TOOLS", raising=False) - @pytest.mark.asyncio - async def test_no_filter_keeps_all_tools(self): + def _tool_names(self) -> set[str]: + """Return current registered tool names without FastMCP internals.""" + return { + k.split(":")[1].split("@")[0] + for k in crg_main.mcp.local_provider._components + if k.startswith("tool:") + } + + def test_no_filter_keeps_all_tools(self): """When neither --tools nor CRG_TOOLS is set, all tools remain.""" - before = set((await crg_main.mcp.get_tools()).keys()) + before = self._tool_names() crg_main._apply_tool_filter(None) - after = set((await crg_main.mcp.get_tools()).keys()) + after = self._tool_names() assert before == after - @pytest.mark.asyncio - async def test_filter_via_argument(self): + def test_filter_via_argument(self): """The ``tools`` argument keeps only the listed tools.""" keep = "query_graph_tool,semantic_search_nodes_tool" crg_main._apply_tool_filter(keep) - remaining = set((await crg_main.mcp.get_tools()).keys()) + remaining = self._tool_names() assert remaining == {"query_graph_tool", "semantic_search_nodes_tool"} - @pytest.mark.asyncio - async def test_filter_via_env_var(self, monkeypatch): + def test_filter_via_env_var(self, monkeypatch): """The ``CRG_TOOLS`` env var works as fallback.""" monkeypatch.setenv("CRG_TOOLS", "query_graph_tool") crg_main._apply_tool_filter(None) - remaining = set((await crg_main.mcp.get_tools()).keys()) + remaining = self._tool_names() assert remaining == {"query_graph_tool"} - @pytest.mark.asyncio - async def test_argument_takes_precedence_over_env(self, monkeypatch): + def test_argument_takes_precedence_over_env(self, monkeypatch): """CLI --tools wins over CRG_TOOLS env var.""" monkeypatch.setenv("CRG_TOOLS", "list_repos_tool") crg_main._apply_tool_filter("query_graph_tool") - remaining = set((await crg_main.mcp.get_tools()).keys()) + remaining = self._tool_names() assert remaining == {"query_graph_tool"} - @pytest.mark.asyncio - async def test_empty_string_is_noop(self): + def test_empty_string_is_noop(self): """An empty string should not remove all tools.""" - before = set((await crg_main.mcp.get_tools()).keys()) + before = self._tool_names() crg_main._apply_tool_filter("") - after = set((await crg_main.mcp.get_tools()).keys()) + after = self._tool_names() assert before == after - @pytest.mark.asyncio - async def test_whitespace_handling(self): + def test_whitespace_handling(self): """Spaces around tool names are stripped.""" crg_main._apply_tool_filter(" query_graph_tool , semantic_search_nodes_tool ") - remaining = set((await crg_main.mcp.get_tools()).keys()) + remaining = self._tool_names() assert remaining == {"query_graph_tool", "semantic_search_nodes_tool"}