diff --git a/python/samples/demos/mcp_server/README.md b/python/samples/demos/mcp_server/README.md index 1eafdff3b4f9..1aef1595c1d2 100644 --- a/python/samples/demos/mcp_server/README.md +++ b/python/samples/demos/mcp_server/README.md @@ -61,6 +61,20 @@ uv --directory=/semantic-kernel/python/samples/demos/mcp_ser This will start a server that listens for incoming requests on port `8000`. +> [!NOTE] +> By default the SSE server binds to `127.0.0.1` (loopback) and only accepts requests +> with a loopback `Host` header and, when present, a loopback `Origin` header. A local +> MCP server exposes tools, plugins and model providers backed by your own credentials, +> so it is good practice to keep it reachable only from your own machine. The +> [MCP specification](https://modelcontextprotocol.io/) recommends validating `Origin` +> and binding to loopback, in part to guard against [DNS rebinding](https://en.wikipedia.org/wiki/DNS_rebinding). +> +> You can override the bind address with `--host`, e.g. `--host 0.0.0.0` to expose the +> server on the network. Do this only on a trusted network. The bundled Host/Origin +> checks only allow loopback callers, so a non-loopback deployment needs proper +> authentication - see the [`mcp_with_oauth`](../mcp_with_oauth/) sample for the +> authenticated, Streamable-HTTP pattern recommended for production. + --- In both cases, `uv` will ensure that `semantic-kernel` is installed with the `mcp` extra in a temporary virtual environment. diff --git a/python/samples/demos/mcp_server/agent_as_server.py b/python/samples/demos/mcp_server/agent_as_server.py index 3cbc012d5652..af9b2e2be6b2 100644 --- a/python/samples/demos/mcp_server/agent_as_server.py +++ b/python/samples/demos/mcp_server/agent_as_server.py @@ -5,6 +5,7 @@ # /// # Copyright (c) Microsoft. All rights reserved. import argparse +import ipaddress import logging from typing import Annotated, Any, Literal @@ -51,6 +52,16 @@ """ +def is_loopback_host(host: str) -> bool: + """Return True if the host refers to a loopback interface (incl. IPv6 ::1).""" + if host == "localhost": + return True + try: + return ipaddress.ip_address(host).is_loopback + except ValueError: + return False + + def parse_arguments(): parser = argparse.ArgumentParser(description="Run the Semantic Kernel MCP server.") parser.add_argument( @@ -66,7 +77,20 @@ def parse_arguments(): default=None, help="Port to use for SSE transport (required if transport is 'sse').", ) - return parser.parse_args() + parser.add_argument( + "--host", + type=str, + default="127.0.0.1", + help=( + "Host/interface to bind the SSE server to (default: 127.0.0.1). " + "Binding to anything other than loopback (e.g. 0.0.0.0) exposes the server " + "to the network and should only be done on a trusted network with authentication added." + ), + ) + args = parser.parse_args() + if args.transport == "sse" and args.port is None: + parser.error("--port is required when --transport is 'sse'.") + return args # Define a simple plugin for the sample @@ -88,7 +112,7 @@ def get_item_price( return "$9.99" -async def run(transport: Literal["sse", "stdio"] = "stdio", port: int | None = None) -> None: +async def run(transport: Literal["sse", "stdio"] = "stdio", port: int | None = None, host: str = "127.0.0.1") -> None: async with ( # 1. Login to Azure and create a Azure AI Project Client AzureCliCredential() as creds, @@ -110,7 +134,53 @@ async def run(transport: Literal["sse", "stdio"] = "stdio", port: int | None = N import uvicorn from mcp.server.sse import SseServerTransport from starlette.applications import Starlette + from starlette.middleware import Middleware + from starlette.middleware.trustedhost import TrustedHostMiddleware + from starlette.responses import PlainTextResponse from starlette.routing import Mount, Route + from starlette.types import ASGIApp, Receive, Scope, Send + + # A local MCP server is a security boundary, not a generic web server: it exposes + # tools, plugins and model providers backed by the developer's credentials. Without + # Host/Origin validation a malicious web page could use DNS rebinding to reach this + # loopback listener from the victim's browser and invoke the exposed MCP tools. + # The MCP spec therefore requires servers to validate Origin and bind to loopback. + allowed_hosts = [ + "localhost", + "127.0.0.1", + "[::1]", + f"localhost:{port}", + f"127.0.0.1:{port}", + f"[::1]:{port}", + ] + allowed_origins = { + "http://localhost", + "http://127.0.0.1", + "http://[::1]", + f"http://localhost:{port}", + f"http://127.0.0.1:{port}", + f"http://[::1]:{port}", + } + + class OriginValidationMiddleware: + """Reject requests with an untrusted Origin header (DNS-rebinding defense).""" + + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http": + origin = dict(scope["headers"]).get(b"origin") + if origin is not None: + try: + origin_value = origin.decode("ascii") + except UnicodeDecodeError: + origin_value = None + if origin_value not in allowed_origins: + response = PlainTextResponse("Forbidden: invalid Origin header", status_code=403) + await response(scope, receive, send) + return + await self.app(scope, receive, send) sse = SseServerTransport("/messages/") @@ -122,14 +192,27 @@ async def handle_sse(request): await server.run(read_stream, write_stream, server.create_initialization_options()) starlette_app = Starlette( - debug=True, + debug=False, routes=[ Route("/sse", endpoint=handle_sse), Mount("/messages/", app=sse.handle_post_message), ], + middleware=[ + Middleware(TrustedHostMiddleware, allowed_hosts=allowed_hosts), + Middleware(OriginValidationMiddleware), + ], ) + + if not is_loopback_host(host): + logger.warning( + "Binding the MCP SSE server to %s exposes it beyond loopback. The bundled Host/Origin " + "checks only allow loopback callers; for a network-reachable or credentialed deployment " + "add proper authentication (see the mcp_with_oauth sample) before doing this.", + host, + ) + nest_asyncio.apply() - uvicorn.run(starlette_app, host="0.0.0.0", port=port) # nosec + uvicorn.run(starlette_app, host=host, port=port) # nosec elif transport == "stdio": from mcp.server.stdio import stdio_server @@ -142,4 +225,4 @@ async def handle_stdin(stdin: Any | None = None, stdout: Any | None = None) -> N if __name__ == "__main__": args = parse_arguments() - anyio.run(run, args.transport, args.port) + anyio.run(run, args.transport, args.port, args.host) diff --git a/python/samples/demos/mcp_server/sk_mcp_server.py b/python/samples/demos/mcp_server/sk_mcp_server.py index 6b7b617d19df..bade18f4f301 100644 --- a/python/samples/demos/mcp_server/sk_mcp_server.py +++ b/python/samples/demos/mcp_server/sk_mcp_server.py @@ -5,6 +5,7 @@ # /// # Copyright (c) Microsoft. All rights reserved. import argparse +import ipaddress import logging from typing import Any, Literal @@ -54,6 +55,16 @@ """ +def is_loopback_host(host: str) -> bool: + """Return True if the host refers to a loopback interface (incl. IPv6 ::1).""" + if host == "localhost": + return True + try: + return ipaddress.ip_address(host).is_loopback + except ValueError: + return False + + def parse_arguments(): parser = argparse.ArgumentParser(description="Run the Semantic Kernel MCP server.") parser.add_argument( @@ -69,10 +80,23 @@ def parse_arguments(): default=None, help="Port to use for SSE transport (required if transport is 'sse').", ) - return parser.parse_args() + parser.add_argument( + "--host", + type=str, + default="127.0.0.1", + help=( + "Host/interface to bind the SSE server to (default: 127.0.0.1). " + "Binding to anything other than loopback (e.g. 0.0.0.0) exposes the server " + "to the network and should only be done on a trusted network with authentication added." + ), + ) + args = parser.parse_args() + if args.transport == "sse" and args.port is None: + parser.error("--port is required when --transport is 'sse'.") + return args -def run(transport: Literal["sse", "stdio"] = "stdio", port: int | None = None) -> None: +def run(transport: Literal["sse", "stdio"] = "stdio", port: int | None = None, host: str = "127.0.0.1") -> None: kernel = Kernel() @kernel_function() @@ -112,7 +136,53 @@ def echo_function(message: str, extra: str = "") -> str: import uvicorn from mcp.server.sse import SseServerTransport from starlette.applications import Starlette + from starlette.middleware import Middleware + from starlette.middleware.trustedhost import TrustedHostMiddleware + from starlette.responses import PlainTextResponse from starlette.routing import Mount, Route + from starlette.types import ASGIApp, Receive, Scope, Send + + # A local MCP server is a security boundary, not a generic web server: it exposes + # tools, plugins and model providers backed by the developer's credentials. Without + # Host/Origin validation a malicious web page could use DNS rebinding to reach this + # loopback listener from the victim's browser and invoke the exposed MCP tools. + # The MCP spec therefore requires servers to validate Origin and bind to loopback. + allowed_hosts = [ + "localhost", + "127.0.0.1", + "[::1]", + f"localhost:{port}", + f"127.0.0.1:{port}", + f"[::1]:{port}", + ] + allowed_origins = { + "http://localhost", + "http://127.0.0.1", + "http://[::1]", + f"http://localhost:{port}", + f"http://127.0.0.1:{port}", + f"http://[::1]:{port}", + } + + class OriginValidationMiddleware: + """Reject requests with an untrusted Origin header (DNS-rebinding defense).""" + + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http": + origin = dict(scope["headers"]).get(b"origin") + if origin is not None: + try: + origin_value = origin.decode("ascii") + except UnicodeDecodeError: + origin_value = None + if origin_value not in allowed_origins: + response = PlainTextResponse("Forbidden: invalid Origin header", status_code=403) + await response(scope, receive, send) + return + await self.app(scope, receive, send) sse = SseServerTransport("/messages/") @@ -121,14 +191,26 @@ async def handle_sse(request): await server.run(read_stream, write_stream, server.create_initialization_options()) starlette_app = Starlette( - debug=True, + debug=False, routes=[ Route("/sse", endpoint=handle_sse), Mount("/messages/", app=sse.handle_post_message), ], + middleware=[ + Middleware(TrustedHostMiddleware, allowed_hosts=allowed_hosts), + Middleware(OriginValidationMiddleware), + ], ) - uvicorn.run(starlette_app, host="0.0.0.0", port=port) # nosec + if not is_loopback_host(host): + logger.warning( + "Binding the MCP SSE server to %s exposes it beyond loopback. The bundled Host/Origin " + "checks only allow loopback callers; for a network-reachable or credentialed deployment " + "add proper authentication (see the mcp_with_oauth sample) before doing this.", + host, + ) + + uvicorn.run(starlette_app, host=host, port=port) # nosec elif transport == "stdio": import anyio from mcp.server.stdio import stdio_server @@ -142,4 +224,4 @@ async def handle_stdin(stdin: Any | None = None, stdout: Any | None = None) -> N if __name__ == "__main__": args = parse_arguments() - run(transport=args.transport, port=args.port) + run(transport=args.transport, port=args.port, host=args.host)