diff --git a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/auth.py b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/auth.py index e749f1882c..401487695a 100644 --- a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/auth.py +++ b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/auth.py @@ -48,11 +48,11 @@ class SnowflakeAuthenticator: def __init__( self, authenticator: Literal["SNOWFLAKE", "SNOWFLAKE_JWT", "OAUTH"], - api_key: Optional[Secret] = None, - private_key_file: Optional[Secret] = None, - private_key_file_pwd: Optional[Secret] = None, - oauth_client_id: Optional[Secret] = None, - oauth_client_secret: Optional[Secret] = None, + api_key: Optional[Secret] = Secret.from_env_var("SNOWFLAKE_API_KEY", strict=False), # noqa: B008 + private_key_file: Optional[Secret] = Secret.from_env_var("SNOWFLAKE_PRIVATE_KEY_FILE", strict=False), # noqa: B008 + private_key_file_pwd: Optional[Secret] = Secret.from_env_var("SNOWFLAKE_PRIVATE_KEY_FILE_PWD", strict=False), # noqa: B008 + oauth_client_id: Optional[Secret] = Secret.from_env_var("SNOWFLAKE_CLIENT_ID", strict=False), # noqa: B008 + oauth_client_secret: Optional[Secret] = Secret.from_env_var("SNOWFLAKE_CLIENT_SECRET", strict=False), # noqa: B008 oauth_token_request_url: Optional[str] = None, oauth_authorization_url: Optional[str] = None, ) -> None: @@ -69,11 +69,11 @@ def __init__( :param oauth_authorization_url: OAuth authorization URL. """ self.authenticator = authenticator - self.api_key = api_key - self.private_key_file = private_key_file - self.private_key_file_pwd = private_key_file_pwd - self.oauth_client_id = oauth_client_id - self.oauth_client_secret = oauth_client_secret + self.api_key = api_key.resolve_value() if api_key else None + self.private_key_file = private_key_file.resolve_value() if private_key_file else None + self.private_key_file_pwd = private_key_file_pwd.resolve_value() if private_key_file_pwd else None + self.oauth_client_id = oauth_client_id.resolve_value() if oauth_client_id else None + self.oauth_client_secret = oauth_client_secret.resolve_value() if oauth_client_secret else None self.oauth_token_request_url = oauth_token_request_url self.oauth_authorization_url = oauth_authorization_url @@ -96,29 +96,6 @@ def validate_auth_params(self) -> None: elif self.authenticator == AUTH_SNOWFLAKE: if not self.api_key: raise ValueError(ERROR_API_KEY_REQUIRED) - try: - api_key_value = self.api_key.resolve_value() - if not api_key_value: - raise ValueError(ERROR_API_KEY_REQUIRED) - except Exception as e: - msg = f"Failed to resolve api_key: {e!s}" - raise ValueError(msg) from e - - def resolve_secret_value(self, value: Optional[Secret]) -> Optional[str]: - """ - Safely resolves a Secret value. - - :param value: Secret to resolve. - :returns: Resolved string value or None. - :raises ValueError: If secret resolution fails. - """ - if value is None: - return None - try: - return value.resolve_value() - except Exception as e: - msg = f"Failed to resolve secret value: {e!s}" - raise ValueError(msg) from e def read_private_key_content(self) -> Optional[str]: """ @@ -131,13 +108,9 @@ def read_private_key_content(self) -> Optional[str]: return None try: - private_key_path = self.resolve_secret_value(self.private_key_file) - if not private_key_path: - return None - - key_path = Path(private_key_path) + key_path = Path(self.private_key_file) if not key_path.exists(): - msg = f"Private key file not found: {private_key_path}" + msg = f"Private key file not found: {self.private_key_file}" raise PrivateKeyReadError(msg) return key_path.read_text() @@ -171,14 +144,11 @@ def _build_jwt_auth_params(self, user: Optional[str] = None) -> list[str]: except Exception as e: logger.warning(f"Failed to read private key content, falling back to file path: {e!s}") # Fallback to file path (though ADBC may not support this) - private_key_path = self.resolve_secret_value(self.private_key_file) - params.append(f"private_key_file={private_key_path}") + params.append(f"private_key_file={self.private_key_file}") # Only include password parameter if it's actually set - if self.private_key_file_pwd: - private_key_pwd = self.resolve_secret_value(self.private_key_file_pwd) - if private_key_pwd: # Only add if not empty string - params.append(f"{ADBC_PARAM_JWT_KEY_PASSWORD}={private_key_pwd}") + if self.private_key_file_pwd: # Only add if not empty string + params.append(f"{ADBC_PARAM_JWT_KEY_PASSWORD}={self.private_key_file_pwd}") return params @@ -191,11 +161,9 @@ def _build_oauth_auth_params(self) -> list[str]: params = [f"authenticator={self.authenticator}"] if self.oauth_client_id: - client_id = self.resolve_secret_value(self.oauth_client_id) - params.append(f"oauth_client_id={client_id}") + params.append(f"oauth_client_id={self.oauth_client_id}") if self.oauth_client_secret: - client_secret = self.resolve_secret_value(self.oauth_client_secret) - params.append(f"oauth_client_secret={client_secret}") + params.append(f"oauth_client_secret={self.oauth_client_secret}") if self.oauth_token_request_url: params.append(f"oauth_token_request_url={self.oauth_token_request_url}") if self.oauth_authorization_url: @@ -225,7 +193,7 @@ def get_password_for_uri(self) -> Optional[str]: :raises ValueError: If secret resolution fails. """ if self.authenticator == AUTH_SNOWFLAKE and self.api_key: - return self.resolve_secret_value(self.api_key) + return self.api_key return None def create_masked_params(self, params: list) -> list[str]: @@ -241,16 +209,12 @@ def create_masked_params(self, params: list) -> list[str]: masked_param = param # Mask private key password - if self.private_key_file_pwd: - private_key_pwd = self.resolve_secret_value(self.private_key_file_pwd) - if private_key_pwd and private_key_pwd in param: - masked_param = param.replace(private_key_pwd, "***REDACTED***") + if self.private_key_file_pwd and self.private_key_file_pwd in param: + masked_param = param.replace(self.private_key_file_pwd, "***REDACTED***") # Mask OAuth client secret - if self.oauth_client_secret: - client_secret = self.resolve_secret_value(self.oauth_client_secret) - if client_secret and client_secret in param: - masked_param = masked_param.replace(client_secret, "***REDACTED***") + if self.oauth_client_secret and self.oauth_client_secret in param: + masked_param = masked_param.replace(self.oauth_client_secret, "***REDACTED***") masked_params.append(masked_param) @@ -276,23 +240,17 @@ def test_connection(self, user: str, account: str, database: Optional[str] = Non connection_params["database"] = database if self.authenticator == AUTH_SNOWFLAKE: - password = self.resolve_secret_value(self.api_key) - if password: - connection_params["password"] = password + if self.api_key: + connection_params["password"] = self.api_key elif self.authenticator == AUTH_SNOWFLAKE_JWT: - private_key_file = self.resolve_secret_value(self.private_key_file) - if private_key_file: - connection_params["private_key_file"] = private_key_file + if self.private_key_file: + connection_params["private_key_file"] = self.private_key_file if self.private_key_file_pwd: - private_key_pwd = self.resolve_secret_value(self.private_key_file_pwd) - if private_key_pwd: - connection_params["private_key_file_pwd"] = private_key_pwd + connection_params["private_key_file_pwd"] = self.private_key_file_pwd elif self.authenticator == AUTH_OAUTH: - client_id = self.resolve_secret_value(self.oauth_client_id) - client_secret = self.resolve_secret_value(self.oauth_client_secret) - if client_id and client_secret: - connection_params["oauth_client_id"] = client_id - connection_params["oauth_client_secret"] = client_secret + if self.oauth_client_id and self.oauth_client_secret: + connection_params["oauth_client_id"] = self.oauth_client_id + connection_params["oauth_client_secret"] = self.oauth_client_secret if self.oauth_token_request_url: connection_params["oauth_token_request_url"] = self.oauth_token_request_url diff --git a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py index 75a1994f2f..610b452173 100644 --- a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py +++ b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2025-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 + from typing import Any, Dict, Literal, Optional from urllib.parse import quote_plus @@ -9,7 +10,7 @@ from haystack.utils import Secret, deserialize_secrets_inplace from pandas import DataFrame -import snowflake.connector # type: ignore[import-not-found] +import snowflake from .auth import SnowflakeAuthenticator @@ -38,6 +39,7 @@ class SnowflakeTableRetriever: db_schema="", warehouse="", ) + executor.warm_up() ``` #### Key-pair Authentication (MFA): @@ -52,6 +54,7 @@ class SnowflakeTableRetriever: db_schema="", warehouse="", ) + executor.warm_up() ``` #### OAuth Authentication (MFA): @@ -67,6 +70,7 @@ class SnowflakeTableRetriever: db_schema="", warehouse="", ) + executor.warm_up() ``` #### Running queries: @@ -97,17 +101,17 @@ def __init__( self, user: str, account: str, - authenticator: Literal["SNOWFLAKE", "SNOWFLAKE_JWT", "OAUTH"], - api_key: Optional[Secret] = None, + authenticator: Literal["SNOWFLAKE", "SNOWFLAKE_JWT", "OAUTH"] = "SNOWFLAKE", + api_key: Optional[Secret] = Secret.from_env_var("SNOWFLAKE_API_KEY", strict=False), # noqa: B008 database: Optional[str] = None, db_schema: Optional[str] = None, warehouse: Optional[str] = None, login_timeout: Optional[int] = 60, return_markdown: bool = True, - private_key_file: Optional[Secret] = None, - private_key_file_pwd: Optional[Secret] = None, - oauth_client_id: Optional[Secret] = None, - oauth_client_secret: Optional[Secret] = None, + private_key_file: Optional[Secret] = Secret.from_env_var("SNOWFLAKE_PRIVATE_KEY_FILE", strict=False), # noqa: B008 + private_key_file_pwd: Optional[Secret] = Secret.from_env_var("SNOWFLAKE_PRIVATE_KEY_PWD", strict=False), # noqa: B008 + oauth_client_id: Optional[Secret] = Secret.from_env_var("SNOWFLAKE_OAUTH_CLIENT_ID", strict=False), # noqa: B008 + oauth_client_secret: Optional[Secret] = Secret.from_env_var("SNOWFLAKE_OAUTH_CLIENT_SECRET", strict=False), # noqa: B008 oauth_token_request_url: Optional[str] = None, oauth_authorization_url: Optional[str] = None, ) -> None: @@ -142,32 +146,39 @@ def __init__( self.warehouse = warehouse self.login_timeout = login_timeout or 60 self.return_markdown = return_markdown - - # Initialize authentication handler + self.authenticator = authenticator + self.private_key_file = private_key_file + self.private_key_file_pwd = private_key_file_pwd + self.oauth_client_id = oauth_client_id + self.oauth_client_secret = oauth_client_secret + self.oauth_token_request_url = oauth_token_request_url + self.oauth_authorization_url = oauth_authorization_url + self.authenticator_handler: Optional[SnowflakeAuthenticator] = None + self._warmed_up = False + + def warm_up(self) -> None: + """ + Warm up the component by initializing the authenticator handler and testing the database connection. + """ + if self._warmed_up: + return self.authenticator_handler = SnowflakeAuthenticator( - authenticator=authenticator, - api_key=api_key, - private_key_file=private_key_file, - private_key_file_pwd=private_key_file_pwd, - oauth_client_id=oauth_client_id, - oauth_client_secret=oauth_client_secret, - oauth_token_request_url=oauth_token_request_url, - oauth_authorization_url=oauth_authorization_url, + authenticator=self.authenticator, + api_key=self.api_key, + private_key_file=self.private_key_file, + private_key_file_pwd=self.private_key_file_pwd, + oauth_client_id=self.oauth_client_id, + oauth_client_secret=self.oauth_client_secret, + oauth_token_request_url=self.oauth_token_request_url, + oauth_authorization_url=self.oauth_authorization_url, ) - self.authenticator = authenticator # Test connection during initialization to verify credentials - if not self.test_connection(): + if not self.authenticator_handler.test_connection(user=self.user, account=self.account, database=self.database): msg = "Failed to connect to Snowflake with provided credentials" raise ConnectionError(msg) - def test_connection(self) -> bool: - """ - Tests the connection with the current authentication settings. - - :returns: True if connection is successful, False otherwise. - """ - return self.authenticator_handler.test_connection(user=self.user, account=self.account, database=self.database) + self._warmed_up = True def to_dict(self) -> Dict[str, Any]: """ @@ -185,22 +196,14 @@ def to_dict(self) -> Dict[str, Any]: "login_timeout": self.login_timeout, "return_markdown": self.return_markdown, "authenticator": self.authenticator, - "oauth_token_request_url": self.authenticator_handler.oauth_token_request_url, - "oauth_authorization_url": self.authenticator_handler.oauth_authorization_url, + "oauth_token_request_url": self.oauth_token_request_url, + "oauth_authorization_url": self.oauth_authorization_url, + "api_key": self.api_key.to_dict() if self.api_key else None, + "private_key_file": self.private_key_file.to_dict() if self.private_key_file else None, + "private_key_file_pwd": self.private_key_file_pwd.to_dict() if self.private_key_file_pwd else None, + "oauth_client_id": self.oauth_client_id.to_dict() if self.oauth_client_id else None, + "oauth_client_secret": self.oauth_client_secret.to_dict() if self.oauth_client_secret else None, } - - # Handle Secret fields - if self.authenticator_handler.api_key: - data["api_key"] = self.authenticator_handler.api_key.to_dict() - if self.authenticator_handler.private_key_file: - data["private_key_file"] = self.authenticator_handler.private_key_file.to_dict() - if self.authenticator_handler.private_key_file_pwd: - data["private_key_file_pwd"] = self.authenticator_handler.private_key_file_pwd.to_dict() - if self.authenticator_handler.oauth_client_id: - data["oauth_client_id"] = self.authenticator_handler.oauth_client_id.to_dict() - if self.authenticator_handler.oauth_client_secret: - data["oauth_client_secret"] = self.authenticator_handler.oauth_client_secret.to_dict() - return default_to_dict(self, **data) @classmethod @@ -244,7 +247,8 @@ def _snowflake_uri_constructor(self) -> str: encoded_user = quote_plus(self.user) encoded_account = quote_plus(self.account) - password = self.authenticator_handler.get_password_for_uri() + # We ignore the mypy error since it doesn't know that self.authenticator_handler has been set at this point + password = self.authenticator_handler.get_password_for_uri() # type: ignore[union-attr] if password: # Traditional password authentication - encode password encoded_password = quote_plus(password) @@ -270,7 +274,8 @@ def _snowflake_uri_constructor(self) -> str: params.append(f"login_timeout={self.login_timeout}") # Add authentication-specific parameters (pass user for JWT ADBC support) - auth_params = self.authenticator_handler.build_auth_params(user=self.user) + # We ignore the mypy error since it doesn't know that self.authenticator_handler has been set at this point + auth_params = self.authenticator_handler.build_auth_params(user=self.user) # type: ignore[union-attr] params.extend(auth_params) if params: @@ -292,7 +297,8 @@ def _create_masked_uri(self, uri: str) -> str: # Mask password if present if self.authenticator == "SNOWFLAKE": - password = self.authenticator_handler.get_password_for_uri() + # We ignore the mypy error since it doesn't know that self.authenticator_handler has been set at this point + password = self.authenticator_handler.get_password_for_uri() # type: ignore[union-attr] if password: encoded_password = quote_plus(password) masked_uri = masked_uri.replace(encoded_password, "***REDACTED***") @@ -301,7 +307,7 @@ def _create_masked_uri(self, uri: str) -> str: if "?" in masked_uri: base_uri, query_params = masked_uri.split("?", 1) param_list = query_params.split("&") - masked_params = self.authenticator_handler.create_masked_params(param_list) + masked_params = self.authenticator_handler.create_masked_params(param_list) # type: ignore[union-attr] masked_uri = base_uri + "?" + "&".join(masked_params) return masked_uri @@ -355,18 +361,15 @@ def _execute_query_with_connector(self, query: str) -> Optional[pl.DataFrame]: # Add JWT-specific parameters if self.authenticator == "SNOWFLAKE_JWT": - private_key_file = self.authenticator_handler.resolve_secret_value( - self.authenticator_handler.private_key_file - ) - if private_key_file: - conn_params["private_key_file"] = private_key_file + # We ignore the mypy error since it doesn't know that self.authenticator_handler has been set at this + # point + if self.authenticator_handler.private_key_file: # type: ignore[union-attr] + conn_params["private_key_file"] = self.authenticator_handler.private_key_file # type: ignore[union-attr] - if self.authenticator_handler.private_key_file_pwd: - private_key_pwd = self.authenticator_handler.resolve_secret_value( - self.authenticator_handler.private_key_file_pwd - ) - if private_key_pwd: - conn_params["private_key_file_pwd"] = private_key_pwd + # We ignore the mypy error since it doesn't know that self.authenticator_handler has been set at this + # point + if self.authenticator_handler.private_key_file_pwd: # type: ignore[union-attr] + conn_params["private_key_file_pwd"] = self.authenticator_handler.private_key_file_pwd # type: ignore[union-attr] # Connect and execute query conn = snowflake.connector.connect(**conn_params) @@ -419,6 +422,10 @@ def run(self, query: str, return_markdown: Optional[bool] = None) -> Dict[str, A - `"dataframe"`: A Pandas DataFrame with the query results. - `"table"`: A Markdown-formatted string representation of the DataFrame. """ + if not self._warmed_up: + msg = "SnowflakeTableRetriever not warmed up. Please call `warm_up()` before running queries." + raise RuntimeError(msg) + # Validate SQL query if not query: logger.warning("Empty query provided, returning empty DataFrame") diff --git a/integrations/snowflake/tests/test_auth.py b/integrations/snowflake/tests/test_auth.py index d25004700f..5a6146906a 100644 --- a/integrations/snowflake/tests/test_auth.py +++ b/integrations/snowflake/tests/test_auth.py @@ -13,24 +13,6 @@ class TestSnowflakeAuthenticator: """Tests for the SnowflakeAuthenticator class.""" - def test_authenticator_resolve_secret_value_error(self, mocker: Mock, tmp_path: Path) -> None: - # Test error handling in resolve_secret_value - - # Create a mock secret that raises an exception - mock_secret = mocker.Mock() - mock_secret.resolve_value.side_effect = Exception("Failed to resolve secret") - - key_file = tmp_path / "key.pem" - key_file.write_text("-----BEGIN PRIVATE KEY-----\ntest\n-----END PRIVATE KEY-----") - - auth = SnowflakeAuthenticator( - authenticator="SNOWFLAKE_JWT", - private_key_file=Secret.from_token(str(key_file)), - ) - - with pytest.raises(ValueError, match="Failed to resolve secret value"): - auth.resolve_secret_value(mock_secret) - def test_authenticator_read_private_key_not_found(self, mocker: Mock) -> None: # Test error handling when private key file doesn't exist @@ -85,11 +67,9 @@ def test_authenticator_validate_api_key_resolution_error(self, mocker: Mock) -> mock_secret = mocker.Mock() mock_secret.resolve_value.side_effect = Exception("Resolution failed") - with pytest.raises(ValueError, match="Failed to resolve api_key"): - SnowflakeAuthenticator( - authenticator="SNOWFLAKE", - api_key=mock_secret, - ) + with pytest.raises(ValueError, match="None of the following authentication environment"): + retriever = SnowflakeAuthenticator(authenticator="SNOWFLAKE", api_key=Secret.from_env_var("TEST_ENV")) + retriever.warm_up() def test_authenticator_validate_empty_api_key(self, mocker: Mock) -> None: # Test validation when api_key resolves to empty string diff --git a/integrations/snowflake/tests/test_snowflake_table_retriever.py b/integrations/snowflake/tests/test_snowflake_table_retriever.py index 6ada43a59d..b8b4a3707c 100644 --- a/integrations/snowflake/tests/test_snowflake_table_retriever.py +++ b/integrations/snowflake/tests/test_snowflake_table_retriever.py @@ -22,10 +22,10 @@ def retriever(mocker: Mock) -> SnowflakeTableRetriever: mocker.patch.dict(os.environ, {"SNOWFLAKE_API_KEY": "test_api_key"}) # Mock the connection test to avoid requiring actual Snowflake connection during tests mocker.patch( - "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.SnowflakeTableRetriever.test_connection", + "haystack_integrations.components.retrievers.snowflake.auth.SnowflakeAuthenticator.test_connection", return_value=True, ) - return SnowflakeTableRetriever( + table_retriever = SnowflakeTableRetriever( user="test_user", account="test_account", authenticator="SNOWFLAKE", @@ -35,6 +35,8 @@ def retriever(mocker: Mock) -> SnowflakeTableRetriever: warehouse="test_warehouse", return_markdown=True, ) + table_retriever.warm_up() + return table_retriever @pytest.fixture @@ -64,10 +66,10 @@ def jwt_retriever(mocker: Mock) -> SnowflakeTableRetriever: ) # Mock the connection test to avoid requiring actual Snowflake connection during tests mocker.patch( - "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.SnowflakeTableRetriever.test_connection", + "haystack_integrations.components.retrievers.snowflake.auth.SnowflakeAuthenticator.test_connection", return_value=True, ) - return SnowflakeTableRetriever( + table_retriever = SnowflakeTableRetriever( user="test_user", account="test_account", authenticator="SNOWFLAKE_JWT", @@ -78,6 +80,8 @@ def jwt_retriever(mocker: Mock) -> SnowflakeTableRetriever: warehouse="test_warehouse", return_markdown=True, ) + table_retriever.warm_up() + return table_retriever @pytest.fixture @@ -88,10 +92,10 @@ def oauth_retriever(mocker: Mock) -> SnowflakeTableRetriever: ) # Mock the connection test to avoid requiring actual Snowflake connection during tests mocker.patch( - "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.SnowflakeTableRetriever.test_connection", + "haystack_integrations.components.retrievers.snowflake.auth.SnowflakeAuthenticator.test_connection", return_value=True, ) - return SnowflakeTableRetriever( + table_retriever = SnowflakeTableRetriever( user="test_user", account="test_account", authenticator="OAUTH", @@ -103,6 +107,8 @@ def oauth_retriever(mocker: Mock) -> SnowflakeTableRetriever: warehouse="test_warehouse", return_markdown=True, ) + table_retriever.warm_up() + return table_retriever class TestSnowflakeTableRetriever: @@ -121,6 +127,18 @@ def test_init_and_serialization(self, retriever: SnowflakeTableRetriever) -> Non assert deserialized.account == "test_account" assert deserialized.return_markdown is True + def test_from_dict_minimal(self): + data = { + "type": "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.SnowflakeTableRetriever", # noqa: E501 + "init_parameters": { + "user": "test_user", + "account": "test_account", + "api_key": {"type": "env_var", "env_vars": ["SNOWFLAKE_API_KEY"], "strict": False}, + }, + } + deserialized = SnowflakeTableRetriever.from_dict(data) + assert isinstance(deserialized, SnowflakeTableRetriever) + @pytest.mark.parametrize( "user, account, db_name, schema_name, warehouse_name, expected_uri, should_raise", [ @@ -178,7 +196,7 @@ def test_snowflake_uri_constructor( mocker.patch.dict(os.environ, {"SNOWFLAKE_API_KEY": "test_api_key"}) # Mock connection test for direct instantiation mocker.patch( - "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.SnowflakeTableRetriever.test_connection", + "haystack_integrations.components.retrievers.snowflake.auth.SnowflakeAuthenticator.test_connection", return_value=True, ) @@ -191,6 +209,7 @@ def test_snowflake_uri_constructor( db_schema=schema_name, warehouse=warehouse_name, ) + retriever.warm_up() if should_raise: with pytest.raises( @@ -349,7 +368,7 @@ def test_run_with_markdown_parameter( ) -> None: mocker.patch.dict(os.environ, {"SNOWFLAKE_API_KEY": "test_api_key"}) mocker.patch( - "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.SnowflakeTableRetriever.test_connection", + "haystack_integrations.components.retrievers.snowflake.auth.SnowflakeAuthenticator.test_connection", return_value=True, ) retriever = SnowflakeTableRetriever( @@ -360,6 +379,7 @@ def test_run_with_markdown_parameter( database="test_db", return_markdown=False, ) + retriever.warm_up() mocker.patch("polars.read_database_uri", return_value=toy_polars_df) mocker.patch.object(toy_polars_df, "to_pandas", return_value=toy_pandas_df) @@ -401,7 +421,7 @@ def test_masked_uri_logging( def test_custom_login_timeout(self, mocker: Mock) -> None: mocker.patch.dict(os.environ, {"SNOWFLAKE_API_KEY": "test_api_key"}) mocker.patch( - "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.SnowflakeTableRetriever.test_connection", + "haystack_integrations.components.retrievers.snowflake.auth.SnowflakeAuthenticator.test_connection", return_value=True, ) custom_timeout = 120 @@ -413,6 +433,7 @@ def test_custom_login_timeout(self, mocker: Mock) -> None: database="test_db", login_timeout=custom_timeout, ) + retriever.warm_up() uri = retriever._snowflake_uri_constructor() expected_uri = f"snowflake://test_user:test_api_key@test_account/test_db?login_timeout={custom_timeout}" @@ -481,44 +502,23 @@ def test_masked_uri_logging_oauth(self, oauth_retriever: SnowflakeTableRetriever ], ) def test_authentication_validation_errors( - self, mocker: Mock, authenticator: str, missing_param: str, expected_error: str + self, authenticator: str, missing_param: str, expected_error: str, monkeypatch ) -> None: # Set up environment variables, excluding the one being tested as missing - env_vars = { - "SNOWFLAKE_PRIVATE_KEY_FILE": "/path/to/key.pem", - "SNOWFLAKE_PRIVATE_KEY_PWD": "test_password", - "SNOWFLAKE_OAUTH_CLIENT_ID": "test_client_id", - "SNOWFLAKE_OAUTH_CLIENT_SECRET": "test_client_secret", - } - - # Only set SNOWFLAKE_API_KEY if we're not testing its absence - if not (authenticator == "SNOWFLAKE" and missing_param == "api_key"): - env_vars["SNOWFLAKE_API_KEY"] = "test_api_key" - - mocker.patch.dict(os.environ, env_vars, clear=True) - - kwargs = { - "user": "test_user", - "account": "test_account", - "authenticator": authenticator, - } - if authenticator == "SNOWFLAKE_JWT": - if missing_param != "private_key_file": - kwargs["private_key_file"] = Secret.from_env_var("SNOWFLAKE_PRIVATE_KEY_FILE") - if missing_param != "private_key_file_pwd": - kwargs["private_key_file_pwd"] = Secret.from_env_var("SNOWFLAKE_PRIVATE_KEY_PWD") + monkeypatch.setenv("SNOWFLAKE_PRIVATE_KEY_PWD", "test_password") elif authenticator == "OAUTH": - if missing_param != "oauth_client_id": - kwargs["oauth_client_id"] = Secret.from_env_var("SNOWFLAKE_OAUTH_CLIENT_ID") - if missing_param != "oauth_client_secret": - kwargs["oauth_client_secret"] = Secret.from_env_var("SNOWFLAKE_OAUTH_CLIENT_SECRET") - elif authenticator == "SNOWFLAKE": - if missing_param != "api_key": - kwargs["api_key"] = Secret.from_env_var("SNOWFLAKE_API_KEY") + if missing_param == "oauth_client_id": + monkeypatch.setenv("SNOWFLAKE_OAUTH_CLIENT_SECRET", "test_client_secret") + else: + monkeypatch.setenv("SNOWFLAKE_OAUTH_CLIENT_ID", "test_client_id") + kwargs = {"user": "test_user", "account": "test_account", "authenticator": authenticator} + + # Validation errors are raised during warm_up (which calls test_connection) with pytest.raises(ValueError, match=expected_error): - SnowflakeTableRetriever(**kwargs) + table_retriever = SnowflakeTableRetriever(**kwargs) + table_retriever.warm_up() def test_jwt_authentication_happy_path( self, @@ -563,13 +563,15 @@ def test_connection_success(self, mocker: Mock) -> None: mock_connection = mocker.Mock() mock_connect = mocker.patch("snowflake.connector.connect", return_value=mock_connection) - # Create retriever (test_connection will be called during init and will use the mock) - SnowflakeTableRetriever( + # Create retriever + table_retriever = SnowflakeTableRetriever( user="test_user", account="test_account", authenticator="SNOWFLAKE", api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"), ) + # test_connection will be called during warm up + table_retriever.warm_up() # Verify the connection was tested during initialization assert mock_connect.call_count >= 1 @@ -583,14 +585,15 @@ def test_connection_failure(self, mocker: Mock) -> None: mock_snowflake.connector.connect.side_effect = Exception("Connection failed") mocker.patch.dict("sys.modules", {"snowflake": mock_snowflake, "snowflake.connector": mock_snowflake.connector}) - # Should raise ConnectionError during initialization + # Should raise ConnectionError during warm up with pytest.raises(ConnectionError, match="Failed to connect to Snowflake"): - SnowflakeTableRetriever( + table_retriever = SnowflakeTableRetriever( user="test_user", account="test_account", authenticator="SNOWFLAKE", api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"), ) + table_retriever.warm_up() def test_connection_jwt_auth(self, mocker: Mock, tmp_path: Path) -> None: # Create a temporary key file @@ -606,14 +609,16 @@ def test_connection_jwt_auth(self, mocker: Mock, tmp_path: Path) -> None: mock_connection = mocker.Mock() mock_connect = mocker.patch("snowflake.connector.connect", return_value=mock_connection) - # Create JWT retriever (test_connection will be called during init) - SnowflakeTableRetriever( + # Create JWT retriever + table_retriever = SnowflakeTableRetriever( user="test_user", account="test_account", authenticator="SNOWFLAKE_JWT", private_key_file=Secret.from_env_var("SNOWFLAKE_PRIVATE_KEY_FILE"), private_key_file_pwd=Secret.from_env_var("SNOWFLAKE_PRIVATE_KEY_PWD"), ) + # test_connection will be called during warm up + table_retriever.warm_up() # Verify that JWT-specific parameters were passed assert mock_connect.call_count >= 1 @@ -632,14 +637,16 @@ def test_connection_oauth_auth(self, mocker: Mock) -> None: mock_connection = mocker.Mock() mock_connect = mocker.patch("snowflake.connector.connect", return_value=mock_connection) - # Create OAuth retriever (test_connection will be called during init) - SnowflakeTableRetriever( + # Create OAuth retriever + table_retriever = SnowflakeTableRetriever( user="test_user", account="test_account", authenticator="OAUTH", oauth_client_id=Secret.from_env_var("SNOWFLAKE_OAUTH_CLIENT_ID"), oauth_client_secret=Secret.from_env_var("SNOWFLAKE_OAUTH_CLIENT_SECRET"), ) + # test_connection will be called during warm up + table_retriever.warm_up() # Verify that OAuth-specific parameters were passed assert mock_connect.call_count >= 1 @@ -724,6 +731,7 @@ def test_run_jwt_auth_flow(self, mocker: Mock, toy_polars_df: pl.DataFrame, tmp_ private_key_file=Secret.from_env_var("SNOWFLAKE_PRIVATE_KEY_FILE"), private_key_file_pwd=Secret.from_env_var("SNOWFLAKE_PRIVATE_KEY_PWD"), ) + jwt_retriever.warm_up() # Mock _execute_query_with_connector to return toy data mocker.patch.object(jwt_retriever, "_execute_query_with_connector", return_value=toy_polars_df)