Skip to content

Make validate_request pass on existing kwargs but remove those part of path #227

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 46 additions & 6 deletions docs/Usage/Request.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,17 @@ Sometimes you want to delay the verification request parameters, such as after l
from flask_openapi3 import validate_request


def login_required(func):
@wraps(func)
def wrapper(*args, **kwargs):
print("login_required ...")
return func(*args, **kwargs)
def login_required():
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
if not request.headers.get("Authorization"):
return {"error": "Unauthorized"}, 401
return func(*args, **kwargs)

return wrapper
return wrapper

return decorator


@app.get("/book")
Expand All @@ -127,6 +131,42 @@ def get_book(query: BookQuery):
...
```

### Custom kwargs are maintained

When your 'auth decorator' injects custom kwargs, these will be passed on to the final function for you to use.

Any kwargs which are part of the 'path' will have been consumed at this point and can only be referenced using the `path`.

So avoid using kwarg-names which overlap with the path.

```python
from flask_openapi3 import validate_request
from functools import wraps


def login_required():
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
if not request.headers.get("Authorization"):
return {"error": "Unauthorized"}, 401
kwargs["client_id"] = "client1234565"
return func(*args, **kwargs)

return wrapper

return decorator



@app.get("/book")
@login_required()
@validate_request()
def get_book(query: BookQuery, client_id:str = None):
print(f"Current user identified as {client_id}")
...
```

## Request model

First, you need to define a [pydantic](https://github.com/pydantic/pydantic) model:
Expand Down
34 changes: 20 additions & 14 deletions flask_openapi3/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import json
from functools import wraps
from json import JSONDecodeError
from typing import Any, Type, Optional
from typing import Any, Optional, Type

from flask import request, current_app, abort
from pydantic import ValidationError, BaseModel
from flask import abort, current_app, request
from pydantic import BaseModel, ValidationError
from pydantic.fields import FieldInfo
from werkzeug.datastructures.structures import MultiDict

Expand Down Expand Up @@ -78,7 +78,11 @@ def _validate_cookie(cookie: Type[BaseModel], func_kwargs: dict):


def _validate_path(path: Type[BaseModel], path_kwargs: dict, func_kwargs: dict):
func_kwargs["path"] = path.model_validate(obj=path_kwargs)
path_obj = path.model_validate(obj=path_kwargs)
func_kwargs["path"] = path_obj
# Consume path parameters to prevent from being passed to the function
for field_name, _ in path_obj:
path_kwargs.pop(field_name, None)


def _validate_query(query: Type[BaseModel], func_kwargs: dict):
Expand Down Expand Up @@ -151,14 +155,14 @@ def _validate_body(body: Type[BaseModel], func_kwargs: dict):


def _validate_request(
header: Optional[Type[BaseModel]] = None,
cookie: Optional[Type[BaseModel]] = None,
path: Optional[Type[BaseModel]] = None,
query: Optional[Type[BaseModel]] = None,
form: Optional[Type[BaseModel]] = None,
body: Optional[Type[BaseModel]] = None,
raw: Optional[Type[BaseModel]] = None,
path_kwargs: Optional[dict[Any, Any]] = None
header: Optional[Type[BaseModel]] = None,
cookie: Optional[Type[BaseModel]] = None,
path: Optional[Type[BaseModel]] = None,
query: Optional[Type[BaseModel]] = None,
form: Optional[Type[BaseModel]] = None,
body: Optional[Type[BaseModel]] = None,
raw: Optional[Type[BaseModel]] = None,
path_kwargs: Optional[dict[Any, Any]] = None,
) -> dict:
"""
Validate requests and responses.
Expand Down Expand Up @@ -212,7 +216,6 @@ def validate_request():
"""

def decorator(func):

setattr(func, "__delay_validate_request__", True)

is_coroutine_function = inspect.iscoroutinefunction(func)
Expand All @@ -223,6 +226,8 @@ def decorator(func):
async def wrapper(*args, **kwargs):
header, cookie, path, query, form, body, raw = parse_parameters(func)
func_kwargs = _validate_request(header, cookie, path, query, form, body, raw, path_kwargs=kwargs)
# Update func_kwargs with any additional keyword arguments passed from other decorators or calls.
func_kwargs.update(kwargs)

return await func(*args, **func_kwargs)

Expand All @@ -233,7 +238,8 @@ async def wrapper(*args, **kwargs):
def wrapper(*args, **kwargs):
header, cookie, path, query, form, body, raw = parse_parameters(func)
func_kwargs = _validate_request(header, cookie, path, query, form, body, raw, path_kwargs=kwargs)

# Update func_kwargs with any additional keyword arguments passed from other decorators or calls.
func_kwargs.update(kwargs)
return func(*args, **func_kwargs)

return wrapper
Expand Down
109 changes: 109 additions & 0 deletions tests/test_validate_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from functools import wraps
from typing import Optional

import pytest
from flask import request
from pydantic import BaseModel, Field

from flask_openapi3 import APIView, Info, OpenAPI, Tag
from flask_openapi3.request import validate_request


class BookNamePath(BaseModel):
name: str


class BookBody(BaseModel):
age: Optional[int] = Field(..., ge=2, le=4, description="Age")
author: str = Field(None, min_length=2, max_length=4, description="Author")
name: str


def login_required():
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
if not request.headers.get("Authorization"):
return {"error": "Unauthorized"}, 401
kwargs["client_id"] = "client1234565"
return func(*args, **kwargs)

return wrapper

return decorator


@pytest.fixture
def app():
app = OpenAPI(__name__)
app.config["TESTING"] = True

info = Info(title="book API", version="1.0.0")
jwt = {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}
security_schemes = {"jwt": jwt}

app = OpenAPI(__name__, info=info, security_schemes=security_schemes)
app.config["TESTING"] = True
security = [{"jwt": []}]

api_view = APIView(url_prefix="/v1/books", view_tags=[Tag(name="book")], view_security=security)

@api_view.route("")
class BookListAPIView:
@api_view.doc(summary="get book list", responses={204: None}, doc_ui=False)
@login_required()
@validate_request()
def get(self, client_id: str):
return {"books": ["book1", "book2"], "client_id": client_id}

@api_view.doc(summary="create book")
@login_required()
@validate_request()
def post(self, body: BookBody, client_id):
"""description for a created book"""
return body.model_dump_json()

@api_view.route("/<name>")
class BookNameAPIView:
@api_view.doc(summary="get book by name")
@login_required()
@validate_request()
def get(self, path: BookNamePath, client_id):
return {"name": path.name, "client_id": client_id}

app.register_api_view(api_view)
return app


@pytest.fixture
def client(app):
client = app.test_client()

return client


def test_get_book_list_happy(app, client):
response = client.get("/v1/books", headers={"Authorization": "Bearer sometoken"})
assert response.status_code == 200
assert response.json == {"books": ["book1", "book2"], "client_id": "client1234565"}


def test_get_book_list_not_auth(app, client):
response = client.get("/v1/books", headers={"Nope": "Bearer sometoken"})
assert response.status_code == 401
assert response.json == {"error": "Unauthorized"}


def test_create_book_happy(app, client):
response = client.post(
"/v1/books",
json={"age": 3, "author": "John", "name": "some_book_name"},
headers={"Authorization": "Bearer sometoken"},
)
assert response.status_code == 200


def test_get_book_detail_happy(app, client):
response = client.get("/v1/books/some_book_name", headers={"Authorization": "Bearer sometoken"})
assert response.status_code == 200
assert response.json == {"name": "some_book_name", "client_id": "client1234565"}