Skip to content

Add ORCID client service #9565

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions h/models/user_identity.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from enum import StrEnum

import sqlalchemy as sa

from h.db import Base


class IdentityProvider(StrEnum):
ORCID = "orcid.org"


class UserIdentity(Base):
__tablename__ = "user_identity"
__table_args__ = (sa.UniqueConstraint("provider", "provider_unique_id"),)
Expand Down
4 changes: 4 additions & 0 deletions h/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from h.services.mention import MentionService
from h.services.notification import NotificationService
from h.services.openid_client import OpenIDClientService
from h.services.orcid_client import ORCIDClientService
from h.services.subscription import SubscriptionService
from h.services.task_done import TaskDoneService

Expand Down Expand Up @@ -187,3 +188,6 @@ def includeme(config): # pragma: no cover # noqa: PLR0915
config.register_service_factory(
"h.services.openid_client.factory", iface=OpenIDClientService
)
config.register_service_factory(
"h.services.orcid_client.factory", iface=ORCIDClientService
)
88 changes: 88 additions & 0 deletions h/services/orcid_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import logging

from sqlalchemy import select
from sqlalchemy.orm import Session

from h.models import User, UserIdentity
from h.models.user_identity import IdentityProvider
from h.services.jwt import JWTService
from h.services.openid_client import OpenIDClientService
from h.services.user import UserService

logger = logging.getLogger(__name__)


class ORCIDClientService:
def __init__( # noqa: PLR0913
self,
db: Session,
host: str,
client_id: str,
client_secret: str,
redirect_uri: str,
openid_client_service: OpenIDClientService,
user_service: UserService,
) -> None:
self._db = db
self._host = host
self._client_id = client_id
self._client_secret = client_secret
self._redirect_uri = redirect_uri
self._openid_client_service = openid_client_service
self._user_service = user_service

def _get_id_token(self, authorization_code: str) -> str:
return self._openid_client_service.get_id_token(
token_url=self.token_url,
redirect_uri=self._redirect_uri,
auth=(self._client_id, self._client_secret),
authorization_code=authorization_code,
)

def get_orcid(self, authorization_code: str) -> str | None:
id_token = self._get_id_token(authorization_code)
decoded_id_token = JWTService.decode_token(id_token, self.key_set_url)
return decoded_id_token.get("sub")

def add_identity(self, user: User, orcid: str) -> None:
identity = UserIdentity(
user=user,
provider=IdentityProvider.ORCID,
provider_unique_id=orcid,
)
self._db.add(identity)

def get_identity(self, user: User) -> UserIdentity | None:
stmt = select(UserIdentity).where(
UserIdentity.user_id == user.id,
UserIdentity.provider == IdentityProvider.ORCID,
)
return self._db.execute(stmt).scalar()

@property
def token_url(self) -> str:
return self._api_url("oauth/token")

@property
def key_set_url(self) -> str:
return self._api_url("oauth/jwks")

def orcid_url(self, orcid: str | None) -> str | None:
return self._api_url(orcid) if orcid else None

def _api_url(self, path: str) -> str:
return f"https://{self._host}/{path}"


def factory(_context, request) -> ORCIDClientService:
settings = request.registry.settings

return ORCIDClientService(
db=request.db,
host=settings["orcid_host"],
client_id=settings["orcid_client_id"],
client_secret=settings["orcid_client_secret"],
redirect_uri=request.route_url("orcid.oauth.callback"),
openid_client_service=request.find_service(OpenIDClientService),
user_service=request.find_service(name="user"),
)
6 changes: 0 additions & 6 deletions tests/unit/h/schemas/oauth_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,6 @@ def test_validate(self, schema):
],
)
def test_validate_with_invalid_data(self, data, expected_error, schema):
data = {
"access_token": "test_access_token",
"refresh_token": "test_refresh_token",
"expires_in": 3600,
}

with pytest.raises(ValidationError, match=expected_error):
schema.validate(data)

Expand Down
123 changes: 123 additions & 0 deletions tests/unit/h/services/orcid_client_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from unittest.mock import sentinel

import pytest
from sqlalchemy import select

from h.models import UserIdentity
from h.models.user_identity import IdentityProvider
from h.services.orcid_client import ORCIDClientService, factory


class TestORCIDClientService:
def test_get_orcid(self, service, openid_client_service, JWTService):
openid_client_service.get_id_token.return_value = sentinel.id_token
JWTService.decode_token.return_value = {"sub": sentinel.orcid}

orcid = service.get_orcid(sentinel.authorization_code)

assert orcid == sentinel.orcid
openid_client_service.get_id_token.assert_called_once_with(
token_url=service.token_url,
redirect_uri=sentinel.redirect_uri,
auth=(sentinel.client_id, sentinel.client_secret),
authorization_code=sentinel.authorization_code,
)
JWTService.decode_token.assert_called_once_with(
sentinel.id_token, service.key_set_url
)

def test_get_orcid_returns_none_if_sub_missing(
self, service, openid_client_service, JWTService
):
openid_client_service.get_id_token.return_value = sentinel.id_token
JWTService.decode_token.return_value = {}

assert service.get_orcid(sentinel.authorization_code) is None

def test_add_identity(self, service, db_session, user):
orcid = "1111-1111-1111-1111"

service.add_identity(user, orcid)

stmt = select(UserIdentity).where(
UserIdentity.user == user,
UserIdentity.provider == IdentityProvider.ORCID,
UserIdentity.provider_unique_id == orcid,
)
assert db_session.execute(stmt).scalar() is not None

def test_get_identity(self, service, user, user_identity):
assert service.get_identity(user) == user_identity

def test_get_identity_without_identities(self, service, user):
user.identities = []

assert service.get_identity(user) is None

def test_orcid_url_with_empty_orcid(self, service):
assert service.orcid_url("") is None

@pytest.fixture
def user(self, factories):
return factories.User()

@pytest.fixture
def user_identity(self, user, db_session):
identity = UserIdentity(
user=user,
provider=IdentityProvider.ORCID,
provider_unique_id="0000-0000-0000-0000",
)
db_session.add(identity)
db_session.flush()
return identity

@pytest.fixture
def service(self, db_session, openid_client_service, user_service):
return ORCIDClientService(
db=db_session,
host=IdentityProvider.ORCID,
client_id=sentinel.client_id,
client_secret=sentinel.client_secret,
redirect_uri=sentinel.redirect_uri,
openid_client_service=openid_client_service,
user_service=user_service,
)

@pytest.fixture(autouse=True)
def JWTService(self, patch):
return patch("h.services.orcid_client.JWTService")


class TestFactory:
def test_it(
self, pyramid_request, ORCIDClientService, openid_client_service, user_service
):
service = factory(sentinel.context, pyramid_request)

ORCIDClientService.assert_called_once_with(
db=pyramid_request.db,
host=IdentityProvider.ORCID,
client_id=sentinel.client_id,
client_secret=sentinel.client_secret,
redirect_uri=sentinel.redirect_uri,
openid_client_service=openid_client_service,
user_service=user_service,
)
assert service == ORCIDClientService.return_value

@pytest.fixture(autouse=True)
def ORCIDClientService(self, patch):
return patch("h.services.orcid_client.ORCIDClientService")

@pytest.fixture
def pyramid_request(self, pyramid_request, mocker):
pyramid_request.registry.settings.update(
{
"orcid_host": IdentityProvider.ORCID,
"orcid_client_id": sentinel.client_id,
"orcid_client_secret": sentinel.client_secret,
}
)
pyramid_request.route_url = mocker.Mock(return_value=sentinel.redirect_uri)
return pyramid_request
Loading