Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 37 additions & 10 deletions osmchadjango/users/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from django.contrib.auth import get_user_model
from django.conf import settings
from urllib.parse import urlparse

from rest_framework.authtoken.models import Token
from rest_framework.generics import (
Expand Down Expand Up @@ -77,14 +78,39 @@ class SocialAuthAPIView(GenericAPIView):
base_oauth2_url = "{}/oauth2".format(settings.OSM_SERVER_URL)
token_url = "{}/token".format(base_oauth2_url)
auth_url = "{}/authorize".format(base_oauth2_url)
consumer = OAuth2Session(
client_id=settings.SOCIAL_AUTH_OPENSTREETMAP_OAUTH2_KEY,
scope=settings.SOCIAL_AUTH_OPENSTREETMAP_OAUTH2_SCOPE,
redirect_uri=settings.OAUTH_REDIRECT_URI,
)

def get_access_token(self, code):
return self.consumer.fetch_token(
def get_redirect_uri(self, request):
"""
Get the redirect URI for the OAuth flow.

Normally, after login we ask the OAuth provider to redirect back to
osmcha.org (or whatever domain you're running OSMCha at). But to allow
developers to run the frontend locally while using the production server
as the backend, we have a special case for when the HTTP Origin header
has a hostname of localhost or 127.0.0.1. In those cases we redirect
back to that hostname.
"""
origin = request.META.get('HTTP_ORIGIN')
if origin:
try:
url = urlparse(origin)
if url.hostname in {'127.0.0.1', 'localhost'}:
return f"{origin}/authorized"
except (ValueError, AttributeError):
pass
return settings.OAUTH_REDIRECT_URI

def get_oauth_consumer(self, request, state=None):
return OAuth2Session(
client_id=settings.SOCIAL_AUTH_OPENSTREETMAP_OAUTH2_KEY,
scope=settings.SOCIAL_AUTH_OPENSTREETMAP_OAUTH2_SCOPE,
redirect_uri=self.get_redirect_uri(request),
state=state,
)

def get_access_token(self, request, code):
consumer = self.get_oauth_consumer(request)
return consumer.fetch_token(
token_url=self.token_url,
code=code,
client_secret=settings.SOCIAL_AUTH_OPENSTREETMAP_OAUTH2_SECRET,
Expand All @@ -95,21 +121,22 @@ def get_user_token(self, request, access_token, *args, **kwargs):
backend = load_backend(
strategy=load_strategy(request),
name="openstreetmap-oauth2",
redirect_uri=settings.OAUTH_REDIRECT_URI,
redirect_uri=self.get_redirect_uri(request),
)
user = backend.do_auth(access_token, *args, **kwargs)
token, created = Token.objects.get_or_create(user=user)
return {'token': token.key}

def post(self, request, *args, **kwargs):
if 'code' not in request.data.keys() or not request.data['code']:
login_url, state = self.consumer.authorization_url(self.auth_url)
consumer = self.get_oauth_consumer(request)
login_url, state = consumer.authorization_url(self.auth_url)
return Response({"auth_url": login_url, "state": state})
else:
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
access_token = self.get_access_token(
request.data['code'],
request, request.data['code'],
).get('access_token')
return Response(self.get_user_token(request, access_token))

Expand Down