Skip to content

Commit d18f145

Browse files
committed
Add orcid client service
1 parent c2cf667 commit d18f145

File tree

9 files changed

+532
-1
lines changed

9 files changed

+532
-1
lines changed

h/models/user_identity.py

+6
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1+
from enum import StrEnum
2+
13
import sqlalchemy as sa
24

35
from h.db import Base
46

57

8+
class ProviderHost(StrEnum):
9+
ORCID = "orcid.org"
10+
11+
612
class UserIdentity(Base):
713
__tablename__ = "user_identity"
814
__table_args__ = (sa.UniqueConstraint("provider", "provider_unique_id"),)

h/schemas/oauth.py

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import logging
2+
import secrets
3+
from typing import Any, ClassVar, TypedDict
4+
5+
from pyramid.request import Request
6+
7+
from h.schemas.base import JSONSchema, ValidationError
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
class OAuthCallbackSchema(JSONSchema):
13+
schema: ClassVar = { # type: ignore[misc]
14+
"type": "object",
15+
"required": ["code"],
16+
"properties": {
17+
"code": {"type": "string"},
18+
"state": {"type": "string"},
19+
"error": {"type": "string"},
20+
"error_description": {"type": "string"},
21+
},
22+
}
23+
24+
25+
class OAuthCallbackData(TypedDict):
26+
code: str
27+
state: str | None
28+
error: str | None
29+
error_description: str | None
30+
31+
32+
class RetrieveOAuthCallbackSchema:
33+
def __init__(self, request: Request) -> None:
34+
self._schema = OAuthCallbackSchema()
35+
self._request = request
36+
37+
def validate(self, data: dict[str, Any]) -> OAuthCallbackData:
38+
if data.get("state") != self._request.session.pop("oauth2_state", None):
39+
raise ValidationError("Invalid oauth state") # noqa: EM101, TRY003
40+
41+
return self._schema.validate(data)
42+
43+
def state_param(self) -> str:
44+
state = secrets.token_hex()
45+
self._request.session["oauth2_state"] = state
46+
return state
47+
48+
49+
class OpenIDTokenSchema(JSONSchema):
50+
schema: ClassVar = { # type: ignore[misc]
51+
"type": "object",
52+
"required": ["access_token", "refresh_token", "expires_in"],
53+
"properties": {
54+
"access_token": {"type": "string"},
55+
"refresh_token": {"type": "string"},
56+
"expires_in": {"type": "integer", "minimum": 1},
57+
"id_token": {"type": "string"},
58+
},
59+
}
60+
61+
62+
class OpenIDTokenData(TypedDict):
63+
access_token: str
64+
refresh_token: str
65+
expires_in: int
66+
id_token: str
67+
68+
69+
class RetrieveOpenIDTokenSchema:
70+
def __init__(self):
71+
self._schema = OpenIDTokenSchema()
72+
73+
def validate(self, data: dict[str, Any]) -> OpenIDTokenData:
74+
return self._schema.validate(data)

h/services/__init__.py

+13
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,13 @@
1212
BulkLMSStatsService,
1313
)
1414
from h.services.email import EmailService
15+
from h.services.http import HTTPService
1516
from h.services.job_queue import JobQueueService
17+
from h.services.jwt import JWTService
1618
from h.services.mention import MentionService
1719
from h.services.notification import NotificationService
20+
from h.services.openid_client import OpenIDClientService
21+
from h.services.orcid_client import ORCIDClientService
1822
from h.services.subscription import SubscriptionService
1923
from h.services.task_done import TaskDoneService
2024

@@ -178,3 +182,12 @@ def includeme(config): # pragma: no cover # noqa: PLR0915
178182
config.register_service_factory(
179183
"h.services.task_done.factory", iface=TaskDoneService
180184
)
185+
186+
# Authentication-related services
187+
config.register_service_factory("h.services.http.factory", iface=HTTPService)
188+
config.register_service_factory(
189+
"h.services.openid_client.factory", iface=OpenIDClientService
190+
)
191+
config.register_service_factory(
192+
"h.services.orcid_client.factory", iface=ORCIDClientService
193+
)

