diff --git a/python/delta_sharing/_internal_auth.py b/python/delta_sharing/_internal_auth.py index 3b1d35103..1ab144030 100644 --- a/python/delta_sharing/_internal_auth.py +++ b/python/delta_sharing/_internal_auth.py @@ -198,16 +198,119 @@ def get_expiration_time(self) -> Optional[str]: return None +class AzureManagedIdentityClient: + def __init__(self): + None + + def managed_identity_token(self) -> OAuthClientCredentials: + # Azure IMDS endpoint to get the access token. + # This interface allows any client application running on the Azure VM + # to acquire an access token via HTTP REST calls. + # For more details, see: + # https://learn.microsoft.com/en-us/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http + url = "http://169.254.169.254/metadata/identity/oauth2/token" + + resource = "https://management.azure.com/" + # Headers required to access Azure Instance Metadata Service + headers = {"Metadata": "true"} + + # Parameters to specify the resource and API version + params = { + "api-version": "2019-08-01", + "resource": resource + } + + # Make the GET request to fetch the token + response = requests.get(url, headers=headers, params=params) + response.raise_for_status() + + # Check if the request was successful + if response.status_code == 200: + # Return the access token + return self.parse_oauth_token_response(response.text) + + else: + # Handle errors + raise Exception(f"Failed to obtain token: {response.status_code} - {response.text}") + + def parse_oauth_token_response(self, response: str) -> OAuthClientCredentials: + if not response: + raise RuntimeError("Empty response from azure managed identity endpoint") + # Parsing the response according to azure managed identity spec + # https://learn.microsoft.com/en-us/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http + json_node = json.loads(response) + if 'access_token' not in json_node or not isinstance(json_node['access_token'], str): + raise RuntimeError("Missing 'access_token' field in OAuth token response") + if 'expires_in' not in json_node: + raise RuntimeError("Missing 'expires_in' field in OAuth token response") + try: + expires_in = int(json_node['expires_in']) # Convert to int if it's a string + except ValueError: + raise RuntimeError( + "'expires_in' field must be an integer or a string convertible to integer" + ) + return OAuthClientCredentials( + json_node['access_token'], + expires_in, + int(datetime.now().timestamp()) + ) + + +class AzureManagedIdentityAuthProvider(AuthCredentialProvider): + def __init__(self, + managed_identity_client: AzureManagedIdentityClient, + auth_config: AuthConfig = AuthConfig()): + self.auth_config = auth_config + self.managed_identity_client = managed_identity_client + self.current_token: Optional[OAuthClientCredentials] = None + self.lock = threading.RLock() + + def add_auth_header(self,session: requests.Session) -> None: + token = self.maybe_refresh_token() + + with self.lock: + session.headers.update( + { + "Authorization": f"Bearer {token.access_token}", + } + ) + + def maybe_refresh_token(self) -> OAuthClientCredentials: + with self.lock: + if self.current_token and not self.needs_refresh(self.current_token): + return self.current_token + new_token = self.managed_identity_client.managed_identity_token() + self.current_token = new_token + return new_token + + def needs_refresh(self, token: OAuthClientCredentials) -> bool: + now = int(time.time()) + expiration_time = token.creation_timestamp + token.expires_in + return expiration_time - now < self.auth_config.token_renewal_threshold_in_seconds + + def is_expired(self) -> bool: + return False + + def get_expiration_time(self) -> Optional[str]: + return None + + class AuthCredentialProviderFactory: __oauth_auth_provider_cache : Dict[ DeltaSharingProfile, OAuthClientCredentialsAuthProvider] = {} + __managed_identity_provider_cache : Dict[ + DeltaSharingProfile, + AzureManagedIdentityAuthProvider] = {} + @staticmethod def create_auth_credential_provider(profile: DeltaSharingProfile): if profile.share_credentials_version == 2: if profile.type == "oauth_client_credentials": return AuthCredentialProviderFactory.__oauth_client_credentials(profile) + elif profile.type == "experimental_managed_identity": + return AuthCredentialProviderFactory.__experimental_managed_identity(profile) elif profile.type == "basic": return AuthCredentialProviderFactory.__auth_basic(profile) elif (profile.share_credentials_version == 1 and @@ -251,3 +354,13 @@ def __auth_bearer_token(profile): @staticmethod def __auth_basic(profile): return BasicAuthProvider(profile.endpoint, profile.username, profile.password) + + @staticmethod + def __experimental_managed_identity(profile): + if profile in AuthCredentialProviderFactory.__managed_identity_provider_cache: + return AuthCredentialProviderFactory.__managed_identity_provider_cache[profile] + + managed_identity_client = AzureManagedIdentityClient() + provider = AzureManagedIdentityAuthProvider(managed_identity_client=managed_identity_client) + AuthCredentialProviderFactory.__managed_identity_provider_cache[profile] = provider + return provider diff --git a/python/delta_sharing/protocol.py b/python/delta_sharing/protocol.py index a99acfb14..f3a46290f 100644 --- a/python/delta_sharing/protocol.py +++ b/python/delta_sharing/protocol.py @@ -99,6 +99,12 @@ def from_json(json) -> "DeltaSharingProfile": bearer_token=json["bearerToken"], expiration_time=json.get("expirationTime") ) + elif type == "experimental_managed_identity": + return DeltaSharingProfile( + share_credentials_version=share_credentials_version, + type=type, + endpoint=endpoint + ) elif type == "basic": return DeltaSharingProfile( share_credentials_version=share_credentials_version, diff --git a/python/delta_sharing/tests/test_auth.py b/python/delta_sharing/tests/test_auth.py index 81dbd1ba9..945915e86 100644 --- a/python/delta_sharing/tests/test_auth.py +++ b/python/delta_sharing/tests/test_auth.py @@ -20,7 +20,8 @@ BasicAuthProvider, AuthCredentialProviderFactory, OAuthClientCredentialsAuthProvider, - OAuthClientCredentials) + OAuthClientCredentials, AzureManagedIdentityAuthProvider, + AzureManagedIdentityClient) from requests import Session import requests from delta_sharing._internal_auth import BearerTokenAuthProvider @@ -292,3 +293,88 @@ def test_oauth_auth_provider_with_different_profiles(): provider2 = AuthCredentialProviderFactory.create_auth_credential_provider(profile_oauth2) assert provider1 != provider2 + + +def test_azure_managed_identity_auth_provider_initialization(): + mock_client = MagicMock(spec=AzureManagedIdentityClient) + provider = AzureManagedIdentityAuthProvider(managed_identity_client=mock_client) + assert provider.current_token is None + + +def test_azure_managed_identity_auth_provider_add_auth_header(): + mock_client = MagicMock(spec=AzureManagedIdentityClient) + provider = AzureManagedIdentityAuthProvider(managed_identity_client=mock_client) + mock_session = MagicMock(spec=Session) + mock_session.headers = MagicMock() + + token = OAuthClientCredentials("access-token", 3600, int(datetime.now().timestamp())) + provider.current_token = token + + provider.add_auth_header(mock_session) + + mock_session.headers.update.assert_called_once_with( + {"Authorization": f"Bearer {token.access_token}"} + ) + + +def test_azure_managed_identity_auth_provider_refresh_token(): + mock_client = MagicMock(spec=AzureManagedIdentityClient) + provider = AzureManagedIdentityAuthProvider(managed_identity_client=mock_client) + mock_session = MagicMock(spec=Session) + mock_session.headers = MagicMock() + + expired_token = OAuthClientCredentials( + "expired-token", 1, int(datetime.now().timestamp()) - 3600) + new_token = OAuthClientCredentials( + "new-token", 3600, int(datetime.now().timestamp())) + provider.current_token = expired_token + + mock_client.managed_identity_token.return_value = new_token + + provider.add_auth_header(mock_session) + + mock_session.headers.update.assert_called_once_with( + {"Authorization": f"Bearer {new_token.access_token}"} + ) + mock_client.managed_identity_token.assert_called_once() + + +def test_azure_managed_identity_auth_provider_needs_refresh(): + mock_client = MagicMock(spec=AzureManagedIdentityClient) + provider = AzureManagedIdentityAuthProvider(managed_identity_client=mock_client) + + expired_token = OAuthClientCredentials( + "expired-token", 1, int(datetime.now().timestamp()) - 3600) + assert provider.needs_refresh(expired_token) + + token_expiring_soon = OAuthClientCredentials( + "expiring-soon-token", 600 - 5, int(datetime.now().timestamp())) + assert provider.needs_refresh(token_expiring_soon) + + valid_token = OAuthClientCredentials( + "valid-token", 600 + 10, int(datetime.now().timestamp())) + assert not provider.needs_refresh(valid_token) + + +def test_azure_managed_identity_auth_provider_is_expired(): + mock_client = MagicMock(spec=AzureManagedIdentityClient) + provider = AzureManagedIdentityAuthProvider(managed_identity_client=mock_client) + assert not provider.is_expired() + + +def test_azure_managed_identity_auth_provider_get_expiration_time(): + mock_client = MagicMock(spec=AzureManagedIdentityClient) + provider = AzureManagedIdentityAuthProvider(managed_identity_client=mock_client) + assert provider.get_expiration_time() is None + + +def test_factory_creation_managed_identity(): + profile_managed_identity = DeltaSharingProfile( + share_credentials_version=2, + type="experimental_managed_identity", + endpoint="https://localhost/delta-sharing/" + ) + provider = AuthCredentialProviderFactory.create_auth_credential_provider( + profile_managed_identity + ) + assert isinstance(provider, AzureManagedIdentityAuthProvider) diff --git a/python/delta_sharing/tests/test_managed_identity_client.py b/python/delta_sharing/tests/test_managed_identity_client.py new file mode 100644 index 000000000..fd36f6a20 --- /dev/null +++ b/python/delta_sharing/tests/test_managed_identity_client.py @@ -0,0 +1,86 @@ +# +# Copyright (C) 2021 The Delta Lake Project Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from datetime import datetime +from unittest.mock import patch + +import pytest +from requests import Response +from delta_sharing._internal_auth import AzureManagedIdentityClient + + +class MockServer: + def __init__(self): + self.url = "http://169.254.169.254/metadata/identity/oauth2/token" + self.responses = [] + + def add_response(self, status_code, json_data): + response = Response() + response.status_code = status_code + response._content = json_data.encode('utf-8') + self.responses.append(response) + + def get_response(self): + return self.responses.pop(0) + + +@pytest.fixture +def mock_server(): + server = MockServer() + yield server + + +@pytest.mark.parametrize("response_data, expected_expires_in, expected_access_token", [ + ( + '{"access_token": "test-access-token", "expires_in": 3600, "token_type": "bearer"}', + 3600, + "test-access-token" + ), + ( + '{"access_token": "test-access-token", "expires_in": "3600", "token_type": "bearer"}', + 3600, + "test-access-token" + ) +]) +def test_managed_identity_client_should_parse_token_response_correctly(mock_server, + response_data, + expected_expires_in, + expected_access_token): + mock_server.add_response(200, response_data) + + with patch('requests.get') as mock_get: + mock_get.side_effect = lambda *args, **kwargs: mock_server.get_response() + client = AzureManagedIdentityClient() + + start = datetime.now().timestamp() + token = client.managed_identity_token() + end = datetime.now().timestamp() + + assert token.access_token == expected_access_token + assert token.expires_in == expected_expires_in + assert int(start) <= token.creation_timestamp + assert token.creation_timestamp <= int(end) + + +def test_managed_identity_client_should_handle_500_internal_server_error(mock_server): + mock_server.add_response(500, 'Internal Server Error') + + with patch('requests.get') as mock_get: + mock_get.side_effect = lambda *args, **kwargs: mock_server.get_response() + client = AzureManagedIdentityClient() + try: + client.managed_identity_token() + except Exception as e: + assert e.response.status_code == 500