Skip to content

Commit 531e9e3

Browse files
committed
Add openid client service
1 parent c2cf667 commit 531e9e3

File tree

6 files changed

+236
-1
lines changed

6 files changed

+236
-1
lines changed

h/schemas/oauth.py

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import logging
2+
from typing import Any, ClassVar, TypedDict
3+
4+
from h.schemas.base import JSONSchema
5+
6+
logger = logging.getLogger(__name__)
7+
8+
9+
class OpenIDTokenSchema(JSONSchema):
10+
schema: ClassVar = { # type: ignore[misc]
11+
"type": "object",
12+
"required": ["access_token", "expires_in", "id_token"],
13+
"properties": {
14+
"access_token": {"type": "string"},
15+
"refresh_token": {"type": "string"},
16+
"expires_in": {"type": "integer", "minimum": 1},
17+
"id_token": {"type": "string"},
18+
},
19+
}
20+
21+
22+
class OpenIDTokenData(TypedDict):
23+
access_token: str
24+
refresh_token: str | None
25+
expires_in: int
26+
id_token: str
27+
28+
29+
class RetrieveOpenIDTokenSchema:
30+
def __init__(self) -> None:
31+
self._schema = OpenIDTokenSchema()
32+
33+
def validate(self, data: dict[str, Any]) -> OpenIDTokenData:
34+
return self._schema.validate(data)

h/services/__init__.py

+9
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
BulkLMSStatsService,
1313
)
1414
from h.services.email import EmailService
15+
from h.services.http import HTTPService
1516
from h.services.job_queue import JobQueueService
17+
from h.services.jwt import JWTService
1618
from h.services.mention import MentionService
1719
from h.services.notification import NotificationService
20+
from h.services.openid_client import OpenIDClientService
1821
from h.services.subscription import SubscriptionService
1922
from h.services.task_done import TaskDoneService
2023

@@ -178,3 +181,9 @@ def includeme(config): # pragma: no cover # noqa: PLR0915
178181
config.register_service_factory(
179182
"h.services.task_done.factory", iface=TaskDoneService
180183
)
184+
185+
# Authentication-related services
186+
config.register_service_factory("h.services.http.factory", iface=HTTPService)
187+
config.register_service_factory(
188+
"h.services.openid_client.factory", iface=OpenIDClientService
189+
)

h/services/openid_client.py

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import logging
2+
from typing import Any
3+
4+
from h.schemas.oauth import OpenIDTokenData, RetrieveOpenIDTokenSchema
5+
from h.services.http import ExternalRequestError, HTTPService
6+
7+
logger = logging.getLogger(__name__)
8+
9+
10+
class OpenIDTokenError(ExternalRequestError):
11+
"""
12+
A problem with an Open ID token for an external API.
13+
14+
This is raised when we don't have an access token for the current user or
15+
when our access token doesn't work (e.g. because it's expired or been
16+
revoked).
17+
"""
18+
19+
20+
class OpenIDClientService:
21+
def __init__(self, http_service: HTTPService) -> None:
22+
self._http_service = http_service
23+
24+
def get_id_token(
25+
self,
26+
token_url: str,
27+
redirect_uri: str,
28+
auth: tuple[str, str],
29+
authorization_code: str,
30+
) -> str:
31+
data = self._request_openid_data(
32+
token_url=token_url,
33+
auth=auth,
34+
data={
35+
"redirect_uri": redirect_uri,
36+
"grant_type": "authorization_code",
37+
"code": authorization_code,
38+
},
39+
)
40+
return data["id_token"]
41+
42+
def _request_openid_data(
43+
self, token_url: str, data: dict[str, Any], auth: tuple[str, str]
44+
) -> OpenIDTokenData:
45+
response = self._http_service.post(token_url, data=data, auth=auth)
46+
47+
return RetrieveOpenIDTokenSchema().validate(response.json())
48+
49+
50+
def factory(_context, request) -> OpenIDClientService:
51+
return OpenIDClientService(request.find_service(HTTPService))

tests/common/fixtures/services.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22

33
import pytest
44

5-
from h.services import MentionService, NotificationService
5+
from h.services import (
6+
HTTPService,
7+
MentionService,
8+
NotificationService,
9+
OpenIDClientService,
10+
)
611
from h.services.analytics import AnalyticsService
712
from h.services.annotation_authority_queue import AnnotationAuthorityQueueService
813
from h.services.annotation_delete import AnnotationDeleteService
@@ -74,6 +79,7 @@
7479
"group_members_service",
7580
"group_service",
7681
"group_update_service",
82+
"http_service",
7783
"links_service",
7884
"list_organizations_service",
7985
"mention_service",
@@ -82,6 +88,7 @@
8288
"nipsa_service",
8389
"notification_service",
8490
"oauth_provider_service",
91+
"openid_client_service",
8592
"organization_service",
8693
"queue_service",
8794
"search_index",
@@ -347,3 +354,13 @@ def email_service(mock_service):
347354
@pytest.fixture
348355
def task_done_service(mock_service):
349356
return mock_service(TaskDoneService)
357+
358+
359+
@pytest.fixture
360+
def http_service(mock_service):
361+
return mock_service(HTTPService)
362+
363+
364+
@pytest.fixture
365+
def openid_client_service(mock_service):
366+
return mock_service(OpenIDClientService)

