diff --git a/README.md b/README.md index c7d86e6..cfd2540 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ for building GraphQL servers or integrations into existing web frameworks using | FastAPI | [fastapi](https://github.com/graphql-python/graphql-server/blob/master/docs/fastapi.md) | | Flask | [flask](https://github.com/graphql-python/graphql-server/blob/master/docs/flask.md) | | Litestar | [litestar](https://github.com/graphql-python/graphql-server/blob/master/docs/litestar.md) | +| WebOb | [webob](https://github.com/graphql-python/graphql-server/blob/master/docs/webob.md) | | Quart | [quart](https://github.com/graphql-python/graphql-server/blob/master/docs/quart.md) | | Sanic | [sanic](https://github.com/graphql-python/graphql-server/blob/master/docs/sanic.md) | diff --git a/noxfile.py b/noxfile.py index f84a465..c7cca39 100644 --- a/noxfile.py +++ b/noxfile.py @@ -38,6 +38,7 @@ "django", "fastapi", "flask", + "webob", "quart", "sanic", "litestar", @@ -119,6 +120,7 @@ def tests_starlette(session: Session, gql_core: str) -> None: "channels", "fastapi", "flask", + "webob", "quart", "sanic", "litestar", diff --git a/pyproject.toml b/pyproject.toml index 081666e..ce70fac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ description = "A library for creating GraphQL APIs" authors = [{ name = "Syrus Akbary", email = "me@syrusakbary.com" }] license = { text = "MIT" } readme = "README.md" -keywords = ["graphql", "api", "rest", "starlette", "async", "fastapi", "django", "flask", "litestar", "sanic", "channels", "aiohttp", "chalice", "pyright", "mypy", "codeflash"] +keywords = ["graphql", "api", "rest", "starlette", "async", "fastapi", "django", "flask", "litestar", "sanic", "channels", "aiohttp", "chalice", "webob", "pyright", "mypy", "codeflash"] classifiers = [ "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", @@ -47,6 +47,7 @@ fastapi = ["fastapi>=0.65.2", "python-multipart>=0.0.7"] chalice = ["chalice~=1.22"] litestar = ["litestar>=2; python_version~='3.10'"] pyinstrument = ["pyinstrument>=4.0.0"] +webob = ["WebOb>=1.8"] [tool.pytest.ini_options] # addopts = "--emoji" @@ -64,6 +65,7 @@ markers = [ "flaky", "flask", "litestar", + "webob", "pydantic", "quart", "relay", diff --git a/src/graphql_server/webob/__init__.py b/src/graphql_server/webob/__init__.py new file mode 100644 index 0000000..61aa119 --- /dev/null +++ b/src/graphql_server/webob/__init__.py @@ -0,0 +1,3 @@ +from .views import GraphQLView + +__all__ = ["GraphQLView"] diff --git a/src/graphql_server/webob/views.py b/src/graphql_server/webob/views.py new file mode 100644 index 0000000..04cb08f --- /dev/null +++ b/src/graphql_server/webob/views.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Any, Mapping, Optional, Union, cast +from typing_extensions import TypeGuard + +from webob import Request, Response + +from graphql_server.http import GraphQLRequestData +from graphql_server.http.exceptions import HTTPException +from graphql_server.http.sync_base_view import SyncBaseHTTPView, SyncHTTPRequestAdapter +from graphql_server.http.typevars import Context, RootValue +from graphql_server.http.types import HTTPMethod, QueryParams + +if TYPE_CHECKING: + from graphql.type import GraphQLSchema + from graphql_server.http import GraphQLHTTPResponse + from graphql_server.http.ides import GraphQL_IDE + + +class WebobHTTPRequestAdapter(SyncHTTPRequestAdapter): + def __init__(self, request: Request) -> None: + self.request = request + + @property + def query_params(self) -> QueryParams: + return dict(self.request.GET.items()) + + @property + def body(self) -> Union[str, bytes]: + return self.request.body + + @property + def method(self) -> HTTPMethod: + return cast("HTTPMethod", self.request.method.upper()) + + @property + def headers(self) -> Mapping[str, str]: + return self.request.headers + + @property + def post_data(self) -> Mapping[str, Union[str, bytes]]: + return self.request.POST + + @property + def files(self) -> Mapping[str, Any]: + return { + name: value.file + for name, value in self.request.POST.items() + if hasattr(value, "file") + } + + @property + def content_type(self) -> Optional[str]: + return self.request.content_type + + +class GraphQLView( + SyncBaseHTTPView[Request, Response, Response, Context, RootValue], +): + allow_queries_via_get: bool = True + request_adapter_class = WebobHTTPRequestAdapter + + def __init__( + self, + schema: GraphQLSchema, + graphiql: Optional[bool] = None, + graphql_ide: Optional[GraphQL_IDE] = "graphiql", + allow_queries_via_get: bool = True, + multipart_uploads_enabled: bool = False, + ) -> None: + self.schema = schema + self.allow_queries_via_get = allow_queries_via_get + self.multipart_uploads_enabled = multipart_uploads_enabled + + if graphiql is not None: + warnings.warn( + "The `graphiql` argument is deprecated in favor of `graphql_ide`", + DeprecationWarning, + stacklevel=2, + ) + self.graphql_ide = "graphiql" if graphiql else None + else: + self.graphql_ide = graphql_ide + + def get_root_value(self, request: Request) -> Optional[RootValue]: + return None + + def get_context(self, request: Request, response: Response) -> Context: + return {"request": request, "response": response} # type: ignore + + def get_sub_response(self, request: Request) -> Response: + return Response(status=200, content_type="application/json") + + def create_response( + self, + response_data: GraphQLHTTPResponse, + sub_response: Response, + is_strict: bool, + ) -> Response: + sub_response.text = self.encode_json(response_data) + sub_response.content_type = ( + "application/graphql-response+json" if is_strict else "application/json" + ) + return sub_response + + def render_graphql_ide( + self, request: Request, request_data: GraphQLRequestData + ) -> Response: + return Response( + text=request_data.to_template_string(self.graphql_ide_html), + content_type="text/html", + status=200, + ) + + def dispatch_request(self, request: Request) -> Response: + try: + return self.run(request=request) + except HTTPException as e: + return Response(text=e.reason, status=e.status_code) + + +__all__ = ["GraphQLView"] diff --git a/src/tests/http/clients/webob.py b/src/tests/http/clients/webob.py new file mode 100644 index 0000000..26e5fc1 --- /dev/null +++ b/src/tests/http/clients/webob.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +import asyncio +import contextvars +import functools +import json +import urllib.parse +from io import BytesIO +from typing import Any, Optional, Union +from typing_extensions import Literal + +from graphql import ExecutionResult +from webob import Request, Response + +from graphql_server.http import GraphQLHTTPResponse +from graphql_server.http.ides import GraphQL_IDE +from graphql_server.webob import GraphQLView as BaseGraphQLView +from tests.http.context import get_context +from tests.views.schema import Query, schema + +from .base import JSON, HttpClient, Response as ClientResponse, ResultOverrideFunction + + +class GraphQLView(BaseGraphQLView[dict[str, object], object]): + result_override: ResultOverrideFunction = None + + def get_root_value(self, request: Request) -> Query: + super().get_root_value(request) # for coverage + return Query() + + def get_context(self, request: Request, response: Response) -> dict[str, object]: + context = super().get_context(request, response) + return get_context(context) + + def process_result( + self, request: Request, result: ExecutionResult, strict: bool = False + ) -> GraphQLHTTPResponse: + if self.result_override: + return self.result_override(result) + return super().process_result(request, result, strict) + + +class WebobHttpClient(HttpClient): + def __init__( + self, + graphiql: Optional[bool] = None, + graphql_ide: Optional[GraphQL_IDE] = "graphiql", + allow_queries_via_get: bool = True, + result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, + ) -> None: + self.view = GraphQLView( + schema=schema, + graphiql=graphiql, + graphql_ide=graphql_ide, + allow_queries_via_get=allow_queries_via_get, + multipart_uploads_enabled=multipart_uploads_enabled, + ) + self.view.result_override = result_override + + async def _graphql_request( + self, + method: Literal["get", "post"], + query: Optional[str] = None, + operation_name: Optional[str] = None, + variables: Optional[dict[str, object]] = None, + files: Optional[dict[str, BytesIO]] = None, + headers: Optional[dict[str, str]] = None, + extensions: Optional[dict[str, Any]] = None, + **kwargs: Any, + ) -> ClientResponse: + body = self._build_body( + query=query, + operation_name=operation_name, + variables=variables, + files=files, + method=method, + extensions=extensions, + ) + + data: Union[dict[str, object], str, None] = None + + url = "/graphql" + + if body and files: + body.update({name: (file, name) for name, file in files.items()}) + + if method == "get": + body_encoded = urllib.parse.urlencode(body or {}) + url = f"{url}?{body_encoded}" + else: + if body: + data = body if files else json.dumps(body) + kwargs["body"] = data + + headers = self._get_headers(method=method, headers=headers, files=files) + + return await self.request(url, method, headers=headers, **kwargs) + + def _do_request( + self, + url: str, + method: Literal["get", "post", "patch", "put", "delete"], + headers: Optional[dict[str, str]] = None, + **kwargs: Any, + ) -> ClientResponse: + body = kwargs.get("body", None) + req = Request.blank( + url, method=method.upper(), headers=headers or {}, body=body + ) + resp = self.view.dispatch_request(req) + return ClientResponse( + status_code=resp.status_code, data=resp.body, headers=resp.headers + ) + + async def request( + self, + url: str, + method: Literal["head", "get", "post", "patch", "put", "delete"], + headers: Optional[dict[str, str]] = None, + **kwargs: Any, + ) -> ClientResponse: + loop = asyncio.get_running_loop() + ctx = contextvars.copy_context() + func_call = functools.partial( + ctx.run, self._do_request, url=url, method=method, headers=headers, **kwargs + ) + return await loop.run_in_executor(None, func_call) # type: ignore + + async def get( + self, url: str, headers: Optional[dict[str, str]] = None + ) -> ClientResponse: + return await self.request(url, "get", headers=headers) + + async def post( + self, + url: str, + data: Optional[bytes] = None, + json: Optional[JSON] = None, + headers: Optional[dict[str, str]] = None, + ) -> ClientResponse: + body = json if json is not None else data + return await self.request(url, "post", headers=headers, body=body) diff --git a/src/tests/http/conftest.py b/src/tests/http/conftest.py index cd8b8b6..2b7a563 100644 --- a/src/tests/http/conftest.py +++ b/src/tests/http/conftest.py @@ -18,6 +18,7 @@ def _get_http_client_classes() -> Generator[Any, None, None]: ("DjangoHttpClient", "django", [pytest.mark.django]), ("FastAPIHttpClient", "fastapi", [pytest.mark.fastapi]), ("FlaskHttpClient", "flask", [pytest.mark.flask]), + ("WebobHttpClient", "webob", [pytest.mark.webob]), ("QuartHttpClient", "quart", [pytest.mark.quart]), ("SanicHttpClient", "sanic", [pytest.mark.sanic]), ("LitestarHttpClient", "litestar", [pytest.mark.litestar]),