Skip to content

Commit

Permalink
Add custom claims and extra verifications
Browse files Browse the repository at this point in the history
  • Loading branch information
ahopkins committed Nov 8, 2018
1 parent f697961 commit 5477b9b
Show file tree
Hide file tree
Showing 38 changed files with 742 additions and 200 deletions.
3 changes: 3 additions & 0 deletions example/custom_authentication_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@


class MyAuthentication(Authentication):

async def authenticate(self, request, *args, **kwargs):
username = request.json.get("username", None)
password = request.json.get("password", None)
Expand All @@ -31,6 +32,7 @@ async def retrieve_user(self, request, payload, *args, **kwargs):
if payload:
user_id = payload.get("user_id", None)
return {"user_id": user_id}

else:
return None

Expand All @@ -54,6 +56,7 @@ async def protected_request(request):
return json({"protected": True})

# this route is for demonstration only

@app.route("/cache")
@sanicjwt.protected()
async def protected_cache(request):
Expand Down
72 changes: 72 additions & 0 deletions example/custom_claims.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from sanic import Sanic
from sanic.response import json
from sanic_jwt import exceptions
from sanic_jwt import Initialize
from sanic_jwt import protected
from sanic_jwt import Claim


class User:

def __init__(self, id, username, password):
self.user_id = id
self.username = username
self.password = password

def __repr__(self):
return "User(id='{}')".format(self.user_id)

def to_dict(self):
return {"user_id": self.user_id, "username": self.username}


users = [User(1, "user1", "abcxyz"), User(2, "user2", "abcxyz")]

username_table = {u.username: u for u in users}
userid_table = {u.user_id: u for u in users}


async def authenticate(request, *args, **kwargs):
username = request.json.get("username", None)
password = request.json.get("password", None)

if not username or not password:
raise exceptions.AuthenticationFailed("Missing username or password.")

user = username_table.get(username, None)
if user is None:
raise exceptions.AuthenticationFailed("User not found.")

if password != user.password:
raise exceptions.AuthenticationFailed("Password is incorrect.")

return user


class User2Claim(Claim):
key = "user_id"

def setup(self, payload, user):
payload[self.key] = user.get("user_id")
return payload

def verify(self, value):
return value == 2


custom_claims = [User2Claim]
app = Sanic()
sanicjwt = Initialize(
app, authenticate=authenticate, custom_claims=custom_claims, debug=True
)


@app.route("/protected")
@protected()
async def protected(request):
print(request.app.auth._custom_claims)
return json({"protected": True})


if __name__ == "__main__":
app.run(host="127.0.0.1", port=8888, auto_reload=True)
64 changes: 64 additions & 0 deletions example/extra_verifications.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from sanic import Sanic
from sanic.response import json
from sanic_jwt import exceptions
from sanic_jwt import Initialize
from sanic_jwt import protected


class User:

def __init__(self, id, username, password):
self.user_id = id
self.username = username
self.password = password

def __repr__(self):
return "User(id='{}')".format(self.user_id)

def to_dict(self):
return {"user_id": self.user_id, "username": self.username}


users = [User(1, "user1", "abcxyz"), User(2, "user2", "abcxyz")]

username_table = {u.username: u for u in users}
userid_table = {u.user_id: u for u in users}


async def authenticate(request, *args, **kwargs):
username = request.json.get("username", None)
password = request.json.get("password", None)

if not username or not password:
raise exceptions.AuthenticationFailed("Missing username or password.")

user = username_table.get(username, None)
if user is None:
raise exceptions.AuthenticationFailed("User not found.")

if password != user.password:
raise exceptions.AuthenticationFailed("Password is incorrect.")

return user


def user2(payload):
return payload.get("user_id") == 2


extra_verifications = [user2]

app = Sanic()
Initialize(
app, authenticate=authenticate, extra_verifications=extra_verifications
)


@app.route("/protected")
@protected()
async def protected(request):
return json({"protected": True})


if __name__ == "__main__":
app.run(host="127.0.0.1", port=8888, auto_reload=True)
5 changes: 3 additions & 2 deletions example/inject_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ async def authenticate(request, *args, **kwargs):


app = Sanic()
sanic_jwt = Initialize(app, authenticate=authenticate,
retrieve_user=retrieve_user)
sanic_jwt = Initialize(
app, authenticate=authenticate, retrieve_user=retrieve_user
)


@app.route("/hello")
Expand Down
11 changes: 7 additions & 4 deletions sanic_jwt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import logging

from .authentication import Authentication
from .claim import Claim
from .configuration import Configuration
from .decorators import protected, scoped
from .decorators import protected, scoped, inject_user
from .endpoints import BaseEndpoint
from .initialization import Initialize, initialize
from .responses import Responses
Expand All @@ -16,11 +17,13 @@

__all__ = [
"Authentication",
"Initialize",
"Configuration",
"Responses",
"BaseEndpoint",
"Claim",
"Configuration",
"initialize",
"Initialize",
"inject_user",
"protected",
"Responses",
"scoped",
]
48 changes: 43 additions & 5 deletions sanic_jwt/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import warnings
import jwt
from .exceptions import SanicJWTException
from .exceptions import InvalidVerification
from .exceptions import InvalidVerificationError
from .exceptions import InvalidCustomClaimError

