diff --git a/osmchadjango/users/views.py b/osmchadjango/users/views.py index f388f19d..29a34b6c 100644 --- a/osmchadjango/users/views.py +++ b/osmchadjango/users/views.py @@ -3,6 +3,7 @@ from django.contrib.auth import get_user_model from django.conf import settings +from urllib.parse import urlparse from rest_framework.authtoken.models import Token from rest_framework.generics import ( @@ -77,14 +78,39 @@ class SocialAuthAPIView(GenericAPIView): base_oauth2_url = "{}/oauth2".format(settings.OSM_SERVER_URL) token_url = "{}/token".format(base_oauth2_url) auth_url = "{}/authorize".format(base_oauth2_url) - consumer = OAuth2Session( - client_id=settings.SOCIAL_AUTH_OPENSTREETMAP_OAUTH2_KEY, - scope=settings.SOCIAL_AUTH_OPENSTREETMAP_OAUTH2_SCOPE, - redirect_uri=settings.OAUTH_REDIRECT_URI, - ) - def get_access_token(self, code): - return self.consumer.fetch_token( + def get_redirect_uri(self, request): + """ + Get the redirect URI for the OAuth flow. + + Normally, after login we ask the OAuth provider to redirect back to + osmcha.org (or whatever domain you're running OSMCha at). But to allow + developers to run the frontend locally while using the production server + as the backend, we have a special case for when the HTTP Origin header + has a hostname of localhost or 127.0.0.1. In those cases we redirect + back to that hostname. + """ + origin = request.META.get('HTTP_ORIGIN') + if origin: + try: + url = urlparse(origin) + if url.hostname in {'127.0.0.1', 'localhost'}: + return f"{origin}/authorized" + except (ValueError, AttributeError): + pass + return settings.OAUTH_REDIRECT_URI + + def get_oauth_consumer(self, request, state=None): + return OAuth2Session( + client_id=settings.SOCIAL_AUTH_OPENSTREETMAP_OAUTH2_KEY, + scope=settings.SOCIAL_AUTH_OPENSTREETMAP_OAUTH2_SCOPE, + redirect_uri=self.get_redirect_uri(request), + state=state, + ) + + def get_access_token(self, request, code): + consumer = self.get_oauth_consumer(request) + return consumer.fetch_token( token_url=self.token_url, code=code, client_secret=settings.SOCIAL_AUTH_OPENSTREETMAP_OAUTH2_SECRET, @@ -95,7 +121,7 @@ def get_user_token(self, request, access_token, *args, **kwargs): backend = load_backend( strategy=load_strategy(request), name="openstreetmap-oauth2", - redirect_uri=settings.OAUTH_REDIRECT_URI, + redirect_uri=self.get_redirect_uri(request), ) user = backend.do_auth(access_token, *args, **kwargs) token, created = Token.objects.get_or_create(user=user) @@ -103,13 +129,14 @@ def get_user_token(self, request, access_token, *args, **kwargs): def post(self, request, *args, **kwargs): if 'code' not in request.data.keys() or not request.data['code']: - login_url, state = self.consumer.authorization_url(self.auth_url) + consumer = self.get_oauth_consumer(request) + login_url, state = consumer.authorization_url(self.auth_url) return Response({"auth_url": login_url, "state": state}) else: serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) access_token = self.get_access_token( - request.data['code'], + request, request.data['code'], ).get('access_token') return Response(self.get_user_token(request, access_token))