Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 135 additions & 114 deletions mssql_python/__init__.py

Large diffs are not rendered by default.

64 changes: 38 additions & 26 deletions mssql_python/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@

import platform
import struct
from typing import Tuple, Dict, Optional, Union
from typing import Tuple, Dict, Optional, List
from mssql_python.constants import AuthType


class AADAuth:
"""Handles Azure Active Directory authentication"""

@staticmethod
def get_token_struct(token: str) -> bytes:
"""Convert token to SQL Server compatible format"""
Expand All @@ -22,21 +23,21 @@ def get_token_struct(token: str) -> bytes:
def get_token(auth_type: str) -> bytes:
"""Get token using the specified authentication type"""
from azure.identity import (
DefaultAzureCredential,
DeviceCodeCredential,
InteractiveBrowserCredential
DefaultAzureCredential,
DeviceCodeCredential,
InteractiveBrowserCredential,
)
from azure.core.exceptions import ClientAuthenticationError

# Mapping of auth types to credential classes
credential_map = {
"default": DefaultAzureCredential,
"devicecode": DeviceCodeCredential,
"interactive": InteractiveBrowserCredential,
}

credential_class = credential_map[auth_type]

try:
credential = credential_class()
token = credential.get_token("https://database.windows.net/.default").token
Expand All @@ -50,18 +51,21 @@ def get_token(auth_type: str) -> bytes:
) from e
except Exception as e:
# Catch any other unexpected exceptions
raise RuntimeError(f"Failed to create {credential_class.__name__}: {e}") from e
raise RuntimeError(
f"Failed to create {credential_class.__name__}: {e}"
) from e


def process_auth_parameters(parameters: list) -> Tuple[list, Optional[str]]:
def process_auth_parameters(parameters: List[str]) -> Tuple[List[str], Optional[str]]:
"""
Process connection parameters and extract authentication type.

Args:
parameters: List of connection string parameters

Returns:
Tuple[list, Optional[str]]: Modified parameters and authentication type

Raises:
ValueError: If an invalid authentication type is provided
"""
Expand All @@ -88,7 +92,7 @@ def process_auth_parameters(parameters: list) -> Tuple[list, Optional[str]]:
# Interactive authentication (browser-based); only append parameter for non-Windows
if platform.system().lower() == "windows":
auth_type = None # Let Windows handle AADInteractive natively

elif value_lower == AuthType.DEVICE_CODE.value:
# Device code authentication (for devices without browser)
auth_type = "devicecode"
Expand All @@ -99,40 +103,48 @@ def process_auth_parameters(parameters: list) -> Tuple[list, Optional[str]]:

return modified_parameters, auth_type

def remove_sensitive_params(parameters: list) -> list:

def remove_sensitive_params(parameters: List[str]) -> List[str]:
"""Remove sensitive parameters from connection string"""
exclude_keys = [
"uid=", "pwd=", "encrypt=", "trustservercertificate=", "authentication="
"uid=",
"pwd=",
"encrypt=",
"trustservercertificate=",
"authentication=",
]
return [
param for param in parameters
param
for param in parameters
if not any(param.lower().startswith(exclude) for exclude in exclude_keys)
]


def get_auth_token(auth_type: str) -> Optional[bytes]:
"""Get authentication token based on auth type"""
if not auth_type:
return None

# Handle platform-specific logic for interactive auth
if auth_type == "interactive" and platform.system().lower() == "windows":
return None # Let Windows handle AADInteractive natively

try:
return AADAuth.get_token(auth_type)
except (ValueError, RuntimeError):
return None

def process_connection_string(connection_string: str) -> Tuple[str, Optional[Dict]]:

def process_connection_string(connection_string: str) -> Tuple[str, Optional[Dict[int, bytes]]]:
"""
Process connection string and handle authentication.

Args:
connection_string: The connection string to process

Returns:
Tuple[str, Optional[Dict]]: Processed connection string and attrs_before dict if needed

Raises:
ValueError: If the connection string is invalid or empty
"""
Expand All @@ -145,9 +157,9 @@ def process_connection_string(connection_string: str) -> Tuple[str, Optional[Dic
raise ValueError("Connection string cannot be empty")

parameters = connection_string.split(";")

# Validate that there's at least one valid parameter
if not any('=' in param for param in parameters):
if not any("=" in param for param in parameters):
raise ValueError("Invalid connection string format")

modified_parameters, auth_type = process_auth_parameters(parameters)
Expand All @@ -158,4 +170,4 @@ def process_connection_string(connection_string: str) -> Tuple[str, Optional[Dic
if token_struct:
return ";".join(modified_parameters) + ";", {1256: token_struct}

return ";".join(modified_parameters) + ";", None
return ";".join(modified_parameters) + ";", None
121 changes: 0 additions & 121 deletions mssql_python/bcp_options.py

This file was deleted.

Loading