diff --git a/mcp_gateway/config.py b/mcp_gateway/config.py index f7a18ee..bfd7caa 100644 --- a/mcp_gateway/config.py +++ b/mcp_gateway/config.py @@ -2,7 +2,10 @@ import logging import os from pathlib import Path -from typing import Dict, Any, List, Tuple +from dataclasses import dataclass +from keyword import iskeyword +import re +from typing import Dict, Any, List from mcp import types CONFIG_FILE_NAME = "mcp.json" @@ -14,6 +17,18 @@ class Constants: SERVERS = "servers" + + +@dataclass(frozen=True) +class ToolParamDescription: + name: str + python_name: str + type_annotation: Any + description: str + required: bool + + def __getitem__(self, index: int) -> Any: + return (self.name, self.type_annotation, self.description)[index] def find_config_file(mcp_json_path: str) -> Path | None: @@ -168,13 +183,15 @@ def load_config(mcp_json_path: str) -> Dict[str, Any]: return {} # Return empty dict -def get_tool_params_description(tool: types.Tool) -> List[Tuple[str, Any, str]]: +def get_tool_params_description(tool: types.Tool) -> List[ToolParamDescription]: param_signatures = [] # Tool has inputSchema (JSON Schema) instead of arguments if hasattr(tool, "inputSchema") and tool.inputSchema: # Try to extract properties from JSON Schema properties = tool.inputSchema.get("properties", {}) + required_params = set(tool.inputSchema.get("required", [])) + used_python_names = set() for param_name, param_schema in properties.items(): param_type = Any # Default type param_description = param_schema.get("description", "") @@ -192,5 +209,29 @@ def get_tool_params_description(tool: types.Tool) -> List[Tuple[str, Any, str]]: } param_type = type_mapping.get(json_type, Any) - param_signatures.append((param_name, param_type, param_description)) + param_signatures.append( + ToolParamDescription( + name=param_name, + python_name=_safe_parameter_name(param_name, used_python_names), + type_annotation=param_type, + description=param_description, + required=param_name in required_params, + ) + ) return param_signatures + + +def _safe_parameter_name(name: str, used_names: set[str]) -> str: + safe_name = re.sub(r"\W", "_", name) + if not safe_name or safe_name[0].isdigit() or iskeyword(safe_name): + safe_name = f"param_{safe_name}" + if not safe_name.isidentifier(): + safe_name = "param" + + candidate = safe_name + index = 2 + while candidate in used_names: + candidate = f"{safe_name}_{index}" + index += 1 + used_names.add(candidate) + return candidate diff --git a/mcp_gateway/gateway.py b/mcp_gateway/gateway.py index 7511f5a..1caea28 100644 --- a/mcp_gateway/gateway.py +++ b/mcp_gateway/gateway.py @@ -10,8 +10,10 @@ AsyncIterator, List, Tuple, + Annotated, ) import inspect +from pydantic import Field from mcp.server.fastmcp import FastMCP, Context from mcp import types @@ -46,7 +48,8 @@ async def register_dynamic_tool( logger.debug(f"Attempting to register dynamic tool: {dynamic_tool_name}") # Extract parameter types from the tool's inputSchema - param_signatures = get_tool_params_description(tool)# Create a properly typed dynamic function based on the original tool's signature + param_signatures = get_tool_params_description(tool) + def create_typed_handler(param_signatures): # Create parameters for the function signature parameters = [ @@ -59,16 +62,24 @@ def create_typed_handler(param_signatures): annotations = {"ctx": Context, "return": types.CallToolResult} + required_params = [param for param in param_signatures if param.required] + optional_params = [param for param in param_signatures if not param.required] + # Add parameters from the original tool - for name, type_ann, description in param_signatures: + for param in required_params + optional_params: + annotation = Annotated[ + param.type_annotation, Field(alias=param.name, description=param.description) + ] + default = inspect.Parameter.empty if param.required else None parameters.append( inspect.Parameter( - name=name, - annotation=type_ann, + name=param.python_name, + annotation=annotation, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + default=default, ) ) - annotations[name] = type_ann + annotations[param.python_name] = annotation # Create the proper signature sig = inspect.Signature(parameters=parameters) @@ -77,7 +88,11 @@ def create_typed_handler(param_signatures): async def dynamic_tool_impl(*args, **kwargs): ctx = kwargs.get("ctx", args[0] if args else None) # Remove ctx from kwargs before passing to the proxied server - tool_kwargs = {k: v for k, v in kwargs.items() if k != "ctx"} + tool_kwargs = { + param.name: value + for param in param_signatures + if (value := _get_tool_argument(kwargs, param)) is not None + } logger.info( f"Executing dynamic tool '{dynamic_tool_name}' (proxied from {server_name}/{tool.name})" @@ -141,6 +156,10 @@ async def dynamic_tool_impl(*args, **kwargs): ) +def _get_tool_argument(kwargs, param): + return kwargs.get(param.name, kwargs.get(param.python_name)) + + async def register_dynamic_prompt( gateway_mcp: FastMCP, server_name: str, @@ -619,4 +638,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/test_dynamic_tool_registration.py b/tests/test_dynamic_tool_registration.py new file mode 100644 index 0000000..a9f4b6d --- /dev/null +++ b/tests/test_dynamic_tool_registration.py @@ -0,0 +1,166 @@ +import inspect +import asyncio + +from mcp.server.fastmcp import FastMCP +from mcp import types + +from mcp_gateway.gateway import register_dynamic_tool + + +class RecordingFastMCP: + def __init__(self): + self.registered = {} + + def tool(self, name=None, description=None): + def decorator(fn): + self.registered[name] = fn + return fn + + return decorator + + +class RecordingProxiedServer: + def __init__(self): + self.calls = [] + + async def call_tool(self, **kwargs): + self.calls.append(kwargs) + return types.CallToolResult( + content=[types.TextContent(type="text", text="ok")] + ) + + +def test_dynamic_tool_accepts_json_schema_names_that_are_not_python_identifiers(): + asyncio.run(_check_dynamic_tool_accepts_json_schema_names_that_are_not_python_identifiers()) + + +async def _check_dynamic_tool_accepts_json_schema_names_that_are_not_python_identifiers(): + gateway_mcp = RecordingFastMCP() + proxied_server = RecordingProxiedServer() + tool = types.Tool( + name="api-get-block-children", + description="Fetch children", + inputSchema={ + "type": "object", + "properties": { + "Notion-Version": {"type": "string"}, + "1st-page": {"type": "integer"}, + "class": {"type": "string"}, + }, + "required": ["Notion-Version"], + }, + ) + + await register_dynamic_tool( + gateway_mcp, + "notion", + tool, + proxied_server, + plugin_manager=None, + ) + + handler = gateway_mcp.registered["notion_api-get-block-children"] + signature = inspect.signature(handler) + + assert "Notion_Version" in signature.parameters + assert "param_1st_page" in signature.parameters + assert "param_class" in signature.parameters + assert signature.parameters["Notion_Version"].default is inspect.Parameter.empty + assert signature.parameters["param_1st_page"].default is None + assert signature.parameters["param_class"].default is None + + await handler( + Notion_Version="2025-06-20", + param_1st_page=3, + param_class="page", + ) + + assert proxied_server.calls[0]["name"] == "api-get-block-children" + assert proxied_server.calls[0]["arguments"] == { + "Notion-Version": "2025-06-20", + "1st-page": 3, + "class": "page", + } + + +def test_dynamic_tool_omits_unset_optional_arguments(): + asyncio.run(_check_dynamic_tool_omits_unset_optional_arguments()) + + +async def _check_dynamic_tool_omits_unset_optional_arguments(): + gateway_mcp = RecordingFastMCP() + proxied_server = RecordingProxiedServer() + tool = types.Tool( + name="search", + description="Search", + inputSchema={ + "type": "object", + "properties": { + "query": {"type": "string"}, + "start-cursor": {"type": "string"}, + }, + "required": ["query"], + }, + ) + + await register_dynamic_tool( + gateway_mcp, + "notion", + tool, + proxied_server, + plugin_manager=None, + ) + + handler = gateway_mcp.registered["notion_search"] + await handler(query="blocks", start_cursor=None) + + assert proxied_server.calls[0]["arguments"] == {"query": "blocks"} + + +def test_dynamic_tool_schema_uses_original_json_schema_names(): + asyncio.run(_check_dynamic_tool_schema_uses_original_json_schema_names()) + + +async def _check_dynamic_tool_schema_uses_original_json_schema_names(): + gateway_mcp = FastMCP("test") + proxied_server = RecordingProxiedServer() + tool = types.Tool( + name="api-get-block-children", + description="Fetch children", + inputSchema={ + "type": "object", + "properties": { + "Notion-Version": {"type": "string"}, + "start-cursor": {"type": "string"}, + }, + "required": ["Notion-Version"], + }, + ) + + await register_dynamic_tool( + gateway_mcp, + "notion", + tool, + proxied_server, + plugin_manager=None, + ) + + tools = await gateway_mcp.list_tools() + registered_tool = next( + tool for tool in tools if tool.name == "notion_api-get-block-children" + ) + + assert registered_tool.inputSchema["properties"].keys() == { + "Notion-Version", + "start-cursor", + } + assert registered_tool.inputSchema["required"] == ["Notion-Version"] + + await gateway_mcp.call_tool( + "notion_api-get-block-children", + {"Notion-Version": "2025-06-20"}, + ) + + assert proxied_server.calls[0]["arguments"] == { + "Notion-Version": "2025-06-20" + }