Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/databricks/sql/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from oauthlib.oauth2.rfc6749.errors import OAuth2Error
from databricks.sql.common.http import HttpMethod, HttpHeader
from databricks.sql.common.http import OAuthResponse
from databricks.sql.auth.retry import CommandType
from databricks.sql.auth.oauth_http_handler import OAuthHttpSingleRequestHandler
from databricks.sql.auth.endpoint import OAuthEndpointCollection
from abc import abstractmethod, ABC
Expand Down Expand Up @@ -87,6 +88,8 @@ def __fetch_well_known_config(self, hostname: str):
known_config_url = self.idp_endpoint.get_openid_config_url(hostname)

try:
# Set command type for OAuth configuration request
self.http_client.setRequestType(CommandType.AUTH)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This again is imposing thread safety concerns. we need to figure out a way to avoid having a state in httpclient.

response = self.http_client.request(HttpMethod.GET, url=known_config_url)
# Convert urllib3 response to requests-like response for compatibility
response.status_code = response.status
Expand Down Expand Up @@ -195,6 +198,8 @@ def __send_token_request(self, token_request_url, data):
"Accept": "application/json",
"Content-Type": "application/x-www-form-urlencoded",
}
# Set command type for OAuth token request
self.http_client.setRequestType(CommandType.AUTH)
# Use unified HTTP client
response = self.http_client.request(
HttpMethod.POST, url=token_request_url, body=data, headers=headers
Expand Down Expand Up @@ -337,6 +342,8 @@ def refresh(self) -> Token:
}
)

# Set command type for OAuth client credentials request
self._http_client.setRequestType(CommandType.AUTH)
response = self._http_client.request(
method=HttpMethod.POST, url=self.token_url, headers=headers, body=data
)
Expand Down
107 changes: 82 additions & 25 deletions src/databricks/sql/auth/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,20 @@


class CommandType(Enum):
EXECUTE_STATEMENT = "ExecuteStatement"
NOT_SET = "NotSet"
OPEN_SESSION = "OpenSession"
CLOSE_SESSION = "CloseSession"
METADATA = "Metadata"
CLOSE_OPERATION = "CloseOperation"
GET_OPERATION_STATUS = "GetOperationStatus"
CANCEL_OPERATION = "CancelOperation"
EXECUTE_STATEMENT = "ExecuteStatement"
FETCH_RESULTS = "FetchResults"
CLOUD_FETCH = "CloudFetch"
AUTH = "Auth"
TELEMETRY_PUSH = "TelemetryPush"
VOLUME_GET = "VolumeGet"
VOLUME_PUT = "VolumePut"
VOLUME_DELETE = "VolumeDelete"
OTHER = "Other"

@classmethod
Expand All @@ -45,9 +55,66 @@ def get(cls, value: str):
if valid_command:
return getattr(cls, str(valid_command))
else:
# Map Thrift metadata operations to METADATA type
metadata_operations = {
"GetOperationStatus", "GetResultSetMetadata", "GetTables",
"GetColumns", "GetSchemas", "GetCatalogs", "GetFunctions",
"GetPrimaryKeys", "GetTypeInfo", "GetCrossReference",
"GetImportedKeys", "GetExportedKeys", "GetTableTypes"
}
if value in metadata_operations:
return cls.METADATA
return cls.OTHER


class CommandIdempotency(Enum):
IDEMPOTENT = "idempotent"
NON_IDEMPOTENT = "non_idempotent"


# Mapping of CommandType to CommandIdempotency
# Based on the official idempotency classification
COMMAND_IDEMPOTENCY_MAP = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

link the retry design doc section here that classifies requests

# NON-IDEMPOTENT operations (safety first - unknown types are not retried)
CommandType.NOT_SET: CommandIdempotency.NON_IDEMPOTENT,
CommandType.EXECUTE_STATEMENT: CommandIdempotency.NON_IDEMPOTENT,
CommandType.FETCH_RESULTS: CommandIdempotency.NON_IDEMPOTENT,
CommandType.VOLUME_PUT: CommandIdempotency.NON_IDEMPOTENT, # PUT can overwrite files

