diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3ce22bbc..b519cfa0 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -10,7 +10,7 @@ Before submitting pr, you need to complete the following steps: 1. Install requirements ```bash - pip install -U flask pydantic + pip install -U flask pydantic pyyaml pytest ruff mypy ``` 2. Running the tests diff --git a/flask_openapi3/openapi.py b/flask_openapi3/openapi.py index 5c1c91e8..56eb6046 100644 --- a/flask_openapi3/openapi.py +++ b/flask_openapi3/openapi.py @@ -306,8 +306,10 @@ def register_api(self, api: APIBlueprint, **options: Any) -> None: url_prefix = options.get("url_prefix") if url_prefix and api.url_prefix and url_prefix != api.url_prefix: api.paths = {url_prefix + k.removeprefix(api.url_prefix): v for k, v in api.paths.items()} + api.url_prefix = url_prefix elif url_prefix and not api.url_prefix: api.paths = {url_prefix.rstrip("/") + "/" + k.lstrip("/"): v for k, v in api.paths.items()} + api.url_prefix = url_prefix self.paths.update(**api.paths) # Update component schemas with the APIBlueprint's component schemas @@ -345,8 +347,10 @@ def register_api_view( # Update paths with the APIView's paths if url_prefix and api_view.url_prefix and url_prefix != api_view.url_prefix: api_view.paths = {url_prefix + k.removeprefix(api_view.url_prefix): v for k, v in api_view.paths.items()} + api_view.url_prefix = url_prefix elif url_prefix and not api_view.url_prefix: api_view.paths = {url_prefix.rstrip("/") + "/" + k.lstrip("/"): v for k, v in api_view.paths.items()} + api_view.url_prefix = url_prefix self.paths.update(**api_view.paths) # Update component schemas with the APIView's component schemas diff --git a/tests/test_api_blueprint.py b/tests/test_api_blueprint.py index 22397251..0cb1bf6c 100644 --- a/tests/test_api_blueprint.py +++ b/tests/test_api_blueprint.py @@ -135,3 +135,38 @@ def test_patch(client): def test_delete(client): resp = client.delete("/api/book/1") assert resp.status_code == 200 + + +# Create a second blueprint here to test when `url_prefix` is None +author_api = APIBlueprint( + '/author', + __name__, + abp_tags=[tag], + abp_security=security, + abp_responses={"401": Unauthorized}, +) + + +class AuthorBody(BaseModel): + age: Optional[int] = Field(..., ge=1, le=100, description='Age') + + +@author_api.post('/') +def get_author(body: AuthorBody): + pass + + +def create_app(): + app = OpenAPI(__name__, info=info, security_schemes=security_schemes) + app.register_api(api, url_prefix='/1.0') + app.register_api(author_api, url_prefix='/1.0/author') + + +# Invoke twice to ensure that call is idempotent +create_app() +create_app() + + +def test_blueprint_path_and_prefix(): + assert list(api.paths.keys()) == ['/1.0/book/{bid}', '/1.0/v2/book/{bid}'] + assert list(author_api.paths.keys()) == ['/1.0/author/{aid}'] diff --git a/tests/test_api_view.py b/tests/test_api_view.py index 00b8a638..fd08e206 100644 --- a/tests/test_api_view.py +++ b/tests/test_api_view.py @@ -24,6 +24,7 @@ api_view = APIView(url_prefix="/api/v1/", view_tags=[Tag(name="book")], view_security=security) api_view2 = APIView(doc_ui=False) +api_view_no_url = APIView(view_tags=[Tag(name="book")], view_security=security) class BookPath(BaseModel): @@ -86,6 +87,13 @@ def get(self, path: BookPath): return path.model_dump() +@api_view_no_url.route("/book3") +class BookAPIViewNoUrl: + @api_view_no_url.doc(summary="get book3") + def get(self, path: BookPath): + return path.model_dump() + + app.register_api_view(api_view) app.register_api_view(api_view2) @@ -132,3 +140,19 @@ def test_get(client): def test_delete(client): resp = client.delete("/api/v1/name1/book/1") assert resp.status_code == 200 + + +def create_app(): + app = OpenAPI(__name__, info=info, security_schemes=security_schemes) + app.register_api_view(api_view, url_prefix='/api/1.0') + app.register_api_view(api_view_no_url, url_prefix='/api/1.0') + + +# Invoke twice to ensure that call is idempotent +create_app() +create_app() + + +def test_register_api_view_idempotency(): + assert list(api_view.paths.keys()) == ['/api/1.0/api/v1/{name}/book', '/api/1.0/api/v1/{name}/book/{id}'] + assert list(api_view_no_url.paths.keys()) == ['/api/1.0/book3']