Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Secret-less Azure Managed Identity in Python Delta-Sharing Client #633

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions python/delta_sharing/_internal_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,16 +198,119 @@ def get_expiration_time(self) -> Optional[str]:
return None


class AzureManagedIdentityClient:
def __init__(self):
None

def managed_identity_token(self) -> OAuthClientCredentials:
# Azure IMDS endpoint to get the access token.
# This interface allows any client application running on the Azure VM
# to acquire an access token via HTTP REST calls.
# For more details, see:
# https://learn.microsoft.com/en-us/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http
url = "http://169.254.169.254/metadata/identity/oauth2/token"

resource = "https://management.azure.com/"
# Headers required to access Azure Instance Metadata Service
headers = {"Metadata": "true"}

# Parameters to specify the resource and API version
params = {
"api-version": "2019-08-01",
"resource": resource
}

# Make the GET request to fetch the token
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()

# Check if the request was successful
if response.status_code == 200:
# Return the access token
return self.parse_oauth_token_response(response.text)

else:
# Handle errors
raise Exception(f"Failed to obtain token: {response.status_code} - {response.text}")

def parse_oauth_token_response(self, response: str) -> OAuthClientCredentials:
if not response:
raise RuntimeError("Empty response from azure managed identity endpoint")
# Parsing the response according to azure managed identity spec
# https://learn.microsoft.com/en-us/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http
json_node = json.loads(response)
if 'access_token' not in json_node or not isinstance(json_node['access_token'], str):
raise RuntimeError("Missing 'access_token' field in OAuth token response")
if 'expires_in' not in json_node:
raise RuntimeError("Missing 'expires_in' field in OAuth token response")
try:
expires_in = int(json_node['expires_in']) # Convert to int if it's a string
except ValueError:
raise RuntimeError(
"'expires_in' field must be an integer or a string convertible to integer"
)
return OAuthClientCredentials(
json_node['access_token'],
expires_in,
int(datetime.now().timestamp())
)


class AzureManagedIdentityAuthProvider(AuthCredentialProvider):
def __init__(self,
managed_identity_client: AzureManagedIdentityClient,
auth_config: AuthConfig = AuthConfig()):
self.auth_config = auth_config
self.managed_identity_client = managed_identity_client
self.current_token: Optional[OAuthClientCredentials] = None
self.lock = threading.RLock()

def add_auth_header(self,session: requests.Session) -> None:
token = self.maybe_refresh_token()

with self.lock:
session.headers.update(
{
"Authorization": f"Bearer {token.access_token}",
}
)

def maybe_refresh_token(self) -> OAuthClientCredentials:
with self.lock:
if self.current_token and not self.needs_refresh(self.current_token):
return self.current_token
new_token = self.managed_identity_client.managed_identity_token()
self.current_token = new_token
return new_token

def needs_refresh(self, token: OAuthClientCredentials) -> bool:
now = int(time.time())
expiration_time = token.creation_timestamp + token.expires_in
return expiration_time - now < self.auth_config.token_renewal_threshold_in_seconds

def is_expired(self) -> bool:
return False

def get_expiration_time(self) -> Optional[str]:
return None


class AuthCredentialProviderFactory:
__oauth_auth_provider_cache : Dict[
DeltaSharingProfile,
OAuthClientCredentialsAuthProvider] = {}

__managed_identity_provider_cache : Dict[
DeltaSharingProfile,
AzureManagedIdentityAuthProvider] = {}

@staticmethod
def create_auth_credential_provider(profile: DeltaSharingProfile):
if profile.share_credentials_version == 2:
if profile.type == "oauth_client_credentials":
return AuthCredentialProviderFactory.__oauth_client_credentials(profile)
elif profile.type == "experimental_managed_identity":
return AuthCredentialProviderFactory.__experimental_managed_identity(profile)
elif profile.type == "basic":
return AuthCredentialProviderFactory.__auth_basic(profile)
elif (profile.share_credentials_version == 1 and
Expand Down Expand Up @@ -251,3 +354,13 @@ def __auth_bearer_token(profile):
@staticmethod
def __auth_basic(profile):
return BasicAuthProvider(profile.endpoint, profile.username, profile.password)

@staticmethod
def __experimental_managed_identity(profile):
if profile in AuthCredentialProviderFactory.__managed_identity_provider_cache:
return AuthCredentialProviderFactory.__managed_identity_provider_cache[profile]

managed_identity_client = AzureManagedIdentityClient()
provider = AzureManagedIdentityAuthProvider(managed_identity_client=managed_identity_client)
AuthCredentialProviderFactory.__managed_identity_provider_cache[profile] = provider
return provider
6 changes: 6 additions & 0 deletions python/delta_sharing/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ def from_json(json) -> "DeltaSharingProfile":
bearer_token=json["bearerToken"],
expiration_time=json.get("expirationTime")
)
elif type == "experimental_managed_identity":
return DeltaSharingProfile(
share_credentials_version=share_credentials_version,
type=type,
endpoint=endpoint
)
elif type == "basic":
return DeltaSharingProfile(
share_credentials_version=share_credentials_version,
Expand Down
88 changes: 87 additions & 1 deletion python/delta_sharing/tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
BasicAuthProvider,
AuthCredentialProviderFactory,
OAuthClientCredentialsAuthProvider,
OAuthClientCredentials)
OAuthClientCredentials, AzureManagedIdentityAuthProvider,
AzureManagedIdentityClient)
from requests import Session
import requests
from delta_sharing._internal_auth import BearerTokenAuthProvider
Expand Down Expand Up @@ -292,3 +293,88 @@ def test_oauth_auth_provider_with_different_profiles():
provider2 = AuthCredentialProviderFactory.create_auth_credential_provider(profile_oauth2)