# IDEMPOTENT operations
CommandType.OPEN_SESSION: CommandIdempotency.IDEMPOTENT,
CommandType.CLOSE_SESSION: CommandIdempotency.IDEMPOTENT,
CommandType.METADATA: CommandIdempotency.IDEMPOTENT,
CommandType.CLOSE_OPERATION: CommandIdempotency.IDEMPOTENT,
CommandType.CANCEL_OPERATION: CommandIdempotency.IDEMPOTENT,
CommandType.CLOUD_FETCH: CommandIdempotency.IDEMPOTENT,
CommandType.AUTH: CommandIdempotency.IDEMPOTENT,
CommandType.TELEMETRY_PUSH: CommandIdempotency.IDEMPOTENT,
CommandType.VOLUME_GET: CommandIdempotency.IDEMPOTENT,
CommandType.VOLUME_DELETE: CommandIdempotency.IDEMPOTENT,
CommandType.OTHER: CommandIdempotency.IDEMPOTENT,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what comes under CommandType.OTHER? Please clarify

}

# HTTP status codes that should never be retried, even for idempotent requests
# These are client error codes that indicate permanent issues
NON_RETRYABLE_STATUS_CODES = {
400, # Bad Request
401, # Unauthorized
403, # Forbidden
404, # Not Found
405, # Method Not Allowed
409, # Conflict
410, # Gone
411, # Length Required
412, # Precondition Failed
413, # Payload Too Large
414, # URI Too Long
415, # Unsupported Media Type
416, # Range Not Satisfiable
501, # Not Implemented
}


