|
| 1 | +import logging |
| 2 | + |
| 3 | +from sqlalchemy.orm import Session |
| 4 | + |
| 5 | +from h.models import User, UserIdentity |
| 6 | +from h.models.user_identity import ProviderHost |
| 7 | +from h.services.jwt import JWTService |
| 8 | +from h.services.openid_client import OpenIDClientService |
| 9 | +from h.services.user import UserService |
| 10 | + |
| 11 | +logger = logging.getLogger(__name__) |
| 12 | + |
| 13 | + |
| 14 | +class ORCIDClientService: |
| 15 | + def __init__( # noqa: PLR0913 |
| 16 | + self, |
| 17 | + session: Session, |
| 18 | + host: str, |
| 19 | + client_id: str, |
| 20 | + client_secret: str, |
| 21 | + redirect_uri: str, |
| 22 | + openid_client_service: OpenIDClientService, |
| 23 | + user_service: UserService, |
| 24 | + ) -> None: |
| 25 | + self._session = session |
| 26 | + self._host = host |
| 27 | + self._client_id = client_id |
| 28 | + self._client_secret = client_secret |
| 29 | + self._redirect_uri = redirect_uri |
| 30 | + self._openid_client_service = openid_client_service |
| 31 | + self._user_service = user_service |
| 32 | + |
| 33 | + def _get_id_token(self, authorization_code: str) -> str: |
| 34 | + return self._openid_client_service.get_id_token( |
| 35 | + token_url=self.token_url, |
| 36 | + redirect_uri=self._redirect_uri, |
| 37 | + auth=(self._client_id, self._client_secret), |
| 38 | + authorization_code=authorization_code, |
| 39 | + ) |
| 40 | + |
| 41 | + def get_orcid(self, authorization_code: str) -> str | None: |
| 42 | + id_token = self._get_id_token(authorization_code) |
| 43 | + decoded_id_token = JWTService.decode_token(id_token, self.key_set_url) |
| 44 | + return decoded_id_token.get("sub") |
| 45 | + |
| 46 | + def add_identity(self, user: User, orcid: str) -> None: |
| 47 | + identity = UserIdentity( |
| 48 | + user=user, |
| 49 | + provider=ProviderHost.ORCID, |
| 50 | + provider_unique_id=orcid, |
| 51 | + ) |
| 52 | + self._session.add(identity) |
| 53 | + |
| 54 | + @staticmethod |
| 55 | + def get_identity(user: User) -> UserIdentity | None: |
| 56 | + for identity in user.identities: |
| 57 | + if identity.provider == ProviderHost.ORCID: |
| 58 | + return identity |
| 59 | + return None |
| 60 | + |
| 61 | + @property |
| 62 | + def token_url(self) -> str: |
| 63 | + return self._api_url("oauth/token") |
| 64 | + |
| 65 | + @property |
| 66 | + def key_set_url(self) -> str: |
| 67 | + return self._api_url("oauth/jwks") |
| 68 | + |
| 69 | + def orcid_url(self, orcid: str) -> str | None: |
| 70 | + return self._api_url(orcid) if orcid else None |
| 71 | + |
| 72 | + def _api_url(self, path: str) -> str: |
| 73 | + return f"https://{self._host}/{path}" |
| 74 | + |
| 75 | + |
| 76 | +def factory(_context, request) -> ORCIDClientService: |
| 77 | + settings = request.registry.settings |
| 78 | + |
| 79 | + return ORCIDClientService( |
| 80 | + session=request.db, |
| 81 | + host=settings["orcid_host"], |
| 82 | + client_id=settings["orcid_client_id"], |
| 83 | + client_secret=settings["orcid_client_secret"], |
| 84 | + redirect_uri=request.route_url("orcid.oauth.callback"), |
| 85 | + openid_client_service=request.find_service(OpenIDClientService), |
| 86 | + user_service=request.find_service(name="user"), |
| 87 | + ) |
0 commit comments