Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ class SnowflakeAuthenticator:
def __init__(
self,
authenticator: Literal["SNOWFLAKE", "SNOWFLAKE_JWT", "OAUTH"],
api_key: Optional[Secret] = None,
private_key_file: Optional[Secret] = None,
private_key_file_pwd: Optional[Secret] = None,
oauth_client_id: Optional[Secret] = None,
oauth_client_secret: Optional[Secret] = None,
api_key: Optional[Secret] = Secret.from_env_var("SNOWFLAKE_API_KEY", strict=False), # noqa: B008
private_key_file: Optional[Secret] = Secret.from_env_var("SNOWFLAKE_PRIVATE_KEY_FILE", strict=False), # noqa: B008
private_key_file_pwd: Optional[Secret] = Secret.from_env_var("SNOWFLAKE_PRIVATE_KEY_FILE_PWD", strict=False), # noqa: B008
oauth_client_id: Optional[Secret] = Secret.from_env_var("SNOWFLAKE_CLIENT_ID", strict=False), # noqa: B008
oauth_client_secret: Optional[Secret] = Secret.from_env_var("SNOWFLAKE_CLIENT_SECRET", strict=False), # noqa: B008
oauth_token_request_url: Optional[str] = None,
oauth_authorization_url: Optional[str] = None,
) -> None:
Expand All @@ -69,11 +69,11 @@ def __init__(
:param oauth_authorization_url: OAuth authorization URL.
"""
self.authenticator = authenticator
self.api_key = api_key
self.private_key_file = private_key_file
self.private_key_file_pwd = private_key_file_pwd
self.oauth_client_id = oauth_client_id
self.oauth_client_secret = oauth_client_secret
self.api_key = api_key.resolve_value() if api_key else None
self.private_key_file = private_key_file.resolve_value() if private_key_file else None
self.private_key_file_pwd = private_key_file_pwd.resolve_value() if private_key_file_pwd else None
self.oauth_client_id = oauth_client_id.resolve_value() if oauth_client_id else None
self.oauth_client_secret = oauth_client_secret.resolve_value() if oauth_client_secret else None
self.oauth_token_request_url = oauth_token_request_url
self.oauth_authorization_url = oauth_authorization_url

Expand All @@ -96,29 +96,6 @@ def validate_auth_params(self) -> None:
elif self.authenticator == AUTH_SNOWFLAKE:
if not self.api_key:
raise ValueError(ERROR_API_KEY_REQUIRED)
try:
api_key_value = self.api_key.resolve_value()
if not api_key_value:
raise ValueError(ERROR_API_KEY_REQUIRED)
except Exception as e:
msg = f"Failed to resolve api_key: {e!s}"
raise ValueError(msg) from e

def resolve_secret_value(self, value: Optional[Secret]) -> Optional[str]:
"""
Safely resolves a Secret value.

:param value: Secret to resolve.
:returns: Resolved string value or None.
:raises ValueError: If secret resolution fails.
"""
if value is None:
return None
try:
return value.resolve_value()
except Exception as e:
msg = f"Failed to resolve secret value: {e!s}"
raise ValueError(msg) from e

def read_private_key_content(self) -> Optional[str]:
"""
Expand All @@ -131,13 +108,9 @@ def read_private_key_content(self) -> Optional[str]:
return None

try:
private_key_path = self.resolve_secret_value(self.private_key_file)
if not private_key_path:
return None

key_path = Path(private_key_path)
key_path = Path(self.private_key_file)
if not key_path.exists():
msg = f"Private key file not found: {private_key_path}"
msg = f"Private key file not found: {self.private_key_file}"
raise PrivateKeyReadError(msg)

return key_path.read_text()
Expand Down Expand Up @@ -171,14 +144,11 @@ def _build_jwt_auth_params(self, user: Optional[str] = None) -> list[str]:
except Exception as e:
logger.warning(f"Failed to read private key content, falling back to file path: {e!s}")
# Fallback to file path (though ADBC may not support this)
private_key_path = self.resolve_secret_value(self.private_key_file)
params.append(f"private_key_file={private_key_path}")
params.append(f"private_key_file={self.private_key_file}")

# Only include password parameter if it's actually set
if self.private_key_file_pwd:
private_key_pwd = self.resolve_secret_value(self.private_key_file_pwd)
if private_key_pwd: # Only add if not empty string
params.append(f"{ADBC_PARAM_JWT_KEY_PASSWORD}={private_key_pwd}")
if self.private_key_file_pwd: # Only add if not empty string
params.append(f"{ADBC_PARAM_JWT_KEY_PASSWORD}={self.private_key_file_pwd}")

return params

Expand All @@ -191,11 +161,9 @@ def _build_oauth_auth_params(self) -> list[str]:
params = [f"authenticator={self.authenticator}"]

if self.oauth_client_id:
client_id = self.resolve_secret_value(self.oauth_client_id)
params.append(f"oauth_client_id={client_id}")
params.append(f"oauth_client_id={self.oauth_client_id}")
if self.oauth_client_secret:
client_secret = self.resolve_secret_value(self.oauth_client_secret)
params.append(f"oauth_client_secret={client_secret}")
params.append(f"oauth_client_secret={self.oauth_client_secret}")
if self.oauth_token_request_url:
params.append(f"oauth_token_request_url={self.oauth_token_request_url}")
if self.oauth_authorization_url:
Expand Down Expand Up @@ -225,7 +193,7 @@ def get_password_for_uri(self) -> Optional[str]:
:raises ValueError: If secret resolution fails.
"""
if self.authenticator == AUTH_SNOWFLAKE and self.api_key:
return self.resolve_secret_value(self.api_key)
return self.api_key
return None

def create_masked_params(self, params: list) -> list[str]:
Expand All @@ -241,16 +209,12 @@ def create_masked_params(self, params: list) -> list[str]:
masked_param = param

# Mask private key password
if self.private_key_file_pwd:
private_key_pwd = self.resolve_secret_value(self.private_key_file_pwd)
if private_key_pwd and private_key_pwd in param:
masked_param = param.replace(private_key_pwd, "***REDACTED***")
if self.private_key_file_pwd and self.private_key_file_pwd in param:
masked_param = param.replace(self.private_key_file_pwd, "***REDACTED***")

# Mask OAuth client secret
if self.oauth_client_secret:
client_secret = self.resolve_secret_value(self.oauth_client_secret)
if client_secret and client_secret in param:
masked_param = masked_param.replace(client_secret, "***REDACTED***")
if self.oauth_client_secret and self.oauth_client_secret in param:
masked_param = masked_param.replace(self.oauth_client_secret, "***REDACTED***")

masked_params.append(masked_param)

Expand All @@ -276,23 +240,17 @@ def test_connection(self, user: str, account: str, database: Optional[str] = Non
connection_params["database"] = database

if self.authenticator == AUTH_SNOWFLAKE:
password = self.resolve_secret_value(self.api_key)
if password:
connection_params["password"] = password
if self.api_key:
connection_params["password"] = self.api_key
elif self.authenticator == AUTH_SNOWFLAKE_JWT:
private_key_file = self.resolve_secret_value(self.private_key_file)
if private_key_file:
connection_params["private_key_file"] = private_key_file
if self.private_key_file:
connection_params["private_key_file"] = self.private_key_file
if self.private_key_file_pwd:
private_key_pwd = self.resolve_secret_value(self.private_key_file_pwd)
if private_key_pwd:
connection_params["private_key_file_pwd"] = private_key_pwd
connection_params["private_key_file_pwd"] = self.private_key_file_pwd
elif self.authenticator == AUTH_OAUTH:
client_id = self.resolve_secret_value(self.oauth_client_id)
client_secret = self.resolve_secret_value(self.oauth_client_secret)
if client_id and client_secret:
connection_params["oauth_client_id"] = client_id
connection_params["oauth_client_secret"] = client_secret
if self.oauth_client_id and self.oauth_client_secret:
connection_params["oauth_client_id"] = self.oauth_client_id
connection_params["oauth_client_secret"] = self.oauth_client_secret
if self.oauth_token_request_url:
connection_params["oauth_token_request_url"] = self.oauth_token_request_url

Expand Down
Loading