class DatabricksRetryPolicy(Retry):
"""
Implements our v3 retry policy by extending urllib3's robust default retry behaviour.
Expand Down Expand Up @@ -354,38 +421,25 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]:

logger.info(f"Received status code {status_code} for {method} request")

# Get command idempotency for use in multiple conditions below
command_idempotency = COMMAND_IDEMPOTENCY_MAP.get(
self.command_type, CommandIdempotency.NON_IDEMPOTENT
)

# Request succeeded. Don't retry.
if status_code // 100 <= 3:
return False, "2xx/3xx codes are not retried"

if status_code == 400:
return (
False,
"Received 400 - BAD_REQUEST. Please check the request parameters.",
)

if status_code == 401:
return (
False,
"Received 401 - UNAUTHORIZED. Confirm your authentication credentials.",
)

if status_code == 403:
return False, "403 codes are not retried"

# Request failed and server said NotImplemented. This isn't recoverable. Don't retry.
if status_code == 501:
return False, "Received code 501 from server."

# Request failed and this method is not retryable. We only retry POST requests.
if not self._is_method_retryable(method):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where is this method defined?

return False, "Only POST requests are retried"

# Request failed with 404 and was a GetOperationStatus. This is not recoverable. Don't retry.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment mentions Request was a GetOperationStatus but in if condition you're checking for all metadata command types. Is this intended?

if status_code == 404 and self.command_type == CommandType.GET_OPERATION_STATUS:
if status_code == 404 and self.command_type == CommandType.METADATA:
return (
False,
"GetOperationStatus received 404 code from Databricks. Operation was canceled.",
"Metadata request received 404 code from Databricks. Operation was canceled.",
)

# Request failed with 404 because CloseSession returns 404 if you repeat the request.
Expand All @@ -408,23 +462,26 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]:
"CloseOperation received 404 code from Databricks. Cursor is already closed."
)

if status_code in NON_RETRYABLE_STATUS_CODES:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should return more descriptive error message, based on status_code.
eg: UNAUTHORIZED for 403, BAD_REQUEST for 400 etc

return False, f"Received {status_code} code from Databricks. Operation was canceled."

# Request failed, was an ExecuteStatement and the command may have reached the server
if (
self.command_type == CommandType.EXECUTE_STATEMENT
command_idempotency == CommandIdempotency.NON_IDEMPOTENT
and status_code not in self.status_forcelist
and status_code not in self.force_dangerous_codes
):
return (
False,
"ExecuteStatement command can only be retried for codes 429 and 503",
"Non Idempotent requests can only be retried for codes 429 and 503",
)

# Request failed with a dangerous code, was an ExecuteStatement, but user forced retries for this
# dangerous code. Note that these lines _are not required_ to make these requests retry. They would
# retry automatically. This code is included only so that we can log the exact reason for the retry.
# This gives users signal that their _retry_dangerous_codes setting actually did something.
if (
self.command_type == CommandType.EXECUTE_STATEMENT
command_idempotency == CommandIdempotency.NON_IDEMPOTENT
and status_code in self.force_dangerous_codes
):
return (
Expand Down
14 changes: 10 additions & 4 deletions src/databricks/sql/backend/sea/utils/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,13 +285,19 @@ def _get_command_type_from_path(self, path: str, method: str) -> CommandType:
if method == "POST" and path.endswith("/statements"):
return CommandType.EXECUTE_STATEMENT
elif "/cancel" in path:
return CommandType.OTHER # Cancel operation
return CommandType.CANCEL_OPERATION
elif method == "DELETE":
return CommandType.CLOSE_OPERATION
elif method == "GET":
return CommandType.GET_OPERATION_STATUS
# For GET requests on statements, determine if it's fetching results or status
if "/result/chunks/" in path:
return CommandType.FETCH_RESULTS
else:
return CommandType.METADATA # Statement status queries
elif "/sessions" in path:
if method == "DELETE":
if method == "POST" and path.endswith("/sessions"):
return CommandType.OPEN_SESSION
elif method == "DELETE":
return CommandType.CLOSE_SESSION

return CommandType.OTHER
return CommandType.NOT_SET
13 changes: 13 additions & 0 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from databricks.sql.thrift_api.TCLIService import ttypes
from databricks.sql.backend.thrift_backend import ThriftDatabricksClient
from databricks.sql.backend.databricks_client import DatabricksClient
from databricks.sql.auth.retry import CommandType
from databricks.sql.utils import (
ParamEscaper,
inject_parameters,
Expand Down Expand Up @@ -774,6 +775,9 @@ def _handle_staging_put(
session_id_hex=self.connection.get_session_id_hex(),
)

# Set command type for volume PUT operation
self.connection.http_client.setRequestType(CommandType.VOLUME_PUT)

with open(local_file, "rb") as fh:
r = self.connection.http_client.request(
HttpMethod.PUT, presigned_url, body=fh.read(), headers=headers
Expand Down Expand Up @@ -830,6 +834,9 @@ def _handle_staging_put_stream(
session_id_hex=self.connection.get_session_id_hex(),
)

# Set command type for volume PUT stream operation
self.connection.http_client.setRequestType(CommandType.VOLUME_PUT)

r = self.connection.http_client.request(
HttpMethod.PUT, presigned_url, body=stream.read(), headers=headers
)
Expand All @@ -851,6 +858,9 @@ def _handle_staging_get(
session_id_hex=self.connection.get_session_id_hex(),
)

# Set command type for volume GET operation
self.connection.http_client.setRequestType(CommandType.VOLUME_GET)

r = self.connection.http_client.request(
HttpMethod.GET, presigned_url, headers=headers
)
Expand All @@ -874,6 +884,9 @@ def _handle_staging_remove(
):
"""Make an HTTP DELETE request to the presigned_url"""

# Set command type for volume DELETE operation
self.connection.http_client.setRequestType(CommandType.VOLUME_DELETE)

r = self.connection.http_client.request(
HttpMethod.DELETE, presigned_url, headers=headers
)
Expand Down
35 changes: 33 additions & 2 deletions src/databricks/sql/common/unified_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,37 @@ def _prepare_headers(
def _prepare_retry_policy(self):
"""Set up the retry policy for the current request."""
if isinstance(self._retry_policy, DatabricksRetryPolicy):
# Set command type for HTTP requests to OTHER (not database commands)
self._retry_policy.command_type = CommandType.OTHER
# Only set command type to NOT_SET if it hasn't been explicitly set via setRequestType()
if self._retry_policy.command_type is None:
self._retry_policy.command_type = CommandType.NOT_SET
# Start the retry timer for duration-based retry limits
self._retry_policy.start_retry_timer()

def setRequestType(self, request_type: CommandType):
"""
Set the specific request type for the next HTTP request.

