Skip to content

Commit 2d61588

Browse files
committed
jwt_backends: create backend mechanism and add authlib support
1 parent 2f733ce commit 2d61588

16 files changed

+697
-397
lines changed

fastapi_jwt/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .jwt import * # noqa: F401, F403
2+
from .jwt_backends import * # noqa: F401, F403

fastapi_jwt/jwt.py

Lines changed: 33 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,21 @@
88
from fastapi.responses import Response
99
from fastapi.security import APIKeyCookie, HTTPBearer
1010
from starlette.status import HTTP_401_UNAUTHORIZED
11+
from .jwt_backends import AuthlibJWTBackend, PythonJoseJWTBackend
1112

12-
try:
13-
from jose import jwt
14-
except ImportError: # pragma: nocover
15-
jwt = None # type: ignore[assignment]
13+
14+
DEFAULT_JWT_BACKEND = None
15+
16+
17+
def define_default_jwt_backend(cls):
18+
global DEFAULT_JWT_BACKEND
19+
DEFAULT_JWT_BACKEND = cls
20+
21+
22+
if AuthlibJWTBackend is not None:
23+
define_default_jwt_backend(AuthlibJWTBackend)
24+
elif PythonJoseJWTBackend is not None:
25+
define_default_jwt_backend(PythonJoseJWTBackend)
1626

1727

1828
def utcnow():
@@ -27,6 +37,7 @@ def utcnow():
2737

2838

