diff --git a/docs/Usage/Request.md b/docs/Usage/Request.md index 9042f6e3..e7765322 100644 --- a/docs/Usage/Request.md +++ b/docs/Usage/Request.md @@ -103,6 +103,88 @@ def get_book(raw: BookRaw): return "ok" ``` +## Multiple content types in the request body + +```python +from typing import Union + +from flask import Request +from pydantic import BaseModel + +from flask_openapi3 import OpenAPI + +app = OpenAPI(__name__) + + +class DogBody(BaseModel): + a: int = None + b: str = None + + model_config = { + "openapi_extra": { + "content_type": "application/vnd.dog+json" + } + } + + +class CatBody(BaseModel): + c: int = None + d: str = None + + model_config = { + "openapi_extra": { + "content_type": "application/vnd.cat+json" + } + } + + +class BsonModel(BaseModel): + e: int = None + f: str = None + + model_config = { + "openapi_extra": { + "content_type": "application/bson" + } + } + + +class ContentTypeModel(BaseModel): + model_config = { + "openapi_extra": { + "content_type": "text/csv" + } + } + + +@app.post("/a", responses={200: DogBody | CatBody | ContentTypeModel | BsonModel}) +def index_a(body: DogBody | CatBody | ContentTypeModel | BsonModel): + """ + multiple content types examples. + + This may be confusing, if the content-type is application/json, the type of body will be auto parsed to + DogBody or CatBody, otherwise it cannot be parsed to ContentTypeModel or BsonModel. + The body is equivalent to the request variable in Flask, and you can use body.data, body.text, etc ... + """ + print(body) + if isinstance(body, Request): + if body.mimetype == "text/csv": + # processing csv data + ... + elif body.mimetype == "application/bson": + # processing bson data + ... + else: + # DogBody or CatBody + ... + return {"hello": "world"} +``` + +The effect in swagger: + +![](../assets/Snipaste_2025-01-14_10-44-00.png) + + ## Request model First, you need to define a [pydantic](https://github.com/pydantic/pydantic) model: @@ -125,7 +207,7 @@ class BookQuery(BaseModel): author: str = Field(None, description='Author', json_schema_extra={"deprecated": True}) ``` -Magic: +The effect in swagger: ![](../assets/Snipaste_2022-09-04_10-10-03.png) diff --git a/docs/Usage/Response.md b/docs/Usage/Response.md index 8d8300b9..aad1c938 100644 --- a/docs/Usage/Response.md +++ b/docs/Usage/Response.md @@ -56,6 +56,122 @@ def hello(path: HelloPath): ![image-20210526104627124](../assets/image-20210526104627124.png) +*Sometimes you may need more description fields about the response, such as description, headers and links. + +You can use the following form: + +```python +@app.get( + "/test", + responses={ + "201": { + "model": BaseResponse, + "description": "Custom description", + "headers": { + "location": { + "description": "URL of the new resource", + "schema": {"type": "string"} + } + }, + "links": { + "dummy": { + "description": "dummy link" + } + } + } + } + ) + def endpoint_test(): + ... +``` + +The effect in swagger: + +![](../assets/Snipaste_2025-01-14_11-08-40.png) + + +## Multiple content types in the responses + +```python +from typing import Union + +from flask import Request +from pydantic import BaseModel + +from flask_openapi3 import OpenAPI + +app = OpenAPI(__name__) + + +class DogBody(BaseModel): + a: int = None + b: str = None + + model_config = { + "openapi_extra": { + "content_type": "application/vnd.dog+json" + } + } + + +class CatBody(BaseModel): + c: int = None + d: str = None + + model_config = { + "openapi_extra": { + "content_type": "application/vnd.cat+json" + } + } + + +class BsonModel(BaseModel): + e: int = None + f: str = None + + model_config = { + "openapi_extra": { + "content_type": "application/bson" + } + } + + +class ContentTypeModel(BaseModel): + model_config = { + "openapi_extra": { + "content_type": "text/csv" + } + } + + +@app.post("/a", responses={200: DogBody | CatBody | ContentTypeModel | BsonModel}) +def index_a(body: DogBody | CatBody | ContentTypeModel | BsonModel): + """ + multiple content types examples. + + This may be confusing, if the content-type is application/json, the type of body will be auto parsed to + DogBody or CatBody, otherwise it cannot be parsed to ContentTypeModel or BsonModel. + The body is equivalent to the request variable in Flask, and you can use body.data, body.text, etc ... + """ + print(body) + if isinstance(body, Request): + if body.mimetype == "text/csv": + # processing csv data + ... + elif body.mimetype == "application/bson": + # processing bson data + ... + else: + # DogBody or CatBody + ... + return {"hello": "world"} +``` + +The effect in swagger: + +![](../assets/Snipaste_2025-01-14_10-49-19.png) + + ## More information about OpenAPI responses - [OpenAPI Responses Object](https://spec.openapis.org/oas/v3.1.0#responses-object), it includes the Response Object. diff --git a/docs/Usage/Route_Operation.md b/docs/Usage/Route_Operation.md index 2db8e3c0..30ed972e 100644 --- a/docs/Usage/Route_Operation.md +++ b/docs/Usage/Route_Operation.md @@ -287,6 +287,29 @@ class BookListAPIView: app.register_api_view(api_view) ``` +## request_body_description + +A brief description of the request body. + +```python +from flask_openapi3 import OpenAPI + +app = OpenAPI(__name__) + +@app.post( + "/", + request_body_description="A brief description of the request body." +) +def create_book(body: Bookbody): + ... +``` + +![](../assets/Snipaste_2025-01-14_10-56-40.png) + +## request_body_required + +Determines if the request body is required in the request. + ## doc_ui You can pass `doc_ui=False` to disable the `OpenAPI spec` when init `OpenAPI `. diff --git a/docs/assets/Snipaste_2025-01-14_10-44-00.png b/docs/assets/Snipaste_2025-01-14_10-44-00.png new file mode 100644 index 00000000..d716d979 Binary files /dev/null and b/docs/assets/Snipaste_2025-01-14_10-44-00.png differ diff --git a/docs/assets/Snipaste_2025-01-14_10-49-19.png b/docs/assets/Snipaste_2025-01-14_10-49-19.png new file mode 100644 index 00000000..f1bb52e7 Binary files /dev/null and b/docs/assets/Snipaste_2025-01-14_10-49-19.png differ diff --git a/docs/assets/Snipaste_2025-01-14_10-56-40.png b/docs/assets/Snipaste_2025-01-14_10-56-40.png new file mode 100644 index 00000000..e66488dc Binary files /dev/null and b/docs/assets/Snipaste_2025-01-14_10-56-40.png differ diff --git a/docs/assets/Snipaste_2025-01-14_11-08-40.png b/docs/assets/Snipaste_2025-01-14_11-08-40.png new file mode 100644 index 00000000..dd537d87 Binary files /dev/null and b/docs/assets/Snipaste_2025-01-14_11-08-40.png differ diff --git a/examples/multi_content_type.py b/examples/multi_content_type.py new file mode 100644 index 00000000..28a2f1fa --- /dev/null +++ b/examples/multi_content_type.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +# @Author : llc +# @Time : 2024/12/27 15:30 +from flask import Request +from pydantic import BaseModel + +from flask_openapi3 import OpenAPI + +app = OpenAPI(__name__) + + +class DogBody(BaseModel): + a: int = None + b: str = None + + model_config = { + "openapi_extra": { + "content_type": "application/vnd.dog+json" + } + } + + +class CatBody(BaseModel): + c: int = None + d: str = None + + model_config = { + "openapi_extra": { + "content_type": "application/vnd.cat+json" + } + } + + +class BsonModel(BaseModel): + e: int = None + f: str = None + + model_config = { + "openapi_extra": { + "content_type": "application/bson" + } + } + + +class ContentTypeModel(BaseModel): + model_config = { + "openapi_extra": { + "content_type": "text/csv" + } + } + + +@app.post("/a", responses={200: DogBody | CatBody | ContentTypeModel | BsonModel}) +def index_a(body: DogBody | CatBody | ContentTypeModel | BsonModel): + """ + multiple content types examples. + + This may be confusing, if the content-type is application/json, the type of body will be auto parsed to + DogBody or CatBody, otherwise it cannot be parsed to ContentTypeModel or BsonModel. + The body is equivalent to the request variable in Flask, and you can use body.data, body.text, etc ... + """ + print(body) + if isinstance(body, Request): + if body.mimetype == "text/csv": + # processing csv data + ... + elif body.mimetype == "application/bson": + # processing bson data + from bson import BSON + + obj = BSON(body.data).decode() + new_body = body.model_validate(obj=obj) + print(new_body) + else: + # DogBody or CatBody + ... + return {"hello": "world"} + + +if __name__ == '__main__': + app.run(debug=True) diff --git a/flask_openapi3/blueprint.py b/flask_openapi3/blueprint.py index b0b85c39..4c95efdf 100644 --- a/flask_openapi3/blueprint.py +++ b/flask_openapi3/blueprint.py @@ -121,6 +121,8 @@ def _collect_openapi_info( security: Optional[list[dict[str, list[Any]]]] = None, servers: Optional[list[Server]] = None, openapi_extensions: Optional[dict[str, Any]] = None, + request_body_description: Optional[str] = None, + request_body_required: Optional[bool] = True, doc_ui: bool = True, method: str = HTTPMethod.GET ) -> ParametersTuple: @@ -140,6 +142,8 @@ def _collect_openapi_info( security: A declaration of which security mechanisms can be used for this operation. servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. + request_body_description: A brief description of the request body. + request_body_required: Determines if the request body is required in the request. doc_ui: Declares this operation to be shown. Default to True. """ if self.doc_ui is True and doc_ui is True: @@ -193,6 +197,11 @@ def _collect_openapi_info( parse_method(uri, method, self.paths, operation) # Parse parameters - return parse_parameters(func, components_schemas=self.components_schemas, operation=operation) + return parse_parameters( + func, components_schemas=self.components_schemas, + operation=operation, + request_body_description=request_body_description, + request_body_required=request_body_required + ) else: return parse_parameters(func, doc_ui=False) diff --git a/flask_openapi3/openapi.py b/flask_openapi3/openapi.py index 5c1c91e8..97c995ae 100644 --- a/flask_openapi3/openapi.py +++ b/flask_openapi3/openapi.py @@ -380,6 +380,8 @@ def _collect_openapi_info( security: Optional[list[dict[str, list[Any]]]] = None, servers: Optional[list[Server]] = None, openapi_extensions: Optional[dict[str, Any]] = None, + request_body_description: Optional[str] = None, + request_body_required: Optional[bool] = True, doc_ui: bool = True, method: str = HTTPMethod.GET ) -> ParametersTuple: @@ -399,6 +401,8 @@ def _collect_openapi_info( security: A declaration of which security mechanisms can be used for this operation. servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. + request_body_description: A brief description of the request body. + request_body_required: Determines if the request body is required in the request. doc_ui: Declares this operation to be shown. Default to True. method: HTTP method for the operation. Defaults to GET. """ @@ -450,6 +454,11 @@ def _collect_openapi_info( parse_method(uri, method, self.paths, operation) # Parse parameters - return parse_parameters(func, components_schemas=self.components_schemas, operation=operation) + return parse_parameters( + func, components_schemas=self.components_schemas, + operation=operation, + request_body_description=request_body_description, + request_body_required=request_body_required + ) else: return parse_parameters(func, doc_ui=False) diff --git a/flask_openapi3/request.py b/flask_openapi3/request.py index f6c1d39e..11aab8bb 100644 --- a/flask_openapi3/request.py +++ b/flask_openapi3/request.py @@ -3,13 +3,22 @@ # @Time : 2022/4/1 16:54 import json from json import JSONDecodeError -from typing import Any, Type, Optional + +from typing import Any, Type, Optional, get_origin, get_args, Union + +try: + from types import UnionType # type: ignore +except ImportError: # pragma: no cover + # python < 3.10 + UnionType = type(Union) # type: ignore from flask import request, current_app, abort -from pydantic import ValidationError, BaseModel +from pydantic import ValidationError, BaseModel, RootModel from pydantic.fields import FieldInfo from werkzeug.datastructures.structures import MultiDict +from flask_openapi3.utils import is_application_json + def _get_list_value(model: Type[BaseModel], args: MultiDict, model_field_key: str, model_field_value: FieldInfo): if model_field_value.alias and model.model_config.get("populate_by_name"): @@ -42,10 +51,8 @@ def _get_value(model: Type[BaseModel], args: MultiDict, model_field_key: str, mo def _validate_header(header: Type[BaseModel], func_kwargs: dict): request_headers = dict(request.headers) header_dict = {} - model_properties = header.model_json_schema().get("properties", {}) for model_field_key, model_field_value in header.model_fields.items(): key_title = model_field_key.replace("_", "-").title() - model_field_schema = model_properties.get(model_field_value.alias or model_field_key) if model_field_value.alias and header.model_config.get("populate_by_name"): key = model_field_value.alias key_alias_title = model_field_value.alias.replace("_", "-").title() @@ -56,11 +63,9 @@ def _validate_header(header: Type[BaseModel], func_kwargs: dict): value = request_headers.get(key_alias_title) else: key = model_field_key - value = request_headers[key_title] + value = request_headers.get(key_title) if value is not None: header_dict[key] = value - if model_field_schema.get("type") == "null": - header_dict[key] = value # type:ignore # extra keys for key, value in request_headers.items(): if key not in header_dict.keys(): @@ -138,12 +143,20 @@ def _validate_form(form: Type[BaseModel], func_kwargs: dict): def _validate_body(body: Type[BaseModel], func_kwargs: dict): - obj = request.get_json(silent=True) - if isinstance(obj, str): - body_model = body.model_validate_json(json_data=obj) + if is_application_json(request.mimetype): + if get_origin(body) in (Union, UnionType): + root_model_list = [model for model in get_args(body)] + Body = RootModel[Union[tuple(root_model_list)]] # type: ignore + else: + Body = body # type: ignore + obj = request.get_json(silent=True) + if isinstance(obj, str): + body_model = Body.model_validate_json(json_data=obj) + else: + body_model = Body.model_validate(obj=obj) + func_kwargs["body"] = body_model else: - body_model = body.model_validate(obj=obj) - func_kwargs["body"] = body_model + func_kwargs["body"] = request def _validate_request( diff --git a/flask_openapi3/scaffold.py b/flask_openapi3/scaffold.py index 0e38e5ac..d7e3669a 100644 --- a/flask_openapi3/scaffold.py +++ b/flask_openapi3/scaffold.py @@ -32,13 +32,15 @@ def _collect_openapi_info( security: Optional[list[dict[str, list[Any]]]] = None, servers: Optional[list[Server]] = None, openapi_extensions: Optional[dict[str, Any]] = None, + request_body_description: Optional[str] = None, + request_body_required: Optional[bool] = True, doc_ui: bool = True, method: str = HTTPMethod.GET ) -> ParametersTuple: - raise NotImplementedError # pragma: no cover + raise NotImplementedError # pragma: no cover def register_api(self, api) -> None: - raise NotImplementedError # pragma: no cover + raise NotImplementedError # pragma: no cover def _add_url_rule( self, @@ -48,7 +50,7 @@ def _add_url_rule( provide_automatic_options=None, **options, ) -> None: - raise NotImplementedError # pragma: no cover + raise NotImplementedError # pragma: no cover @staticmethod def create_view_func( @@ -199,6 +201,8 @@ def post( security: Optional[list[dict[str, list[Any]]]] = None, servers: Optional[list[Server]] = None, openapi_extensions: Optional[dict[str, Any]] = None, + request_body_description: Optional[str] = None, + request_body_required: Optional[bool] = True, doc_ui: bool = True, **options: Any ) -> Callable: @@ -218,6 +222,8 @@ def post( security: A declaration of which security mechanisms can be used for this operation. servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. + request_body_description: A brief description of the request body. + request_body_required: Determines if the request body is required in the request. doc_ui: Declares this operation to be shown. Default to True. """ @@ -236,6 +242,8 @@ def decorator(func) -> Callable: security=security, servers=servers, openapi_extensions=openapi_extensions, + request_body_description=request_body_description, + request_body_required=request_body_required, doc_ui=doc_ui, method=HTTPMethod.POST ) @@ -262,6 +270,8 @@ def put( security: Optional[list[dict[str, list[Any]]]] = None, servers: Optional[list[Server]] = None, openapi_extensions: Optional[dict[str, Any]] = None, + request_body_description: Optional[str] = None, + request_body_required: Optional[bool] = True, doc_ui: bool = True, **options: Any ) -> Callable: @@ -281,6 +291,8 @@ def put( security: A declaration of which security mechanisms can be used for this operation. servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. + request_body_description: A brief description of the request body. + request_body_required: Determines if the request body is required in the request. doc_ui: Declares this operation to be shown. Default to True. """ @@ -299,6 +311,8 @@ def decorator(func) -> Callable: security=security, servers=servers, openapi_extensions=openapi_extensions, + request_body_description=request_body_description, + request_body_required=request_body_required, doc_ui=doc_ui, method=HTTPMethod.PUT ) @@ -325,6 +339,8 @@ def delete( security: Optional[list[dict[str, list[Any]]]] = None, servers: Optional[list[Server]] = None, openapi_extensions: Optional[dict[str, Any]] = None, + request_body_description: Optional[str] = None, + request_body_required: Optional[bool] = True, doc_ui: bool = True, **options: Any ) -> Callable: @@ -344,6 +360,8 @@ def delete( security: A declaration of which security mechanisms can be used for this operation. servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. + request_body_description: A brief description of the request body. + request_body_required: Determines if the request body is required in the request. doc_ui: Declares this operation to be shown. Default to True. """ @@ -362,6 +380,8 @@ def decorator(func) -> Callable: security=security, servers=servers, openapi_extensions=openapi_extensions, + request_body_description=request_body_description, + request_body_required=request_body_required, doc_ui=doc_ui, method=HTTPMethod.DELETE ) @@ -388,6 +408,8 @@ def patch( security: Optional[list[dict[str, list[Any]]]] = None, servers: Optional[list[Server]] = None, openapi_extensions: Optional[dict[str, Any]] = None, + request_body_description: Optional[str] = None, + request_body_required: Optional[bool] = True, doc_ui: bool = True, **options: Any ) -> Callable: @@ -407,6 +429,8 @@ def patch( security: A declaration of which security mechanisms can be used for this operation. servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. + request_body_description: A brief description of the request body. + request_body_required: Determines if the request body is required in the request. doc_ui: Declares this operation to be shown. Default to True. """ @@ -425,6 +449,8 @@ def decorator(func) -> Callable: security=security, servers=servers, openapi_extensions=openapi_extensions, + request_body_description=request_body_description, + request_body_required=request_body_required, doc_ui=doc_ui, method=HTTPMethod.PATCH ) diff --git a/flask_openapi3/types.py b/flask_openapi3/types.py index d7e2b437..90de0d49 100644 --- a/flask_openapi3/types.py +++ b/flask_openapi3/types.py @@ -2,14 +2,16 @@ # @Author : llc # @Time : 2023/7/9 15:25 from http import HTTPStatus -from typing import Union, Type, Any, Optional +from typing import Union, Type, Any, Optional, TypeVar from pydantic import BaseModel from .models import RawModel from .models import SecurityScheme -_ResponseDictValue = Union[Type[BaseModel], dict[Any, Any], None] +_MultiBaseModel = TypeVar("_MultiBaseModel", bound=Type[BaseModel]) + +_ResponseDictValue = Union[Type[BaseModel], _MultiBaseModel, dict[Any, Any], None] ResponseDict = dict[Union[str, int, HTTPStatus], _ResponseDictValue] diff --git a/flask_openapi3/utils.py b/flask_openapi3/utils.py index 205a8cfc..817a0398 100644 --- a/flask_openapi3/utils.py +++ b/flask_openapi3/utils.py @@ -7,7 +7,14 @@ import sys from enum import Enum from http import HTTPStatus -from typing import get_type_hints, Type, Callable, Optional, Any, DefaultDict +from typing import Type, Callable, Optional, Any, DefaultDict, Union +from typing import get_args, get_origin, get_type_hints + +try: + from types import UnionType # type: ignore +except ImportError: # pragma: no cover + # python < 3.10 + UnionType = Union # type: ignore from flask import make_response, current_app from flask.wrappers import Response as FlaskResponse @@ -268,6 +275,11 @@ def parse_form( ) -> tuple[dict[str, MediaType], dict]: """Parses a form model and returns a list of parameters and component schemas.""" schema = get_model_schema(form) + + model_config: DefaultDict[str, Any] = form.model_config # type: ignore + openapi_extra = model_config.get("openapi_extra", {}) + content_type = openapi_extra.get("content_type", "multipart/form-data") + components_schemas = dict() properties = schema.get("properties", {}) @@ -280,14 +292,22 @@ def parse_form( for k, v in properties.items(): if v.get("type") == "array": encoding[k] = Encoding(style="form", explode=True) - content = { - "multipart/form-data": MediaType( - schema=Schema(**{"$ref": f"{OPENAPI3_REF_PREFIX}/{title}"}), - ) - } + + media_type = MediaType(**{"schema": Schema(**{"$ref": f"{OPENAPI3_REF_PREFIX}/{title}"})}) + + if openapi_extra: + openapi_extra_keys = openapi_extra.keys() + if "example" in openapi_extra_keys: + media_type.example = openapi_extra.get("example") + if "examples" in openapi_extra_keys: + media_type.examples = openapi_extra.get("examples") + if "encoding" in openapi_extra_keys: + media_type.encoding = openapi_extra.get("encoding") + if encoding: - content["multipart/form-data"].encoding = encoding + media_type.encoding = encoding + content = {content_type: media_type} # Parse definitions definitions = schema.get("$defs", {}) for name, value in definitions.items(): @@ -300,22 +320,49 @@ def parse_body( body: Type[BaseModel], ) -> tuple[dict[str, MediaType], dict]: """Parses a body model and returns a list of parameters and component schemas.""" - schema = get_model_schema(body) - components_schemas = dict() - original_title = schema.get("title") or body.__name__ - title = normalize_name(original_title) - components_schemas[title] = Schema(**schema) - content = { - "application/json": MediaType( - schema=Schema(**{"$ref": f"{OPENAPI3_REF_PREFIX}/{title}"}) - ) - } + content = {} + components_schemas = {} - # Parse definitions - definitions = schema.get("$defs", {}) - for name, value in definitions.items(): - components_schemas[name] = Schema(**value) + def _parse_body(_model): + model_config: DefaultDict[str, Any] = _model.model_config # type: ignore + openapi_extra = model_config.get("openapi_extra", {}) + content_type = openapi_extra.get("content_type", "application/json") + + if not is_application_json(content_type): + content_schema = openapi_extra.get("content_schema", {"type": DataType.STRING}) + content[content_type] = MediaType(**{"schema": content_schema}) + return + + schema = get_model_schema(_model) + + original_title = schema.get("title") or _model.__name__ + title = normalize_name(original_title) + components_schemas[title] = Schema(**schema) + + media_type = MediaType(**{"schema": Schema(**{"$ref": f"{OPENAPI3_REF_PREFIX}/{title}"})}) + + if openapi_extra: + openapi_extra_keys = openapi_extra.keys() + if "example" in openapi_extra_keys: + media_type.example = openapi_extra.get("example") + if "examples" in openapi_extra_keys: + media_type.examples = openapi_extra.get("examples") + if "encoding" in openapi_extra_keys: + media_type.encoding = openapi_extra.get("encoding") + + content[content_type] = media_type + + # Parse definitions + definitions = schema.get("$defs", {}) + for name, value in definitions.items(): + components_schemas[name] = Schema(**value) + + if get_origin(body) in (Union, UnionType): + for model in get_args(body): + _parse_body(model) + else: + _parse_body(body) return content, components_schemas @@ -325,54 +372,86 @@ def get_responses( components_schemas: dict, operation: Operation ) -> None: - _responses = {} - _schemas = {} + _responses: dict = {} + _schemas: dict = {} + + def _parse_response(_key, _model): + model_config: DefaultDict[str, Any] = _model.model_config # type: ignore + openapi_extra = model_config.get("openapi_extra", {}) + content_type = openapi_extra.get("content_type", "application/json") + + if not is_application_json(content_type): + content_schema = openapi_extra.get("content_schema", {"type": DataType.STRING}) + media_type = MediaType(**{"schema": content_schema}) + if _responses.get(_key): + _responses[_key].content[content_type] = media_type + else: + _responses[_key] = Response( + description=HTTP_STATUS.get(_key, ""), + content={content_type: media_type} + ) + return + + schema = get_model_schema(_model, mode="serialization") + # OpenAPI 3 support ^[a-zA-Z0-9\.\-_]+$ so we should normalize __name__ + original_title = schema.get("title") or _model.__name__ + name = normalize_name(original_title) + + media_type = MediaType(**{"schema": Schema(**{"$ref": f"{OPENAPI3_REF_PREFIX}/{name}"})}) + + if openapi_extra: + openapi_extra_keys = openapi_extra.keys() + if "example" in openapi_extra_keys: + media_type.example = openapi_extra.get("example") + if "examples" in openapi_extra_keys: + media_type.examples = openapi_extra.get("examples") + if "encoding" in openapi_extra_keys: + media_type.encoding = openapi_extra.get("encoding") + if _responses.get(_key): + _responses[_key].content[content_type] = media_type + else: + _responses[_key] = Response( + description=HTTP_STATUS.get(_key, ""), + content={content_type: media_type} + ) + + _schemas[name] = Schema(**schema) + definitions = schema.get("$defs") + if definitions: + # Add schema definitions to _schemas + for name, value in definitions.items(): + _schemas[normalize_name(name)] = Schema(**value) for key, response in responses.items(): - if response is None: + if isinstance(response, dict) and "model" in response: + response_model = response.get("model") + response_description = response.get("description") + response_headers = response.get("headers") + response_links = response.get("links") + else: + response_model = response + response_description = None + response_headers = None + response_links = None + + if response_model is None: # If the response is None, it means HTTP status code "204" (No Content) _responses[key] = Response(description=HTTP_STATUS.get(key, "")) - elif isinstance(response, dict): - response["description"] = response.get("description", HTTP_STATUS.get(key, "")) - _responses[key] = Response(**response) + elif isinstance(response_model, dict): + response_model["description"] = response_model.get("description", HTTP_STATUS.get(key, "")) + _responses[key] = Response(**response_model) + elif get_origin(response_model) in [UnionType, Union]: + for model in get_args(response_model): + _parse_response(key, model) else: - # OpenAPI 3 support ^[a-zA-Z0-9\.\-_]+$ so we should normalize __name__ - schema = get_model_schema(response, mode="serialization") - original_title = schema.get("title") or response.__name__ - name = normalize_name(original_title) - _responses[key] = Response( - description=HTTP_STATUS.get(key, ""), - content={ - "application/json": MediaType( - schema=Schema(**{"$ref": f"{OPENAPI3_REF_PREFIX}/{name}"}) - )}) - - model_config: DefaultDict[str, Any] = response.model_config # type: ignore - openapi_extra = model_config.get("openapi_extra", {}) - if openapi_extra: - openapi_extra_keys = openapi_extra.keys() - # Add additional information from model_config to the response - if "description" in openapi_extra_keys: - _responses[key].description = openapi_extra.get("description") - if "headers" in openapi_extra_keys: - _responses[key].headers = openapi_extra.get("headers") - if "links" in openapi_extra_keys: - _responses[key].links = openapi_extra.get("links") - _content = _responses[key].content - if "example" in openapi_extra_keys: - _content["application/json"].example = openapi_extra.get("example") # type: ignore - if "examples" in openapi_extra_keys: - _content["application/json"].examples = openapi_extra.get("examples") # type: ignore - if "encoding" in openapi_extra_keys: - _content["application/json"].encoding = openapi_extra.get("encoding") # type: ignore - _content.update(openapi_extra.get("content", {})) # type: ignore - - _schemas[name] = Schema(**schema) - definitions = schema.get("$defs") - if definitions: - # Add schema definitions to _schemas - for name, value in definitions.items(): - _schemas[normalize_name(name)] = Schema(**value) + _parse_response(key, response_model) + + if response_description is not None: + _responses[key].description = response_description + if response_headers is not None: + _responses[key].headers = response_headers + if response_links is not None: + _responses[key].links = response_links components_schemas.update(**_schemas) operation.responses = _responses @@ -413,6 +492,8 @@ def parse_parameters( *, components_schemas: Optional[dict] = None, operation: Optional[Operation] = None, + request_body_description: Optional[str] = None, + request_body_required: Optional[bool] = True, doc_ui: bool = True, ) -> ParametersTuple: """ @@ -423,6 +504,8 @@ def parse_parameters( func: The function to parse the parameters from. components_schemas: Dictionary to store the parsed components schemas (default: None). operation: Operation object to populate with parsed parameters (default: None). + request_body_description: A brief description of the request body (default: None). + request_body_required: Determines if the request body is required in the request (default: True). doc_ui: Flag indicating whether to return types for documentation UI (default: True). Returns: @@ -481,51 +564,31 @@ def parse_parameters( _content, _components_schemas = parse_form(form) components_schemas.update(**_components_schemas) request_body = RequestBody(content=_content, required=True) - model_config: DefaultDict[str, Any] = form.model_config # type: ignore - openapi_extra = model_config.get("openapi_extra", {}) - if openapi_extra: - openapi_extra_keys = openapi_extra.keys() - if "description" in openapi_extra_keys: - request_body.description = openapi_extra.get("description") - if "example" in openapi_extra_keys: - request_body.content["multipart/form-data"].example = openapi_extra.get("example") - if "examples" in openapi_extra_keys: - request_body.content["multipart/form-data"].examples = openapi_extra.get("examples") - if "encoding" in openapi_extra_keys: - request_body.content["multipart/form-data"].encoding = openapi_extra.get("encoding") + if request_body_description: + request_body.description = request_body_description + request_body.required = request_body_required operation.requestBody = request_body if body: _content, _components_schemas = parse_body(body) components_schemas.update(**_components_schemas) request_body = RequestBody(content=_content, required=True) - model_config: DefaultDict[str, Any] = body.model_config # type: ignore - openapi_extra = model_config.get("openapi_extra", {}) - if openapi_extra: - openapi_extra_keys = openapi_extra.keys() - if "description" in openapi_extra_keys: - request_body.description = openapi_extra.get("description") - request_body.required = openapi_extra.get("required", True) - if "example" in openapi_extra_keys: - request_body.content["application/json"].example = openapi_extra.get("example") - if "examples" in openapi_extra_keys: - request_body.content["application/json"].examples = openapi_extra.get("examples") - if "encoding" in openapi_extra_keys: - request_body.content["application/json"].encoding = openapi_extra.get("encoding") + if request_body_description: + request_body.description = request_body_description + request_body.required = request_body_required operation.requestBody = request_body if raw: _content = {} for mimetype in raw.mimetypes: - if mimetype.startswith("application/json"): - _content[mimetype] = MediaType( - schema=Schema(type=DataType.OBJECT) - ) + if is_application_json(mimetype): + _content[mimetype] = MediaType(**{"schema": Schema(type=DataType.OBJECT)}) else: - _content[mimetype] = MediaType( - schema=Schema(type=DataType.STRING) - ) + _content[mimetype] = MediaType(**{"schema": Schema(type=DataType.STRING)}) request_body = RequestBody(content=_content) + if request_body_description: + request_body.description = request_body_description + request_body.required = request_body_required operation.requestBody = request_body if parameters: @@ -615,3 +678,7 @@ def convert_responses_key_to_string(responses: ResponseDict) -> ResponseStrKeyDi def normalize_name(name: str) -> str: return re.sub(r"[^\w.\-]", "_", name) + + +def is_application_json(content_type: str) -> bool: + return "application" in content_type and "json" in content_type diff --git a/flask_openapi3/view.py b/flask_openapi3/view.py index 23711bf9..b34cd9de 100644 --- a/flask_openapi3/view.py +++ b/flask_openapi3/view.py @@ -112,6 +112,8 @@ def doc( security: Optional[list[dict[str, list[Any]]]] = None, servers: Optional[list[Server]] = None, openapi_extensions: Optional[dict[str, Any]] = None, + request_body_description: Optional[str] = None, + request_body_required: Optional[bool] = True, doc_ui: bool = True ) -> Callable: """ @@ -129,6 +131,8 @@ def doc( security: A declaration of which security mechanisms can be used for this operation. servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. + request_body_description: A brief description of the request body. + request_body_required: Determines if the request body is required in the request. doc_ui: Declares this operation to be shown. Default to True. """ @@ -177,9 +181,10 @@ def decorator(func): # Parse parameters parse_parameters( - func, - components_schemas=self.components_schemas, - operation=operation + func, components_schemas=self.components_schemas, + operation=operation, + request_body_description=request_body_description, + request_body_required=request_body_required ) # Parse response diff --git a/tests/test_api_blueprint.py b/tests/test_api_blueprint.py index 4ca6bfcc..c5db8aee 100644 --- a/tests/test_api_blueprint.py +++ b/tests/test_api_blueprint.py @@ -7,7 +7,7 @@ import pytest from pydantic import BaseModel, Field -from flask_openapi3 import APIBlueprint, OpenAPI +from flask_openapi3 import APIBlueprint, OpenAPI, Server, ExternalDocumentation from flask_openapi3 import Tag, Info info = Info(title='book API', version='1.0.0') @@ -82,7 +82,17 @@ def update_book1(path: BookPath, body: BookBody): return {"code": 0, "message": "ok"} -@api.patch('/v2/book/') +@api.patch( + '/v2/book/', + servers=[Server( + url="http://127.0.0.1:5000", + variables=None + )], + external_docs=ExternalDocumentation( + url="https://www.openapis.org/", + description="Something great got better, get excited!"), + deprecated=True +) def update_book1_v2(path: BookPath, body: BookBody): assert path.bid == 1 assert body.age == 3 diff --git a/tests/test_api_view.py b/tests/test_api_view.py index 00b8a638..2f7e3d63 100644 --- a/tests/test_api_view.py +++ b/tests/test_api_view.py @@ -7,7 +7,7 @@ import pytest from pydantic import BaseModel, Field -from flask_openapi3 import APIView +from flask_openapi3 import APIView, Server, ExternalDocumentation from flask_openapi3 import OpenAPI, Tag, Info info = Info(title='book API', version='1.0.0') @@ -73,7 +73,17 @@ def put(self, path: BookPath): print(path) return "put" - @api_view.doc(summary="delete book", deprecated=True) + @api_view.doc( + summary="delete book", + servers=[Server( + url="http://127.0.0.1:5000", + variables=None + )], + external_docs=ExternalDocumentation( + url="https://www.openapis.org/", + description="Something great got better, get excited!"), + deprecated=True + ) def delete(self, path: BookPath): print(path) return "delete" diff --git a/tests/test_model_config.py b/tests/test_model_config.py index 37fd8c0c..e2273901 100644 --- a/tests/test_model_config.py +++ b/tests/test_model_config.py @@ -42,7 +42,6 @@ class BookBody(BaseModel): model_config = dict( openapi_extra={ - "description": "This is post RequestBody", "example": {"age": 12, "author": "author1"}, "examples": { "example1": { @@ -97,7 +96,7 @@ def api_form(form: UploadFilesForm): print(form) # pragma: no cover -@app.post("/body", responses={"200": MessageResponse}) +@app.post("/body", request_body_description="This is post RequestBody", responses={"200": MessageResponse}) def api_error_json(body: BookBody): print(body) # pragma: no cover diff --git a/tests/test_multi_content_type.py b/tests/test_multi_content_type.py new file mode 100644 index 00000000..cc9be776 --- /dev/null +++ b/tests/test_multi_content_type.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- +# @Author : llc +# @Time : 2025/1/6 16:37 +from typing import Union + +import pytest +from flask import Request +from pydantic import BaseModel + +from flask_openapi3 import OpenAPI + +app = OpenAPI(__name__) +app.config["TESTING"] = True + + +class DogBody(BaseModel): + a: int = None + b: str = None + + model_config = { + "openapi_extra": { + "content_type": "application/vnd.dog+json" + } + } + + +class CatBody(BaseModel): + c: int = None + d: str = None + + model_config = { + "openapi_extra": { + "content_type": "application/vnd.cat+json" + } + } + + +class BsonModel(BaseModel): + e: int = None + f: str = None + + model_config = { + "openapi_extra": { + "content_type": "application/bson" + } + } + + +class ContentTypeModel(BaseModel): + model_config = { + "openapi_extra": { + "content_type": "text/csv" + } + } + + +@app.post("/a", responses={200: Union[DogBody, CatBody, ContentTypeModel, BsonModel]}) +def index_a(body: Union[DogBody, CatBody, ContentTypeModel, BsonModel]): + """ + This may be confusing, if the content-type is application/json, the type of body will be auto parsed to + DogBody or CatBody, otherwise it cannot be parsed to ContentTypeModel or BsonModel. + The body is equivalent to the request variable in Flask, and you can use body.data, body.text, etc ... + """ + print(body) + if isinstance(body, Request): + if body.mimetype == "text/csv": + # processing csv data + ... + elif body.mimetype == "application/bson": + # processing bson data + ... + else: + # DogBody or CatBody + ... + return {"hello": "world"} + + +@app.post("/b", responses={200: Union[ContentTypeModel, BsonModel]}) +def index_b(body: Union[ContentTypeModel, BsonModel]): + """ + This may be confusing, if the content-type is application/json, the type of body will be auto parsed to + DogBody or CatBody, otherwise it cannot be parsed to ContentTypeModel or BsonModel. + The body is equivalent to the request variable in Flask, and you can use body.data, body.text, etc ... + """ + print(body) + if isinstance(body, Request): + if body.mimetype == "text/csv": + # processing csv data + ... + elif body.mimetype == "application/bson": + # processing bson data + ... + else: + # DogBody or CatBody + ... + return {"hello": "world"} + + +@pytest.fixture +def client(): + client = app.test_client() + + return client + + +def test_openapi(client): + resp = client.get("/openapi/openapi.json") + assert resp.status_code == 200 + + resp = client.post("/a", json={"a": 1, "b": "2"}) + assert resp.status_code == 200 + + resp = client.post("/a", data="a,b,c\n1,2,3", headers={"Content-Type": "text/csv"}) + assert resp.status_code == 200 diff --git a/tests/test_openapi.py b/tests/test_openapi.py index bcdeaf72..b30be10e 100644 --- a/tests/test_openapi.py +++ b/tests/test_openapi.py @@ -17,8 +17,11 @@ class BaseResponse(BaseModel): """Base description""" test: int - model_config = dict( - openapi_extra={ + @test_app.get( + "/test", + responses={ + "201": { + "model": BaseResponse, "description": "Custom description", "headers": { "location": { @@ -26,20 +29,14 @@ class BaseResponse(BaseModel): "schema": {"type": "string"} } }, - "content": { - "text/plain": { - "schema": {"type": "string"} - } - }, "links": { "dummy": { "description": "dummy link" } } } - ) - - @test_app.get("/test", responses={"201": BaseResponse}) + } + ) def endpoint_test(): return b"", 201 # pragma: no cover @@ -58,10 +55,6 @@ def endpoint_test(): # This content is coming from responses "application/json": { "schema": {"$ref": "#/components/schemas/BaseResponse"} - }, - # While this one comes from responses - "text/plain": { - "schema": {"type": "string"} } }, "links": { @@ -561,16 +554,16 @@ def endpoint_test(body: TupleModel): assert schema == {'$ref': '#/components/schemas/TupleModel'} components = test_app.api_doc["components"]["schemas"] assert components["TupleModel"] == {'properties': {'my_tuple': {'maxItems': 2, - 'minItems': 2, - 'prefixItems': [{'enum': ['a', 'b'], - 'type': 'string'}, - {'enum': ['c', 'd'], - 'type': 'string'}], - 'title': 'My Tuple', - 'type': 'array'}}, - 'required': ['my_tuple'], - 'title': 'TupleModel', - 'type': 'object'} + 'minItems': 2, + 'prefixItems': [{'enum': ['a', 'b'], + 'type': 'string'}, + {'enum': ['c', 'd'], + 'type': 'string'}], + 'title': 'My Tuple', + 'type': 'array'}}, + 'required': ['my_tuple'], + 'title': 'TupleModel', + 'type': 'object'} def test_schema_bigint(request): diff --git a/tests/test_restapi.py b/tests/test_restapi.py index 4a04aafc..21140e92 100644 --- a/tests/test_restapi.py +++ b/tests/test_restapi.py @@ -11,7 +11,7 @@ from flask import Response from pydantic import BaseModel, RootModel, Field -from flask_openapi3 import ExternalDocumentation +from flask_openapi3 import ExternalDocumentation, Server from flask_openapi3 import Info, Tag from flask_openapi3 import OpenAPI @@ -50,6 +50,8 @@ def get_operation_id_for_path_callback(*, name: str, path: str, method: str) -> class BookQuery(BaseModel): age: Optional[int] = Field(None, description='Age') + author: str + none: Optional[None] = None class BookBody(BaseModel): @@ -104,8 +106,13 @@ def client(): external_docs=ExternalDocumentation( url="https://www.openapis.org/", description="Something great got better, get excited!"), + servers=[Server( + url="http://127.0.0.1:5000", + variables=None + )], responses={"200": BookResponse}, - security=security + security=security, + deprecated=True, ) def get_book(path: BookPath): """Get a book @@ -117,7 +124,7 @@ def get_book(path: BookPath): @app.get('/book', tags=[book_tag], responses={"200": BookListResponseV1}) -def get_books(query: BookBody): +def get_books(query: BookQuery): """get books to get all books """ diff --git a/tests/test_server.py b/tests/test_server.py index 3e9b1830..6c8ac800 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -3,7 +3,7 @@ # @Time : 2024/11/10 12:17 from pydantic import ValidationError -from flask_openapi3 import Server, ServerVariable +from flask_openapi3 import OpenAPI, Server, ServerVariable, ExternalDocumentation def test_server_variable(): @@ -11,33 +11,37 @@ def test_server_variable(): url="http://127.0.0.1:5000", variables=None ) + error = 0 try: variables = {"one": ServerVariable(default="one", enum=[])} - Server( - url="http://127.0.0.1:5000", - variables=variables - ) - error = 0 except ValidationError: error = 1 assert error == 1 - try: - variables = {"one": ServerVariable(default="one")} - Server( - url="http://127.0.0.1:5000", - variables=variables - ) - error = 0 - except ValidationError: - error = 1 + variables = {"one": ServerVariable(default="one")} + Server( + url="http://127.0.0.1:5000", + variables=variables + ) + error = 0 assert error == 0 - try: - variables = {"one": ServerVariable(default="one", enum=["one", "two"])} - Server( - url="http://127.0.0.1:5000", - variables=variables - ) - error = 0 - except ValidationError: - error = 1 + variables = {"one": ServerVariable(default="one", enum=["one", "two"])} + Server( + url="http://127.0.0.1:5000", + variables=variables + ) + error = 0 assert error == 0 + + app = OpenAPI( + __name__, + servers=[Server( + url="http://127.0.0.1:5000", + variables=None + )], + external_docs=ExternalDocumentation( + url="https://www.openapis.org/", + description="Something great got better, get excited!") + ) + + assert "servers" in app.api_doc + assert "externalDocs" in app.api_doc