Skip to content

Commit a17b409

Browse files
committed
support dict as Schema
1 parent 5ac7296 commit a17b409

File tree

4 files changed

+78
-5
lines changed

4 files changed

+78
-5
lines changed

flask_smorest/arguments.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from functools import wraps
55
import http
66

7+
import marshmallow as ma
78
from webargs.flaskparser import FlaskParser
89

910
from .utils import deepupdate
@@ -28,8 +29,8 @@ def arguments(
2829
):
2930
"""Decorator specifying the schema used to deserialize parameters
3031
31-
:param type|Schema schema: Marshmallow ``Schema`` class or instance
32-
used to deserialize and validate the argument.
32+
:param type|Schema|dict schema: Marshmallow ``Schema`` class or instance
33+
or dict used to deserialize and validate the argument.
3334
:param str location: Location of the argument.
3435
:param str content_type: Content type of the argument.
3536
Should only be used in conjunction with ``json``, ``form`` or
@@ -56,6 +57,8 @@ def arguments(
5657
5758
See :doc:`Arguments <arguments>`.
5859
"""
60+
if isinstance(schema, dict):
61+
schema = ma.Schema.from_dict(schema)
5962
# At this stage, put schema instance in doc dictionary. Il will be
6063
# replaced later on by $ref or json.
6164
parameters = {

flask_smorest/utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from collections import abc
44

5+
import marshmallow as ma
56
from werkzeug.datastructures import Headers
67
from flask import g
78
from apispec.utils import trim_docstring, dedent
@@ -31,9 +32,14 @@ def remove_none(mapping):
3132
def resolve_schema_instance(schema):
3233
"""Return schema instance for given schema (instance or class).
3334
34-
:param type|Schema schema: marshmallow.Schema instance or class
35+
:param type|Schema|dict schema: marshmallow.Schema instance or class or dict
3536
:return: schema instance of given schema
3637
"""
38+
39+
# this dict may be used to document a file response, no a schema dict
40+
if isinstance(schema, dict) and all([isinstance(v, (type, ma.fields.Field)) for v in schema.values()]):
41+
schema = ma.Schema.from_dict(schema)
42+
3743
return schema() if isinstance(schema, type) else schema
3844

3945

tests/conftest.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,11 @@ class ClientErrorSchema(ma.Schema):
7878
error_id = ma.fields.Str()
7979
text = ma.fields.Str()
8080

81-
return namedtuple("Model", ("DocSchema", "QueryArgsSchema", "ClientErrorSchema"))(
82-
DocSchema, QueryArgsSchema, ClientErrorSchema
81+
DictSchema = {
82+
"item_id": ma.fields.Int(dump_only=True),
83+
"field": ma.fields.Int(attribute="db_field"),
84+
}
85+
86+
return namedtuple("Model", ("DocSchema", "QueryArgsSchema", "ClientErrorSchema", "DictSchema"))(
87+
DocSchema, QueryArgsSchema, ClientErrorSchema, DictSchema
8388
)

tests/test_blueprint.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,65 @@ def func(document, query_args):
307307
"query_args": {"arg1": "test"},
308308
}
309309

310+
@pytest.mark.parametrize("openapi_version", ("2.0", "3.0.2"))
311+
def test_blueprint_dict_argument_schema(self, app, schemas, openapi_version):
312+
app.config["OPENAPI_VERSION"] = openapi_version
313+
api = Api(app)
314+
blp = Blueprint("test", __name__, url_prefix="/test")
315+
client = app.test_client()
316+
317+
@blp.route("/", methods=("POST",))
318+
@blp.arguments(schemas.DictSchema)
319+
def func(document):
320+
return {"document": document}
321+
322+
api.register_blueprint(blp)
323+
spec = api.spec.to_dict()
324+
325+
# Check parameters are documented
326+
if openapi_version == "2.0":
327+
parameters = spec["paths"]["/test/"]["post"]["parameters"]
328+
assert len(parameters) == 1
329+
assert parameters[0]["in"] == "body"
330+
assert "schema" in parameters[0]
331+
else:
332+
assert (
333+
"schema"
334+
in spec["paths"]["/test/"]["post"]["requestBody"]["content"][
335+
"application/json"
336+
]
337+
)
338+
339+
# Check parameters are passed as arguments to view function
340+
item_data = {"field": 12}
341+
response = client.post(
342+
"/test/",
343+
data=json.dumps(item_data),
344+
content_type="application/json",
345+
)
346+
assert response.status_code == 200
347+
assert response.json == {
348+
"document": {"db_field": 12},
349+
}
350+
351+
@pytest.mark.parametrize("openapi_version", ["2.0", "3.0.2"])
352+
def test_blueprint_dict_response_schema(self, app, schemas, openapi_version):
353+
"""Check alt_response passes response transparently"""
354+
app.config["OPENAPI_VERSION"] = openapi_version
355+
api = Api(app)
356+
blp = Blueprint("test", "test", url_prefix="/test")
357+
client = app.test_client()
358+
359+
@blp.route("/")
360+
@blp.response(200, schema=schemas.DictSchema)
361+
def func():
362+
return {"item_id": 12}
363+
364+
api.register_blueprint(blp)
365+
366+
resp = client.get("/test/")
367+
assert resp.json == {"item_id": 12}
368+
310369
@pytest.mark.parametrize("openapi_version", ("2.0", "3.0.2"))
311370
def test_blueprint_arguments_files_multipart(self, app, schemas, openapi_version):
312371
app.config["OPENAPI_VERSION"] = openapi_version

0 commit comments

Comments
 (0)