assert provider1 != provider2


def test_azure_managed_identity_auth_provider_initialization():
mock_client = MagicMock(spec=AzureManagedIdentityClient)
provider = AzureManagedIdentityAuthProvider(managed_identity_client=mock_client)
assert provider.current_token is None


def test_azure_managed_identity_auth_provider_add_auth_header():
mock_client = MagicMock(spec=AzureManagedIdentityClient)
provider = AzureManagedIdentityAuthProvider(managed_identity_client=mock_client)
mock_session = MagicMock(spec=Session)
mock_session.headers = MagicMock()

token = OAuthClientCredentials("access-token", 3600, int(datetime.now().timestamp()))
provider.current_token = token

provider.add_auth_header(mock_session)

mock_session.headers.update.assert_called_once_with(
{"Authorization": f"Bearer {token.access_token}"}
)


def test_azure_managed_identity_auth_provider_refresh_token():
mock_client = MagicMock(spec=AzureManagedIdentityClient)
provider = AzureManagedIdentityAuthProvider(managed_identity_client=mock_client)
mock_session = MagicMock(spec=Session)
mock_session.headers = MagicMock()

expired_token = OAuthClientCredentials(
"expired-token", 1, int(datetime.now().timestamp()) - 3600)
new_token = OAuthClientCredentials(
"new-token", 3600, int(datetime.now().timestamp()))
provider.current_token = expired_token

mock_client.managed_identity_token.return_value = new_token

provider.add_auth_header(mock_session)

mock_session.headers.update.assert_called_once_with(
{"Authorization": f"Bearer {new_token.access_token}"}
)
mock_client.managed_identity_token.assert_called_once()


def test_azure_managed_identity_auth_provider_needs_refresh():
mock_client = MagicMock(spec=AzureManagedIdentityClient)
provider = AzureManagedIdentityAuthProvider(managed_identity_client=mock_client)

expired_token = OAuthClientCredentials(
"expired-token", 1, int(datetime.now().timestamp()) - 3600)
assert provider.needs_refresh(expired_token)

token_expiring_soon = OAuthClientCredentials(
"expiring-soon-token", 600 - 5, int(datetime.now().timestamp()))
assert provider.needs_refresh(token_expiring_soon)

valid_token = OAuthClientCredentials(
"valid-token", 600 + 10, int(datetime.now().timestamp()))
assert not provider.needs_refresh(valid_token)


def test_azure_managed_identity_auth_provider_is_expired():
mock_client = MagicMock(spec=AzureManagedIdentityClient)
provider = AzureManagedIdentityAuthProvider(managed_identity_client=mock_client)
assert not provider.is_expired()


def test_azure_managed_identity_auth_provider_get_expiration_time():
mock_client = MagicMock(spec=AzureManagedIdentityClient)
provider = AzureManagedIdentityAuthProvider(managed_identity_client=mock_client)
assert provider.get_expiration_time() is None


def test_factory_creation_managed_identity():
profile_managed_identity = DeltaSharingProfile(
share_credentials_version=2,
type="experimental_managed_identity",
endpoint="https://localhost/delta-sharing/"
)
provider = AuthCredentialProviderFactory.create_auth_credential_provider(
profile_managed_identity
)
assert isinstance(provider, AzureManagedIdentityAuthProvider)
86 changes: 86 additions & 0 deletions python/delta_sharing/tests/test_managed_identity_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#
# Copyright (C) 2021 The Delta Lake Project Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime
from unittest.mock import patch

import pytest
from requests import Response
from delta_sharing._internal_auth import AzureManagedIdentityClient


class MockServer:
def __init__(self):
self.url = "http://169.254.169.254/metadata/identity/oauth2/token"
self.responses = []

def add_response(self, status_code, json_data):
response = Response()
response.status_code = status_code
response._content = json_data.encode('utf-8')
self.responses.append(response)

def get_response(self):
return self.responses.pop(0)


@pytest.fixture
def mock_server():
server = MockServer()
yield server


@pytest.mark.parametrize("response_data, expected_expires_in, expected_access_token", [
(
'{"access_token": "test-access-token", "expires_in": 3600, "token_type": "bearer"}',
3600,
"test-access-token"
),
(
'{"access_token": "test-access-token", "expires_in": "3600", "token_type": "bearer"}',
3600,
"test-access-token"
)
])
def test_managed_identity_client_should_parse_token_response_correctly(mock_server,
response_data,
expected_expires_in,
expected_access_token):
mock_server.add_response(200, response_data)

with patch('requests.get') as mock_get:
mock_get.side_effect = lambda *args, **kwargs: mock_server.get_response()
client = AzureManagedIdentityClient()

start = datetime.now().timestamp()
token = client.managed_identity_token()
end = datetime.now().timestamp()

assert token.access_token == expected_access_token
assert token.expires_in == expected_expires_in
assert int(start) <= token.creation_timestamp
assert token.creation_timestamp <= int(end)


def test_managed_identity_client_should_handle_500_internal_server_error(mock_server):
mock_server.add_response(500, 'Internal Server Error')

with patch('requests.get') as mock_get:
mock_get.side_effect = lambda *args, **kwargs: mock_server.get_response()
client = AzureManagedIdentityClient()
try:
client.managed_identity_token()
except Exception as e:
assert e.response.status_code == 500
Loading