diff --git a/adrf/routers.py b/adrf/routers.py new file mode 100644 index 0000000..388fad8 --- /dev/null +++ b/adrf/routers.py @@ -0,0 +1,42 @@ +from rest_framework.routers import ( + DefaultRouter as DRFDefaultRouter, +) +from rest_framework.routers import ( + SimpleRouter as DRFSimpleRouter, +) + + +class SimpleRouter(DRFSimpleRouter): + sync_to_async_action_map = { + "list": "alist", + "create": "acreate", + "retrieve": "aretrieve", + "update": "aupdate", + "destroy": "adestroy", + "partial_update": "partial_aupdate", + } + + def get_method_map(self, viewset, method_map): + """ + Given a viewset, and a mapping of http methods to actions, + return a new mapping which only includes any mappings that + are actually implemented by the viewset. + + To allow the use of a single router that registers sync and async + viewsets, the actions defined in the routes' method maps are + updated to be the "a"-prefixed names for async viewsets. + """ + bound_methods = {} + if getattr(viewset, "view_is_async", False): + method_map = { + method: self.sync_to_async_action_map.get(action, action) + for method, action in method_map.items() + } + for method, action in method_map.items(): + if hasattr(viewset, action): + bound_methods[method] = action + return bound_methods + + +class DefaultRouter(SimpleRouter, DRFDefaultRouter): + pass diff --git a/adrf/viewsets.py b/adrf/viewsets.py index be1c4f7..3f5b8db 100644 --- a/adrf/viewsets.py +++ b/adrf/viewsets.py @@ -20,9 +20,9 @@ class ViewSetMixin(DRFViewSetMixin): the binding of HTTP methods to actions on the Resource. For example, to create a concrete view binding the 'GET' and 'POST' methods - to the 'list' and 'create' actions... + to the 'alist' and 'acreate' actions... - view = MyViewSet.as_view({'get': 'list', 'post': 'create'}) + view = MyViewSet.as_view({'get': 'alist', 'post': 'acreate'}) """ @classonlymethod @@ -155,14 +155,13 @@ def view_is_async(cls): """ Checks whether any viewset methods are coroutines. """ - result = [ + return any( asyncio.iscoroutinefunction(function) for name, function in getmembers( cls, inspect.iscoroutinefunction, exclude_names=["view_is_async"] ) if not name.startswith("__") and name not in cls._ASYNC_NON_DISPATCH_METHODS - ] - return any(result) + ) class GenericViewSet(ViewSet, GenericAPIView): diff --git a/tests/test_routers.py b/tests/test_routers.py new file mode 100644 index 0000000..edb7058 --- /dev/null +++ b/tests/test_routers.py @@ -0,0 +1,131 @@ +from asgiref.sync import async_to_sync +from django.contrib.auth.models import User +from django.test import Client, TestCase, override_settings +from rest_framework import status +from rest_framework.response import Response +from rest_framework.test import APIRequestFactory +from rest_framework.viewsets import ModelViewSet as DRFModelViewSet + +from adrf.routers import SimpleRouter, DefaultRouter +from adrf.serializers import ModelSerializer +from adrf.viewsets import ModelViewSet as AsyncModelViewSet +from tests.test_views import JSON_ERROR, sanitise_json_error + + +class SyncViewSet(DRFModelViewSet): + def list(self, request): + return Response({"method": "GET", "async": False}) + + def create(self, request): + return Response({"method": "POST", "data": request.data, "async": False}) + + def retrieve(self, request, pk): + return Response({"method": "GET", "data": {"pk": pk}, "async": False}) + + def update(self, request, pk): + return Response({"method": "PUT", "data": request.data, "async": False}) + + def partial_update(self, request, pk): + return Response({"method": "PATCH", "data": request.data, "async": False}) + + def destroy(self, request, pk): + return Response({"method": "DELETE", "async": False}) + + +class AsyncViewSet(AsyncModelViewSet): + async def alist(self, request): + return Response({"method": "GET", "async": True}) + + async def acreate(self, request): + return Response({"method": "POST", "data": request.data, "async": True}) + + async def aretrieve(self, request, pk): + return Response({"method": "GET", "data": {"pk": pk}, "async": True}) + + async def aupdate(self, request, pk): + return Response({"method": "PUT", "data": request.data, "async": True}) + + async def partial_aupdate(self, request, pk): + return Response({"method": "PATCH", "data": request.data, "async": True}) + + async def adestroy(self, request, pk): + return Response({"method": "DELETE", "async": True}) + + +router = SimpleRouter() +router.register("sync", SyncViewSet, basename="sync") +router.register("async", AsyncViewSet, basename="async") +urlpatterns = router.urls + + +@override_settings(ROOT_URLCONF="tests.test_routers") +class _RouterIntegrationTests(TestCase): + use_async = None + __test__ = False + + def setUp(self): + self.client = Client() + self.url = "/" + ("async" if self.use_async else "sync") + "/" + self.detail_url = self.url + "1/" + + def test_list(self): + resp = self.client.get(self.url) + assert resp.status_code == 200 + assert resp.data == {"method": "GET", "async": self.use_async} + + def test_create(self): + resp = self.client.post( + self.url, {"foo": "bar"}, content_type="application/json" + ) + assert resp.status_code == 200 + assert resp.data == { + "method": "POST", + "data": {"foo": "bar"}, + "async": self.use_async, + } + + def test_retrieve(self): + resp = self.client.get(self.detail_url) + assert resp.status_code == 200 + assert resp.data == { + "method": "GET", + "data": {"pk": "1"}, + "async": self.use_async, + } + + def test_update(self): + resp = self.client.put( + self.detail_url, {"foo": "bar"}, content_type="application/json" + ) + assert resp.status_code == 200 + assert resp.data == { + "method": "PUT", + "data": {"foo": "bar"}, + "async": self.use_async, + } + + def test_partial_update(self): + resp = self.client.patch( + self.detail_url, {"foo": "bar"}, content_type="application/json" + ) + assert resp.status_code == 200 + assert resp.data == { + "method": "PATCH", + "data": {"foo": "bar"}, + "async": self.use_async, + } + + def test_destroy(self): + resp = self.client.delete(self.detail_url) + assert resp.status_code == 200 + assert resp.data == {"method": "DELETE", "async": self.use_async} + + +class TestSyncRouterIntegrationTests(_RouterIntegrationTests): + use_async = False + __test__ = True + + +class AsyncRouterIntegrationTests(_RouterIntegrationTests): + use_async = True + __test__ = True