diff --git a/README.md b/README.md index 874983b53..f2470c6f4 100644 --- a/README.md +++ b/README.md @@ -269,7 +269,9 @@ Resources are how you expose data to LLMs. They're similar to GET endpoints in a ```python -from mcp.server.fastmcp import FastMCP +from typing import Annotated + +from mcp.server.fastmcp import FastMCP, Path, Query mcp = FastMCP(name="Resource Example") @@ -289,6 +291,118 @@ def get_settings() -> str: "language": "en", "debug": false }""" + + +# Form-style query expansion examples using RFC 6570 URI templates + + +@mcp.resource("articles://{article_id}/view") +def view_article(article_id: str, format: str = "html", lang: str = "en") -> str: + """View an article with optional format and language selection. + + Example URIs: + - articles://123/view (uses defaults: format=html, lang=en) + - articles://123/view?format=pdf (format=pdf, lang=en) + - articles://123/view?format=pdf&lang=fr (format=pdf, lang=fr) + """ + if format == "pdf": + content = f"PDF content for article {article_id} in {lang}" + elif format == "json": + content = f'{{"article_id": "{article_id}", "content": "...", "lang": "{lang}"}}' + else: + content = f"Article {article_id} in {lang}" + + return content + + +@mcp.resource("search://query/{search_term}") +def search_content( + search_term: str, page: int = 1, limit: int = 10, category: str = "all", sort: str = "relevance" +) -> str: + """Search content with optional pagination and filtering. + + Example URIs: + - search://query/python (basic search) + - search://query/python?page=2&limit=20 (pagination) + - search://query/python?category=tutorial&sort=date (filtering) + """ + offset = (page - 1) * limit + results = f"Search results for '{search_term}' (category: {category}, sort: {sort})" + results += f"\nShowing {limit} results starting from {offset + 1}" + + # Simulated search results + for i in range(limit): + result_num = offset + i + 1 + results += f"\n{result_num}. Result about {search_term} in {category}" + + return results + + +@mcp.resource("users://{user_id}/profile") +def get_user_profile(user_id: str, include_private: bool = False, format: str = "summary") -> str: + """Get user profile with optional private data and format selection. + + Example URIs: + - users://123/profile (public data, summary format) + - users://123/profile?include_private=true (includes private data) + - users://123/profile?format=detailed&include_private=true (detailed with private) + """ + from typing import Any + + profile_data: dict[str, Any] = {"user_id": user_id, "name": "John Doe", "public_bio": "Software developer"} + + if include_private: + profile_data.update({"email": "john@example.com", "phone": "+1234567890"}) + + if format == "detailed": + profile_data.update({"last_active": "2024-01-20", "preferences": {"notifications": True}}) + + return str(profile_data) + + +@mcp.resource("api://weather/{location}") +def get_weather_data( + location: str, units: str = "metric", lang: str = "en", include_forecast: bool = False, days: int = 5 +) -> str: + """Get weather data with customizable options. + + Example URIs: + - api://weather/london (basic weather) + - api://weather/london?units=imperial&lang=es (different units and language) + - api://weather/london?include_forecast=true&days=7 (with 7-day forecast) + """ + temp_unit = "C" if units == "metric" else "F" + base_temp = 22 if units == "metric" else 72 + + weather_info = f"Weather for {location}: {base_temp}{temp_unit}" + + if include_forecast: + weather_info += f"\n{days}-day forecast:" + for day in range(1, days + 1): + forecast_temp = base_temp + (day % 3) + weather_info += f"\nDay {day}: {forecast_temp}{temp_unit}" + + return weather_info + + +@mcp.resource("api://data/{user_id}/{region}/{city}/{file_path:path}") +def resource_fn( + # Path parameters + user_id: Annotated[int, Path(gt=0, description="User ID")], # explicit Path + region, # inferred path # type: ignore + city: str, # inferred path + file_path: str, # inferred path {file_path:path} + # Required query parameter (no default) + version: int, + # Optional query parameters (defaults or Query(...)) + format: Annotated[str, Query("json", description="Output format")], + include_metadata: bool = False, + tags: list[str] = [], + lang: str = "en", + debug: bool = False, + precision: float = 0.5, +) -> str: + return f"{user_id}/{region}/{city}/{file_path}" ``` _Full example: [examples/snippets/servers/basic_resource.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/basic_resource.py)_ diff --git a/examples/snippets/servers/basic_resource.py b/examples/snippets/servers/basic_resource.py index 5c1973059..0f2b76731 100644 --- a/examples/snippets/servers/basic_resource.py +++ b/examples/snippets/servers/basic_resource.py @@ -1,4 +1,6 @@ -from mcp.server.fastmcp import FastMCP +from typing import Annotated + +from mcp.server.fastmcp import FastMCP, Path, Query mcp = FastMCP(name="Resource Example") @@ -18,3 +20,115 @@ def get_settings() -> str: "language": "en", "debug": false }""" + + +# Form-style query expansion examples using RFC 6570 URI templates + + +@mcp.resource("articles://{article_id}/view") +def view_article(article_id: str, format: str = "html", lang: str = "en") -> str: + """View an article with optional format and language selection. + + Example URIs: + - articles://123/view (uses defaults: format=html, lang=en) + - articles://123/view?format=pdf (format=pdf, lang=en) + - articles://123/view?format=pdf&lang=fr (format=pdf, lang=fr) + """ + if format == "pdf": + content = f"PDF content for article {article_id} in {lang}" + elif format == "json": + content = f'{{"article_id": "{article_id}", "content": "...", "lang": "{lang}"}}' + else: + content = f"Article {article_id} in {lang}" + + return content + + +@mcp.resource("search://query/{search_term}") +def search_content( + search_term: str, page: int = 1, limit: int = 10, category: str = "all", sort: str = "relevance" +) -> str: + """Search content with optional pagination and filtering. + + Example URIs: + - search://query/python (basic search) + - search://query/python?page=2&limit=20 (pagination) + - search://query/python?category=tutorial&sort=date (filtering) + """ + offset = (page - 1) * limit + results = f"Search results for '{search_term}' (category: {category}, sort: {sort})" + results += f"\nShowing {limit} results starting from {offset + 1}" + + # Simulated search results + for i in range(limit): + result_num = offset + i + 1 + results += f"\n{result_num}. Result about {search_term} in {category}" + + return results + + +@mcp.resource("users://{user_id}/profile") +def get_user_profile(user_id: str, include_private: bool = False, format: str = "summary") -> str: + """Get user profile with optional private data and format selection. + + Example URIs: + - users://123/profile (public data, summary format) + - users://123/profile?include_private=true (includes private data) + - users://123/profile?format=detailed&include_private=true (detailed with private) + """ + from typing import Any + + profile_data: dict[str, Any] = {"user_id": user_id, "name": "John Doe", "public_bio": "Software developer"} + + if include_private: + profile_data.update({"email": "john@example.com", "phone": "+1234567890"}) + + if format == "detailed": + profile_data.update({"last_active": "2024-01-20", "preferences": {"notifications": True}}) + + return str(profile_data) + + +@mcp.resource("api://weather/{location}") +def get_weather_data( + location: str, units: str = "metric", lang: str = "en", include_forecast: bool = False, days: int = 5 +) -> str: + """Get weather data with customizable options. + + Example URIs: + - api://weather/london (basic weather) + - api://weather/london?units=imperial&lang=es (different units and language) + - api://weather/london?include_forecast=true&days=7 (with 7-day forecast) + """ + temp_unit = "C" if units == "metric" else "F" + base_temp = 22 if units == "metric" else 72 + + weather_info = f"Weather for {location}: {base_temp}{temp_unit}" + + if include_forecast: + weather_info += f"\n{days}-day forecast:" + for day in range(1, days + 1): + forecast_temp = base_temp + (day % 3) + weather_info += f"\nDay {day}: {forecast_temp}{temp_unit}" + + return weather_info + + +@mcp.resource("api://data/{user_id}/{region}/{city}/{file_path:path}") +def resource_fn( + # Path parameters + user_id: Annotated[int, Path(gt=0, description="User ID")], # explicit Path + region, # inferred path # type: ignore + city: str, # inferred path + file_path: str, # inferred path {file_path:path} + # Required query parameter (no default) + version: int, + # Optional query parameters (defaults or Query(...)) + format: Annotated[str, Query("json", description="Output format")], + include_metadata: bool = False, + tags: list[str] = [], + lang: str = "en", + debug: bool = False, + precision: float = 0.5, +) -> str: + return f"{user_id}/{region}/{city}/{file_path}" diff --git a/src/mcp/server/fastmcp/__init__.py b/src/mcp/server/fastmcp/__init__.py index a89902cfd..ba3cca0e8 100644 --- a/src/mcp/server/fastmcp/__init__.py +++ b/src/mcp/server/fastmcp/__init__.py @@ -5,7 +5,8 @@ from mcp.types import Icon from .server import Context, FastMCP +from .utilities.param_functions import Path, Query from .utilities.types import Audio, Image __version__ = version("mcp") -__all__ = ["FastMCP", "Context", "Image", "Audio", "Icon"] +__all__ = ["FastMCP", "Context", "Image", "Audio", "Icon", "Path", "Query"] diff --git a/src/mcp/server/fastmcp/resources/templates.py b/src/mcp/server/fastmcp/resources/templates.py index 3f02ebcba..bbf0a8c14 100644 --- a/src/mcp/server/fastmcp/resources/templates.py +++ b/src/mcp/server/fastmcp/resources/templates.py @@ -4,6 +4,7 @@ import inspect import re +import urllib.parse from collections.abc import Callable from typing import TYPE_CHECKING, Any @@ -11,7 +12,9 @@ from mcp.server.fastmcp.resources.types import FunctionResource, Resource from mcp.server.fastmcp.utilities.context_injection import find_context_parameter, inject_context -from mcp.server.fastmcp.utilities.func_metadata import func_metadata +from mcp.server.fastmcp.utilities.convertors import Convertor +from mcp.server.fastmcp.utilities.func_metadata import func_metadata, use_defaults_on_optional_validation_error +from mcp.server.fastmcp.utilities.param_validation import validate_and_sync_params from mcp.types import Annotations, Icon if TYPE_CHECKING: @@ -23,7 +26,7 @@ class ResourceTemplate(BaseModel): """A template for dynamically creating resources.""" - uri_template: str = Field(description="URI template with parameters (e.g. weather://{city}/current)") + uri_template: str = Field(description="URI template with parameters (e.g. weather://{city}/current{?units,format})") name: str = Field(description="Name of the resource") title: str | None = Field(description="Human-readable title of the resource", default=None) description: str | None = Field(description="Description of what the resource does") @@ -33,6 +36,24 @@ class ResourceTemplate(BaseModel): fn: Callable[..., Any] = Field(exclude=True) parameters: dict[str, Any] = Field(description="JSON schema for function parameters") context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context") + compiled_pattern: re.Pattern[str] | None = Field( + default=None, description="Compiled regular expression pattern for matching the URI template." + ) + convertors: dict[str, Convertor[Any]] | None = Field( + default=None, description="Mapping of parameter names to their respective type converters." + ) + path_params: set[str] = Field( + default_factory=set, + description="Set of required parameters from the path component", + ) + required_query_params: set[str] = Field( + default_factory=set, + description="Set of required parameters specified in the query component", + ) + optional_query_params: set[str] = Field( + default_factory=set, + description="Set of optional parameters specified in the query component", + ) @classmethod def from_function( @@ -48,7 +69,8 @@ def from_function( context_kwarg: str | None = None, ) -> ResourceTemplate: """Create a template from a function.""" - func_name = name or fn.__name__ + original_fn = fn + func_name = name or original_fn.__name__ if func_name == "": raise ValueError("You must provide a name for lambda functions") @@ -63,44 +85,98 @@ def from_function( ) parameters = func_arg_metadata.arg_model.model_json_schema() - # ensure the arguments are properly cast - fn = validate_call(fn) + # First, apply pydantic's validation and coercion + validated_fn = validate_call(original_fn) + + # Then, apply our decorator to handle default fallback for optional params + final_fn = use_defaults_on_optional_validation_error(validated_fn) + + # Extract required and optional params from the original function's signature + (path_params, required_query_params, optional_query_params, convertors, compiled_pattern) = ( + validate_and_sync_params(original_fn, uri_template) + ) return cls( uri_template=uri_template, name=func_name, title=title, - description=description or fn.__doc__ or "", + description=description or original_fn.__doc__ or "", mime_type=mime_type or "text/plain", icons=icons, annotations=annotations, - fn=fn, + fn=final_fn, parameters=parameters, context_kwarg=context_kwarg, + path_params=path_params, + required_query_params=required_query_params, + optional_query_params=optional_query_params, + convertors=convertors, + compiled_pattern=compiled_pattern, ) def matches(self, uri: str) -> dict[str, Any] | None: """Check if URI matches template and extract parameters.""" - # Convert template to regex pattern - pattern = self.uri_template.replace("{", "(?P<").replace("}", ">[^/]+)") - match = re.match(f"^{pattern}$", uri) - if match: - return match.groupdict() - return None + if not self.compiled_pattern or not self.convertors: + raise RuntimeError("Pattern did not compile for matching") + + # Split URI into path and query parts + if "?" in uri: + path, query = uri.split("?", 1) + else: + path, query = uri, "" + + match = self.compiled_pattern.match(path.strip("/")) + if not match: + return None + + params: dict[str, Any] = {} + + # ---- Extract and convert path parameters ---- + for name, conv in self.convertors.items(): + raw_value = match.group(name) + try: + params[name] = conv.convert(raw_value) + except Exception as e: + raise RuntimeError(f"Failed to convert '{raw_value}' for '{name}': {e}") + + # ---- Parse and merge query parameters ---- + query_dict = urllib.parse.parse_qs(query) if query else {} + + # Normalize and flatten query params + for key, values in query_dict.items(): + value = values[0] if values else None + if key in self.required_query_params or key in self.optional_query_params: + params[key] = value + + # ---- Validate required query parameters ---- + missing_required = [key for key in self.required_query_params if key not in params] + if missing_required: + raise ValueError(f"Missing required query parameters: {missing_required}") + + return params async def create_resource( self, uri: str, params: dict[str, Any], - context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, + context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, # type: ignore ) -> Resource: """Create a resource from the template with the given parameters.""" try: - # Add context to params if needed - params = inject_context(self.fn, params, context, self.context_kwarg) - - # Call function and check if result is a coroutine - result = self.fn(**params) + # Prepare parameters for function call + # For optional parameters not in URL, use their default values + # First add extracted parameters + fn_params = { + name: value + for name, value in params.items() + if name in self.path_params or name in self.required_query_params or name in self.optional_query_params + } + # Add context to params + fn_params = inject_context(self.fn, fn_params, context, self.context_kwarg) # type: ignore + # self.fn is now multiply-decorated: + # 1. validate_call for coercion/validation + # 2. our new decorator for default fallback on optional param validation err + result = self.fn(**fn_params) if inspect.iscoroutine(result): result = await result @@ -115,4 +191,6 @@ async def create_resource( fn=lambda: result, # Capture result in closure ) except Exception as e: + # This will catch errors from validate_call (e.g., for required params) + # or from our decorator if retry also fails, or any other errors. raise ValueError(f"Error creating resource from template: {e}") diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 719595916..18f4298c9 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -3,7 +3,6 @@ from __future__ import annotations as _annotations import inspect -import re from collections.abc import ( AsyncIterator, Awaitable, @@ -48,7 +47,6 @@ from mcp.server.fastmcp.prompts import Prompt, PromptManager from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager from mcp.server.fastmcp.tools import Tool, ToolManager -from mcp.server.fastmcp.utilities.context_injection import find_context_parameter from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.lowlevel.server import LifespanResultT @@ -537,6 +535,15 @@ def resource( If the URI contains parameters (e.g. "resource://{param}") or the function has parameters, it will be registered as a template resource. + Function parameters in the path are required, + while parameters with default values + can be optionally provided as query parameters using RFC 6570 form-style query + expansion syntax: {?param1,param2,...} + + Examples: + - resource://{category}/{id}{?filter,sort,limit} + - resource://{user_id}/profile{?format,fields} + Args: uri: URI for the resource (e.g. "resource://my-resource" or "resource://{param}") name: Optional name for the resource @@ -558,6 +565,19 @@ def get_data() -> str: def get_weather(city: str) -> str: return f"Weather for {city}" + @server.resource("resource://{city}/weather{?units}") + def get_weather_with_options(city: str, units: str = "metric") -> str: + # Can be called with resource://paris/weather?units=imperial + return f"Weather for {city} in {units} units" + + @server.resource("resource://{category}/{id} + {?filter,sort,limit}") + def get_item(category: str, id: str, filter: str = "all", sort: str = "name" + , limit: int = 10) -> str: + # Can be called with resource://electronics/1234?filter=new&sort=price&limit=20 + return f"Item {id} in {category}, filtered by {filter}, sorted by {sort} + , limited to {limit}" + @server.resource("resource://{city}/weather") async def get_weather(city: str) -> str: data = await fetch_weather(city) @@ -577,20 +597,6 @@ def decorator(fn: AnyFunction) -> AnyFunction: has_func_params = bool(sig.parameters) if has_uri_params or has_func_params: - # Check for Context parameter to exclude from validation - context_param = find_context_parameter(fn) - - # Validate that URI params match function params (excluding context) - uri_params = set(re.findall(r"{(\w+)}", uri)) - # We need to remove the context_param from the resource function if - # there is any. - func_params = {p for p in sig.parameters.keys() if p != context_param} - - if uri_params != func_params: - raise ValueError( - f"Mismatch between URI parameters {uri_params} and function parameters {func_params}" - ) - # Register as template self._resource_manager.add_template( fn=fn, diff --git a/src/mcp/server/fastmcp/utilities/convertors.py b/src/mcp/server/fastmcp/utilities/convertors.py new file mode 100644 index 000000000..148c78a81 --- /dev/null +++ b/src/mcp/server/fastmcp/utilities/convertors.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +import math +import uuid +from typing import Any, ClassVar, Generic, TypeVar, get_args + +from pydantic import GetCoreSchemaHandler +from pydantic_core import core_schema + +T = TypeVar("T") + + +class Convertor(Generic[T]): + regex: ClassVar[str] = "" + python_type: Any = Any # type hint for runtime type + + def __init_subclass__(cls, **kwargs: dict[str, Any]) -> None: + super().__init_subclass__(**kwargs) + # Extract the concrete type from the generic base + base = cls.__orig_bases__[0] # type: ignore[attr-defined] + args = get_args(base) + if args: + cls.python_type = args[0] # type: ignore[assignment] + else: + raise RuntimeError(f"Bad converter definition in class {cls.__name__}") + + @classmethod + def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler): + return core_schema.any_schema() + + def convert(self, value: str) -> T: + raise NotImplementedError() + + def to_string(self, value: T) -> str: + raise NotImplementedError() + + +class StringConvertor(Convertor[str]): + regex = r"[^/]+" + + def convert(self, value: str) -> str: + return value + + def to_string(self, value: str) -> str: + value = str(value) + assert "/" not in value, "May not contain path separators" + assert value, "Must not be empty" + return value + + +class PathConvertor(Convertor[str]): + regex = r".*" + + def convert(self, value: str) -> str: + return str(value) + + def to_string(self, value: str) -> str: + return str(value) + + +class IntegerConvertor(Convertor[int]): + regex = r"[0-9]+" + + def convert(self, value: str) -> int: + try: + return int(value) + except ValueError: + raise ValueError(f"Value '{value}' is not a valid integer") + + def to_string(self, value: int) -> str: + value = int(value) + assert value >= 0, "Negative integers are not supported" + return str(value) + + +class FloatConvertor(Convertor[float]): + regex = r"[0-9]+(?:\.[0-9]+)?" + + def convert(self, value: str) -> float: + try: + return float(value) + except ValueError: + raise ValueError(f"Value '{value}' is not a valid float") + + def to_string(self, value: float) -> str: + value = float(value) + assert value >= 0.0, "Negative floats are not supported" + assert not math.isnan(value), "NaN values are not supported" + assert not math.isinf(value), "Infinite values are not supported" + return f"{value:.20f}".rstrip("0").rstrip(".") + + +class UUIDConvertor(Convertor[uuid.UUID]): + regex = r"[0-9a-fA-F]{8}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{12}" + + def convert(self, value: str) -> uuid.UUID: + try: + return uuid.UUID(value) + except ValueError: + raise ValueError(f"Value '{value}' is not a valid UUID") + + def to_string(self, value: uuid.UUID) -> str: + return str(value) + + +CONVERTOR_TYPES: dict[str, Convertor[Any]] = { + "str": StringConvertor(), + "path": PathConvertor(), + "int": IntegerConvertor(), + "float": FloatConvertor(), + "uuid": UUIDConvertor(), +} diff --git a/src/mcp/server/fastmcp/utilities/func_metadata.py b/src/mcp/server/fastmcp/utilities/func_metadata.py index 873b1ae19..04dc8bb49 100644 --- a/src/mcp/server/fastmcp/utilities/func_metadata.py +++ b/src/mcp/server/fastmcp/utilities/func_metadata.py @@ -1,3 +1,4 @@ +import functools import inspect import json import types @@ -12,6 +13,7 @@ ConfigDict, Field, RootModel, + ValidationError, WithJsonSchema, create_model, ) @@ -548,3 +550,110 @@ def _convert_to_content( result = pydantic_core.to_json(result, fallback=str, indent=2).decode() return [TextContent(type="text", text=result)] + + +def use_defaults_on_optional_validation_error( + decorated_fn: Callable[..., Any], +) -> Callable[..., Any]: + """ + Decorator for a function already wrapped by pydantic.validate_call. + If the wrapped function call fails due to a ValidationError, this decorator + checks if the error was caused by an optional parameter. If so, it retries + the call, explicitly omitting the failing optional parameter(s) to allow + Pydantic/the function to use their default values. + + If the error is for a required parameter, or if the retry fails, the original + error is re-raised. + """ + # Get the original function's signature (before validate_call) to inspect defaults + original_fn = inspect.unwrap(decorated_fn) + original_sig = inspect.signature(original_fn) + optional_params_with_defaults = { + name: param.default + for name, param in original_sig.parameters.items() + if param.default is not inspect.Parameter.empty + } + + @functools.wraps(decorated_fn) + async def async_wrapper(*args: Any, **kwargs: Any) -> Any: + try: + return await decorated_fn(*args, **kwargs) + except ValidationError as e: + # Check if the validation error is solely for optional parameters + failing_optional_params_to_retry: dict[str, bool] = {} + failing_required_params: list[str] = [] # Explicitly typed + + for error in e.errors(): + # error['loc'] is a tuple, e.g., ('param_name',) + # Pydantic error locations are tuples of strings or ints. + # For field errors, the first element is the field name (str). + if error["loc"] and isinstance(error["loc"][0], str): + param_name: str = error["loc"][0] + if param_name in optional_params_with_defaults: + # It's an optional param that failed. Mark for retry by exclude. + failing_optional_params_to_retry[param_name] = True + else: + # It's a required parameter or a non-parameter error + failing_required_params.append(param_name) + else: # Non-parameter specific error or unexpected error structure + raise e + + if failing_required_params or not failing_optional_params_to_retry: + # re-raise if any req params failed, or if no opt params were identified + logger.debug( + f"Validation failed for required params or no optional params " + f"identified. Re-raising original error for {original_fn.__name__}." + ) + raise e + + # At this point, only optional parameters caused the ValidationError. + # Retry the call, removing the failing optional params from kwargs. + # This allows validate_call/the function to use their defaults. + new_kwargs = {k: v for k, v in kwargs.items() if k not in failing_optional_params_to_retry} + + # Preserve positional arguments + # failing_optional_params_to_retry.keys() is a KeysView[str] + # list(KeysView[str]) is list[str] + logger.info( + f"Retrying {original_fn.__name__} with default values" + f"for optional params: {list(failing_optional_params_to_retry.keys())}" + ) + return await decorated_fn(*args, **new_kwargs) + + @functools.wraps(decorated_fn) + def sync_wrapper(*args: Any, **kwargs: Any) -> Any: + try: + return decorated_fn(*args, **kwargs) + except ValidationError as e: + failing_optional_params_to_retry: dict[str, bool] = {} + failing_required_params: list[str] = [] # Explicitly typed + + for error in e.errors(): + if error["loc"] and isinstance(error["loc"][0], str): + param_name: str = error["loc"][0] + if param_name in optional_params_with_defaults: + failing_optional_params_to_retry[param_name] = True + else: + failing_required_params.append(param_name) + else: + raise e + + if failing_required_params or not failing_optional_params_to_retry: + logger.debug( + f"Validation failed for required params or no optional params " + f"identified. Re-raising original error for {original_fn.__name__}." + ) + raise e + + new_kwargs = {k: v for k, v in kwargs.items() if k not in failing_optional_params_to_retry} + logger.info( + f"Retrying {original_fn.__name__} with default values" + f"for optional params: {list(failing_optional_params_to_retry.keys())}" + ) + return decorated_fn(*args, **new_kwargs) + + if inspect.iscoroutinefunction( + original_fn + ): # Check original_fn because decorated_fn might be a partial or already wrapped + return async_wrapper + return sync_wrapper diff --git a/src/mcp/server/fastmcp/utilities/param_functions.py b/src/mcp/server/fastmcp/utilities/param_functions.py new file mode 100644 index 000000000..031878774 --- /dev/null +++ b/src/mcp/server/fastmcp/utilities/param_functions.py @@ -0,0 +1,502 @@ +from collections.abc import Callable +from typing import Annotated, Any + +from pydantic.version import VERSION as PYDANTIC_VERSION +from typing_extensions import Doc + +from mcp.server.fastmcp.utilities import params + +PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2]) +PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2 + +if not PYDANTIC_V2: + from pydantic.fields import Undefined # type: ignore[attr-defined] +else: + from pydantic.v1.fields import Undefined + +# difference between not given not needed, not given maybe needed. +_Unset: Any = Undefined # type: ignore + + +def Path( # noqa: PLR0913 + default: Annotated[ + Any, + Doc( + """ + Default value if the parameter field is not set. + + This doesn't affect `Path` parameters as the value is always required. + The parameter is available only for compatibility. + """ + ), + ] = ..., + *, + default_factory: Annotated[ + Callable[[], Any] | None, + Doc( + """ + A callable to generate the default value. + + This doesn't affect `Path` parameters as the value is always required. + The parameter is available only for compatibility. + """ + ), + ] = _Unset, + alias: Annotated[ + str | None, + Doc( + """ + An alternative name for the parameter field. + + This will be used to extract the data and for the generated OpenAPI. + It is particularly useful when you can't use the name you want because it + is a Python reserved keyword or similar. + """ + ), + ] = None, + alias_priority: Annotated[ + int | None, + Doc( + """ + Priority of the alias. This affects whether an alias generator is used. + """ + ), + ] = None, + validation_alias: Annotated[ + str | None, + Doc( + """ + 'Whitelist' validation step. The parameter field will be the single one + allowed by the alias or set of aliases defined. + """ + ), + ] = None, + serialization_alias: Annotated[ + str | None, + Doc( + """ + 'Blacklist' validation step. The vanilla parameter field will be the + single one of the alias' or set of aliases' fields and all the other + fields will be ignored at serialization time. + """ + ), + ] = None, + title: Annotated[ + str | None, + Doc( + """ + Human-readable title. + """ + ), + ] = None, + description: Annotated[ + str | None, + Doc( + """ + Human-readable description. + """ + ), + ] = None, + gt: Annotated[ + float | None, + Doc( + """ + Greater than. If set, value must be greater than this. Only applicable to + numbers. + """ + ), + ] = None, + ge: Annotated[ + float | None, + Doc( + """ + Greater than or equal. If set, value must be greater than or equal to + this. Only applicable to numbers. + """ + ), + ] = None, + lt: Annotated[ + float | None, + Doc( + """ + Less than. If set, value must be less than this. Only applicable to numbers. + """ + ), + ] = None, + le: Annotated[ + float | None, + Doc( + """ + Less than or equal. If set, value must be less than or equal to this. + Only applicable to numbers. + """ + ), + ] = None, + min_length: Annotated[ + int | None, + Doc( + """ + Minimum length for strings. + """ + ), + ] = None, + max_length: Annotated[ + int | None, + Doc( + """ + Maximum length for strings. + """ + ), + ] = None, + pattern: Annotated[ + str | None, + Doc( + """ + RegEx pattern for strings. + """ + ), + ] = None, + discriminator: Annotated[ + str | None, + Doc( + """ + Parameter field name for discriminating the type in a tagged union. + """ + ), + ] = None, + strict: Annotated[ + bool | None, + Doc( + """ + If `True`, strict validation is applied to the field. + """ + ), + ] = None, + multiple_of: Annotated[ + float | None, + Doc( + """ + Value must be a multiple of this. Only applicable to numbers. + """ + ), + ] = None, + allow_inf_nan: Annotated[ + bool | None, + Doc( + """ + Allow `inf`, `-inf`, `nan`. Only applicable to numbers. + """ + ), + ] = None, + max_digits: Annotated[ + int | None, + Doc( + """ + Maximum number of allow digits for strings. + """ + ), + ] = None, + decimal_places: Annotated[ + int | None, + Doc( + """ + Maximum number of decimal places allowed for numbers. + """ + ), + ] = None, + examples: Annotated[ + list[Any] | None, + Doc( + """ + Example values for this field. + """ + ), + ] = None, + include_in_schema: Annotated[ + bool, + Doc( + """ + To include (or not) this parameter field in the generated OpenAPI. + You probably don't need it, but it's available. + + This affects the generated OpenAPI (e.g. visible at `/docs`). + """ + ), + ] = True, + json_schema_extra: Annotated[ + dict[str, Any] | None, + Doc( + """ + Any additional JSON schema data. + """ + ), + ] = None, +) -> Any: + """ + Declare a path parameter for a *path operation*. + """ + return params.Path( + default=default, + default_factory=default_factory, + alias=alias, + alias_priority=alias_priority, + validation_alias=validation_alias, + serialization_alias=serialization_alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + pattern=pattern, + discriminator=discriminator, + strict=strict, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + examples=examples, + include_in_schema=include_in_schema, + json_schema_extra=json_schema_extra, + ) + + +def Query( # noqa: PLR0913 + default: Annotated[ + Any, + Doc( + """ + Default value if the parameter field is not set. + """ + ), + ] = Undefined, + *, + default_factory: Annotated[ + Callable[[], Any] | None, + Doc( + """ + A callable to generate the default value. + + This doesn't affect `Path` parameters as the value is always required. + The parameter is available only for compatibility. + """ + ), + ] = _Unset, + alias: Annotated[ + str | None, + Doc( + """ + An alternative name for the parameter field. + + This will be used to extract the data and for the generated OpenAPI. + It is particularly useful when you can't use the name you want because it + is a Python reserved keyword or similar. + """ + ), + ] = None, + alias_priority: Annotated[ + int | None, + Doc( + """ + Priority of the alias. This affects whether an alias generator is used. + """ + ), + ] = _Unset, + validation_alias: Annotated[ + str | None, + Doc( + """ + 'Whitelist' validation step. The parameter field will be the single one + allowed by the alias or set of aliases defined. + """ + ), + ] = None, + serialization_alias: Annotated[ + str | None, + Doc( + """ + 'Blacklist' validation step. The vanilla parameter field will be the + single one of the alias' or set of aliases' fields and all the other + fields will be ignored at serialization time. + """ + ), + ] = None, + title: Annotated[ + str | None, + Doc( + """ + Human-readable title. + """ + ), + ] = None, + description: Annotated[ + str | None, + Doc( + """ + Human-readable description. + """ + ), + ] = None, + gt: Annotated[ + float | None, + Doc( + """ + Greater than. If set, value must be greater than this. Only applicable to + numbers. + """ + ), + ] = None, + ge: Annotated[ + float | None, + Doc( + """ + Greater than or equal. If set, value must be greater than or equal to + this. Only applicable to numbers. + """ + ), + ] = None, + lt: Annotated[ + float | None, + Doc( + """ + Less than. If set, value must be less than this. Only applicable to numbers. + """ + ), + ] = None, + le: Annotated[ + float | None, + Doc( + """ + Less than or equal. If set, value must be less than or equal to this. + Only applicable to numbers. + """ + ), + ] = None, + min_length: Annotated[ + int | None, + Doc( + """ + Minimum length for strings. + """ + ), + ] = None, + max_length: Annotated[ + int | None, + Doc( + """ + Maximum length for strings. + """ + ), + ] = None, + pattern: Annotated[ + str | None, + Doc( + """ + RegEx pattern for strings. + """ + ), + ] = None, + discriminator: Annotated[ + str | None, + Doc( + """ + Parameter field name for discriminating the type in a tagged union. + """ + ), + ] = None, + strict: Annotated[ + bool | None, + Doc( + """ + If `True`, strict validation is applied to the field. + """ + ), + ] = _Unset, + multiple_of: Annotated[ + float | None, + Doc( + """ + Value must be a multiple of this. Only applicable to numbers. + """ + ), + ] = _Unset, + allow_inf_nan: Annotated[ + bool | None, + Doc( + """ + Allow `inf`, `-inf`, `nan`. Only applicable to numbers. + """ + ), + ] = _Unset, + max_digits: Annotated[ + int | None, + Doc( + """ + Maximum number of allow digits for strings. + """ + ), + ] = _Unset, + decimal_places: Annotated[ + int | None, + Doc( + """ + Maximum number of decimal places allowed for numbers. + """ + ), + ] = _Unset, + examples: Annotated[ + list[Any] | None, + Doc( + """ + Example values for this field. + """ + ), + ] = None, + include_in_schema: Annotated[ + bool, + Doc( + """ + To include (or not) this parameter field in the generated OpenAPI. + You probably don't need it, but it's available. + + This affects the generated OpenAPI (e.g. visible at `/docs`). + """ + ), + ] = True, + json_schema_extra: Annotated[ + dict[str, Any] | None, + Doc( + """ + Any additional JSON schema data. + """ + ), + ] = None, +) -> Any: + return params.Query( + default=default, + default_factory=default_factory, + alias=alias, + alias_priority=alias_priority, + validation_alias=validation_alias, + serialization_alias=serialization_alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + pattern=pattern, + discriminator=discriminator, + strict=strict, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + examples=examples, + include_in_schema=include_in_schema, + json_schema_extra=json_schema_extra, + ) diff --git a/src/mcp/server/fastmcp/utilities/param_validation.py b/src/mcp/server/fastmcp/utilities/param_validation.py new file mode 100644 index 000000000..1df135965 --- /dev/null +++ b/src/mcp/server/fastmcp/utilities/param_validation.py @@ -0,0 +1,207 @@ +"""Utility functions for validating and aligning function parameters with URI templates.""" + +from __future__ import annotations + +import inspect +import re +from collections.abc import Callable +from typing import Annotated, Any, get_args, get_origin + +from pydantic.version import VERSION as PYDANTIC_VERSION + +from mcp.server.fastmcp.utilities.convertors import CONVERTOR_TYPES, Convertor +from mcp.server.fastmcp.utilities.params import Path, Query + +PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2]) +PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2 + +if not PYDANTIC_V2: + from pydantic.fields import Undefined # type: ignore[attr-defined] +else: + from pydantic.v1.fields import Undefined + +# difference between not given not needed, not given maybe needed. +_Unset: Any = Undefined # type: ignore + + +def validate_and_sync_params( + fn: Callable[..., Any], + uri_template: str, +) -> tuple[ + set[str], # path_params + set[str], # required_query_params + set[str], # optional_query_params + dict[str, Convertor[Any]], + re.Pattern[str], +]: + """ + Analyze a function signature and URI template to: + - Collect parameter types from the function + - Validate and align URI parameters with the function signature + - Infer or validate types between both sides + - Build regex pattern and converters + + Returns: + (path_params, required_query_params, optional_query_params, converters, compiled_pattern, fn_defaults) + """ + fn_param_types, explicit_path_params, fn_defaults, ignore_query_params = _extract_function_params(fn) + parts = uri_template.strip("/").split("/") + + uri_pattern, converters, path_params = _parse_uri_and_validate_types(fn.__name__, parts, fn_param_types) + + path_params, required_query, optional_query = _postprocess_params( + ignore_query_params, fn_param_types, explicit_path_params, fn_defaults, path_params, parts + ) + + return path_params, required_query, optional_query, converters, uri_pattern + + +def _extract_function_params( + fn: Callable[..., Any], +) -> tuple[dict[str, Any], set[str], dict[str, Any], set[str]]: + """ + Extract parameter types and defaults from the function signature. + Detect explicitly annotated Path parameters (Annotated[..., Path()]). + """ + IGNORE_TYPES: set[str] = {"Context"} + sig = inspect.signature(fn) + fn_param_types: dict[str, Any] = {} + explicit_path: set[str] = set() + fn_defaults: dict[str, Any] = {} + ignore_query_params: set[str] = set() + for name, param in sig.parameters.items(): + base_type = param.annotation + if get_origin(param.annotation) is Annotated: + args = get_args(param.annotation) + if args: + base_type = args[0] + for meta in args[1:]: + if isinstance(meta, Path): + explicit_path.add(name) + if isinstance(meta, Query): + if meta.default is not Undefined: + fn_defaults[name] = meta.default + + fn_param_types[name] = base_type + + # IGNORE TYPES caused a circular import with Context so using it as a string + # defaults are optional query, path cannot have defaults + # In a weird way have to do this to get base(origin) for Context + def get_base(name: str) -> str: + return name.split("[")[0] + + base = get_base(param.annotation.__name__) + if base in IGNORE_TYPES: + ignore_query_params.add(name) + if param.default is not inspect._empty: # type: ignore + fn_defaults[name] = param.default + return fn_param_types, explicit_path, fn_defaults, ignore_query_params + + +def _parse_uri_and_validate_types( + fn_name: str, parts: list[str], fn_param_types: dict[str, Any] +) -> tuple[re.Pattern[str], dict[str, Convertor[Any]], set[str]]: + """ + Parse URI path components, infer or validate types, and build converters. + Returns a compiled regex pattern, converter mapping, and detected path parameters. + """ + pattern_parts: list[str] = [] + converters: dict[str, Convertor[Any]] = {} + path_params: set[str] = set() + + for part in parts: + match = re.fullmatch(r"\{(\w+)(?::(\w+))?\}", part) + if not match: + pattern_parts.append(re.escape(part)) + continue + + name, uri_type = match.groups() + if name not in fn_param_types: + raise ValueError( + f"Mismatch between URI path parameters '{name}' and required function parameters in '{fn_name}'" + ) + + fn_type = fn_param_types[name] + uri_type = _resolve_type(name, uri_type, fn_type) + + if uri_type not in CONVERTOR_TYPES: + raise NotImplementedError(f"Parameter '{name}' in URI uses unsupported type '{uri_type}'.") + + conv = CONVERTOR_TYPES[uri_type] + converters[name] = conv + pattern_parts.append(f"(?P<{name}>{conv.regex})") + path_params.add(name) + + uri_pattern = re.compile("^" + "/".join(pattern_parts) + "$") + return uri_pattern, converters, path_params + + +def _resolve_type(name: str, uri_type: str | None, fn_type: Any) -> str: + """ + Infer or validate type consistency between URI and function parameter. + - If only one side defines a type, the other inherits it. + - If both define types, they must be compatible. + """ + if uri_type and uri_type not in CONVERTOR_TYPES: + raise NotImplementedError(f"Unknown converter type '{uri_type}' in URI template") + + if uri_type is None: + if fn_type is not inspect._empty: # type: ignore + tname = getattr(fn_type, "__name__", None) + if tname in CONVERTOR_TYPES: + return tname + return "str" + + if fn_type is not inspect._empty and uri_type in CONVERTOR_TYPES: # type: ignore + expected_type = CONVERTOR_TYPES[uri_type].python_type + if fn_type != expected_type and not issubclass(fn_type, expected_type): + raise TypeError( + f"Type mismatch for '{name}': URI declares {expected_type.__name__}, " + f"function declares {getattr(fn_type, '__name__', fn_type)}" + ) + + return uri_type + + +def _postprocess_params( + ignore_query_params: set[str], + fn_param_types: dict[str, Any], + explicit_path: set[str], + fn_defaults: dict[str, Any], + path_params: set[str], + parts: list[str], +) -> tuple[set[str], set[str], set[str]]: + """ + Perform final validation and classification of parameters: + - Ensure 'path' converters only appear as the last URI segment + - Ensure explicitly declared Path parameters exist in the URI + - Derive query parameters as required or optional based on defaults + """ + # Validate 'path' types appear last + for i, part in enumerate(parts): + match = re.fullmatch(r"\{(\w+)(?::(\w+))?\}", part) + if not match: + continue + _, uri_type = match.groups() + if uri_type == "path" and i != len(parts) - 1: + raise ValueError("Path parameters must appear last in the URI template") + + # Ensure explicit Path() parameters exist in URI + missing = explicit_path - path_params + if missing: + raise ValueError(f"Explicit Path parameters {missing} are not present in URI template") + + # Ensure path parameters dont have defaults. + if not path_params.isdisjoint(fn_defaults.keys()): + raise ValueError("Path parameters cannot have defaults.") + + # Everything not in path_params and ingore_query_params is a query param + def is_query_param(name: str) -> bool: + return name not in path_params and name not in ignore_query_params + + query_params = {name for name in fn_param_types if is_query_param(name)} + + required_query = {n for n in query_params if n not in fn_defaults} + optional_query = {n for n in query_params if n in fn_defaults} + + return path_params, required_query, optional_query diff --git a/src/mcp/server/fastmcp/utilities/params.py b/src/mcp/server/fastmcp/utilities/params.py new file mode 100644 index 000000000..7e51d3272 --- /dev/null +++ b/src/mcp/server/fastmcp/utilities/params.py @@ -0,0 +1,223 @@ +from collections.abc import Callable +from enum import Enum +from typing import Any + +from pydantic.fields import FieldInfo +from pydantic.version import VERSION as PYDANTIC_VERSION + +PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2]) +PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2 + +if not PYDANTIC_V2: + from pydantic.fields import Undefined # type: ignore[attr-defined] +else: + from pydantic.v1.fields import Undefined + +# difference between not given not needed, not given maybe needed. +_Unset: Any = Undefined # type: ignore + + +class ParamTypes(Enum): + query = "query" + path = "path" + + +class Param(FieldInfo): # type: ignore[misc] + in_: ParamTypes + + def __init__( # noqa: PLR0913 + self, + default: Any = Undefined, + *, + default_factory: Callable[[], Any] | None = _Unset, + annotation: Any | None = None, + alias: str | None = None, + alias_priority: int | None = _Unset, + validation_alias: str | None = None, + serialization_alias: str | None = None, + title: str | None = None, + description: str | None = None, + gt: float | None = None, + ge: float | None = None, + lt: float | None = None, + le: float | None = None, + min_length: int | None = None, + max_length: int | None = None, + pattern: str | None = None, + discriminator: str | None = None, + strict: bool | None = _Unset, + multiple_of: float | None = _Unset, + allow_inf_nan: bool | None = _Unset, + max_digits: int | None = _Unset, + decimal_places: int | None = _Unset, + examples: list[Any] | None = None, + include_in_schema: bool = True, + json_schema_extra: dict[str, Any] | None = None, + ): + self.include_in_schema = include_in_schema + kwargs = { + "default": default, + "default_factory": default_factory, + "alias": alias, + "title": title, + "description": description, + "gt": gt, + "ge": ge, + "lt": lt, + "le": le, + "min_length": min_length, + "max_length": max_length, + "discriminator": discriminator, + "multiple_of": multiple_of, + "allow_inf_nan": allow_inf_nan, + "max_digits": max_digits, + "decimal_places": decimal_places, + } + if examples is not None: + kwargs["examples"] = examples + current_json_schema_extra = json_schema_extra + if PYDANTIC_V2: + kwargs.update( + { + "annotation": annotation, + "alias_priority": alias_priority, + "validation_alias": validation_alias, + "serialization_alias": serialization_alias, + "strict": strict, + "json_schema_extra": current_json_schema_extra, + } + ) + kwargs["pattern"] = pattern + else: + kwargs["regex"] = pattern + kwargs.update(**current_json_schema_extra) # type: ignore + use_kwargs = {k: v for k, v in kwargs.items() if v is not _Unset} + + super().__init__(**use_kwargs) # type: ignore + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.default})" + + +class Path(Param): # type: ignore[misc] + in_ = ParamTypes.path + + def __init__( # noqa: PLR0913 + self, + default: Any = ..., + *, + default_factory: Callable[[], Any] | None = _Unset, + annotation: Any | None = None, + alias: str | None = None, + alias_priority: int | None = _Unset, + validation_alias: str | None = None, + serialization_alias: str | None = None, + title: str | None = None, + description: str | None = None, + gt: float | None = None, + ge: float | None = None, + lt: float | None = None, + le: float | None = None, + min_length: int | None = None, + max_length: int | None = None, + pattern: str | None = None, + discriminator: str | None = None, + strict: bool | None = _Unset, + multiple_of: float | None = _Unset, + allow_inf_nan: bool | None = _Unset, + max_digits: int | None = _Unset, + decimal_places: int | None = _Unset, + examples: list[Any] | None = None, + include_in_schema: bool = True, + json_schema_extra: dict[str, Any] | None = None, + ): + assert default is ..., "Path parameters cannot have a default value" + self.in_ = self.in_ + super().__init__( + default=default, + default_factory=default_factory, + annotation=annotation, + alias=alias, + alias_priority=alias_priority, + validation_alias=validation_alias, + serialization_alias=serialization_alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + pattern=pattern, + discriminator=discriminator, + strict=strict, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + examples=examples, + include_in_schema=include_in_schema, + json_schema_extra=json_schema_extra, + ) + + +class Query(Param): # type: ignore[misc] + in_ = ParamTypes.query + + def __init__( # noqa: PLR0913 + self, + default: Any = Undefined, + *, + default_factory: Callable[[], Any] | None = _Unset, + annotation: Any | None = None, + alias: str | None = None, + alias_priority: int | None = _Unset, + validation_alias: str | None = None, + serialization_alias: str | None = None, + title: str | None = None, + description: str | None = None, + gt: float | None = None, + ge: float | None = None, + lt: float | None = None, + le: float | None = None, + min_length: int | None = None, + max_length: int | None = None, + pattern: str | None = None, + discriminator: str | None = None, + strict: bool | None = _Unset, + multiple_of: float | None = _Unset, + allow_inf_nan: bool | None = _Unset, + max_digits: int | None = _Unset, + decimal_places: int | None = _Unset, + examples: list[Any] | None = None, + include_in_schema: bool = True, + json_schema_extra: dict[str, Any] | None = None, + ): + super().__init__( + default=default, + default_factory=default_factory, + annotation=annotation, + alias=alias, + alias_priority=alias_priority, + validation_alias=validation_alias, + serialization_alias=serialization_alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + pattern=pattern, + discriminator=discriminator, + strict=strict, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + examples=examples, + include_in_schema=include_in_schema, + json_schema_extra=json_schema_extra, + ) diff --git a/tests/issues/test_100_tool_listing.py b/tests/issues/test_100_tool_listing.py index 6dccec84d..34618356b 100644 --- a/tests/issues/test_100_tool_listing.py +++ b/tests/issues/test_100_tool_listing.py @@ -5,9 +5,9 @@ pytestmark = pytest.mark.anyio +@pytest.mark.anyio async def test_list_tools_returns_all_tools(): mcp = FastMCP("TestTools") - # Create 100 tools with unique names num_tools = 100 for i in range(num_tools): diff --git a/tests/issues/test_141_resource_templates.py b/tests/issues/test_141_resource_templates.py index 3145f65e8..19fadf49a 100644 --- a/tests/issues/test_141_resource_templates.py +++ b/tests/issues/test_141_resource_templates.py @@ -1,3 +1,5 @@ +import json + import pytest from pydantic import AnyUrl @@ -21,34 +23,46 @@ async def test_resource_template_edge_cases(): def get_user_post(user_id: str, post_id: str) -> str: return f"Post {post_id} by user {user_id}" - # Test case 2: Template with optional parameter (should fail) - with pytest.raises(ValueError, match="Mismatch between URI parameters"): - - @mcp.resource("resource://users/{user_id}/profile") - def get_user_profile(user_id: str, optional_param: str | None = None) -> str: - return f"Profile for user {user_id}" + # Test case 2: Template with valid optional parameters + # using form-style query expansion + @mcp.resource("resource://users/{user_id}/profile") + def get_user_profile(user_id: str, format: str = "json", fields: str = "basic") -> str: + return f"Profile for user {user_id} in {format} format with fields: {fields}" # Test case 3: Template with mismatched parameters - with pytest.raises(ValueError, match="Mismatch between URI parameters"): + with pytest.raises( + ValueError, + match="Mismatch between URI path parameters .* and required function parameters .*", + ): @mcp.resource("resource://users/{user_id}/profile") def get_user_profile_mismatch(different_param: str) -> str: return f"Profile for user {different_param}" - # Test case 4: Template with extra function parameters - with pytest.raises(ValueError, match="Mismatch between URI parameters"): - - @mcp.resource("resource://users/{user_id}/profile") - def get_user_profile_extra(user_id: str, extra_param: str) -> str: - return f"Profile for user {user_id}" - # Test case 5: Template with missing function parameters - with pytest.raises(ValueError, match="Mismatch between URI parameters"): + with pytest.raises( + ValueError, + match="Mismatch between URI path parameters .* and required function parameters .*", + ): @mcp.resource("resource://users/{user_id}/profile/{section}") def get_user_profile_missing(user_id: str) -> str: return f"Profile for user {user_id}" + # Test case 7: Make sure the resource with form-style query parameters works + async with client_session(mcp._mcp_server) as client: + result = await client.read_resource(AnyUrl("resource://users/123/profile")) + assert isinstance(result.contents[0], TextResourceContents) + assert result.contents[0].text == "Profile for user 123 in json format with fields: basic" + + result = await client.read_resource(AnyUrl("resource://users/123/profile?format=xml")) + assert isinstance(result.contents[0], TextResourceContents) + assert result.contents[0].text == "Profile for user 123 in xml format with fields: basic" + + result = await client.read_resource(AnyUrl("resource://users/123/profile?format=xml&fields=detailed")) + assert isinstance(result.contents[0], TextResourceContents) + assert result.contents[0].text == "Profile for user 123 in xml format with fields: detailed" + # Verify valid template works result = await mcp.read_resource("resource://users/123/posts/456") result_list = list(result) @@ -112,3 +126,114 @@ def get_user_profile(user_id: str) -> str: with pytest.raises(Exception): # Specific exception type may vary await session.read_resource(AnyUrl("resource://users/123/invalid")) # Invalid template + + +@pytest.mark.anyio +async def test_resource_template_optional_param_default_fallback_e2e(): + """Test end-to-end that optional params fallback to defaults on validation error.""" + mcp = FastMCP("FallbackDemo") + + @mcp.resource("resource://config/{section}") + def get_config( + section: str, + theme: str = "dark", + timeout: int = 30, + is_feature_enabled: bool = False, + ) -> dict[str, str | int | bool]: + return { + "section": section, + "theme": theme, + "timeout": timeout, + "is_feature_enabled": is_feature_enabled, + } + + async with client_session(mcp._mcp_server) as client: + await client.initialize() + + # 1. All defaults for optional params + uri1 = "resource://config/network" + res1 = await client.read_resource(AnyUrl(uri1)) + assert res1.contents and isinstance(res1.contents[0], TextResourceContents) + data1 = json.loads(res1.contents[0].text) + assert data1 == { + "section": "network", + "theme": "dark", + "timeout": 30, + "is_feature_enabled": False, + } + + # 2. Valid optional params (theme is URL encoded, timeout is valid int string) + uri2 = "resource://config/ui?theme=light%20blue&timeout=60&is_feature_enabled=true" + res2 = await client.read_resource(AnyUrl(uri2)) + assert res2.contents and isinstance(res2.contents[0], TextResourceContents) + data2 = json.loads(res2.contents[0].text) + assert data2 == { + "section": "ui", + "theme": "light blue", + "timeout": 60, + "is_feature_enabled": True, + } + + # 3.Invalid 'timeout'(optional int),valid 'theme','is_feature_enabled' not given + # timeout=abc should use default 30 + uri3 = "resource://config/storage?theme=grayscale&timeout=abc" + res3 = await client.read_resource(AnyUrl(uri3)) + assert res3.contents and isinstance(res3.contents[0], TextResourceContents) + data3 = json.loads(res3.contents[0].text) + assert data3 == { + "section": "storage", + "theme": "grayscale", + "timeout": 30, # Fallback to default + "is_feature_enabled": False, # Fallback to default + } + + # 4.Invalid 'is_feature_enabled'(optional bool),'timeout'valid,'theme' not given + # is_feature_enabled=notbool should use default False + uri4 = "resource://config/user?timeout=15&is_feature_enabled=notbool" + res4 = await client.read_resource(AnyUrl(uri4)) + assert res4.contents and isinstance(res4.contents[0], TextResourceContents) + data4 = json.loads(res4.contents[0].text) + assert data4 == { + "section": "user", + "theme": "dark", # Fallback to default + "timeout": 15, + "is_feature_enabled": False, # Fallback to default + } + + # 5. Empty value for optional 'theme' (string type) + uri5 = "resource://config/general?theme=" + res5 = await client.read_resource(AnyUrl(uri5)) + assert res5.contents and isinstance(res5.contents[0], TextResourceContents) + data5 = json.loads(res5.contents[0].text) + assert data5 == { + "section": "general", + "theme": "dark", # Fallback to default because param is removed by parse_qs + "timeout": 30, + "is_feature_enabled": False, + } + + # 6. Empty value for optional 'timeout' (int type) + # timeout= (empty value) should fall back to default + uri6 = "resource://config/advanced?timeout=" + res6 = await client.read_resource(AnyUrl(uri6)) + assert res6.contents and isinstance(res6.contents[0], TextResourceContents) + data6 = json.loads(res6.contents[0].text) + assert data6 == { + "section": "advanced", + "theme": "dark", + "timeout": 30, # Fallback to default because param is removed by parse_qs + "is_feature_enabled": False, + } + + # 7. Invalid required path param type + # This scenario is more about the FastMCP.read_resource and its error handling + @mcp.resource("resource://item/{item_code}/check") # item_code is string here + def check_item(item_code: int) -> dict[str, str | bool]: # but int in function + return {"item_code_type": str(type(item_code)), "valid_code": item_code > 0} + + uri7 = "resource://item/notaninteger/check" + # specific exception may vary + with pytest.raises(Exception): + # The err is caught by FastMCP.read_resource and re-raised as ResourceError, + # which the client sees as a general McpError or similar. + await client.read_resource(AnyUrl(uri7)) diff --git a/tests/server/fastmcp/resources/test_resource_template.py b/tests/server/fastmcp/resources/test_resource_template.py index 8224d04b1..6c16fdeff 100644 --- a/tests/server/fastmcp/resources/test_resource_template.py +++ b/tests/server/fastmcp/resources/test_resource_template.py @@ -1,10 +1,10 @@ import json -from typing import Any +from typing import Annotated, Any import pytest from pydantic import BaseModel -from mcp.server.fastmcp import FastMCP +from mcp.server.fastmcp import FastMCP, Path, Query from mcp.server.fastmcp.resources import FunctionResource, ResourceTemplate from mcp.types import Annotations @@ -42,12 +42,93 @@ def my_func(key: str, value: int) -> dict[str, Any]: # Valid match params = template.matches("test://foo/123") - assert params == {"key": "foo", "value": "123"} + assert params == {"key": "foo", "value": 123} # No match assert template.matches("test://foo") is None assert template.matches("other://foo/123") is None + def test_template_matches_with_types(self): + """Test matching URIs with typed placeholders.""" + + def my_func(a: int, b: float, name: str) -> dict[str, Any]: + return {"a": a, "b": b, "name": name} + + template = ResourceTemplate.from_function( + fn=my_func, + uri_template="calc://{a:int}/{b:float}/{name:str}", + name="calc", + ) + + params = template.matches("calc://10/3.14/foo") + + assert params == {"a": 10, "b": 3.14, "name": "foo"} + assert template.matches("calc://x/3.14/foo") is None + assert template.matches("calc://10/bar/foo") is None + + def test_template_matches_with_path(self): + """Test matching URIs with {path:path} placeholder.""" + + def my_func(path: str) -> str: + return path + + template = ResourceTemplate.from_function( + fn=my_func, + uri_template="files://{path:path}", + name="file", + ) + + params = template.matches("files://foo/bar/baz.txt") + assert params == {"path": "foo/bar/baz.txt"} + assert template.matches("wrong://foo/bar") is None + + def test_template_with_optional_parameters(self): + """Test templates with optional parameters via query string.""" + + def my_func(key: str, sort: str = "asc", limit: int = 10) -> dict[str, str | int]: + return {"key": key, "sort": sort, "limit": limit} + + template = ResourceTemplate.from_function( + fn=my_func, + uri_template="test://{key}", + name="test", + ) + + # Verify required/optional params + assert template.path_params == {"key"} + assert template.optional_query_params == {"sort", "limit"} + + # Match with no query params - should only extract path param + params = template.matches("test://foo") + assert params == {"key": "foo"} + + # Match with query params + params = template.matches("test://foo?sort=desc&limit=20") + assert params == {"key": "foo", "sort": "desc", "limit": "20"} + + # Match with partial query params + params = template.matches("test://foo?sort=desc") + assert params == {"key": "foo", "sort": "desc"} + + # Match with unknown query params - should ignore + params = template.matches("test://foo?unknown=value") + assert params == {"key": "foo"} + + def test_template_validation(self): + """Test template validation with required/optional parameters.""" + + # Valid: required param in path + def valid_func(key: str, optional: str = "default") -> str: + return f"{key}-{optional}" + + template = ResourceTemplate.from_function( + fn=valid_func, + uri_template="test://{key}", + name="test", + ) + assert template.path_params == {"key"} + assert template.optional_query_params == {"optional"} + @pytest.mark.anyio async def test_create_resource(self): """Test creating a resource from a template.""" @@ -189,6 +270,273 @@ def get_data(value: str) -> CustomData: content = await resource.read() assert content == '"hello"' + @pytest.mark.anyio + async def test_create_resource_with_optional_params(self): + """Test creating resources with optional parameters.""" + + def my_func(key: str, sort: str = "asc", limit: int = 10) -> dict[str, str | int]: + return {"key": key, "sort": sort, "limit": limit} + + template = ResourceTemplate.from_function( + fn=my_func, + uri_template="test://{key}", + name="test", + ) + + # Create with only required params + params = {"key": "foo"} + resource = await template.create_resource("test://foo", params) + result = await resource.read() + assert isinstance(result, str) + assert json.loads(result) == {"key": "foo", "sort": "asc", "limit": 10} + + # Create with all params + params = {"key": "foo", "sort": "desc", "limit": "20"} + resource = await template.create_resource("test://foo?sort=desc&limit=20", params) + result = await resource.read() + assert isinstance(result, str) + assert json.loads(result) == {"key": "foo", "sort": "desc", "limit": 20} + + def test_template_with_form_style_query_expansion(self): + """Test templates with RFC 6570 form-style query expansion.""" + + def my_func( + category: str, + id: str, + filter: str = "all", + sort: str = "name", + limit: int = 10, + ) -> dict[str, str | int]: + return { + "category": category, + "id": id, + "filter": filter, + "sort": sort, + "limit": limit, + } + + template = ResourceTemplate.from_function( + fn=my_func, + uri_template="test://{category}/{id}", + name="test", + ) + + # Verify required/optional params + assert template.path_params == {"category", "id"} + assert template.optional_query_params == {"filter", "sort", "limit"} + + # Match with no query params - should only extract path params + params = template.matches("test://electronics/1234") + assert params == {"category": "electronics", "id": "1234"} + + # Match with all query params + params = template.matches("test://electronics/1234?filter=new&sort=price&limit=20") + assert params == { + "category": "electronics", + "id": "1234", + "filter": "new", + "sort": "price", + "limit": "20", + } + + # Match with partial query params + params = template.matches("test://electronics/1234?filter=new&sort=price") + assert params == { + "category": "electronics", + "id": "1234", + "filter": "new", + "sort": "price", + } + + # Match with unknown query params - should ignore + params = template.matches("test://electronics/1234?filter=new&unknown=value") + assert params == {"category": "electronics", "id": "1234", "filter": "new"} + + def test_form_style_query_validation(self): + """Test validation of form-style query parameters.""" + + # Valid: query params are subset of optional params + def valid_func(key: str, opt1: str = "default", opt2: int = 10, opt3: bool = False) -> str: + return f"{key}-{opt1}-{opt2}-{opt3}" + + template = ResourceTemplate.from_function( + fn=valid_func, + uri_template="test://{key}", + name="test", + ) + assert template.path_params == {"key"} + assert template.optional_query_params == {"opt1", "opt2", "opt3"} + + @pytest.mark.anyio + async def test_create_resource_with_form_style_query(self): + """Test creating resources with form-style query parameters.""" + + def item_func( + category: str, + id: str, + filter: str = "all", + sort: str = "name", + limit: int = 10, + ) -> dict[str, str | int]: + return { + "category": category, + "id": id, + "filter": filter, + "sort": sort, + "limit": limit, + } + + template = ResourceTemplate.from_function( + fn=item_func, + uri_template="items://{category}/{id}", + name="item", + ) + + # Create with only required params + params = {"category": "electronics", "id": "1234"} + resource = await template.create_resource("items://electronics/1234", params) + result = await resource.read() + assert isinstance(result, str) + assert json.loads(result) == { + "category": "electronics", + "id": "1234", + "filter": "all", + "sort": "name", + "limit": 10, + } + + # Create with all params (limit will be string "20",Pydantic handles conversion) + uri = "items://electronics/1234?filter=new&sort=price&limit=20" + params = { + "category": "electronics", + "id": "1234", + "filter": "new", + "sort": "price", + "limit": "20", # value from URI is a string + } + resource = await template.create_resource(uri, params) + result = await resource.read() + assert isinstance(result, str) + assert json.loads(result) == { + "category": "electronics", + "id": "1234", + "filter": "new", + "sort": "price", + "limit": 20, # Pydantic converted "20" to 20 + } + + @pytest.mark.anyio + async def test_create_resource_optional_param_validation_fallback(self): + """ + Test that if optional parameters fail Pydantic validation, + their default values are used due to the + use_defaults_on_optional_validation_error decorator. + """ + + def func_with_optional_typed_params( + key: str, opt_int: int = 42, opt_bool: bool = True + ) -> dict[str, str | int | bool]: + return {"key": key, "opt_int": opt_int, "opt_bool": opt_bool} + + template = ResourceTemplate.from_function( + fn=func_with_optional_typed_params, + uri_template="test://{key}", + name="test_optional_fallback", + ) + + # Case 1: opt_int is invalid, opt_bool is not provided + # URI like "test://mykey?opt_int=notanint" + params_invalid_int = {"key": "mykey", "opt_int": "notanint"} + resource1 = await template.create_resource("test://mykey?opt_int=notanint", params_invalid_int) + result1_str = await resource1.read() + result1 = json.loads(result1_str) + assert result1["key"] == "mykey" + assert result1["opt_int"] == 42 # Default used + assert result1["opt_bool"] is True # Default used + + # Case 2: opt_bool is invalid, opt_int is valid + # URI like "test://mykey?opt_int=100&opt_bool=notabool" + params_invalid_bool = { + "key": "mykey", + "opt_int": "100", # Valid string for int + "opt_bool": "notabool", + } + resource2 = await template.create_resource("test://mykey?opt_int=100&opt_bool=notabool", params_invalid_bool) + result2_str = await resource2.read() + result2 = json.loads(result2_str) + assert result2["key"] == "mykey" + assert result2["opt_int"] == 100 # Provided valid value used + assert result2["opt_bool"] is True # Default used + + # Case 3: Both opt_int and opt_bool are invalid + # URI like "test://mykey?opt_int=bad&opt_bool=bad" + params_both_invalid = { + "key": "mykey", + "opt_int": "bad", + "opt_bool": "bad", + } + resource3 = await template.create_resource("test://mykey?opt_int=bad&opt_bool=bad", params_both_invalid) + result3_str = await resource3.read() + result3 = json.loads(result3_str) + assert result3["key"] == "mykey" + assert result3["opt_int"] == 42 # Default used + assert result3["opt_bool"] is True # Default used + + # Case 4: Empty value for opt_int (should fall back to default) + # URI like "test://mykey?opt_int=" + params_empty_int = {"key": "mykey"} + resource4 = await template.create_resource("test://mykey?opt_int=", params_empty_int) + result4_str = await resource4.read() + result4 = json.loads(result4_str) + assert result4["key"] == "mykey" + assert result4["opt_int"] == 42 # Default used + assert result4["opt_bool"] is True # Default used + + # Case 5: Empty value for opt_bool (should fall back to default) + # URI like "test://mykey?opt_bool=" + params_empty_bool = {"key": "mykey"} + resource5 = await template.create_resource("test://mykey?opt_bool=", params_empty_bool) + result5_str = await resource5.read() + result5 = json.loads(result5_str) + assert result5["key"] == "mykey" + assert result5["opt_int"] == 42 # Default used + assert result5["opt_bool"] is True # Default used + + # Case 6: Optional string param with empty value, should use default value + def func_opt_str(key: str, opt_s: str = "default_val") -> dict[str, str]: + return {"key": key, "opt_s": opt_s} + + template_str = ResourceTemplate.from_function(fn=func_opt_str, uri_template="test://{key}", name="test_opt_str") + params_empty_str = {"key": "mykey"} + resource6 = await template_str.create_resource("test://mykey?opt_s=", params_empty_str) + result6_str = await resource6.read() + result6 = json.loads(result6_str) + assert result6["key"] == "mykey" + assert result6["opt_s"] == "default_val" # Pydantic allows empty string for str type + + @pytest.mark.anyio + async def test_create_resource_required_param_validation_error(self): + """ + Test that if a required parameter fails Pydantic validation, an error is raised + and not suppressed by the new decorator. + """ + + def func_with_required_typed_param(req_int: int, key: str) -> dict[str, int | str]: + return {"req_int": req_int, "key": key} + + template = ResourceTemplate.from_function( + fn=func_with_required_typed_param, + uri_template="test://{key}/{req_int}", # req_int is part of path + name="test_req_error", + ) + + # req_int is "notanint", which is invalid for int type + params_invalid_req = {"key": "mykey", "req_int": "notanint"} + with pytest.raises(ValueError, match="Error creating resource from template"): + # This ValueError comes from ResourceTemplate.create_resource own try-except + # which catches Pydantic's ValidationError. + await template.create_resource("test://mykey/notanint", params_invalid_req) + class TestResourceTemplateAnnotations: """Test annotations on resource templates.""" @@ -258,3 +606,50 @@ def get_item(item_id: str) -> str: # Verify the resource works correctly content = await resource.read() assert content == "Item 123" + + def test_full_parameter_inference(self): + """Test MCP path and query parameter inference: path, required query, optional query.""" + + # Function under test + def resource_fn( + # Path parameters + user_id: Annotated[int, Path(gt=0, description="User ID")], # explicit Path + region, # inferred path # type: ignore + city: str, # inferred path + file_path: str, # inferred path {file_path:path} + # Required query parameter (no default) + version: int, + # Optional query parameters (defaults or Query(...)) + format: Annotated[str, Query("json", description="Output format")], + include_metadata: bool = False, + tags: list[str] = [], + lang: str = "en", + debug: bool = False, + precision: float = 0.5, + ) -> str: + return f"{user_id}/{region}/{city}/{file_path}" + + # Create resource template + template = ResourceTemplate.from_function( + fn=resource_fn, # type: ignore + uri_template="api://data/{user_id}/{region}/{city}/{file_path:path}", + name="full_resource", + ) + + # --- Assertions --- + + # Path parameters + assert template.path_params == {"user_id", "region", "city", "file_path"} + + # Required query parameters (no default) + assert template.required_query_params == {"version"} + + # Optional query parameters (have default or Query) + assert template.optional_query_params == { + "include_metadata", + "tags", + "format", + "lang", + "debug", + "precision", + } diff --git a/tests/server/fastmcp/test_func_metadata.py b/tests/server/fastmcp/test_func_metadata.py index 793dfc324..a60892cf5 100644 --- a/tests/server/fastmcp/test_func_metadata.py +++ b/tests/server/fastmcp/test_func_metadata.py @@ -3,16 +3,19 @@ # pyright: reportMissingParameterType=false # pyright: reportUnknownArgumentType=false # pyright: reportUnknownLambdaType=false -from collections.abc import Callable +from collections.abc import Awaitable, Callable from dataclasses import dataclass from typing import Annotated, Any, TypedDict import annotated_types import pytest from dirty_equals import IsPartialDict -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ValidationError, validate_call -from mcp.server.fastmcp.utilities.func_metadata import func_metadata +from mcp.server.fastmcp.utilities.func_metadata import ( + func_metadata, + use_defaults_on_optional_validation_error, +) from mcp.types import CallToolResult @@ -1182,3 +1185,118 @@ def func_with_reserved_json( assert result["json"] == {"nested": "data"} assert result["model_dump"] == [1, 2, 3] assert result["normal"] == "plain string" + + +# Test functions for use_defaults_on_optional_validation_error decorator + + +def sync_func_for_decorator(req_param: str, opt_int: int = 10, opt_bool: bool = False) -> dict[str, str | int | bool]: + return {"req_param": req_param, "opt_int": opt_int, "opt_bool": opt_bool} + + +async def async_func_for_decorator(req_param: str, opt_int: int = 20, opt_str: str = "default") -> dict[str, str | int]: + return {"req_param": req_param, "opt_int": opt_int, "opt_str": opt_str} + + +class TestUseDefaultsOnOptionalValidationErrorDecorator: + @pytest.fixture + def decorated_sync_func(self) -> Callable[..., dict[str, str | int | bool]]: + # Apply validate_call first, then our decorator + return use_defaults_on_optional_validation_error(validate_call(sync_func_for_decorator)) + + @pytest.fixture + def decorated_async_func(self) -> Callable[..., Awaitable[dict[str, str | int]]]: + # Apply validate_call first, then our decorator + return use_defaults_on_optional_validation_error(validate_call(async_func_for_decorator)) + + def test_sync_all_valid(self, decorated_sync_func: Callable[..., dict[str, str | int | bool]]): + result = decorated_sync_func(req_param="test", opt_int=100, opt_bool=True) + assert result == {"req_param": "test", "opt_int": 100, "opt_bool": True} + + def test_sync_omit_optionals(self, decorated_sync_func: Callable[..., dict[str, str | int | bool]]): + result = decorated_sync_func(req_param="test") + assert result == {"req_param": "test", "opt_int": 10, "opt_bool": False} + + def test_sync_invalid_opt_int(self, decorated_sync_func: Callable[..., dict[str, str | int | bool]]): + # opt_int="bad" should cause ValidationError, decorator catches, uses default 10 + result = decorated_sync_func(req_param="test", opt_int="bad") + assert result == {"req_param": "test", "opt_int": 10, "opt_bool": False} + + def test_sync_invalid_opt_bool(self, decorated_sync_func: Callable[..., dict[str, str | int | bool]]): + # opt_bool="bad" should cause ValidationError, decorator catches, uses default False + result = decorated_sync_func(req_param="test", opt_bool="bad") + assert result == {"req_param": "test", "opt_int": 10, "opt_bool": False} + + def test_sync_invalid_opt_int_and_valid_opt_bool( + self, decorated_sync_func: Callable[..., dict[str, str | int | bool]] + ): + result = decorated_sync_func(req_param="test", opt_int="bad", opt_bool=True) + assert result == {"req_param": "test", "opt_int": 10, "opt_bool": True} + + def test_sync_all_optionals_invalid(self, decorated_sync_func: Callable[..., dict[str, str | int | bool]]): + result = decorated_sync_func(req_param="test", opt_int="bad", opt_bool="bad") + assert result == {"req_param": "test", "opt_int": 10, "opt_bool": False} + + def test_sync_required_param_missing(self, decorated_sync_func: Callable[..., dict[str, str | int | bool]]): + with pytest.raises(ValidationError): + decorated_sync_func(opt_int=100) # Missing req_param + + def test_sync_required_param_invalid(self, decorated_sync_func): + # If req_param itself was typed, e.g., req_param: int, and we passed "bad" + # For this test, sync_func_for_decorator has req_param: str, which is flexible. + # Let's define a quick one for this specific case. + def temp_sync_func(req_int_param: int, opt_str: str = "s") -> dict[str, int | str]: + return {"req_int_param": req_int_param, "opt_str": opt_str} + + decorated_temp_func = use_defaults_on_optional_validation_error(validate_call(temp_sync_func)) + with pytest.raises(ValidationError): + decorated_temp_func(req_int_param="notanint") + + @pytest.mark.anyio + async def test_async_all_valid(self, decorated_async_func: Callable[..., Awaitable[dict[str, str | int]]]): + result = await decorated_async_func(req_param="async_test", opt_int=200, opt_str="custom") + assert result == { + "req_param": "async_test", + "opt_int": 200, + "opt_str": "custom", + } + + @pytest.mark.anyio + async def test_async_omit_optionals(self, decorated_async_func: Callable[..., Awaitable[dict[str, str | int]]]): + result = await decorated_async_func(req_param="async_test") + assert result == { + "req_param": "async_test", + "opt_int": 20, + "opt_str": "default", + } + + @pytest.mark.anyio + async def test_async_invalid_opt_int(self, decorated_async_func: Callable[..., Awaitable[dict[str, str | int]]]): + result = await decorated_async_func(req_param="async_test", opt_int="bad") + assert result == { + "req_param": "async_test", + "opt_int": 20, # Default + "opt_str": "default", + } + + @pytest.mark.anyio + async def test_async_invalid_opt_str_but_is_int( + self, decorated_async_func: Callable[..., Awaitable[dict[str, str | int]]] + ): + # opt_str=123 (int) for str type should cause ValidationError, decorator uses default "default" + # Note: pydantic's validate_call might auto-convert int to str if not in strict mode. + # Let's assume default strictness where int is not directly valid for str. + # If validate_call is not strict, this test might need adjustment or a stricter type. + result = await decorated_async_func(req_param="async_test", opt_str=123) + assert result == { + "req_param": "async_test", + "opt_int": 20, + "opt_str": "default", # Default + } + + @pytest.mark.anyio + async def test_async_required_param_missing( + self, decorated_async_func: Callable[..., Awaitable[dict[str, str | int]]] + ): + with pytest.raises(ValidationError): + await decorated_async_func(opt_int=100) diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index 8caa3b1f6..870376152 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -1,4 +1,5 @@ import base64 +import json from pathlib import Path from typing import TYPE_CHECKING, Any from unittest.mock import patch @@ -751,6 +752,41 @@ async def test_file_resource_binary(self, tmp_path: Path): assert isinstance(result.contents[0], BlobResourceContents) assert result.contents[0].blob == base64.b64encode(b"Binary file data").decode() + @pytest.mark.anyio + async def test_resource_with_form_style_query(self): + """Test that resources with form-style query expansion work correctly""" + mcp = FastMCP() + + @mcp.resource("resource://{category}/{id}") + def get_item( + category: str, + id: str, + filter: str = "all", + sort: str = "name", + limit: int = 10, + ) -> str: + return f"Item {id} in {category}, filtered by {filter}, sorted by {sort}, limited to {limit}" + + async with client_session(mcp._mcp_server) as client: + # Test with default values + result = await client.read_resource(AnyUrl("resource://electronics/1234")) + assert isinstance(result.contents[0], TextResourceContents) + assert result.contents[0].text == "Item 1234 in electronics, filtered by all, sorted by name, limited to 10" + + # Test with query parameters + result = await client.read_resource(AnyUrl("resource://electronics/1234?filter=new&sort=price&limit=20")) + assert isinstance(result.contents[0], TextResourceContents) + assert ( + result.contents[0].text == "Item 1234 in electronics, filtered by new, sorted by price, limited to 20" + ) + + # Test with partial query parameters + result = await client.read_resource(AnyUrl("resource://electronics/1234?filter=used")) + assert isinstance(result.contents[0], TextResourceContents) + assert ( + result.contents[0].text == "Item 1234 in electronics, filtered by used, sorted by name, limited to 10" + ) + @pytest.mark.anyio async def test_function_resource(self): mcp = FastMCP() @@ -771,24 +807,15 @@ def get_data() -> str: class TestServerResourceTemplates: - @pytest.mark.anyio - async def test_resource_with_params(self): - """Test that a resource with function parameters raises an error if the URI - parameters don't match""" - mcp = FastMCP() - - with pytest.raises(ValueError, match="Mismatch between URI parameters"): - - @mcp.resource("resource://data") - def get_data_fn(param: str) -> str: - return f"Data: {param}" - @pytest.mark.anyio async def test_resource_with_uri_params(self): """Test that a resource with URI parameters is automatically a template""" mcp = FastMCP() - with pytest.raises(ValueError, match="Mismatch between URI parameters"): + with pytest.raises( + ValueError, + match="Mismatch between URI path parameters .* and required function parameters .*", + ): @mcp.resource("resource://{param}") def get_data() -> str: @@ -817,12 +844,40 @@ def get_data(name: str) -> str: assert isinstance(result.contents[0], TextResourceContents) assert result.contents[0].text == "Data for test" + @pytest.mark.anyio + async def test_resource_with_optional_params(self): + """Test that resources with optional parameters work correctly""" + mcp = FastMCP() + + @mcp.resource("resource://{name}/data") + def get_data_with_options(name: str, format: str = "text", limit: int = 10) -> str: + return f"Data for {name} in {format} format with limit {limit}" + + async with client_session(mcp._mcp_server) as client: + # Test with default values + result = await client.read_resource(AnyUrl("resource://test/data")) + assert isinstance(result.contents[0], TextResourceContents) + assert result.contents[0].text == "Data for test in text format with limit 10" + + # Test with query parameters + result = await client.read_resource(AnyUrl("resource://test/data?format=json&limit=20")) + assert isinstance(result.contents[0], TextResourceContents) + assert result.contents[0].text == "Data for test in json format with limit 20" + + # Test with partial query parameters + result = await client.read_resource(AnyUrl("resource://test/data?format=xml")) + assert isinstance(result.contents[0], TextResourceContents) + assert result.contents[0].text == "Data for test in xml format with limit 10" + @pytest.mark.anyio async def test_resource_mismatched_params(self): """Test that mismatched parameters raise an error""" mcp = FastMCP() - with pytest.raises(ValueError, match="Mismatch between URI parameters"): + with pytest.raises( + ValueError, + match="Mismatch between URI path parameters .* and required function parameters .*", + ): @mcp.resource("resource://{name}/data") def get_data(user: str) -> str: @@ -847,7 +902,10 @@ async def test_resource_multiple_mismatched_params(self): """Test that mismatched parameters raise an error""" mcp = FastMCP() - with pytest.raises(ValueError, match="Mismatch between URI parameters"): + with pytest.raises( + ValueError, + match="Mismatch between URI path parameters .* and required function parameters .*", + ): @mcp.resource("resource://{org}/{repo}/data") def get_data_mismatched(org: str, repo_2: str) -> str: @@ -905,6 +963,116 @@ def get_csv(user: str) -> str: assert isinstance(result.contents[0], TextResourceContents) assert result.contents[0].text == "csv for bob" + @pytest.mark.anyio + async def test_resource_optional_param_validation_fallback_and_url_encoding( + self, + ): + """Test handling of optional param validation fallback & URL encoding.""" + mcp = FastMCP() + + @mcp.resource("resource://test_item/{item_id}") + def get_test_item_details( + item_id: str, + name: str = "default_name", + count: int = 0, + active: bool = False, + ) -> dict[str, str | int | bool]: + return { + "item_id": item_id, + "name": name, + "count": count, + "active": active, + } + + async with client_session(mcp._mcp_server) as client: + # 1. All defaults + res1_uri = "resource://test_item/item001" + res1_content_result = await client.read_resource(AnyUrl(res1_uri)) + assert res1_content_result.contents and isinstance(res1_content_result.contents[0], TextResourceContents) + data1 = json.loads(res1_content_result.contents[0].text) + assert data1 == { + "item_id": "item001", + "name": "default_name", + "count": 0, + "active": False, + } + + # 2. Valid optional params (name is URL encoded) + res2_uri = "resource://test_item/item002?name=My%20Product&count=10&active=true" + res2_content_result = await client.read_resource(AnyUrl(res2_uri)) + assert res2_content_result.contents and isinstance(res2_content_result.contents[0], TextResourceContents) + data2 = json.loads(res2_content_result.contents[0].text) + assert data2 == { + "item_id": "item002", + "name": "My Product", # Decoded + "count": 10, + "active": True, + } + + # 3. Invalid 'count' (optional int), valid 'name', 'active' not provided + # count=notanint should make it use default_count = 0 + res3_uri = "resource://test_item/item003?name=Another%20Item&count=notanint" + res3_content_result = await client.read_resource(AnyUrl(res3_uri)) + assert res3_content_result.contents and isinstance(res3_content_result.contents[0], TextResourceContents) + data3 = json.loads(res3_content_result.contents[0].text) + assert data3 == { + "item_id": "item003", + "name": "Another Item", + "count": 0, # Fallback to default + "active": False, # Fallback to default + } + + # 4. Invalid 'active' (optional bool), valid 'count', 'name' not provided + # active=notabool should make it use default_active = False + res4_uri = "resource://test_item/item004?count=50&active=notabool" + res4_content_result = await client.read_resource(AnyUrl(res4_uri)) + assert res4_content_result.contents and isinstance(res4_content_result.contents[0], TextResourceContents) + data4 = json.loads(res4_content_result.contents[0].text) + assert data4 == { + "item_id": "item004", + "name": "default_name", # Fallback to default + "count": 50, + "active": False, # Fallback to default + } + + # 5. Empty value for optional 'name' (string type) + # name= (empty value) should fall back to default + res5_uri = "resource://test_item/item005?name=" + res5_content_result = await client.read_resource(AnyUrl(res5_uri)) + assert res5_content_result.contents and isinstance(res5_content_result.contents[0], TextResourceContents) + data5 = json.loads(res5_content_result.contents[0].text) + assert data5 == { + "item_id": "item005", + "name": "default_name", # Fallback to default + "count": 0, + "active": False, + } + + # 6. Empty value for optional 'count' (int type) + # count= (empty value) should fall back to default + res6_uri = "resource://test_item/item006?count=" + res6_content_result = await client.read_resource(AnyUrl(res6_uri)) + assert res6_content_result.contents and isinstance(res6_content_result.contents[0], TextResourceContents) + data6 = json.loads(res6_content_result.contents[0].text) + assert data6 == { + "item_id": "item006", + "name": "default_name", + "count": 0, # Fallback to default because param is removed by parse_qs + "active": False, + } + + # Test required param failing validation at server level + @mcp.resource("resource://req_fail/{req_id}/details") + def get_req_details(req_id: int, detail_type: str = "summary") -> dict[str, str | int]: + return {"req_id": req_id, "detail_type": detail_type} + + async with client_session(mcp._mcp_server) as client: + invalid_req_uri = "resource://req_fail/notanint/details" + # The FastMCP.read_resource wraps internal errors, + # from template.create_resource, into a ResourceError, as McpError. + with pytest.raises(McpError): + await client.read_resource(AnyUrl(invalid_req_uri)) + class TestContextInjection: """Test context injection in tools, resources, and prompts."""