66
77import platform
88import struct
9- from typing import Tuple , Dict , Optional , Union
9+ from typing import Tuple , Dict , Optional , List
1010from mssql_python .constants import AuthType
1111
12+
1213class AADAuth :
1314 """Handles Azure Active Directory authentication"""
14-
15+
1516 @staticmethod
1617 def get_token_struct (token : str ) -> bytes :
1718 """Convert token to SQL Server compatible format"""
@@ -22,21 +23,21 @@ def get_token_struct(token: str) -> bytes:
2223 def get_token (auth_type : str ) -> bytes :
2324 """Get token using the specified authentication type"""
2425 from azure .identity import (
25- DefaultAzureCredential ,
26- DeviceCodeCredential ,
27- InteractiveBrowserCredential
26+ DefaultAzureCredential ,
27+ DeviceCodeCredential ,
28+ InteractiveBrowserCredential ,
2829 )
2930 from azure .core .exceptions import ClientAuthenticationError
30-
31+
3132 # Mapping of auth types to credential classes
3233 credential_map = {
3334 "default" : DefaultAzureCredential ,
3435 "devicecode" : DeviceCodeCredential ,
3536 "interactive" : InteractiveBrowserCredential ,
3637 }
37-
38+
3839 credential_class = credential_map [auth_type ]
39-
40+
4041 try :
4142 credential = credential_class ()
4243 token = credential .get_token ("https://database.windows.net/.default" ).token
@@ -50,18 +51,21 @@ def get_token(auth_type: str) -> bytes:
5051 ) from e
5152 except Exception as e :
5253 # Catch any other unexpected exceptions
53- raise RuntimeError (f"Failed to create { credential_class .__name__ } : { e } " ) from e
54+ raise RuntimeError (
55+ f"Failed to create { credential_class .__name__ } : { e } "
56+ ) from e
57+
5458
55- def process_auth_parameters (parameters : list ) -> Tuple [list , Optional [str ]]:
59+ def process_auth_parameters (parameters : List [ str ] ) -> Tuple [List [ str ] , Optional [str ]]:
5660 """
5761 Process connection parameters and extract authentication type.
58-
62+
5963 Args:
6064 parameters: List of connection string parameters
61-
65+
6266 Returns:
6367 Tuple[list, Optional[str]]: Modified parameters and authentication type
64-
68+
6569 Raises:
6670 ValueError: If an invalid authentication type is provided
6771 """
@@ -88,7 +92,7 @@ def process_auth_parameters(parameters: list) -> Tuple[list, Optional[str]]:
8892 # Interactive authentication (browser-based); only append parameter for non-Windows
8993 if platform .system ().lower () == "windows" :
9094 auth_type = None # Let Windows handle AADInteractive natively
91-
95+
9296 elif value_lower == AuthType .DEVICE_CODE .value :
9397 # Device code authentication (for devices without browser)
9498 auth_type = "devicecode"
@@ -99,40 +103,48 @@ def process_auth_parameters(parameters: list) -> Tuple[list, Optional[str]]:
99103
100104 return modified_parameters , auth_type
101105
102- def remove_sensitive_params (parameters : list ) -> list :
106+
107+ def remove_sensitive_params (parameters : List [str ]) -> List [str ]:
103108 """Remove sensitive parameters from connection string"""
104109 exclude_keys = [
105- "uid=" , "pwd=" , "encrypt=" , "trustservercertificate=" , "authentication="
110+ "uid=" ,
111+ "pwd=" ,
112+ "encrypt=" ,
113+ "trustservercertificate=" ,
114+ "authentication=" ,
106115 ]
107116 return [
108- param for param in parameters
117+ param
118+ for param in parameters
109119 if not any (param .lower ().startswith (exclude ) for exclude in exclude_keys )
110120 ]
111121
122+
112123def get_auth_token (auth_type : str ) -> Optional [bytes ]:
113124 """Get authentication token based on auth type"""
114125 if not auth_type :
115126 return None
116-
127+
117128 # Handle platform-specific logic for interactive auth
118129 if auth_type == "interactive" and platform .system ().lower () == "windows" :
119130 return None # Let Windows handle AADInteractive natively
120-
131+
121132 try :
122133 return AADAuth .get_token (auth_type )
123134 except (ValueError , RuntimeError ):
124135 return None
125136
126- def process_connection_string (connection_string : str ) -> Tuple [str , Optional [Dict ]]:
137+
138+ def process_connection_string (connection_string : str ) -> Tuple [str , Optional [Dict [int , bytes ]]]:
127139 """
128140 Process connection string and handle authentication.
129-
141+
130142 Args:
131143 connection_string: The connection string to process
132-
144+
133145 Returns:
134146 Tuple[str, Optional[Dict]]: Processed connection string and attrs_before dict if needed
135-
147+
136148 Raises:
137149 ValueError: If the connection string is invalid or empty
138150 """
@@ -145,9 +157,9 @@ def process_connection_string(connection_string: str) -> Tuple[str, Optional[Dic
145157 raise ValueError ("Connection string cannot be empty" )
146158
147159 parameters = connection_string .split (";" )
148-
160+
149161 # Validate that there's at least one valid parameter
150- if not any ('=' in param for param in parameters ):
162+ if not any ("=" in param for param in parameters ):
151163 raise ValueError ("Invalid connection string format" )
152164
153165 modified_parameters , auth_type = process_auth_parameters (parameters )
@@ -158,4 +170,4 @@ def process_connection_string(connection_string: str) -> Tuple[str, Optional[Dic
158170 if token_struct :
159171 return ";" .join (modified_parameters ) + ";" , {1256 : token_struct }
160172
161- return ";" .join (modified_parameters ) + ";" , None
173+ return ";" .join (modified_parameters ) + ";" , None
0 commit comments