Skip to content

Commit e75df1a

Browse files
committed
Add ORCID client service
1 parent 531e9e3 commit e75df1a

File tree

4 files changed

+223
-0
lines changed

4 files changed

+223
-0
lines changed

h/models/user_identity.py

+6
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1+
from enum import StrEnum
2+
13
import sqlalchemy as sa
24

35
from h.db import Base
46

57

8+
class IdentityProvider(StrEnum):
9+
ORCID = "orcid.org"
10+
11+
612
class UserIdentity(Base):
713
__tablename__ = "user_identity"
814
__table_args__ = (sa.UniqueConstraint("provider", "provider_unique_id"),)

h/services/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from h.services.mention import MentionService
1919
from h.services.notification import NotificationService
2020
from h.services.openid_client import OpenIDClientService
21+
from h.services.orcid_client import ORCIDClientService
2122
from h.services.subscription import SubscriptionService
2223
from h.services.task_done import TaskDoneService
2324

@@ -187,3 +188,6 @@ def includeme(config): # pragma: no cover # noqa: PLR0915
187188
config.register_service_factory(
188189
"h.services.openid_client.factory", iface=OpenIDClientService
189190
)
191+
config.register_service_factory(
192+
"h.services.orcid_client.factory", iface=ORCIDClientService
193+
)

h/services/orcid_client.py

