From e752e533407bb1d7432468f049a86346bee2397e Mon Sep 17 00:00:00 2001 From: Diego Restrepo Date: Sun, 10 Aug 2025 00:03:19 -0500 Subject: [PATCH] feat(back): add support for predefined GraphQL queries - Introduced a new command-line option to specify a file containing predefined GraphQL queries. - Enhanced the server lifecycle management to parse and load queries from the specified file. - Updated the tool execution logic to allow for predefined queries, improving flexibility and usability. - Added error handling for query parsing to ensure robustness. --- README.md | 12 +++- mcp_graphql/__init__.py | 10 +++- mcp_graphql/server.py | 127 ++++++++++++++++++++++++++++++++-------- mcp_graphql/types.py | 6 ++ pyproject.toml | 2 +- uv.lock | 2 +- 6 files changed, 131 insertions(+), 28 deletions(-) diff --git a/README.md b/README.md index 4747285..0357ab1 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,7 @@ python -m mcp_graphql --api-url="https://api.example.com/graphql" --auth-token=" ### Available options - `--api-url`: GraphQL API URL (required) -- `--auth-token`: Authentication token (optional) +- `--auth-token`: Authentication token (optional, can also be set via `MCP_AUTH_TOKEN` environment variable) - `--auth-type`: Authentication type, default is "Bearer" (optional) - `--auth-headers`: Custom authentication headers in JSON format (optional) @@ -169,6 +169,16 @@ uv sync ruff check . ``` +### Running the server in development mode + +When working locally you can start the MCP GraphQL server with hot-reloading and inspect its tools using the Model Context Protocol Inspector: + +```bash +npx "@modelcontextprotocol/inspector" uv run -n --project $PWD mcp-graphql --api-url http://localhost:3010/graphql +``` + +Replace `http://localhost:3010/graphql` with the URL of your local GraphQL endpoint if it differs. + ## License This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details. diff --git a/mcp_graphql/__init__.py b/mcp_graphql/__init__.py index 69f5ac2..250586d 100644 --- a/mcp_graphql/__init__.py +++ b/mcp_graphql/__init__.py @@ -1,5 +1,6 @@ import asyncio import json +from pathlib import Path from typing import Any import click @@ -45,11 +46,17 @@ def convert( 'Custom authentication headers as JSON string (e.g. \'{"Authorization": "Bearer token", "X-API-Key": "key"}\')' # noqa: E501 ), ) +@click.option( + "--queries-file", + type=click.Path(exists=True, dir_okay=False, path_type=Path), + help="Path to a .gql file with predefined GraphQL queries (optional)", +) def main( api_url: str, auth_token: str | None, auth_type: str, auth_headers: dict[str, Any] | None, + queries_file: Path | None, ) -> None: """MCP Graphql Server - Graphql server for MCP""" @@ -62,9 +69,8 @@ def main( # Otherwise use auth_token and auth_type if provided elif auth_token: auth_headers_dict["Authorization"] = f"{auth_type} {auth_token}" - # If no auth is provided, proceed with empty headers - asyncio.run(serve(api_url, auth_headers_dict)) + asyncio.run(serve(api_url, auth_headers_dict, queries_file=queries_file)) if __name__ == "__main__": diff --git a/mcp_graphql/server.py b/mcp_graphql/server.py index 32c2721..68e893b 100644 --- a/mcp_graphql/server.py +++ b/mcp_graphql/server.py @@ -4,9 +4,10 @@ from contextlib import asynccontextmanager from functools import partial from logging import INFO, WARNING, basicConfig, getLogger +from pathlib import Path from typing import Any, cast -from gql import Client +from gql import Client, gql from gql.dsl import DSLField, DSLQuery, DSLSchema, DSLType, dsl_gql from gql.transport.aiohttp import AIOHTTPTransport from graphql import ( @@ -23,9 +24,11 @@ GraphQLObjectType, GraphQLOutputType, GraphQLScalarType, + parse, print_ast, print_type, ) +from graphql.language.ast import OperationDefinitionNode from graphql.pyutils import inspect from graphql.type import GraphQLSchema from mcp import Resource @@ -46,7 +49,7 @@ ServerContext, ) -# Configurar logging +# Configure logging basicConfig( level=INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", @@ -69,17 +72,40 @@ async def server_lifespan( server: Server[ServerContext], # noqa: ARG001 api_url: str, auth_headers: dict[str, str], + queries_file: Path | None = None, ) -> AsyncIterator[ServerContext]: """Manage server startup and shutdown lifecycle.""" # Initialize resources on startup transport = AIOHTTPTransport(url=api_url, headers=auth_headers) client = Client(transport=transport, fetch_schema_from_transport=True) + predefined_queries: dict[str, OperationDefinitionNode] = {} # Use the client directly instead of trying to use session as a context manager + if queries_file and queries_file.exists(): + try: + with queries_file.open(encoding="utf-8") as f: + doc_src = f.read() + + # Parse the .gql document and extract every named operation + doc = parse(doc_src) + + for definition in doc.definitions: + if not isinstance(definition, OperationDefinitionNode): + continue + # Only consider operations that have a name; anonymous + # operations cannot be called as tools. + if definition.name is None: + continue + op_name = definition.name.value + predefined_queries[op_name] = definition + except Exception: + logger.exception("Error while parsing predefined queries file") + async with client as session: try: context: ServerContext = { "session": session, "dsl_schema": DSLSchema(session.client.schema or GraphQLSchema()), + "predefined_queries": predefined_queries if predefined_queries else None, } yield context finally: @@ -129,11 +155,6 @@ def convert_type_to_json_schema( # noqa: C901 elif isinstance(gql_type, GraphQLList): # List wrapper inner_schema = convert_type_to_json_schema(gql_type.of_type, max_depth, current_depth) - logger.info( - "inner_schema: %s %s", - gql_type.of_type.__class__.__name__, - json.dumps(inner_schema, indent=2), - ) schema = {"type": "array", "items": inner_schema} elif isinstance(gql_type, GraphQLScalarType): @@ -304,7 +325,6 @@ def build_selection( def get_args_schema(args_map: GraphQLArgumentMap) -> JsonSchema: args_schema: JsonSchema = {"type": "object", "properties": {}, "required": []} for arg_name, arg in args_map.items(): - logger.debug("Converting GraphQL type for %s: %s", arg_name, arg.type) type_schema = convert_type_to_json_schema(arg.type, max_depth=5, current_depth=1) # Remove the "required" flag which was used for tracking is_required = type_schema.pop("required", False) @@ -321,7 +341,6 @@ def get_args_schema(args_map: GraphQLArgumentMap) -> JsonSchema: and not isinstance(args_schema["required"], bool) ): args_schema["required"].append(arg_name) - logger.debug("args_schema: %s", json.dumps(args_schema, indent=2)) return args_schema @@ -331,7 +350,7 @@ async def list_tools_impl(_server: Server[ServerContext]) -> list[Tool]: ds: DSLSchema = ctx.lifespan_context["dsl_schema"] except LookupError as exc: logger.exception( - "Error al obtener el contexto", + "Error obtaining context", ) # Configura el transporte transport = AIOHTTPTransport(url="http://localhost:8080/graphql") @@ -342,23 +361,42 @@ async def list_tools_impl(_server: Server[ServerContext]) -> list[Tool]: if not session.client.schema: raise SchemaRetrievalError from exc ds = DSLSchema(session.client.schema) + + # Access context pieces + predefined_queries: dict[str, OperationDefinitionNode] = ( + ctx.lifespan_context.get("predefined_queries") or {} + ) + tools: list[Tool] = [] - # Establece la sesión del cliente - if ds: - # Accede al esquema dentro de la sesión + # Determine which query names we should expose as tools + if predefined_queries: + query_names = list(predefined_queries.keys()) + else: + if not ds._schema.query_type: + raise QueryTypeNotFoundError + query_names = list(ds._schema.query_type.fields.keys()) + + # Iterate over the selected query names and build Tool objects + if ds and query_names: if not ds._schema.query_type: raise QueryTypeNotFoundError fields: dict[str, GraphQLField] = ds._schema.query_type.fields - for query_name, field in fields.items(): + + for query_definition in predefined_queries.values(): + # Skip if the query does not exist in the schema (e.g. mutation) + query_name = query_definition.selection_set.selections[0].to_dict()["name"]["value"] + if query_name not in fields: + continue + + field = fields[query_name] dsl_field: DSLField = getattr(ds.Query, query_name) return_type_description = inspect(dsl_field.field.type) # Get the arguments schema for this field args_schema = get_args_schema(dsl_field.field.args) - logger.info("args_schema: %s", json.dumps(args_schema, indent=2)) tools.append( Tool( - name=query_name, + name=query_definition.name.value, # type: ignore[union-attr] description=(field.description or f"GraphQL query: {query_name}") + f" (Returns: {return_type_description})", inputSchema=args_schema, # type: ignore[arg-type] @@ -373,17 +411,52 @@ async def call_tool_impl( name: str, arguments: dict[str, Any], ) -> list[mcp_types.TextContent]: - logger.debug("calling tool %s with arguments %s", name, arguments) ctx = _server.request_context + + # Access context pieces session = ctx.lifespan_context["session"] + predefined_queries = ctx.lifespan_context.get("predefined_queries") or {} + # Don't use the session as a context manager, use it directly ds: DSLSchema = ctx.lifespan_context["dsl_schema"] if not ds._schema.query_type: raise QueryTypeNotFoundError fields: dict[str, GraphQLField] = ds._schema.query_type.fields + # --------------------------------------------------------------------- + # 1. If the query comes from the predefined queries file, execute it as-is + # --------------------------------------------------------------------- + if name in predefined_queries: + query_src = predefined_queries[name] + try: + result = await session.execute( + gql(print_ast(query_src)), + variable_values=arguments or None, + ) + return [mcp_types.TextContent(type="text", text=json.dumps(result))] + except Exception as exc: + logger.exception("Error executing predefined query %s", name) + return [ + mcp_types.TextContent( + type="text", + text=f"Error executing query {name}: {exc}", + ), + ] + + # --------------------------------------------------------------------- + # 2. Fallback to dynamic query generation from the schema + # --------------------------------------------------------------------- + + # If the server was started with predefined queries, reject any other + if predefined_queries: + return [ + mcp_types.TextContent( + type="text", + text=f"The query '{name}' is not among the allowed queries.", + ), + ] + max_depth = 5 - logger.debug("calling tool %s with arguments %s", name, arguments) if _query_name := next((_query_name for _query_name in fields if _query_name == name), None): attr: DSLField = getattr(ds.Query, _query_name) @@ -402,7 +475,7 @@ async def call_tool_impl( return [ mcp_types.TextContent( type="text", - text=f"Error: No se pudo determinar el tipo de retorno para {name}", + text=f"Error: Could not determine the return type for {name}", ), ] @@ -414,20 +487,28 @@ async def call_tool_impl( # Build the actual query query_selections = build_selection(ds, return_type, selections) query = dsl_gql(DSLQuery(attr(**arguments).select(*query_selections))) - logger.info("query: %s", print_ast(query)) # # Execute the query result = await session.execute(query) return [mcp_types.TextContent(type="text", text=json.dumps(result))] # Error case - tool not found - return [mcp_types.TextContent(type="text", text="No se encontró la herramienta")] + return [mcp_types.TextContent(type="text", text="Tool not found")] -async def serve(api_url: str, auth_headers: dict[str, str] | None) -> None: +async def serve( + api_url: str, + auth_headers: dict[str, str] | None, + queries_file: Path | None = None, +) -> None: server = Server[ServerContext]( "mcp-graphql", - lifespan=partial(server_lifespan, api_url=api_url, auth_headers=auth_headers or {}), + lifespan=partial( + server_lifespan, + api_url=api_url, + auth_headers=auth_headers or {}, + queries_file=queries_file, + ), ) server.list_tools()(functools.partial(list_tools_impl, server)) diff --git a/mcp_graphql/types.py b/mcp_graphql/types.py index 5edcced..0469869 100644 --- a/mcp_graphql/types.py +++ b/mcp_graphql/types.py @@ -2,6 +2,7 @@ from gql.client import AsyncClientSession from gql.dsl import DSLSchema +from graphql.language.ast import OperationDefinitionNode class ServerContext(TypedDict): @@ -10,6 +11,11 @@ class ServerContext(TypedDict): session: AsyncClientSession dsl_schema: DSLSchema + # Optional mapping of query name -> GraphQL query string when the server is + # started with a predefined queries file. When absent or empty, the server + # will fall back to exposing every query present in the remote schema. + predefined_queries: dict[str, OperationDefinitionNode] | None + class SchemaRetrievalError(Exception): """Exception raised when the GraphQL schema cannot be retrieved.""" diff --git a/pyproject.toml b/pyproject.toml index 4281ee1..8985ea0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "mcp-graphql" -version = "0.3.3" +version = "0.4.0" description = "MCP server for GraphQL" readme = "README.md" requires-python = ">=3.11" diff --git a/uv.lock b/uv.lock index ea83639..c90c06f 100644 --- a/uv.lock +++ b/uv.lock @@ -387,7 +387,7 @@ cli = [ [[package]] name = "mcp-graphql" -version = "0.3.3" +version = "0.4.0" source = { editable = "." } dependencies = [ { name = "aiohttp" },