Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ python -m mcp_graphql --api-url="https://api.example.com/graphql" --auth-token="
### Available options

- `--api-url`: GraphQL API URL (required)
- `--auth-token`: Authentication token (optional)
- `--auth-token`: Authentication token (optional, can also be set via `MCP_AUTH_TOKEN` environment variable)
- `--auth-type`: Authentication type, default is "Bearer" (optional)
- `--auth-headers`: Custom authentication headers in JSON format (optional)

Expand Down Expand Up @@ -169,6 +169,16 @@ uv sync
ruff check .
```

### Running the server in development mode

When working locally you can start the MCP GraphQL server with hot-reloading and inspect its tools using the Model Context Protocol Inspector:

```bash
npx "@modelcontextprotocol/inspector" uv run -n --project $PWD mcp-graphql --api-url http://localhost:3010/graphql
```

Replace `http://localhost:3010/graphql` with the URL of your local GraphQL endpoint if it differs.

## License

This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
Expand Down
10 changes: 8 additions & 2 deletions mcp_graphql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import json
from pathlib import Path
from typing import Any

import click
Expand Down Expand Up @@ -45,11 +46,17 @@ def convert(
'Custom authentication headers as JSON string (e.g. \'{"Authorization": "Bearer token", "X-API-Key": "key"}\')' # noqa: E501
),
)
@click.option(
"--queries-file",
type=click.Path(exists=True, dir_okay=False, path_type=Path),
help="Path to a .gql file with predefined GraphQL queries (optional)",
)
def main(
api_url: str,
auth_token: str | None,
auth_type: str,
auth_headers: dict[str, Any] | None,
queries_file: Path | None,
) -> None:
"""MCP Graphql Server - Graphql server for MCP"""

Expand All @@ -62,9 +69,8 @@ def main(
# Otherwise use auth_token and auth_type if provided
elif auth_token:
auth_headers_dict["Authorization"] = f"{auth_type} {auth_token}"
# If no auth is provided, proceed with empty headers

asyncio.run(serve(api_url, auth_headers_dict))
asyncio.run(serve(api_url, auth_headers_dict, queries_file=queries_file))


if __name__ == "__main__":
Expand Down
127 changes: 104 additions & 23 deletions mcp_graphql/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from contextlib import asynccontextmanager
from functools import partial
from logging import INFO, WARNING, basicConfig, getLogger
from pathlib import Path
from typing import Any, cast

from gql import Client
from gql import Client, gql
from gql.dsl import DSLField, DSLQuery, DSLSchema, DSLType, dsl_gql
from gql.transport.aiohttp import AIOHTTPTransport
from graphql import (
Expand All @@ -23,9 +24,11 @@
GraphQLObjectType,
GraphQLOutputType,
GraphQLScalarType,
parse,
print_ast,
print_type,
)
from graphql.language.ast import OperationDefinitionNode
from graphql.pyutils import inspect
from graphql.type import GraphQLSchema
from mcp import Resource
Expand All @@ -46,7 +49,7 @@
ServerContext,
)