from . import exceptions, utils

Expand All @@ -16,8 +19,10 @@ class BaseAuthentication:
def __init__(self, app, config):
self.app = app
self.claims = ["exp"]
self._extra_verifications = None
self.config = config
self._reasons = []
self._custom_claims = set()

async def _get_user_id(self, user, *, asdict=False):
"""
Expand All @@ -42,9 +47,10 @@ async def _get_user_id(self, user, *, asdict=False):
async def build_payload(self, user, *args, **kwargs):
return await self._get_user_id(user, asdict=True)

async def add_claims(self, payload, *args, **kwargs):
async def add_claims(self, payload, user, *args, **kwargs):
"""
Injects standard claims into the payload for: exp, iss, iat, nbf, aud.
And, custom claims, if they exist
"""
delta = timedelta(seconds=self.config.expiration_delta())
exp = datetime.utcnow() + delta
Expand All @@ -62,6 +68,13 @@ async def add_claims(self, payload, *args, **kwargs):

payload.update(additional)

if self._custom_claims:
custom_claims = {
x.get_key(): await utils.call(x.setup, payload, user)
for x in self._custom_claims
}
payload.update(custom_claims)

return payload

async def extend_payload(self, payload, user=None, *args, **kwargs):
Expand Down Expand Up @@ -173,7 +186,7 @@ async def _get_payload(self, user):
):
raise exceptions.InvalidPayload

payload = await utils.call(self.add_claims, payload)
payload = await utils.call(self.add_claims, payload, user)

extend_payload_args = inspect.getfullargspec(self.extend_payload)
args = [payload]
Expand All @@ -187,7 +200,8 @@ async def _get_payload(self, user):
scopes = [scopes]
payload[self.config.scopes_name()] = scopes

missing = [x for x in self.claims if x not in payload]
claims = self.claims + [x.get_key() for x in self._custom_claims]
missing = [x for x in claims if x not in payload]
if missing:
logger.debug("")
raise exceptions.MissingRegisteredClaim(missing=missing)
Expand Down Expand Up @@ -329,16 +343,24 @@ def _verify(
if token:
try:
payload = self._decode(token, verify=verify)

if verify:
if self._extra_verifications:
self._verify_extras(payload)
if self._custom_claims:
self._verify_custom_claims(payload)
except (
jwt.exceptions.ExpiredSignatureError,
jwt.exceptions.InvalidIssuerError,
jwt.exceptions.ImmatureSignatureError,
jwt.exceptions.InvalidIssuedAtError,
jwt.exceptions.InvalidAudienceError,
InvalidVerificationError,
InvalidCustomClaimError,
) as e:
# Make sure that the reasons all end with '.' for consistency
reason = [
x if x.endswith('.') else '{}.'.format(x)
x if x.endswith(".") else "{}.".format(x)
for x in list(e.args)
]
payload = None
Expand All @@ -348,7 +370,7 @@ def _verify(
self._reasons = e.args
# Make sure that the reasons all end with '.' for consistency
reason = [
x if x.endswith('.') else '{}.'.format(x)
x if x.endswith(".") else "{}.".format(x)
for x in list(e.args)
] if self.config.debug() else "Auth required."
logger.debug(e.args)
Expand All @@ -365,6 +387,22 @@ def _verify(

return is_valid, status, reason

def _verify_extras(self, payload):
for verification in self._extra_verifications:
if not callable(verification):
raise InvalidVerification()

verified = verification(payload)
if not isinstance(verified, bool):
raise InvalidVerification()

if verified is False:
raise InvalidVerificationError()

def _verify_custom_claims(self, payload):
for claim in self._custom_claims:
claim._verify(payload)

def extract_payload(self, request, verify=True, *args, **kwargs):
"""
Extract a payload from a request object.
Expand Down
2 changes: 1 addition & 1 deletion sanic_jwt/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .exceptions import LoopNotRunning


def _get_current_task(loop):
def _get_current_task(loop): # noqa
if sys.version_info[:2] < (3, 7): # to avoid deprecation warning
return asyncio.Task.current_task(loop=loop)
else:
Expand Down
25 changes: 25 additions & 0 deletions sanic_jwt/claim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from sanic_jwt import exceptions


class Claim:
@classmethod
def _register(cls, sanicjwt):
required = ('key', 'setup', 'verify')
instance = cls()
if any(not hasattr(instance, x) for x in required):
raise AttributeError
sanicjwt.instance.auth._custom_claims.add(instance)

def get_key(self):
return self.key

def _verify(self, payload):
key = self.get_key()
value = payload.get(key)
valid_claim = self.verify(value)
if not isinstance(valid_claim, bool):
raise exceptions.InvalidCustomClaim()

if valid_claim is False:
message = "Invalid claim: {}".format(key)
raise exceptions.InvalidCustomClaimError(message=message)
Loading

0 comments on commit 5477b9b

Please sign in to comment.