h/services/openid_client.py

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import logging
2+
from typing import Any
3+
4+
from h.schemas.oauth import OpenIDTokenData, RetrieveOpenIDTokenSchema
5+
from h.services.http import ExternalRequestError, HTTPService
6+
7+
logger = logging.getLogger(__name__)
8+
9+
10+
class OpenIDTokenError(ExternalRequestError):
11+
"""
12+
A problem with an Open ID token for an external API.
13+
14+
This is raised when we don't have an access token for the current user or
15+
when our access token doesn't work (e.g. because it's expired or been
16+
revoked).
17+
"""
18+
19+
20+
class OpenIDClientService:
21+
def __init__(self, http_service: HTTPService) -> None:
22+
self._http_service = http_service
23+
24+
def get_id_token(
25+
self,
26+
token_url: str,
27+
redirect_uri: str,
28+
auth: tuple[str, str],
29+
authorization_code: str,
30+
) -> str:
31+
data = self._request_openid_data(
32+
token_url=token_url,
33+
auth=auth,
34+
data={
35+
"redirect_uri": redirect_uri,
36+
"grant_type": "authorization_code",
37+
"code": authorization_code,
38+
},
39+
)
40+
return data["id_token"]
41+
42+
def _request_openid_data(
43+
self, token_url: str, data: dict[str, Any], auth: tuple[str, str]
44+
) -> OpenIDTokenData:
45+
response = self._http_service.post(token_url, data=data, auth=auth)
46+
47+
return RetrieveOpenIDTokenSchema().validate(response.json())
48+
49+
50+
def factory(_context, request) -> OpenIDClientService:
51+
return OpenIDClientService(request.find_service(HTTPService))

h/services/orcid_client.py

+87
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import logging
2+
3+
from sqlalchemy.orm import Session
4+
5+
from h.models import User, UserIdentity
6+
from h.models.user_identity import ProviderHost
7+
from h.services.jwt import JWTService
8+
from h.services.openid_client import OpenIDClientService
9+
from h.services.user import UserService
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
class ORCIDClientService:
15+
def __init__( # noqa: PLR0913
16+
self,
17+
session: Session,
18+
host: str,
19+
client_id: str,
20+
client_secret: str,
21+
redirect_uri: str,
22+
openid_client_service: OpenIDClientService,
23+
user_service: UserService,
24+
) -> None:
25+
self._session = session
26+
self._host = host
27+
self._client_id = client_id
28+
self._client_secret = client_secret
29+
self._redirect_uri = redirect_uri
30+
self._openid_client_service = openid_client_service
31+
self._user_service = user_service
32+
33+
def _get_id_token(self, authorization_code: str) -> str:
34+
return self._openid_client_service.get_id_token(
35+
token_url=self.token_url,
36+
redirect_uri=self._redirect_uri,
37+
auth=(self._client_id, self._client_secret),
38+
authorization_code=authorization_code,
39+
)
40+
41+
def get_orcid(self, authorization_code: str) -> str | None:
42+
id_token = self._get_id_token(authorization_code)
43+
decoded_id_token = JWTService.decode_token(id_token, self.key_set_url)
44+
return decoded_id_token.get("sub")
45+
46+
def add_identity(self, user: User, orcid: str) -> None:
47+
identity = UserIdentity(
48+
user=user,
49+
provider=ProviderHost.ORCID,
50+
provider_unique_id=orcid,
51+
)
52+
self._session.add(identity)
53+
54+
@staticmethod
55+
def get_identity(user: User) -> UserIdentity | None:
56+
for identity in user.identities:
57+
if identity.provider == ProviderHost.ORCID:
58+
return identity
59+
return None
60+
61+
@property
62+
def token_url(self) -> str:
63+
return self._api_url("oauth/token")
64+
65+
@property
66+
def key_set_url(self) -> str:
67+
return self._api_url("oauth/jwks")
68+
69+
def orcid_url(self, orcid: str) -> str | None:
70+
return self._api_url(orcid) if orcid else None
71+
72+
def _api_url(self, path: str) -> str:
73+
return f"https://{self._host}/{path}"
74+
75+
76+
def factory(_context, request) -> ORCIDClientService:
77+
settings = request.registry.settings
78+
79+
return ORCIDClientService(
80+
session=request.db,
81+
host=settings["orcid_host"],
82+
client_id=settings["orcid_client_id"],
83+
client_secret=settings["orcid_client_secret"],
84+
redirect_uri=request.route_url("orcid.oauth.callback"),
85+
openid_client_service=request.find_service(OpenIDClientService),
86+
user_service=request.find_service(name="user"),
87+
)