tests/unit/h/schemas/oauth_test.py

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import pytest
2+
3+
from h.schemas import ValidationError
4+
from h.schemas.oauth import RetrieveOpenIDTokenSchema
5+
6+
7+
class TestRetrieveOpenIDTokenSchema:
8+
def test_validate(self, schema):
9+
data = {
10+
"access_token": "test_access_token",
11+
"refresh_token": "test_refresh_token",
12+
"expires_in": 3600,
13+
"id_token": "test_id_token",
14+
}
15+
16+
result = schema.validate(data)
17+
18+
assert result == data
19+
20+
@pytest.mark.parametrize(
21+
"data,expected_error",
22+
[
23+
(
24+
{"access_token": "test_access_token"},
25+
"^'expires_in' is a required property, 'id_token' is a required property$",
26+
),
27+
(
28+
{
29+
"access_token": "test_access_token",
30+
"expires_in": 3600,
31+
},
32+
"^'id_token' is a required property$",
33+
),
34+
],
35+
)
36+
def test_validate_with_invalid_data(self, data, expected_error, schema):
37+
with pytest.raises(ValidationError, match=expected_error):
38+
schema.validate(data)
39+
40+
@pytest.fixture
41+
def schema(self):
42+
return RetrieveOpenIDTokenSchema()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from unittest.mock import sentinel
2+
3+
import pytest
4+
5+
from h.services.exceptions import ExternalRequestError
6+
from h.services.openid_client import OpenIDClientService, factory
7+
8+
9+
class TestOpenIDClientService:
10+
def test_get_id_token(
11+
self, svc, http_service, passed_args, RetrieveOpenIDTokenSchema, id_token_data
12+
):
13+
http_response = http_service.post.return_value
14+
http_response.json.return_value = id_token_data
15+
RetrieveOpenIDTokenSchema.return_value.validate.return_value = id_token_data
16+
17+
result = svc.get_id_token(**passed_args)
18+
19+
http_service.post.assert_called_once_with(
20+
passed_args["token_url"],
21+
data={
22+
"redirect_uri": passed_args["redirect_uri"],
23+
"grant_type": "authorization_code",
24+
"code": passed_args["authorization_code"],
25+
},
26+
auth=passed_args["auth"],
27+
)
28+
RetrieveOpenIDTokenSchema.return_value.validate.assert_called_once_with(
29+
http_response.json.return_value
30+
)
31+
assert result == id_token_data["id_token"]
32+
33+
def test_get_id_token_raises_if_the_request_fails(
34+
self, svc, http_service, passed_args
35+
):
36+
http_service.post.side_effect = ExternalRequestError(
37+
request=sentinel.err_request, response=sentinel.err_response
38+
)
39+
40+
with pytest.raises(ExternalRequestError) as exc_info:
41+
svc.get_id_token(**passed_args)
42+
43+
assert exc_info.value.request == sentinel.err_request
44+
assert exc_info.value.response == sentinel.err_response
45+
46+
@pytest.fixture
47+
def id_token_data(self):
48+
return {
49+
"access_token": "test_access_token",
50+
"refresh_token": "test_refresh_token",
51+
"expires_in": 3600,
52+
"id_token": "test_id_token",
53+
}
54+
55+
@pytest.fixture
56+
def passed_args(self):
57+
return {
58+
"token_url": sentinel.token_url,
59+
"redirect_uri": sentinel.redirect_uri,
60+
"auth": sentinel.auth,
61+
"authorization_code": sentinel.authorization_code,
62+
}
63+
64+
@pytest.fixture
65+
def svc(self, http_service):
66+
return OpenIDClientService(http_service)
67+
68+
@pytest.fixture(autouse=True)
69+
def RetrieveOpenIDTokenSchema(self, patch):
70+
return patch("h.services.openid_client.RetrieveOpenIDTokenSchema")
71+
72+
73+
class TestFactory:
74+
def test_it(self, pyramid_request, http_service, OpenIDClientService):
75+
svc = factory(sentinel.context, pyramid_request)
76+
77+
OpenIDClientService.assert_called_once_with(http_service)
78+
assert svc == OpenIDClientService.return_value
79+
80+
@pytest.fixture(autouse=True)
81+
def OpenIDClientService(self, patch):
82+
return patch("h.services.openid_client.OpenIDClientService")

0 commit comments

Comments
 (0)