Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

d2o python authorization code grant #572

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 223 additions & 0 deletions python/delta_sharing/_internal_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ def create_auth_credential_provider(profile: DeltaSharingProfile):
if profile.share_credentials_version == 2:
if profile.type == "oauth_client_credentials":
return AuthCredentialProviderFactory.__oauth_client_credentials(profile)
elif profile.type == "oauth_authorization_code":
return AuthCredentialProviderFactory.__oauth_authorization_code(profile)
elif profile.type == "basic":
return AuthCredentialProviderFactory.__auth_basic(profile)
elif (profile.share_credentials_version == 1 and
Expand Down Expand Up @@ -224,10 +226,231 @@ def __oauth_client_credentials(profile):
AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile] = provider
return provider

@staticmethod
def __oauth_authorization_code(profile):
# Once a clientId/clientSecret is exchanged for an accessToken,
# the accessToken can be reused until it expires.
# The Python client re-creates DeltaSharingClient for different requests.
# To ensure the OAuth access_token is reused,
# we keep a mapping from profile -> OAuthClientCredentialsAuthProvider.
# This prevents re-initializing OAuthClientCredentialsAuthProvider for the same profile,
# ensuring the access_token can be reused.
if profile in AuthCredentialProviderFactory.__oauth_auth_provider_cache:
return AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile]


provider = OAuthU2MAuthCredentialProvider(
client_id=profile.client_id,
client_secret=profile.client_secret,
token_url=profile.token_url,
auth_url=profile.auth_url,
scope=profile.scope,
redirect_uri=profile.redirect_uri
)
AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile] = provider
return provider

@staticmethod
def __auth_bearer_token(profile):
return BearerTokenAuthProvider(profile.bearer_token, profile.expiration_time)

@staticmethod
def __auth_basic(profile):
return BasicAuthProvider(profile.endpoint, profile.username, profile.password)




import time
import json
from typing import Optional
from abc import ABC, abstractmethod
from urllib.parse import urlencode, urlparse, parse_qs
from urllib.request import Request, urlopen
from http.server import BaseHTTPRequestHandler, HTTPServer
import threading

class AuthCredentialProvider(ABC):
@abstractmethod
def add_auth_header(self, session) -> None:
pass

def is_expired(self) -> bool:
return False

@abstractmethod
def get_expiration_time(self) -> Optional[str]:
return None

import time
import json
import base64
import hashlib
import secrets
import webbrowser
from typing import Optional
from urllib.parse import urlencode, urlparse, parse_qs
from urllib.request import Request, urlopen
from http.server import BaseHTTPRequestHandler, HTTPServer
import threading
from abc import ABC, abstractmethod


class AuthCredentialProvider(ABC):
@abstractmethod
def add_auth_header(self, session) -> None:
pass

def is_expired(self) -> bool:
return False

@abstractmethod
def get_expiration_time(self) -> Optional[str]:
return None


class OAuthU2MAuthCredentialProvider(AuthCredentialProvider):
def __init__(self, client_id: str, client_secret: str, token_url: str, auth_url: str, scope: str, redirect_uri: str, port_range=range(8080, 8081)):
self.client_id = client_id
self.client_secret = client_secret
self.token_url = token_url
self.auth_url = auth_url
self.scope = scope
self.redirect_uri = redirect_uri
self.access_token = None
self.token_expiry = None
self.server = None
self.authorization_code = None
self.port_range = port_range

def start_http_server(self, port: int):
"""
Starts an HTTP server to listen for the OAuth provider's redirect with the authorization code.
"""
class OAuthCallbackHandler(BaseHTTPRequestHandler):
def do_GET(self):
# Parse the query parameters from the redirect URI
query_components = parse_qs(urlparse(self.path).query)
code = query_components.get('code', [None])[0]

if code:
self.server.provider.authorization_code = code
self.send_response(200)
self.end_headers()
self.wfile.write(b"OAuth Authentication successful! You can close this window.")
else:
self.send_response(400)
self.end_headers()
self.wfile.write(b"Error: Missing authorization code.")
def log_message(self, format, *args):
pass # Suppress log messages

OAuthCallbackHandler.server = self

self.server = HTTPServer(('localhost', port), OAuthCallbackHandler)
self.server.provider = self
#print(f"Starting HTTP server at http://localhost:{port}")
threading.Thread(target=self.server.serve_forever).start()

def stop_http_server(self):
"""
Stops the HTTP server if it's running.
"""
if self.server:
self.server.shutdown()
self.server.server_close()
#print("HTTP server stopped.")

def _get_access_token(self) -> None:
"""
Fetches a new access token from the OAuth provider using the authorization code.
"""
if not self.authorization_code:
self._get_authorization_code()

token_data = {
'grant_type': 'authorization_code',
'code': self.authorization_code,
'redirect_uri': self.redirect_uri,
'client_id': self.client_id,
'code_verifier': self.code_verifier # Include the correct code_verifier here
}

token_request = Request(self.token_url, data=urlencode(token_data).encode('utf-8'))
token_request.add_header('Content-Type', 'application/x-www-form-urlencoded')
#print(token_data)

