diff --git a/backend/main.py b/backend/main.py index 70e7f837d2..29b10d91c1 100644 --- a/backend/main.py +++ b/backend/main.py @@ -10,7 +10,6 @@ from pyinstrument import Profiler from sentry_sdk.integrations.asgi import SentryAsgiMiddleware from starlette.middleware.authentication import AuthenticationMiddleware - from backend.config import settings from backend.db import db_connection from backend.exceptions import BadRequest, Conflict, Forbidden, NotFound, Unauthorized @@ -55,16 +54,21 @@ async def lifespan(app): # Custom exception handler for invalid token and logout. @_app.exception_handler(HTTPException) async def custom_http_exception_handler(request: Request, exc: HTTPException): - if exc.status_code == 401 and "InvalidToken" in exc.detail.get("SubCode", ""): - return JSONResponse( - content={ - "Error": exc.detail["Error"], - "SubCode": exc.detail["SubCode"], - }, - status_code=exc.status_code, - headers={"WWW-Authenticate": "Bearer"}, - ) - + try: + if exc.status_code == 401 and "InvalidToken" in exc.detail.get( + "SubCode", "" + ): + return JSONResponse( + content={ + "Error": exc.detail["Error"], + "SubCode": exc.detail["SubCode"], + }, + status_code=exc.status_code, + headers={"WWW-Authenticate": "Bearer"}, + ) + except Exception as e: + logging.debug(f"Exception while handling custom HTTPException: {e}") + pass if isinstance(exc.detail, dict) and "error" in exc.detail: error_response = exc.detail else: diff --git a/backend/services/users/authentication_service.py b/backend/services/users/authentication_service.py index ee6d71b53a..51685c7fe4 100644 --- a/backend/services/users/authentication_service.py +++ b/backend/services/users/authentication_service.py @@ -70,18 +70,18 @@ def verify_token(token): class TokenAuthBackend(AuthenticationBackend): async def authenticate(self, conn): if "authorization" not in conn.headers: - return + return None auth = conn.headers["authorization"] try: scheme, credentials = auth.split() if scheme.lower() != "token": - return + return None try: decoded_token = base64.b64decode(credentials).decode("ascii") except UnicodeDecodeError: logger.debug("Unable to decode token") - return False + return None except (ValueError, UnicodeDecodeError, binascii.Error): raise AuthenticationError("Invalid auth credentials") @@ -90,7 +90,7 @@ async def authenticate(self, conn): ) if not valid_token: logger.debug("Token not valid.") - return + return None tm.authenticated_user_id = user_id return AuthCredentials(["authenticated"]), SimpleUser(user_id) @@ -251,7 +251,6 @@ async def login_required( raise AuthenticationError("Invalid auth credentials") valid_token, user_id = AuthenticationService.is_valid_token(decoded_token, 604800) if not valid_token: - logger.debug("Token not valid") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail={"Error": "Token is expired or invalid", "SubCode": "InvalidToken"}, @@ -275,12 +274,18 @@ async def login_required_optional( decoded_token = base64.b64decode(credentials).decode("ascii") except UnicodeDecodeError: logger.debug("Unable to decode token") - raise HTTPException(status_code=401, detail="Invalid token") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail={ + "Error": "Token is expired or invalid", + "SubCode": "InvalidToken", + }, + headers={"WWW-Authenticate": "Bearer"}, + ) except (ValueError, UnicodeDecodeError, binascii.Error): raise AuthenticationError("Invalid auth credentials") valid_token, user_id = AuthenticationService.is_valid_token(decoded_token, 604800) if not valid_token: - logger.debug("Token not valid") return None return AuthUserDTO(id=user_id)