diff --git a/data-pipeline/data_pipeline/aisr.py b/data-pipeline/data_pipeline/aisr.py index b317395..1cc1a4d 100644 --- a/data-pipeline/data_pipeline/aisr.py +++ b/data-pipeline/data_pipeline/aisr.py @@ -83,8 +83,13 @@ def _get_code_from_response(response: requests.Response) -> str: """ location = response.headers.get("Location") if location: - query_dict = parse_qs(urlparse(location).query) - return query_dict["code"][0] + parsed_url = urlparse(location) + fragment = parsed_url.fragment + fragment_dict = parse_qs(fragment) + code_list = fragment_dict.get("code") + if code_list: + return code_list[0] + raise CodeNotFoundError("Code not found in response fragment.") raise CodeNotFoundError("Code not found in response Location header.") diff --git a/data-pipeline/tests/unit/test_aisr.py b/data-pipeline/tests/unit/test_aisr.py index 4a88874..35ac5dd 100644 --- a/data-pipeline/tests/unit/test_aisr.py +++ b/data-pipeline/tests/unit/test_aisr.py @@ -28,7 +28,7 @@ def test_extract_code_from_auth_response_headers(fastapi_server): test_realm_url = f"{fastapi_server}/auth/realms/idepc-aisr-realm" mock_response = Mock() mock_response.status_code = 302 - mock_response.headers = {"Location": f"{test_realm_url}?code=test_code"} + mock_response.headers = {"Location": f"{test_realm_url}#code=test_code"} code = _get_code_from_response(mock_response)