import urllib.request # Import the urllib.request module for making HTTP requests
import urllib.error # Import the urllib.error module for handling HTTP errors
try:
with urlopen(token_request) as response:
token_response = json.loads(response.read().decode('utf-8'))
self.access_token = token_response.get('access_token')
expires_in = token_response.get('expires_in')
self.token_expiry = time.time() + expires_in

except urllib.error.HTTPError as e:
error_response = e.read().decode('utf-8')
print(f"HTTP Error: {e.code} - {e.reason}\nResponse: {error_response}")

def _get_authorization_code(self):
"""
Directs the user to the OAuth provider's authorization URL to get the authorization code.
"""
import base64
import hashlib
import secrets

# Generate code_verifier and code_challenge
self.code_verifier = secrets.token_urlsafe(128)
self.code_challenge = base64.urlsafe_b64encode(hashlib.sha256(self.code_verifier.encode()).digest()).rstrip(b'=').decode('utf-8')

auth_params = urlencode({
'response_type': 'code',
'client_id': self.client_id,
'redirect_uri': self.redirect_uri,
'scope': self.scope,
'code_challenge': self.code_challenge,
'code_challenge_method': 'S256'
})
auth_url = f"{self.auth_url}?{auth_params}"
#print(f"Initiating U2M OAuth: {auth_url}")


webbrowser.open(auth_url) # Open the browser for user authorization

for port in self.port_range:
try:
self.start_http_server(port)
self.redirect_uri = f"http://localhost:{port}"
while not self.authorization_code:
time.sleep(1)
break
except OSError:
continue

self.stop_http_server()

def add_auth_header(self, session) -> None:
"""
Adds the OAuth token to the request headers if needed.
"""
if not self.access_token or self.is_expired():
self._get_access_token()

session.headers['Authorization'] = f'Bearer {self.access_token}'

def is_expired(self) -> bool:
"""
Checks if the current access token is expired.
"""
return self.token_expiry is None or time.time() >= self.token_expiry

def get_expiration_time(self) -> Optional[str]:
"""
Returns the expiration time of the current access token.
"""
if self.token_expiry:
return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(self.token_expiry))
return None

16 changes: 16 additions & 0 deletions python/delta_sharing/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ class DeltaSharingProfile:
password: Optional[str] = None
scope: Optional[str] = None

token_url: Optional[str] = None
auth_url: Optional[str] = None
scope: Optional[str] = None
redirect_uri: Optional[str] = None

def __post_init__(self):
if self.share_credentials_version > DeltaSharingProfile.CURRENT:
raise ValueError(
Expand Down Expand Up @@ -107,6 +112,17 @@ def from_json(json) -> "DeltaSharingProfile":
username=json["username"],
password=json["password"],
)
elif type == "oauth_authorization_code":
return DeltaSharingProfile(
share_credentials_version=share_credentials_version,
type=type,
endpoint=endpoint,
client_id=json["clientId"],
redirect_uri=json["redirectUri"],
token_url=json["tokenEndpoint"],
auth_url=json["authorizeEndpoint"],
scope=json.get("scope"),
)
else:
raise ValueError(
f"The current release does not supports {type} type. "
Expand Down
3 changes: 2 additions & 1 deletion python/delta_sharing/rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ def __init__(self, profile: DeltaSharingProfile, num_retries=10):

self._session.headers.update(
{
"User-Agent": DataSharingRestClient.USER_AGENT,
"Custom-Header-Recipient-Id": "7ccbb5da-b1b1-4519-ae53-190db7988199",
"User-Agent": "Python-Delta-Sharing-Client"
}
)

Expand Down
39 changes: 39 additions & 0 deletions python/delta_sharing/tests/test_e2e.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import delta_sharing

# Point to the profile file. It can be a file on the local file system or a file on a remote storage.
profile_file = "/Users/Moe.Derakhshani/Documents/oauth/demo/u2m.share"

# Create a SharingClient.
client = delta_sharing.SharingClient(profile_file)
#
# List all shared tables.
tables = client.list_all_tables()

print(tables)


#
# # Create a url to access a shared table.
# # A table path is the profile file path following with `#` and the fully qualified name of a table
# # (`<share-name>.<schema-name>.<table-name>`).
table_url = profile_file + "#demo-d2o-identity-federation.my_schema.my_table"

# Fetch 10 rows from a table and convert it to a Pandas DataFrame. This can be used to read sample data
# from a table that cannot fit in the memory.
df = delta_sharing.load_as_pandas(table_url, limit=10)

print(df)

#
# Load a table as a Pandas DataFrame. This can be used to process tables that can fit in the memory.
delta_sharing.load_as_pandas(table_url)

# Load a table as a Pandas DataFrame explicitly using Delta Format
#delta_sharing.load_as_pandas(table_url, use_delta_format = True)

# # If the code is running with PySpark, you can use `load_as_spark` to load the table as a Spark DataFrame.
# delta_sharing.load_as_spark(table_url)



print("DONE")
Loading