tests/common/fixtures/services.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22

33
import pytest
44

5-
from h.services import MentionService, NotificationService
5+
from h.services import (
6+
HTTPService,
7+
MentionService,
8+
NotificationService,
9+
OpenIDClientService,
10+
)
611
from h.services.analytics import AnalyticsService
712
from h.services.annotation_authority_queue import AnnotationAuthorityQueueService
813
from h.services.annotation_delete import AnnotationDeleteService
@@ -74,6 +79,7 @@
7479
"group_members_service",
7580
"group_service",
7681
"group_update_service",
82+
"http_service",
7783
"links_service",
7884
"list_organizations_service",
7985
"mention_service",
@@ -82,6 +88,7 @@
8288
"nipsa_service",
8389
"notification_service",
8490
"oauth_provider_service",
91+
"openid_client_service",
8592
"organization_service",
8693
"queue_service",
8794
"search_index",
@@ -347,3 +354,13 @@ def email_service(mock_service):
347354
@pytest.fixture
348355
def task_done_service(mock_service):
349356
return mock_service(TaskDoneService)
357+
358+
359+
@pytest.fixture
360+
def http_service(mock_service):
361+
return mock_service(HTTPService)
362+
363+
364+
@pytest.fixture
365+
def openid_client_service(mock_service):
366+
return mock_service(OpenIDClientService)

tests/unit/h/schemas/oauth_test.py

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import pytest
2+
3+
from h.schemas import ValidationError
4+
from h.schemas.oauth import (
5+
RetrieveOAuthCallbackSchema,
6+
RetrieveOpenIDTokenSchema,
7+
)
8+
9+
10+
class TestRetrieveOAuthCallbackSchema:
11+
def test_validate(self, pyramid_request, schema):
12+
data = {"code": "test-code", "state": "test-state"}
13+
pyramid_request.session["oauth2_state"] = data["state"]
14+
15+
result = schema.validate(data)
16+
17+
assert result == data
18+
19+
def test_validate_with_invalid_state(self, pyramid_request, schema):
20+
data = {"code": "test-code", "state": "test-state"}
21+
pyramid_request.session["oauth2_state"] = "different-test-state"
22+
23+
with pytest.raises(ValidationError, match="Invalid oauth state"):
24+
schema.validate(data)
25+
26+
def test_state_param_generates_token(self, pyramid_request, schema, secrets):
27+
state_token = "test-token" # noqa: S105
28+
secrets.token_hex.return_value = state_token
29+
30+
result = schema.state_param()
31+
32+
assert result == state_token
33+
assert pyramid_request.session["oauth2_state"] == state_token
34+
35+
@pytest.fixture
36+
def schema(self, pyramid_request):
37+
return RetrieveOAuthCallbackSchema(pyramid_request)
38+
39+
@pytest.fixture(autouse=True)
40+
def secrets(self, patch):
41+
return patch("h.schemas.oauth.secrets")
42+
43+
44+
class TestRetrieveOpenIDTokenSchema:
45+
def test_validate(self, schema):
46+
data = {
47+
"access_token": "test-access-token",
48+
"refresh_token": "test-refresh-token",
49+
"expires_in": 3600,
50+
"id_token": "test-id-token",
51+
}
52+
53+
result = schema.validate(data)
54+
55+
assert result == data
56+
57+
@pytest.mark.parametrize(
58+
"data,expected_error",
59+
[
60+
(
61+
{"access_token": "test-access-token"},
62+
"^'refresh_token' is a required property, 'expires_in' is a required property$",
63+
),
64+
(
65+
{
66+
"access_token": "test-access-token",
67+
"refresh_token": "test-refresh-token",
68+
},
69+
"^'expires_in' is a required property$",
70+
),
71+
],
72+
)
73+
def test_validate_with_invalid_data(self, data, expected_error, schema):
74+
with pytest.raises(ValidationError, match=expected_error):
75+
schema.validate(data)
76+
77+
@pytest.fixture
78+
def schema(self):
79+
return RetrieveOpenIDTokenSchema()

0 commit comments

Comments
 (0)