From a797afcd7683967c831e2ec88e9e658faa03fa09 Mon Sep 17 00:00:00 2001 From: Anton Zhukharev Date: Thu, 26 Dec 2024 11:51:53 +0300 Subject: [PATCH 1/2] use flask machinery for json --- flask_restx/representations.py | 6 +----- tests/legacy/test_api_legacy.py | 10 +++------- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/flask_restx/representations.py b/flask_restx/representations.py index 123b7864..8fe04da0 100644 --- a/flask_restx/representations.py +++ b/flask_restx/representations.py @@ -1,9 +1,5 @@ -try: - from ujson import dumps -except ImportError: - from json import dumps - from flask import make_response, current_app +from flask.json import dumps def output_json(data, code, headers=None): diff --git a/tests/legacy/test_api_legacy.py b/tests/legacy/test_api_legacy.py index b15a6027..c132b441 100644 --- a/tests/legacy/test_api_legacy.py +++ b/tests/legacy/test_api_legacy.py @@ -334,14 +334,10 @@ def get(self): assert data == expected def test_use_custom_jsonencoder(self, app, client): - class CabageEncoder(JSONEncoder): - def default(self, obj): - return "cabbage" + def default(obj): + return "cabbage" - class TestConfig(object): - RESTX_JSON = {"cls": CabageEncoder} - - app.config.from_object(TestConfig) + app.json.default = default api = restx.Api(app) class Cabbage(restx.Resource): From 2be088c960aa3c5637994059af19c3859761315d Mon Sep 17 00:00:00 2001 From: Anton Zhukharev Date: Thu, 26 Dec 2024 19:09:16 +0300 Subject: [PATCH 2/2] test custom json setting --- tests/test_custom_json.py | 45 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 tests/test_custom_json.py diff --git a/tests/test_custom_json.py b/tests/test_custom_json.py new file mode 100644 index 00000000..1f8e278d --- /dev/null +++ b/tests/test_custom_json.py @@ -0,0 +1,45 @@ +import json +import pytest + +from flask.json.provider import JSONProvider +from flask_restx import Api, Resource + +from datetime import datetime + + +class CustomJSONTest(object): + @pytest.mark.parametrize("provider", ["json", "ujson", "sdjson"]) + def test_custom_json(self, app, client, provider): + provmod = pytest.importorskip(provider) + class CustomJSONProvider(JSONProvider): + def dumps(self, obj, **kwargs): + extra = {"serializer": provmod.__name__} + return provmod.dumps( + obj | extra, + default=CustomJSONProvider._default, + **kwargs + ) + + def loads(self, s, **kwargs): + extra = {"deserializer": provmod.__name__} + return provmod.loads(s, **kwargs) | extra + + @staticmethod + def _default(obj): + if isinstance(obj, datetime): + return obj.isoformat() + return super().default(obj) + + app.json_provider_class = CustomJSONProvider + app.json = app.json_provider_class(app) + api = Api(app) + + @api.route("/test") + class TestResource(Resource): + def post(self): + return api.payload, 200 + + resp = client.post("/test", json={"name": "tester"}) + assert resp.status_code == 200 + assert json.loads(resp.data.decode("utf-8"))["serializer"] == provmod.__name__ + assert json.loads(resp.data.decode("utf-8"))["deserializer"] == provmod.__name__