diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index 1fc5894c..2485fb71 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -12,6 +12,7 @@ from oauthlib.oauth2.rfc6749.errors import OAuth2Error from databricks.sql.common.http import HttpMethod, HttpHeader from databricks.sql.common.http import OAuthResponse +from databricks.sql.auth.retry import CommandType from databricks.sql.auth.oauth_http_handler import OAuthHttpSingleRequestHandler from databricks.sql.auth.endpoint import OAuthEndpointCollection from abc import abstractmethod, ABC @@ -87,6 +88,8 @@ def __fetch_well_known_config(self, hostname: str): known_config_url = self.idp_endpoint.get_openid_config_url(hostname) try: + # Set command type for OAuth configuration request + self.http_client.setRequestType(CommandType.AUTH) response = self.http_client.request(HttpMethod.GET, url=known_config_url) # Convert urllib3 response to requests-like response for compatibility response.status_code = response.status @@ -195,6 +198,8 @@ def __send_token_request(self, token_request_url, data): "Accept": "application/json", "Content-Type": "application/x-www-form-urlencoded", } + # Set command type for OAuth token request + self.http_client.setRequestType(CommandType.AUTH) # Use unified HTTP client response = self.http_client.request( HttpMethod.POST, url=token_request_url, body=data, headers=headers @@ -337,6 +342,8 @@ def refresh(self) -> Token: } ) + # Set command type for OAuth client credentials request + self._http_client.setRequestType(CommandType.AUTH) response = self._http_client.request( method=HttpMethod.POST, url=self.token_url, headers=headers, body=data ) diff --git a/src/databricks/sql/auth/retry.py b/src/databricks/sql/auth/retry.py index 4281883d..79604f3e 100755 --- a/src/databricks/sql/auth/retry.py +++ b/src/databricks/sql/auth/retry.py @@ -32,10 +32,20 @@ class CommandType(Enum): - EXECUTE_STATEMENT = "ExecuteStatement" + NOT_SET = "NotSet" + OPEN_SESSION = "OpenSession" CLOSE_SESSION = "CloseSession" + METADATA = "Metadata" CLOSE_OPERATION = "CloseOperation" - GET_OPERATION_STATUS = "GetOperationStatus" + CANCEL_OPERATION = "CancelOperation" + EXECUTE_STATEMENT = "ExecuteStatement" + FETCH_RESULTS = "FetchResults" + CLOUD_FETCH = "CloudFetch" + AUTH = "Auth" + TELEMETRY_PUSH = "TelemetryPush" + VOLUME_GET = "VolumeGet" + VOLUME_PUT = "VolumePut" + VOLUME_DELETE = "VolumeDelete" OTHER = "Other" @classmethod @@ -45,9 +55,66 @@ def get(cls, value: str): if valid_command: return getattr(cls, str(valid_command)) else: + # Map Thrift metadata operations to METADATA type + metadata_operations = { + "GetOperationStatus", "GetResultSetMetadata", "GetTables", + "GetColumns", "GetSchemas", "GetCatalogs", "GetFunctions", + "GetPrimaryKeys", "GetTypeInfo", "GetCrossReference", + "GetImportedKeys", "GetExportedKeys", "GetTableTypes" + } + if value in metadata_operations: + return cls.METADATA return cls.OTHER +class CommandIdempotency(Enum): + IDEMPOTENT = "idempotent" + NON_IDEMPOTENT = "non_idempotent" + + +# Mapping of CommandType to CommandIdempotency +# Based on the official idempotency classification +COMMAND_IDEMPOTENCY_MAP = { + # NON-IDEMPOTENT operations (safety first - unknown types are not retried) + CommandType.NOT_SET: CommandIdempotency.NON_IDEMPOTENT, + CommandType.EXECUTE_STATEMENT: CommandIdempotency.NON_IDEMPOTENT, + CommandType.FETCH_RESULTS: CommandIdempotency.NON_IDEMPOTENT, + CommandType.VOLUME_PUT: CommandIdempotency.NON_IDEMPOTENT, # PUT can overwrite files + + # IDEMPOTENT operations + CommandType.OPEN_SESSION: CommandIdempotency.IDEMPOTENT, + CommandType.CLOSE_SESSION: CommandIdempotency.IDEMPOTENT, + CommandType.METADATA: CommandIdempotency.IDEMPOTENT, + CommandType.CLOSE_OPERATION: CommandIdempotency.IDEMPOTENT, + CommandType.CANCEL_OPERATION: CommandIdempotency.IDEMPOTENT, + CommandType.CLOUD_FETCH: CommandIdempotency.IDEMPOTENT, + CommandType.AUTH: CommandIdempotency.IDEMPOTENT, + CommandType.TELEMETRY_PUSH: CommandIdempotency.IDEMPOTENT, + CommandType.VOLUME_GET: CommandIdempotency.IDEMPOTENT, + CommandType.VOLUME_DELETE: CommandIdempotency.IDEMPOTENT, + CommandType.OTHER: CommandIdempotency.IDEMPOTENT, +} + +# HTTP status codes that should never be retried, even for idempotent requests +# These are client error codes that indicate permanent issues +NON_RETRYABLE_STATUS_CODES = { + 400, # Bad Request + 401, # Unauthorized + 403, # Forbidden + 404, # Not Found + 405, # Method Not Allowed + 409, # Conflict + 410, # Gone + 411, # Length Required + 412, # Precondition Failed + 413, # Payload Too Large + 414, # URI Too Long + 415, # Unsupported Media Type + 416, # Range Not Satisfiable + 501, # Not Implemented +} + + class DatabricksRetryPolicy(Retry): """ Implements our v3 retry policy by extending urllib3's robust default retry behaviour. @@ -354,38 +421,25 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]: logger.info(f"Received status code {status_code} for {method} request") + # Get command idempotency for use in multiple conditions below + command_idempotency = COMMAND_IDEMPOTENCY_MAP.get( + self.command_type, CommandIdempotency.NON_IDEMPOTENT + ) + # Request succeeded. Don't retry. if status_code // 100 <= 3: return False, "2xx/3xx codes are not retried" - if status_code == 400: - return ( - False, - "Received 400 - BAD_REQUEST. Please check the request parameters.", - ) - - if status_code == 401: - return ( - False, - "Received 401 - UNAUTHORIZED. Confirm your authentication credentials.", - ) - - if status_code == 403: - return False, "403 codes are not retried" - - # Request failed and server said NotImplemented. This isn't recoverable. Don't retry. - if status_code == 501: - return False, "Received code 501 from server." # Request failed and this method is not retryable. We only retry POST requests. if not self._is_method_retryable(method): return False, "Only POST requests are retried" # Request failed with 404 and was a GetOperationStatus. This is not recoverable. Don't retry. - if status_code == 404 and self.command_type == CommandType.GET_OPERATION_STATUS: + if status_code == 404 and self.command_type == CommandType.METADATA: return ( False, - "GetOperationStatus received 404 code from Databricks. Operation was canceled.", + "Metadata request received 404 code from Databricks. Operation was canceled.", ) # Request failed with 404 because CloseSession returns 404 if you repeat the request. @@ -408,15 +462,18 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]: "CloseOperation received 404 code from Databricks. Cursor is already closed." ) + if status_code in NON_RETRYABLE_STATUS_CODES: + return False, f"Received {status_code} code from Databricks. Operation was canceled." + # Request failed, was an ExecuteStatement and the command may have reached the server if ( - self.command_type == CommandType.EXECUTE_STATEMENT + command_idempotency == CommandIdempotency.NON_IDEMPOTENT and status_code not in self.status_forcelist and status_code not in self.force_dangerous_codes ): return ( False, - "ExecuteStatement command can only be retried for codes 429 and 503", + "Non Idempotent requests can only be retried for codes 429 and 503", ) # Request failed with a dangerous code, was an ExecuteStatement, but user forced retries for this @@ -424,7 +481,7 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]: # retry automatically. This code is included only so that we can log the exact reason for the retry. # This gives users signal that their _retry_dangerous_codes setting actually did something. if ( - self.command_type == CommandType.EXECUTE_STATEMENT + command_idempotency == CommandIdempotency.NON_IDEMPOTENT and status_code in self.force_dangerous_codes ): return ( diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index b47f2add..821d2e4f 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -285,13 +285,19 @@ def _get_command_type_from_path(self, path: str, method: str) -> CommandType: if method == "POST" and path.endswith("/statements"): return CommandType.EXECUTE_STATEMENT elif "/cancel" in path: - return CommandType.OTHER # Cancel operation + return CommandType.CANCEL_OPERATION elif method == "DELETE": return CommandType.CLOSE_OPERATION elif method == "GET": - return CommandType.GET_OPERATION_STATUS + # For GET requests on statements, determine if it's fetching results or status + if "/result/chunks/" in path: + return CommandType.FETCH_RESULTS + else: + return CommandType.METADATA # Statement status queries elif "/sessions" in path: - if method == "DELETE": + if method == "POST" and path.endswith("/sessions"): + return CommandType.OPEN_SESSION + elif method == "DELETE": return CommandType.CLOSE_SESSION - return CommandType.OTHER + return CommandType.NOT_SET diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 78a01142..cd13efec 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -25,6 +25,7 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient +from databricks.sql.auth.retry import CommandType from databricks.sql.utils import ( ParamEscaper, inject_parameters, @@ -774,6 +775,9 @@ def _handle_staging_put( session_id_hex=self.connection.get_session_id_hex(), ) + # Set command type for volume PUT operation + self.connection.http_client.setRequestType(CommandType.VOLUME_PUT) + with open(local_file, "rb") as fh: r = self.connection.http_client.request( HttpMethod.PUT, presigned_url, body=fh.read(), headers=headers @@ -830,6 +834,9 @@ def _handle_staging_put_stream( session_id_hex=self.connection.get_session_id_hex(), ) + # Set command type for volume PUT stream operation + self.connection.http_client.setRequestType(CommandType.VOLUME_PUT) + r = self.connection.http_client.request( HttpMethod.PUT, presigned_url, body=stream.read(), headers=headers ) @@ -851,6 +858,9 @@ def _handle_staging_get( session_id_hex=self.connection.get_session_id_hex(), ) + # Set command type for volume GET operation + self.connection.http_client.setRequestType(CommandType.VOLUME_GET) + r = self.connection.http_client.request( HttpMethod.GET, presigned_url, headers=headers ) @@ -874,6 +884,9 @@ def _handle_staging_remove( ): """Make an HTTP DELETE request to the presigned_url""" + # Set command type for volume DELETE operation + self.connection.http_client.setRequestType(CommandType.VOLUME_DELETE) + r = self.connection.http_client.request( HttpMethod.DELETE, presigned_url, headers=headers ) diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py index 7ccd69c5..da201826 100644 --- a/src/databricks/sql/common/unified_http_client.py +++ b/src/databricks/sql/common/unified_http_client.py @@ -218,11 +218,37 @@ def _prepare_headers( def _prepare_retry_policy(self): """Set up the retry policy for the current request.""" if isinstance(self._retry_policy, DatabricksRetryPolicy): - # Set command type for HTTP requests to OTHER (not database commands) - self._retry_policy.command_type = CommandType.OTHER + # Only set command type to NOT_SET if it hasn't been explicitly set via setRequestType() + if self._retry_policy.command_type is None: + self._retry_policy.command_type = CommandType.NOT_SET # Start the retry timer for duration-based retry limits self._retry_policy.start_retry_timer() + def setRequestType(self, request_type: CommandType): + """ + Set the specific request type for the next HTTP request. + + This allows clients to specify what type of operation they're performing + so the retry policy can make appropriate idempotency decisions. + + Args: + request_type: The CommandType enum value for this operation + + Example: + # For authentication requests (OAuth, etc.) + http_client.setRequestType(CommandType.AUTH) + response = http_client.request(HttpMethod.POST, url, body=data) + + # For cloud fetch operations + http_client.setRequestType(CommandType.CLOUD_FETCH) + response = http_client.request(HttpMethod.GET, cloud_url) + """ + if isinstance(self._retry_policy, DatabricksRetryPolicy): + self._retry_policy.command_type = request_type + logger.debug(f"Set request type to: {request_type.value}") + else: + logger.warning(f"Cannot set request type {request_type.value}: retry policy is not DatabricksRetryPolicy") + @contextmanager def request_context( self, @@ -269,6 +295,11 @@ def request_context( logger.error("HTTP request error: %s", e) raise RequestError(f"HTTP request error: {e}") finally: + # Reset command type after request completion to prevent it from affecting subsequent requests + if isinstance(self._retry_policy, DatabricksRetryPolicy): + self._retry_policy.command_type = None + logger.debug("Reset command type after request completion") + if response: response.close() diff --git a/tests/e2e/common/large_queries_mixin.py b/tests/e2e/common/large_queries_mixin.py index dd7c5699..779cc57c 100644 --- a/tests/e2e/common/large_queries_mixin.py +++ b/tests/e2e/common/large_queries_mixin.py @@ -33,7 +33,7 @@ def fetch_rows(self, cursor, row_count, fetchmany_size): rows = self.get_some_rows(cursor, fetchmany_size) if not rows: # Read all the rows, row_count should match - self.assertEqual(n, row_count) + assert n == row_count num_fetches = max(math.ceil(n / 10000), 1) latency_ms = int((time.time() - start_time) * 1000 / num_fetches), 1 diff --git a/tests/e2e/common/retry_test_mixins.py b/tests/e2e/common/retry_test_mixins.py index b2350bd9..aab219a8 100755 --- a/tests/e2e/common/retry_test_mixins.py +++ b/tests/e2e/common/retry_test_mixins.py @@ -278,7 +278,7 @@ def test_retry_max_count_not_exceeded(self, mock_send_telemetry, extra_params): THEN the connector issues six request (original plus five retries) before raising an exception """ - with mocked_server_response(status=404) as mock_obj: + with mocked_server_response(status=429) as mock_obj: with pytest.raises(MaxRetryError) as cm: extra_params = {**extra_params, **self._retry_policy} with self.connection(extra_params=extra_params) as conn: diff --git a/tests/e2e/common/staging_ingestion_tests.py b/tests/e2e/common/staging_ingestion_tests.py index 73aa0a11..a88f5523 100644 --- a/tests/e2e/common/staging_ingestion_tests.py +++ b/tests/e2e/common/staging_ingestion_tests.py @@ -81,7 +81,7 @@ def test_staging_ingestion_life_cycle(self, ingestion_user): # GET after REMOVE should fail with pytest.raises( - Error, match="too many 404 error responses" + Error, match="Staging operation over HTTP was unsuccessful: 404" ): cursor = conn.cursor() query = f"GET 'stage://tmp/{ingestion_user}/tmp/11/16/file1.csv' TO '{new_temp_path}'" diff --git a/tests/e2e/common/uc_volume_tests.py b/tests/e2e/common/uc_volume_tests.py index 93e63bd2..5b4086f9 100644 --- a/tests/e2e/common/uc_volume_tests.py +++ b/tests/e2e/common/uc_volume_tests.py @@ -81,7 +81,7 @@ def test_uc_volume_life_cycle(self, catalog, schema): # GET after REMOVE should fail with pytest.raises( - Error, match="too many 404 error responses" + Error, match="Staging operation over HTTP was unsuccessful: 404" ): cursor = conn.cursor() query = f"GET '/Volumes/{catalog}/{schema}/e2etests/file1.csv' TO '{new_temp_path}'" diff --git a/tests/unit/test_sea_http_client.py b/tests/unit/test_sea_http_client.py index 39ecb58a..3250d633 100644 --- a/tests/unit/test_sea_http_client.py +++ b/tests/unit/test_sea_http_client.py @@ -57,7 +57,7 @@ def test_get_command_type_from_path(self, sea_http_client): sea_http_client._get_command_type_from_path( "/statements/123/cancel", "POST" ) - == CommandType.OTHER + == CommandType.CANCEL_OPERATION ) # Test statement deletion (close operation) @@ -69,7 +69,13 @@ def test_get_command_type_from_path(self, sea_http_client): # Test get statement status assert ( sea_http_client._get_command_type_from_path("/statements/123", "GET") - == CommandType.GET_OPERATION_STATUS + == CommandType.METADATA + ) + + # Test session creation + assert ( + sea_http_client._get_command_type_from_path("/sessions", "POST") + == CommandType.OPEN_SESSION ) # Test session close @@ -78,14 +84,20 @@ def test_get_command_type_from_path(self, sea_http_client): == CommandType.CLOSE_SESSION ) + # Test result chunk fetching + assert ( + sea_http_client._get_command_type_from_path("/statements/123/result/chunks/0", "GET") + == CommandType.FETCH_RESULTS + ) + # Test other paths assert ( sea_http_client._get_command_type_from_path("/other/endpoint", "GET") - == CommandType.OTHER + == CommandType.NOT_SET ) assert ( sea_http_client._get_command_type_from_path("/other/endpoint", "POST") - == CommandType.OTHER + == CommandType.NOT_SET ) @patch(