+90
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import logging
2+
3+
from sqlalchemy import select
4+
from sqlalchemy.orm import Session
5+
6+
from h.models import User, UserIdentity
7+
from h.models.user_identity import IdentityProvider
8+
from h.services.jwt import JWTService
9+
from h.services.openid_client import OpenIDClientService
10+
from h.services.user import UserService
11+
12+
logger = logging.getLogger(__name__)
13+
14+
15+
class ORCIDClientService:
16+
def __init__( # noqa: PLR0913
17+
self,
18+
db: Session,
19+
host: str,
20+
client_id: str,
21+
client_secret: str,
22+
redirect_uri: str,
23+
openid_client_service: OpenIDClientService,
24+
user_service: UserService,
25+
) -> None:
26+
self._db = db
27+
self._host = host
28+
self._client_id = client_id
29+
self._client_secret = client_secret
30+
self._redirect_uri = redirect_uri
31+
self._openid_client_service = openid_client_service
32+
self._user_service = user_service
33+
34+
def _get_id_token(self, authorization_code: str) -> str:
35+
return self._openid_client_service.get_id_token(
36+
token_url=self.token_url,
37+
redirect_uri=self._redirect_uri,
38+
auth=(self._client_id, self._client_secret),
39+
authorization_code=authorization_code,
40+
)
41+
42+
def get_orcid(self, authorization_code: str) -> str | None:
43+
id_token = self._get_id_token(authorization_code)
44+
decoded_id_token = JWTService.decode_token(id_token, self.key_set_url)
45+
return decoded_id_token.get("sub")
46+
47+
def add_identity(self, user: User, orcid: str) -> None:
48+
identity = UserIdentity(
49+
user=user,
50+
provider=IdentityProvider.ORCID,
51+
provider_unique_id=orcid,
52+
)
53+
self._db.add(identity)
54+
55+
def get_identity(self, user: User) -> UserIdentity | None:
56+
stmt = select(UserIdentity).where(
57+
UserIdentity.user_id == user.id,
58+
UserIdentity.provider == IdentityProvider.ORCID,
59+
)
60+
return self._db.execute(stmt).scalar()
61+
62+
@property
63+
def token_url(self) -> str:
64+
return self._api_url("oauth/token")
65+
66+
@property
67+
def key_set_url(self) -> str:
68+
return self._api_url("oauth/jwks")
69+
70+
def orcid_url(self, orcid: str | None) -> str | None:
71+
if not orcid:
72+
return None
73+
return self._api_url(orcid)
74+
75+
def _api_url(self, path: str) -> str:
76+
return f"https://{self._host}/{path}"
77+
78+
79+
def factory(_context, request) -> ORCIDClientService:
80+
settings = request.registry.settings
81+
82+
return ORCIDClientService(
83+
db=request.db,
84+
host=settings["orcid_host"],
85+
client_id=settings["orcid_client_id"],
86+
client_secret=settings["orcid_client_secret"],
87+
redirect_uri=request.route_url("orcid.oauth.callback"),
88+
openid_client_service=request.find_service(OpenIDClientService),
89+
user_service=request.find_service(name="user"),
90+
)
+123
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
from unittest.mock import sentinel
2+
3+
import pytest
4+
from sqlalchemy import select
5+
6+
from h.models import UserIdentity
7+
from h.models.user_identity import IdentityProvider
8+
from h.services.orcid_client import ORCIDClientService, factory
9+
10+
11+
class TestORCIDClientService:
12+
def test_get_orcid(self, service, openid_client_service, JWTService):
13+
openid_client_service.get_id_token.return_value = sentinel.id_token
14+
JWTService.decode_token.return_value = {"sub": sentinel.orcid}
15+
16+
orcid = service.get_orcid(sentinel.authorization_code)
17+
18+
assert orcid == sentinel.orcid
19+
openid_client_service.get_id_token.assert_called_once_with(
20+
token_url=service.token_url,
21+
redirect_uri=sentinel.redirect_uri,
22+
auth=(sentinel.client_id, sentinel.client_secret),
23+
authorization_code=sentinel.authorization_code,
24+
)
25+
JWTService.decode_token.assert_called_once_with(
26+
sentinel.id_token, service.key_set_url
27+
)
28+
29+
def test_get_orcid_returns_none_if_sub_missing(
30+
self, service, openid_client_service, JWTService
31+
):
32+
openid_client_service.get_id_token.return_value = sentinel.id_token
33+
JWTService.decode_token.return_value = {}
34+
35+
assert service.get_orcid(sentinel.authorization_code) is None
36+
37+
def test_add_identity(self, service, db_session, user):
38+
orcid = "1111-1111-1111-1111"
39+
40+
service.add_identity(user, orcid)
41+
42+
stmt = select(UserIdentity).where(
43+
UserIdentity.user == user,
44+
UserIdentity.provider == IdentityProvider.ORCID,
45+
UserIdentity.provider_unique_id == orcid,
46+
)
47+
assert db_session.execute(stmt).scalar() is not None
48+
49+
def test_get_identity(self, service, user, user_identity):
50+
assert service.get_identity(user) == user_identity
51+
52+
def test_get_identity_without_identities(self, service, user):
53+
user.identities = []
54+
55+
assert service.get_identity(user) is None
56+
57+
def test_orcid_url_with_empty_orcid(self, service):
58+
assert service.orcid_url("") is None
59+
60+
@pytest.fixture
61+
def user(self, factories):
62+
return factories.User()
63+
64+
@pytest.fixture
65+
def user_identity(self, user, db_session):
66+
identity = UserIdentity(
67+
user=user,
68+
provider=IdentityProvider.ORCID,
69+
provider_unique_id="0000-0000-0000-0000",
70+
)
71+
db_session.add(identity)
72+
db_session.flush()
73+
return identity
74+
75+
@pytest.fixture
76+
def service(self, db_session, openid_client_service, user_service):
77+
return ORCIDClientService(
78+
db=db_session,
79+
host=IdentityProvider.ORCID,
80+
client_id=sentinel.client_id,
81+
client_secret=sentinel.client_secret,
82+
redirect_uri=sentinel.redirect_uri,
83+
openid_client_service=openid_client_service,
84+
user_service=user_service,
85+
)
86+
87+
@pytest.fixture(autouse=True)
88+
def JWTService(self, patch):
89+
return patch("h.services.orcid_client.JWTService")
90+
91+
92+
class TestFactory:
93+
def test_it(
94+
self, pyramid_request, ORCIDClientService, openid_client_service, user_service
95+
):
96+
service = factory(sentinel.context, pyramid_request)
97+
98+
ORCIDClientService.assert_called_once_with(
99+
db=pyramid_request.db,
100+
host=IdentityProvider.ORCID,
101+
client_id=sentinel.client_id,
102+
client_secret=sentinel.client_secret,
103+
redirect_uri=sentinel.redirect_uri,
104+
openid_client_service=openid_client_service,
105+
user_service=user_service,
106+
)
107+
assert service == ORCIDClientService.return_value
108+
109+
@pytest.fixture(autouse=True)
110+
def ORCIDClientService(self, patch):
111+
return patch("h.services.orcid_client.ORCIDClientService")
112+
113+
@pytest.fixture
114+
def pyramid_request(self, pyramid_request, mocker):
115+
pyramid_request.registry.settings.update(
116+
{
117+
"orcid_host": IdentityProvider.ORCID,
118+
"orcid_client_id": sentinel.client_id,
119+
"orcid_client_secret": sentinel.client_secret,
120+
}
121+
)
122+
pyramid_request.route_url = mocker.Mock(return_value=sentinel.redirect_uri)
123+
return pyramid_request

0 commit comments

Comments
 (0)