Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions code_review_graph/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,10 +942,14 @@ def _apply_tool_filter(tools: str | None = None) -> None:
allowed = {t.strip() for t in raw.split(",") if t.strip()}
if not allowed:
return
registered = list(mcp._tool_manager._tools.keys())
registered = [
k.split(":")[1].split("@")[0]
for k in mcp.local_provider._components
if k.startswith("tool:")
]
for name in registered:
if name not in allowed:
mcp.remove_tool(name)
mcp.local_provider.remove_tool(name)



Expand Down
50 changes: 26 additions & 24 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,63 +225,65 @@ class TestApplyToolFilter:
def _restore_tools(self):
"""Snapshot registered tools before test, restore after.

_apply_tool_filter calls ``mcp.remove_tool()`` which is
_apply_tool_filter calls ``local_provider.remove_tool()`` which is
permanent. We restore by re-adding from the saved snapshot.
"""
original = dict(crg_main.mcp._tool_manager._tools)
original = dict(crg_main.mcp.local_provider._components)
yield
crg_main.mcp._tool_manager._tools.clear()
crg_main.mcp._tool_manager._tools.update(original)
crg_main.mcp.local_provider._components.clear()
crg_main.mcp.local_provider._components.update(original)

@pytest.fixture(autouse=True)
def _clean_env(self, monkeypatch):
"""Ensure CRG_TOOLS is not set from the outer environment."""
monkeypatch.delenv("CRG_TOOLS", raising=False)

@pytest.mark.asyncio
async def test_no_filter_keeps_all_tools(self):
def _tool_names(self) -> set[str]:
"""Return current registered tool names without FastMCP internals."""
return {
k.split(":")[1].split("@")[0]
for k in crg_main.mcp.local_provider._components
if k.startswith("tool:")
}

def test_no_filter_keeps_all_tools(self):
"""When neither --tools nor CRG_TOOLS is set, all tools remain."""
before = set((await crg_main.mcp.get_tools()).keys())
before = self._tool_names()
crg_main._apply_tool_filter(None)
after = set((await crg_main.mcp.get_tools()).keys())
after = self._tool_names()
assert before == after

@pytest.mark.asyncio
async def test_filter_via_argument(self):
def test_filter_via_argument(self):
"""The ``tools`` argument keeps only the listed tools."""
keep = "query_graph_tool,semantic_search_nodes_tool"
crg_main._apply_tool_filter(keep)
remaining = set((await crg_main.mcp.get_tools()).keys())
remaining = self._tool_names()
assert remaining == {"query_graph_tool", "semantic_search_nodes_tool"}

@pytest.mark.asyncio
async def test_filter_via_env_var(self, monkeypatch):
def test_filter_via_env_var(self, monkeypatch):
"""The ``CRG_TOOLS`` env var works as fallback."""
monkeypatch.setenv("CRG_TOOLS", "query_graph_tool")
crg_main._apply_tool_filter(None)
remaining = set((await crg_main.mcp.get_tools()).keys())
remaining = self._tool_names()
assert remaining == {"query_graph_tool"}

@pytest.mark.asyncio
async def test_argument_takes_precedence_over_env(self, monkeypatch):
def test_argument_takes_precedence_over_env(self, monkeypatch):
"""CLI --tools wins over CRG_TOOLS env var."""
monkeypatch.setenv("CRG_TOOLS", "list_repos_tool")
crg_main._apply_tool_filter("query_graph_tool")
remaining = set((await crg_main.mcp.get_tools()).keys())
remaining = self._tool_names()
assert remaining == {"query_graph_tool"}

@pytest.mark.asyncio
async def test_empty_string_is_noop(self):
def test_empty_string_is_noop(self):
"""An empty string should not remove all tools."""
before = set((await crg_main.mcp.get_tools()).keys())
before = self._tool_names()
crg_main._apply_tool_filter("")
after = set((await crg_main.mcp.get_tools()).keys())
after = self._tool_names()
assert before == after

@pytest.mark.asyncio
async def test_whitespace_handling(self):
def test_whitespace_handling(self):
"""Spaces around tool names are stripped."""
crg_main._apply_tool_filter(" query_graph_tool , semantic_search_nodes_tool ")
remaining = set((await crg_main.mcp.get_tools()).keys())
remaining = self._tool_names()
assert remaining == {"query_graph_tool", "semantic_search_nodes_tool"}

Loading