# Configurar logging
# Configure logging
basicConfig(
level=INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
Expand All @@ -69,17 +72,40 @@ async def server_lifespan(
server: Server[ServerContext], # noqa: ARG001
api_url: str,
auth_headers: dict[str, str],
queries_file: Path | None = None,
) -> AsyncIterator[ServerContext]:
"""Manage server startup and shutdown lifecycle."""
# Initialize resources on startup
transport = AIOHTTPTransport(url=api_url, headers=auth_headers)
client = Client(transport=transport, fetch_schema_from_transport=True)
predefined_queries: dict[str, OperationDefinitionNode] = {}
# Use the client directly instead of trying to use session as a context manager
if queries_file and queries_file.exists():
try:
with queries_file.open(encoding="utf-8") as f:
doc_src = f.read()

# Parse the .gql document and extract every named operation
doc = parse(doc_src)

for definition in doc.definitions:
if not isinstance(definition, OperationDefinitionNode):
continue
# Only consider operations that have a name; anonymous
# operations cannot be called as tools.
if definition.name is None:
continue
op_name = definition.name.value
predefined_queries[op_name] = definition
except Exception:
logger.exception("Error while parsing predefined queries file")

async with client as session:
try:
context: ServerContext = {
"session": session,
"dsl_schema": DSLSchema(session.client.schema or GraphQLSchema()),
"predefined_queries": predefined_queries if predefined_queries else None,
}
yield context
finally:
Expand Down Expand Up @@ -129,11 +155,6 @@ def convert_type_to_json_schema( # noqa: C901
elif isinstance(gql_type, GraphQLList):
# List wrapper
inner_schema = convert_type_to_json_schema(gql_type.of_type, max_depth, current_depth)
logger.info(
"inner_schema: %s %s",
gql_type.of_type.__class__.__name__,
json.dumps(inner_schema, indent=2),
)
schema = {"type": "array", "items": inner_schema}

elif isinstance(gql_type, GraphQLScalarType):
Expand Down Expand Up @@ -304,7 +325,6 @@ def build_selection(
def get_args_schema(args_map: GraphQLArgumentMap) -> JsonSchema:
args_schema: JsonSchema = {"type": "object", "properties": {}, "required": []}
for arg_name, arg in args_map.items():
logger.debug("Converting GraphQL type for %s: %s", arg_name, arg.type)
type_schema = convert_type_to_json_schema(arg.type, max_depth=5, current_depth=1)
# Remove the "required" flag which was used for tracking
is_required = type_schema.pop("required", False)
Expand All @@ -321,7 +341,6 @@ def get_args_schema(args_map: GraphQLArgumentMap) -> JsonSchema:
and not isinstance(args_schema["required"], bool)
):
args_schema["required"].append(arg_name)
logger.debug("args_schema: %s", json.dumps(args_schema, indent=2))
return args_schema


Expand All @@ -331,7 +350,7 @@ async def list_tools_impl(_server: Server[ServerContext]) -> list[Tool]:
ds: DSLSchema = ctx.lifespan_context["dsl_schema"]
except LookupError as exc:
logger.exception(
"Error al obtener el contexto",
"Error obtaining context",
)
# Configura el transporte
transport = AIOHTTPTransport(url="http://localhost:8080/graphql")
Expand All @@ -342,23 +361,42 @@ async def list_tools_impl(_server: Server[ServerContext]) -> list[Tool]:
if not session.client.schema:
raise SchemaRetrievalError from exc
ds = DSLSchema(session.client.schema)

# Access context pieces
predefined_queries: dict[str, OperationDefinitionNode] = (
ctx.lifespan_context.get("predefined_queries") or {}
)

tools: list[Tool] = []

# Establece la sesión del cliente
if ds:
# Accede al esquema dentro de la sesión
# Determine which query names we should expose as tools
if predefined_queries:
query_names = list(predefined_queries.keys())
else:
if not ds._schema.query_type:
raise QueryTypeNotFoundError
query_names = list(ds._schema.query_type.fields.keys())

# Iterate over the selected query names and build Tool objects
if ds and query_names:
if not ds._schema.query_type:
raise QueryTypeNotFoundError
fields: dict[str, GraphQLField] = ds._schema.query_type.fields
for query_name, field in fields.items():

for query_definition in predefined_queries.values():
# Skip if the query does not exist in the schema (e.g. mutation)
query_name = query_definition.selection_set.selections[0].to_dict()["name"]["value"]
if query_name not in fields:
continue

field = fields[query_name]
dsl_field: DSLField = getattr(ds.Query, query_name)
return_type_description = inspect(dsl_field.field.type)
# Get the arguments schema for this field
args_schema = get_args_schema(dsl_field.field.args)
logger.info("args_schema: %s", json.dumps(args_schema, indent=2))
tools.append(
Tool(
name=query_name,
name=query_definition.name.value, # type: ignore[union-attr]
description=(field.description or f"GraphQL query: {query_name}")
+ f" (Returns: {return_type_description})",
inputSchema=args_schema, # type: ignore[arg-type]
Expand All @@ -373,17 +411,52 @@ async def call_tool_impl(
name: str,
arguments: dict[str, Any],
) -> list[mcp_types.TextContent]:
logger.debug("calling tool %s with arguments %s", name, arguments)
ctx = _server.request_context

# Access context pieces
session = ctx.lifespan_context["session"]
predefined_queries = ctx.lifespan_context.get("predefined_queries") or {}

# Don't use the session as a context manager, use it directly
ds: DSLSchema = ctx.lifespan_context["dsl_schema"]
if not ds._schema.query_type:
raise QueryTypeNotFoundError
fields: dict[str, GraphQLField] = ds._schema.query_type.fields

# ---------------------------------------------------------------------
# 1. If the query comes from the predefined queries file, execute it as-is
# ---------------------------------------------------------------------
if name in predefined_queries:
query_src = predefined_queries[name]
try:
result = await session.execute(
gql(print_ast(query_src)),
variable_values=arguments or None,
)
return [mcp_types.TextContent(type="text", text=json.dumps(result))]
except Exception as exc:
logger.exception("Error executing predefined query %s", name)
return [
mcp_types.TextContent(
type="text",
text=f"Error executing query {name}: {exc}",
),
]

# ---------------------------------------------------------------------
# 2. Fallback to dynamic query generation from the schema
# ---------------------------------------------------------------------

# If the server was started with predefined queries, reject any other
if predefined_queries:
return [
mcp_types.TextContent(
type="text",
text=f"The query '{name}' is not among the allowed queries.",
),
]

max_depth = 5
logger.debug("calling tool %s with arguments %s", name, arguments)
if _query_name := next((_query_name for _query_name in fields if _query_name == name), None):
attr: DSLField = getattr(ds.Query, _query_name)

Expand All @@ -402,7 +475,7 @@ async def call_tool_impl(
return [
mcp_types.TextContent(
type="text",
text=f"Error: No se pudo determinar el tipo de retorno para {name}",
text=f"Error: Could not determine the return type for {name}",
),
]

Expand All @@ -414,20 +487,28 @@ async def call_tool_impl(
# Build the actual query
query_selections = build_selection(ds, return_type, selections)
query = dsl_gql(DSLQuery(attr(**arguments).select(*query_selections)))
logger.info("query: %s", print_ast(query))

# # Execute the query
result = await session.execute(query)
return [mcp_types.TextContent(type="text", text=json.dumps(result))]

# Error case - tool not found
return [mcp_types.TextContent(type="text", text="No se encontró la herramienta")]
return [mcp_types.TextContent(type="text", text="Tool not found")]


async def serve(api_url: str, auth_headers: dict[str, str] | None) -> None:
async def serve(
api_url: str,
auth_headers: dict[str, str] | None,
queries_file: Path | None = None,
) -> None:
server = Server[ServerContext](
"mcp-graphql",
lifespan=partial(server_lifespan, api_url=api_url, auth_headers=auth_headers or {}),
lifespan=partial(
server_lifespan,
api_url=api_url,
auth_headers=auth_headers or {},
queries_file=queries_file,
),
)

server.list_tools()(functools.partial(list_tools_impl, server))
Expand Down
6 changes: 6 additions & 0 deletions mcp_graphql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from gql.client import AsyncClientSession
from gql.dsl import DSLSchema
from graphql.language.ast import OperationDefinitionNode


class ServerContext(TypedDict):
Expand All @@ -10,6 +11,11 @@ class ServerContext(TypedDict):
session: AsyncClientSession
dsl_schema: DSLSchema

# Optional mapping of query name -> GraphQL query string when the server is
# started with a predefined queries file. When absent or empty, the server
# will fall back to exposing every query present in the remote schema.
predefined_queries: dict[str, OperationDefinitionNode] | None


class SchemaRetrievalError(Exception):
"""Exception raised when the GraphQL schema cannot be retrieved."""
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "mcp-graphql"
version = "0.3.3"
version = "0.4.0"
description = "MCP server for GraphQL"
readme = "README.md"
requires-python = ">=3.11"
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.