diff --git a/README.md b/README.md index a5e2e8f..141402f 100644 --- a/README.md +++ b/README.md @@ -234,6 +234,7 @@ Environment variables (prefixed with `MF_`) override CLI flags. See `.env.exampl - `MF_TRANSPORT` / `-t`: `stdio` (default) or `http` - `MF_STDIO_COMMAND` / `MF_STDIO_ARGS`: upstream binary + args +- `MF_STDIO_ENV` / `--stdio-env`: environment variables for stdio subprocess (`KEY=value;ANOTHER=value` or repeatable `--stdio-env KEY=value`) - `MF_HTTP_URL` / `MF_HTTP_HEADERS`: SSE/HTTP endpoint and extra headers (`key=value;Another=Value`) - `MF_ALLOW_TOOLS` / `-a`: exact tool names (repeatable, or comma-separated) - `MF_ALLOW_PATTERNS`: regex patterns for tool names (repeatable, or comma-separated) diff --git a/src/mcp_filter/cli.py b/src/mcp_filter/cli.py index 8d51fee..e55e7a8 100644 --- a/src/mcp_filter/cli.py +++ b/src/mcp_filter/cli.py @@ -60,6 +60,11 @@ def run( "--stdio-arg", help="Additional argument(s) for the stdio command (repeatable). Can be individual args or a quoted string that will be split.", ), + stdio_env: Optional[List[str]] = typer.Option( + None, + "--stdio-env", + help="Environment variable for the stdio command as KEY=VALUE (repeatable).", + ), http_url: Optional[str] = typer.Option( None, "--http-url", help="HTTP/SSE endpoint for the upstream MCP server." ), @@ -114,6 +119,7 @@ def run( transport=transport.lower() if transport else None, stdio_command=stdio_command, stdio_args=_parse_stdio_args(stdio_args), + stdio_env=_parse_env(stdio_env), http_url=http_url, http_headers=_parse_headers(http_headers), allow_tools=allow_tools, @@ -155,6 +161,19 @@ def _parse_headers(values: Optional[List[str]]) -> Optional[Dict[str, str]]: return headers +def _parse_env(values: Optional[List[str]]) -> Optional[Dict[str, str]]: + """Parse environment variables from KEY=VALUE format.""" + if not values: + return None + env_vars: Dict[str, str] = {} + for item in values: + if "=" not in item: + raise ConfigError(f"Environment variable '{item}' must be in KEY=VALUE format.") + key, value = item.split("=", 1) + env_vars[key.strip()] = value.strip() + return env_vars + + def _parse_stdio_args(values: Optional[List[str]]) -> Optional[List[str]]: """Parse stdio args, splitting quoted strings if needed.""" if not values: diff --git a/src/mcp_filter/config.py b/src/mcp_filter/config.py index 272474e..20c84b5 100644 --- a/src/mcp_filter/config.py +++ b/src/mcp_filter/config.py @@ -49,6 +49,7 @@ class UpstreamConfig(BaseModel): transport: Transport = "stdio" stdio_command: Optional[str] = None stdio_args: List[str] = Field(default_factory=list) + stdio_env: Dict[str, str] = Field(default_factory=dict) http_url: Optional[AnyUrl] = None http_headers: Dict[str, str] = Field(default_factory=dict) @@ -79,6 +80,7 @@ class ConfigOverrides: transport: Optional[Transport] = None stdio_command: Optional[str] = None stdio_args: Optional[List[str]] = None + stdio_env: Optional[Dict[str, str]] = None http_url: Optional[str] = None http_headers: Optional[Dict[str, str]] = None allow_tools: Optional[List[str]] = None @@ -105,6 +107,8 @@ def as_dict(self) -> Dict[str, Any]: upstream["stdio_command"] = self.stdio_command if self.stdio_args is not None: upstream["stdio_args"] = self.stdio_args + if self.stdio_env is not None: + upstream["stdio_env"] = self.stdio_env if self.http_url is not None: upstream["http_url"] = self.http_url if self.http_headers is not None: @@ -179,6 +183,9 @@ def _load_from_env(env: Mapping[str, str]) -> Dict[str, Any]: stdio_args = env.get("MF_STDIO_ARGS") if stdio_args: upstream["stdio_args"] = shlex.split(stdio_args) + stdio_env_raw = env.get("MF_STDIO_ENV") + if stdio_env_raw: + upstream["stdio_env"] = _parse_env_vars(stdio_env_raw) http_url = env.get("MF_HTTP_URL") if http_url: @@ -231,6 +238,21 @@ def _parse_headers(value: str) -> Dict[str, str]: return headers +def _parse_env_vars(value: str) -> Dict[str, str]: + """Parse environment variables from semicolon-separated key=value pairs.""" + env_vars: Dict[str, str] = {} + for item in value.split(";"): + if not item.strip(): + continue + if "=" not in item: + raise ConfigError( + f"Environment variable '{item}' must be in key=value form (separated by ';')." + ) + key, val = item.split("=", 1) + env_vars[key.strip()] = val.strip() + return env_vars + + def _to_bool(value: str) -> bool: truthy = {"1", "true", "t", "yes", "y", "on"} falsy = {"0", "false", "f", "no", "n", "off"} diff --git a/src/mcp_filter/upstream.py b/src/mcp_filter/upstream.py index 64a0453..b091b8d 100644 --- a/src/mcp_filter/upstream.py +++ b/src/mcp_filter/upstream.py @@ -96,7 +96,7 @@ async def make_upstream(cfg: UpstreamConfig) -> Upstream: if cfg.transport == "stdio": if not cfg.stdio_command: raise ConfigError("stdio transport requires a command to spawn.") - client = await _connect_stdio(fastmcp, cfg.stdio_command, cfg.stdio_args) + client = await _connect_stdio(fastmcp, cfg.stdio_command, cfg.stdio_args, cfg.stdio_env) elif cfg.transport == "http": if not cfg.http_url: raise ConfigError("http transport requires an http_url.") @@ -107,7 +107,7 @@ async def make_upstream(cfg: UpstreamConfig) -> Upstream: return _FastMCPUpstream(client) -async def _connect_stdio(fastmcp: Any, command: str, args: Optional[List[str]]) -> Any: +async def _connect_stdio(fastmcp: Any, command: str, args: Optional[List[str]], env: Optional[Dict[str, str]] = None) -> Any: args = args or [] # Try modern FastMCP (>= 2.0) with Client + StdioTransport @@ -122,7 +122,7 @@ async def _connect_stdio(fastmcp: Any, command: str, args: Optional[List[str]]) raise ConfigError("npx transport requires a package name") package = args[0] package_args = args[1:] if len(args) > 1 else [] - transport = NpxStdioTransport(package=package, args=package_args) + transport = NpxStdioTransport(package=package, args=package_args, env_vars=env) elif command.endswith(".py") or command == "python": # python script.py args -> PythonStdioTransport(script, args) if command == "python" and args: @@ -131,17 +131,17 @@ async def _connect_stdio(fastmcp: Any, command: str, args: Optional[List[str]]) else: script = command script_args = args - transport = PythonStdioTransport(script_path=script, args=script_args) + transport = PythonStdioTransport(script_path=script, args=script_args, env=env) else: # Generic command -> try importing generic StdioTransport or NodeStdioTransport from fastmcp.client import StdioTransport # StdioTransport might not accept command directly, let's check NodeStdioTransport try: from fastmcp.client import NodeStdioTransport - transport = NodeStdioTransport(command=command, args=args) + transport = NodeStdioTransport(command=command, args=args, env=env) except (ImportError, TypeError): # Fallback to generic if available - transport = StdioTransport(command=command, args=args) + transport = StdioTransport(command=command, args=args, env=env) client = Client(transport)