This allows clients to specify what type of operation they're performing
so the retry policy can make appropriate idempotency decisions.

Args:
request_type: The CommandType enum value for this operation

Example:
# For authentication requests (OAuth, etc.)
http_client.setRequestType(CommandType.AUTH)
response = http_client.request(HttpMethod.POST, url, body=data)

# For cloud fetch operations
http_client.setRequestType(CommandType.CLOUD_FETCH)
response = http_client.request(HttpMethod.GET, cloud_url)
"""
if isinstance(self._retry_policy, DatabricksRetryPolicy):
self._retry_policy.command_type = request_type
logger.debug(f"Set request type to: {request_type.value}")
else:
logger.warning(f"Cannot set request type {request_type.value}: retry policy is not DatabricksRetryPolicy")

@contextmanager
def request_context(
self,
Expand Down Expand Up @@ -269,6 +295,11 @@ def request_context(
logger.error("HTTP request error: %s", e)
raise RequestError(f"HTTP request error: {e}")
finally:
# Reset command type after request completion to prevent it from affecting subsequent requests
if isinstance(self._retry_policy, DatabricksRetryPolicy):
self._retry_policy.command_type = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't setting command_type to NOT_SET better? it will be type safe

logger.debug("Reset command type after request completion")

if response:
response.close()

Expand Down
2 changes: 1 addition & 1 deletion tests/e2e/common/large_queries_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def fetch_rows(self, cursor, row_count, fetchmany_size):
rows = self.get_some_rows(cursor, fetchmany_size)
if not rows:
# Read all the rows, row_count should match
self.assertEqual(n, row_count)
assert n == row_count
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this change related to this PR


num_fetches = max(math.ceil(n / 10000), 1)
latency_ms = int((time.time() - start_time) * 1000 / num_fetches), 1
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e/common/retry_test_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def test_retry_max_count_not_exceeded(self, mock_send_telemetry, extra_params):
THEN the connector issues six request (original plus five retries)
before raising an exception
"""
with mocked_server_response(status=404) as mock_obj:
with mocked_server_response(status=429) as mock_obj:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice catch!

with pytest.raises(MaxRetryError) as cm:
extra_params = {**extra_params, **self._retry_policy}
with self.connection(extra_params=extra_params) as conn:
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e/common/staging_ingestion_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_staging_ingestion_life_cycle(self, ingestion_user):
# GET after REMOVE should fail

with pytest.raises(
Error, match="too many 404 error responses"
Error, match="Staging operation over HTTP was unsuccessful: 404"
):
cursor = conn.cursor()
query = f"GET 'stage://tmp/{ingestion_user}/tmp/11/16/file1.csv' TO '{new_temp_path}'"
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e/common/uc_volume_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_uc_volume_life_cycle(self, catalog, schema):
# GET after REMOVE should fail

with pytest.raises(
Error, match="too many 404 error responses"
Error, match="Staging operation over HTTP was unsuccessful: 404"
):
cursor = conn.cursor()
query = f"GET '/Volumes/{catalog}/{schema}/e2etests/file1.csv' TO '{new_temp_path}'"
Expand Down
Loading
Loading