88from fastapi .responses import Response
99from fastapi .security import APIKeyCookie , HTTPBearer
1010from starlette .status import HTTP_401_UNAUTHORIZED
11+ from .jwt_backends import AuthlibJWTBackend , PythonJoseJWTBackend
1112
12- try :
13- from jose import jwt
14- except ImportError : # pragma: nocover
15- jwt = None # type: ignore[assignment]
13+
14+ DEFAULT_JWT_BACKEND = None
15+
16+
17+ def define_default_jwt_backend (cls ):
18+ global DEFAULT_JWT_BACKEND
19+ DEFAULT_JWT_BACKEND = cls
20+
21+
22+ if AuthlibJWTBackend is not None :
23+ define_default_jwt_backend (AuthlibJWTBackend )
24+ elif PythonJoseJWTBackend is not None :
25+ define_default_jwt_backend (PythonJoseJWTBackend )
1626
1727
1828def utcnow ():
@@ -27,6 +37,7 @@ def utcnow():
2737
2838
2939__all__ = [
40+ "define_default_jwt_backend" ,
3041 "JwtAuthorizationCredentials" ,
3142 "JwtAccessBearer" ,
3243 "JwtAccessCookie" ,
@@ -72,28 +83,26 @@ def __init__(
7283 secret_key : str ,
7384 places : Optional [Set [str ]] = None ,
7485 auto_error : bool = True ,
75- algorithm : str = jwt . ALGORITHMS . HS256 , # type: ignore[attr-defined]
86+ algorithm : Optional [ str ] = None ,
7687 access_expires_delta : Optional [timedelta ] = None ,
7788 refresh_expires_delta : Optional [timedelta ] = None ,
7889 ):
79- assert jwt is not None , "python-jose must be installed to use JwtAuth"
90+ self .jwt_backend = DEFAULT_JWT_BACKEND (algorithm )
91+ self .secret_key = secret_key
8092 if places :
8193 assert places .issubset (
8294 {"header" , "cookie" }
8395 ), "only 'header'/'cookie' are supported"
84- algorithm = algorithm .upper ()
85- assert (
86- hasattr (jwt .ALGORITHMS , algorithm ) is True # type: ignore[attr-defined]
87- ), f"{ algorithm } algorithm is not supported by python-jose library"
88-
89- self .secret_key = secret_key
9096
9197 self .places = places or {"header" }
9298 self .auto_error = auto_error
93- self .algorithm = algorithm
9499 self .access_expires_delta = access_expires_delta or timedelta (minutes = 15 )
95100 self .refresh_expires_delta = refresh_expires_delta or timedelta (days = 31 )
96101
102+ @property
103+ def algorithm (self ):
104+ return self .jwt_backend .algorithm
105+
97106 @classmethod
98107 def from_other (
99108 cls ,
@@ -112,30 +121,6 @@ def from_other(
112121 refresh_expires_delta = refresh_expires_delta or other .refresh_expires_delta ,
113122 )
114123
115- def _decode (self , token : str ) -> Optional [Dict [str , Any ]]:
116- try :
117- payload : Dict [str , Any ] = jwt .decode (
118- token ,
119- self .secret_key ,
120- algorithms = [self .algorithm ],
121- options = {"leeway" : 10 },
122- )
123- return payload
124- except jwt .ExpiredSignatureError as e : # type: ignore[attr-defined]
125- if self .auto_error :
126- raise HTTPException (
127- status_code = HTTP_401_UNAUTHORIZED , detail = f"Token time expired: { e } "
128- )
129- else :
130- return None
131- except jwt .JWTError as e : # type: ignore[attr-defined]
132- if self .auto_error :
133- raise HTTPException (
134- status_code = HTTP_401_UNAUTHORIZED , detail = f"Wrong token: { e } "
135- )
136- else :
137- return None
138-
139124 def _generate_payload (
140125 self ,
141126 subject : Dict [str , Any ],
@@ -144,7 +129,6 @@ def _generate_payload(
144129 token_type : str ,
145130 ) -> Dict [str , Any ]:
146131 now = utcnow ()
147-
148132 return {
149133 "subject" : subject .copy (), # main subject
150134 "type" : token_type , # 'access' or 'refresh' token
@@ -172,8 +156,7 @@ async def _get_payload(
172156 return None
173157
174158 # Try to decode jwt token. auto_error on error
175- payload = self ._decode (token )
176- return payload
159+ return self .jwt_backend .decode (token , self .secret_key , self .auto_error )
177160
178161 def create_access_token (
179162 self ,
@@ -186,11 +169,7 @@ def create_access_token(
186169 to_encode = self ._generate_payload (
187170 subject , expires_delta , unique_identifier , "access"
188171 )
189-
190- jwt_encoded : str = jwt .encode (
191- to_encode , self .secret_key , algorithm = self .algorithm
192- )
193- return jwt_encoded
172+ return self .jwt_backend .encode (to_encode , self .secret_key )
194173
195174 def create_refresh_token (
196175 self ,
@@ -203,11 +182,7 @@ def create_refresh_token(
203182 to_encode = self ._generate_payload (
204183 subject , expires_delta , unique_identifier , "refresh"
205184 )
206-
207- jwt_encoded : str = jwt .encode (
208- to_encode , self .secret_key , algorithm = self .algorithm
209- )
210- return jwt_encoded
185+ return self .jwt_backend .encode (to_encode , self .secret_key )
211186
212187 @staticmethod
213188 def set_access_cookie (
@@ -261,7 +236,7 @@ def __init__(
261236 secret_key : str ,
262237 places : Optional [Set [str ]] = None ,
263238 auto_error : bool = True ,
264- algorithm : str = jwt . ALGORITHMS . HS256 , # type: ignore[attr-defined]
239+ algorithm : Optional [ str ] = None ,
265240 access_expires_delta : Optional [timedelta ] = None ,
266241 refresh_expires_delta : Optional [timedelta ] = None ,
267242 ):
@@ -293,7 +268,7 @@ def __init__(
293268 self ,
294269 secret_key : str ,
295270 auto_error : bool = True ,
296- algorithm : str = jwt . ALGORITHMS . HS256 , # type: ignore[attr-defined]
271+ algorithm : Optional [ str ] = None ,
297272 access_expires_delta : Optional [timedelta ] = None ,
298273 refresh_expires_delta : Optional [timedelta ] = None ,
299274 ):
@@ -317,7 +292,7 @@ def __init__(
317292 self ,
318293 secret_key : str ,
319294 auto_error : bool = True ,
320- algorithm : str = jwt . ALGORITHMS . HS256 , # type: ignore[attr-defined]
295+ algorithm : Optional [ str ] = None ,
321296 access_expires_delta : Optional [timedelta ] = None ,
322297 refresh_expires_delta : Optional [timedelta ] = None ,
323298 ):
@@ -342,7 +317,7 @@ def __init__(
342317 self ,
343318 secret_key : str ,
344319 auto_error : bool = True ,
345- algorithm : str = jwt . ALGORITHMS . HS256 , # type: ignore[attr-defined]
320+ algorithm : Optional [ str ] = None ,
346321 access_expires_delta : Optional [timedelta ] = None ,
347322 refresh_expires_delta : Optional [timedelta ] = None ,
348323 ):
@@ -372,7 +347,7 @@ def __init__(
372347 secret_key : str ,
373348 places : Optional [Set [str ]] = None ,
374349 auto_error : bool = True ,
375- algorithm : str = jwt . ALGORITHMS . HS256 , # type: ignore[attr-defined]
350+ algorithm : Optional [ str ] = None ,
376351 access_expires_delta : Optional [timedelta ] = None ,
377352 refresh_expires_delta : Optional [timedelta ] = None ,
378353 ):
@@ -414,7 +389,7 @@ def __init__(
414389 self ,
415390 secret_key : str ,
416391 auto_error : bool = True ,
417- algorithm : str = jwt . ALGORITHMS . HS256 , # type: ignore[attr-defined]
392+ algorithm : Optional [ str ] = None ,
418393 access_expires_delta : Optional [timedelta ] = None ,
419394 refresh_expires_delta : Optional [timedelta ] = None ,
420395 ):
@@ -438,7 +413,7 @@ def __init__(
438413 self ,
439414 secret_key : str ,
440415 auto_error : bool = True ,
441- algorithm : str = jwt . ALGORITHMS . HS256 , # type: ignore[attr-defined]
416+ algorithm : Optional [ str ] = None ,
442417 access_expires_delta : Optional [timedelta ] = None ,
443418 refresh_expires_delta : Optional [timedelta ] = None ,
444419 ):
@@ -463,7 +438,7 @@ def __init__(
463438 self ,
464439 secret_key : str ,
465440 auto_error : bool = True ,
466- algorithm : str = jwt . ALGORITHMS . HS256 , # type: ignore[attr-defined]
441+ algorithm : Optional [ str ] = None ,
467442 access_expires_delta : Optional [timedelta ] = None ,
468443 refresh_expires_delta : Optional [timedelta ] = None ,
469444 ):
0 commit comments