2939
__all__ = [
40+
"define_default_jwt_backend",
3041
"JwtAuthorizationCredentials",
3142
"JwtAccessBearer",
3243
"JwtAccessCookie",
@@ -72,28 +83,26 @@ def __init__(
7283
secret_key: str,
7384
places: Optional[Set[str]] = None,
7485
auto_error: bool = True,
75-
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
86+
algorithm: Optional[str] = None,
7687
access_expires_delta: Optional[timedelta] = None,
7788
refresh_expires_delta: Optional[timedelta] = None,
7889
):
79-
assert jwt is not None, "python-jose must be installed to use JwtAuth"
90+
self.jwt_backend = DEFAULT_JWT_BACKEND(algorithm)
91+
self.secret_key = secret_key
8092
if places:
8193
assert places.issubset(
8294
{"header", "cookie"}
8395
), "only 'header'/'cookie' are supported"
84-
algorithm = algorithm.upper()
85-
assert (
86-
hasattr(jwt.ALGORITHMS, algorithm) is True # type: ignore[attr-defined]
87-
), f"{algorithm} algorithm is not supported by python-jose library"
88-
89-
self.secret_key = secret_key
9096

9197
self.places = places or {"header"}
9298
self.auto_error = auto_error
93-
self.algorithm = algorithm
9499
self.access_expires_delta = access_expires_delta or timedelta(minutes=15)
95100
self.refresh_expires_delta = refresh_expires_delta or timedelta(days=31)
96101

102+
@property
103+
def algorithm(self):
104+
return self.jwt_backend.algorithm
105+
97106
@classmethod
98107
def from_other(
99108
cls,
@@ -112,30 +121,6 @@ def from_other(
112121
refresh_expires_delta=refresh_expires_delta or other.refresh_expires_delta,
113122
)
114123

115-
def _decode(self, token: str) -> Optional[Dict[str, Any]]:
116-
try:
117-
payload: Dict[str, Any] = jwt.decode(
118-
token,
119-
self.secret_key,
120-
algorithms=[self.algorithm],
121-
options={"leeway": 10},
122-
)
123-
return payload
124-
except jwt.ExpiredSignatureError as e: # type: ignore[attr-defined]
125-
if self.auto_error:
126-
raise HTTPException(
127-
status_code=HTTP_401_UNAUTHORIZED, detail=f"Token time expired: {e}"
128-
)
129-
else:
130-
return None
131-
except jwt.JWTError as e: # type: ignore[attr-defined]
132-
if self.auto_error:
133-
raise HTTPException(
134-
status_code=HTTP_401_UNAUTHORIZED, detail=f"Wrong token: {e}"
135-
)
136-
else:
137-
return None
138-
139124
def _generate_payload(
140125
self,
141126
subject: Dict[str, Any],
@@ -144,7 +129,6 @@ def _generate_payload(
144129
token_type: str,
145130
) -> Dict[str, Any]:
146131
now = utcnow()
147-
148132
return {
149133
"subject": subject.copy(), # main subject
150134
"type": token_type, # 'access' or 'refresh' token
@@ -172,8 +156,7 @@ async def _get_payload(
172156
return None
173157

174158
# Try to decode jwt token. auto_error on error
175-
payload = self._decode(token)
176-
return payload
159+
return self.jwt_backend.decode(token, self.secret_key, self.auto_error)
177160

178161
def create_access_token(
179162
self,
@@ -186,11 +169,7 @@ def create_access_token(
186169
to_encode = self._generate_payload(
187170
subject, expires_delta, unique_identifier, "access"
188171
)
189-
190-
jwt_encoded: str = jwt.encode(
191-
to_encode, self.secret_key, algorithm=self.algorithm
192-
)
193-
return jwt_encoded
172+
return self.jwt_backend.encode(to_encode, self.secret_key)
194173

195174
def create_refresh_token(
196175
self,
@@ -203,11 +182,7 @@ def create_refresh_token(
203182
to_encode = self._generate_payload(
204183
subject, expires_delta, unique_identifier, "refresh"
205184
)
206-
207-
jwt_encoded: str = jwt.encode(
208-
to_encode, self.secret_key, algorithm=self.algorithm
209-
)
210-
return jwt_encoded
185+
return self.jwt_backend.encode(to_encode, self.secret_key)
211186

212187
@staticmethod
213188
def set_access_cookie(
@@ -261,7 +236,7 @@ def __init__(
261236
secret_key: str,
262237
places: Optional[Set[str]] = None,
263238
auto_error: bool = True,
264-
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
239+
algorithm: Optional[str] = None,
265240
access_expires_delta: Optional[timedelta] = None,
266241
refresh_expires_delta: Optional[timedelta] = None,
267242
):
@@ -293,7 +268,7 @@ def __init__(
293268
self,
294269
secret_key: str,
295270
auto_error: bool = True,
296-
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
271+
algorithm: Optional[str] = None,
297272
access_expires_delta: Optional[timedelta] = None,
298273
refresh_expires_delta: Optional[timedelta] = None,
299274
):
@@ -317,7 +292,7 @@ def __init__(
317292
self,
318293
secret_key: str,
319294
auto_error: bool = True,
320-
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
295+
algorithm: Optional[str] = None,
321296
access_expires_delta: Optional[timedelta] = None,
322297
refresh_expires_delta: Optional[timedelta] = None,
323298
):
@@ -342,7 +317,7 @@ def __init__(
342317
self,
343318
secret_key: str,
344319
auto_error: bool = True,
345-
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
320+
algorithm: Optional[str] = None,
346321
access_expires_delta: Optional[timedelta] = None,
347322
refresh_expires_delta: Optional[timedelta] = None,
348323
):
@@ -372,7 +347,7 @@ def __init__(
372347
secret_key: str,
373348
places: Optional[Set[str]] = None,
374349
auto_error: bool = True,
375-
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
350+
algorithm: Optional[str] = None,
376351
access_expires_delta: Optional[timedelta] = None,
377352
refresh_expires_delta: Optional[timedelta] = None,
378353
):
@@ -414,7 +389,7 @@ def __init__(
414389
self,
415390
secret_key: str,
416391
auto_error: bool = True,
417-
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
392+
algorithm: Optional[str] = None,
418393
access_expires_delta: Optional[timedelta] = None,
419394
refresh_expires_delta: Optional[timedelta] = None,
420395
):
@@ -438,7 +413,7 @@ def __init__(
438413
self,
439414
secret_key: str,
440415
auto_error: bool = True,
441-
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
416+
algorithm: Optional[str] = None,
442417
access_expires_delta: Optional[timedelta] = None,
443418
refresh_expires_delta: Optional[timedelta] = None,
444419
):
@@ -463,7 +438,7 @@ def __init__(
463438
self,
464439
secret_key: str,
465440
auto_error: bool = True,
466-
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
441+
algorithm: Optional[str] = None,
467442
access_expires_delta: Optional[timedelta] = None,
468443
refresh_expires_delta: Optional[timedelta] = None,
469444
):
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
try:
2+
from .authlib_backend import AuthlibJWTBackend
3+
except ImportError:
4+
AuthlibJWTBackend = None
5+
6+
try:
7+
from .python_jose_backend import PythonJoseJWTBackend
8+
except ImportError:
9+
PythonJoseJWTBackend = None
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from abc import ABCMeta, abstractmethod, abstractproperty
2+
from typing import Any, Dict, Optional, Self
3+
4+
5+
6+
class AbstractJWTBackend(metaclass=ABCMeta):
7+
8+
# simple "SingletonArgs" implementation to keep a JWTBackend per algorithm
9+
_instances = {}
10+
11+
def __new__(cls, algorithm) -> Self:
12+
instance_key = (cls, algorithm)
13+
if instance_key not in cls._instances:
14+
cls._instances[instance_key] = super(AbstractJWTBackend, cls).__new__(cls)
15+
return cls._instances[instance_key]
16+
17+
@abstractmethod
18+
def __init__(self, algorithm) -> None:
19+
pass
20+
21+
@abstractproperty
22+
def default_algorithm(self) -> str:
23+
pass
24+
25+
@abstractmethod
26+
def encode(self, to_encode, secret_key) -> str:
27+
pass
28+
29+
@abstractmethod
30+
def decode(self, token, secret_key, auto_error) -> Optional[Dict[str, Any]]:
31+
pass
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from fastapi import HTTPException
2+
from typing import Any, Dict, Optional
3+
from starlette.status import HTTP_401_UNAUTHORIZED
4+
5+
from authlib.jose import JsonWebSignature, JsonWebToken
6+
from authlib.jose.errors import (
7+
DecodeError, ExpiredTokenError, InvalidClaimError, InvalidTokenError
8+
)
9+
from .abstract_backend import AbstractJWTBackend
10+
11+
12+
class AuthlibJWTBackend(AbstractJWTBackend):
13+
14+
def __init__(self, algorithm) -> None:
15+
self.algorithm = algorithm if algorithm is not None else self.default_algorithm
16+
# from https://github.com/lepture/authlib/blob/85f9ff/authlib/jose/__init__.py#L45
17+
valid_algorithms = JsonWebSignature.ALGORITHMS_REGISTRY.keys()
18+
assert (
19+
self.algorithm in valid_algorithms
20+
), f"{self.algorithm} algorithm is not supported by authlib"
21+
self.jwt = JsonWebToken(algorithms=[self.algorithm])
22+
23+
@property
24+
def default_algorithm(self) -> str:
25+
return "HS256"
26+
27+
def encode(self, to_encode, secret_key) -> str:
28+
token = self.jwt.encode(header={"alg": self.algorithm}, payload=to_encode, key=secret_key)
29+
return token.decode() # convert to string
30+
31+
def decode(self, token, secret_key, auto_error) -> Optional[Dict[str, Any]]:
32+
try:
33+
payload = self.jwt.decode(token, secret_key)
34+
payload.validate(leeway=10)
35+
return dict(payload)
36+
except ExpiredTokenError as e: # type: ignore[attr-defined]
37+
if auto_error:
38+
raise HTTPException(
39+
status_code=HTTP_401_UNAUTHORIZED, detail=f"Token time expired: {e}"
40+
)
41+
else:
42+
return None
43+
except (InvalidClaimError,
44+
InvalidTokenError,
45+
DecodeError) as e: # type: ignore[attr-defined]
46+
if auto_error:
47+
raise HTTPException(
48+
status_code=HTTP_401_UNAUTHORIZED, detail=f"Wrong token: {e}"
49+
)
50+
else:
51+
return None
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from fastapi import HTTPException
2+
from typing import Any, Dict, Optional
3+
from starlette.status import HTTP_401_UNAUTHORIZED
4+
5+
from jose import jwt
6+
7+
from .abstract_backend import AbstractJWTBackend
8+
9+
10+
class PythonJoseJWTBackend(AbstractJWTBackend):
11+
12+
def __init__(self, algorithm) -> None:
13+
self.algorithm = algorithm if algorithm is not None else self.default_algorithm
14+
assert (
15+
hasattr(jwt.ALGORITHMS, self.algorithm) is True # type: ignore[attr-defined]
16+
), f"{algorithm} algorithm is not supported by python-jose library"
17+
18+
@property
19+
def default_algorithm(self) -> str:
20+
return jwt.ALGORITHMS.HS256
21+
22+
def encode(self, to_encode, secret_key) -> str:
23+
return jwt.encode(to_encode, secret_key, algorithm=self.algorithm)
24+
25+
def decode(self, token, secret_key, auto_error) -> Optional[Dict[str, Any]]:
26+
try:
27+
payload: Dict[str, Any] = jwt.decode(
28+
token,
29+
secret_key,
30+
algorithms=[self.algorithm],
31+
options={"leeway": 10},
32+
)
33+
return payload
34+
except jwt.ExpiredSignatureError as e: # type: ignore[attr-defined]
35+
if auto_error:
36+
raise HTTPException(
37+
status_code=HTTP_401_UNAUTHORIZED, detail=f"Token time expired: {e}"
38+
)
39+
else:
40+
return None
41+
except jwt.JWTError as e: # type: ignore[attr-defined]
42+
if auto_error:
43+
raise HTTPException(
44+
status_code=HTTP_401_UNAUTHORIZED, detail=f"Wrong token: {e}"
45+
)
46+
else:
47+
return None

pyproject.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ classifiers = [
2727

2828
dependencies = [
2929
"fastapi >=0.50.0",
30-
"python-jose[cryptography] >=3.3.0"
3130
]
3231

3332

@@ -37,7 +36,15 @@ documentation = "https://k4black.github.io/fastapi-jwt/"
3736

3837

3938
[project.optional-dependencies]
39+
authlib = [
40+
"Authlib >=1.3.0"
41+
]
42+
python_jose = [
43+
"python-jose[cryptography] >=3.3.0"
44+
]
4045
test = [
46+
"Authlib >=1.3.0",
47+
"python-jose[cryptography] >=3.3.0",
4148
"httpx >=0.23.0,<1.0.0",
4249
"pytest >=7.0.0,<9.0.0",
4350
"pytest-cov >=4.0.0,<5.0.0",

0 commit comments

Comments
 (0)