diff --git a/examples/streaming_put.py b/examples/streaming_put.py new file mode 100644 index 00000000..4e769709 --- /dev/null +++ b/examples/streaming_put.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +""" +Simple example of streaming PUT operations. + +This demonstrates the basic usage of streaming PUT with the __input_stream__ token. +""" + +import io +import os +from databricks import sql + +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), +) as connection: + + with connection.cursor() as cursor: + # Create a simple data stream + data = b"Hello, streaming world!" + stream = io.BytesIO(data) + + # Get catalog, schema, and volume from environment variables + catalog = os.getenv("DATABRICKS_CATALOG") + schema = os.getenv("DATABRICKS_SCHEMA") + volume = os.getenv("DATABRICKS_VOLUME") + + # Upload to Unity Catalog volume + cursor.execute( + f"PUT '__input_stream__' INTO '/Volumes/{catalog}/{schema}/{volume}/hello.txt' OVERWRITE", + input_stream=stream + ) + + print("File uploaded successfully!") \ No newline at end of file diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 3cd7bcac..78a01142 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1,5 +1,5 @@ import time -from typing import Dict, Tuple, List, Optional, Any, Union, Sequence +from typing import Dict, Tuple, List, Optional, Any, Union, Sequence, BinaryIO import pandas try: @@ -662,7 +662,9 @@ def _check_not_closed(self): ) def _handle_staging_operation( - self, staging_allowed_local_path: Union[None, str, List[str]] + self, + staging_allowed_local_path: Union[None, str, List[str]], + input_stream: Optional[BinaryIO] = None, ): """Fetch the HTTP request instruction from a staging ingestion command and call the designated handler. @@ -671,6 +673,28 @@ def _handle_staging_operation( is not descended from staging_allowed_local_path. """ + assert self.active_result_set is not None + row = self.active_result_set.fetchone() + assert row is not None + + # May be real headers, or could be json string + headers = ( + json.loads(row.headers) if isinstance(row.headers, str) else row.headers + ) + headers = dict(headers) if headers else {} + + # Handle __input_stream__ token for PUT operations + if ( + row.operation == "PUT" + and getattr(row, "localFile", None) == "__input_stream__" + ): + return self._handle_staging_put_stream( + presigned_url=row.presignedUrl, + stream=input_stream, + headers=headers, + ) + + # For non-streaming operations, validate staging_allowed_local_path if isinstance(staging_allowed_local_path, type(str())): _staging_allowed_local_paths = [staging_allowed_local_path] elif isinstance(staging_allowed_local_path, type(list())): @@ -685,10 +709,6 @@ def _handle_staging_operation( os.path.abspath(i) for i in _staging_allowed_local_paths ] - assert self.active_result_set is not None - row = self.active_result_set.fetchone() - assert row is not None - # Must set to None in cases where server response does not include localFile abs_localFile = None @@ -711,19 +731,16 @@ def _handle_staging_operation( session_id_hex=self.connection.get_session_id_hex(), ) - # May be real headers, or could be json string - headers = ( - json.loads(row.headers) if isinstance(row.headers, str) else row.headers - ) - handler_args = { "presigned_url": row.presignedUrl, "local_file": abs_localFile, - "headers": dict(headers) or {}, + "headers": headers, } logger.debug( - f"Attempting staging operation indicated by server: {row.operation} - {getattr(row, 'localFile', '')}" + "Attempting staging operation indicated by server: %s - %s", + row.operation, + getattr(row, "localFile", ""), ) # TODO: Create a retry loop here to re-attempt if the request times out or fails @@ -762,6 +779,10 @@ def _handle_staging_put( HttpMethod.PUT, presigned_url, body=fh.read(), headers=headers ) + self._handle_staging_http_response(r) + + def _handle_staging_http_response(self, r): + # fmt: off # HTTP status codes OK = 200 @@ -784,6 +805,37 @@ def _handle_staging_put( + "but not yet applied on the server. It's possible this command may fail later." ) + @log_latency(StatementType.SQL) + def _handle_staging_put_stream( + self, + presigned_url: str, + stream: BinaryIO, + headers: dict = {}, + ) -> None: + """Handle PUT operation with streaming data. + + Args: + presigned_url: The presigned URL for upload + stream: Binary stream to upload + headers: HTTP headers + + Raises: + ProgrammingError: If no input stream is provided + OperationalError: If the upload fails + """ + + if not stream: + raise ProgrammingError( + "No input stream provided for streaming operation", + session_id_hex=self.connection.get_session_id_hex(), + ) + + r = self.connection.http_client.request( + HttpMethod.PUT, presigned_url, body=stream.read(), headers=headers + ) + + self._handle_staging_http_response(r) + @log_latency(StatementType.SQL) def _handle_staging_get( self, local_file: str, presigned_url: str, headers: Optional[dict] = None @@ -840,6 +892,7 @@ def execute( operation: str, parameters: Optional[TParameterCollection] = None, enforce_embedded_schema_correctness=False, + input_stream: Optional[BinaryIO] = None, ) -> "Cursor": """ Execute a query and wait for execution to complete. @@ -914,7 +967,8 @@ def execute( if self.active_result_set and self.active_result_set.is_staging_operation: self._handle_staging_operation( - staging_allowed_local_path=self.connection.staging_allowed_local_path + staging_allowed_local_path=self.connection.staging_allowed_local_path, + input_stream=input_stream, ) return self diff --git a/tests/e2e/common/streaming_put_tests.py b/tests/e2e/common/streaming_put_tests.py new file mode 100644 index 00000000..83da10fd --- /dev/null +++ b/tests/e2e/common/streaming_put_tests.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +""" +E2E tests for streaming PUT operations. +""" + +import io +import logging +import pytest + +logger = logging.getLogger(__name__) + + +class PySQLStreamingPutTestSuiteMixin: + """Test suite for streaming PUT operations.""" + + def test_streaming_put_basic(self, catalog, schema): + """Test basic streaming PUT functionality.""" + + # Create test data + test_data = b"Hello, streaming world! This is test data." + filename = "streaming_put_test.txt" + file_path = f"/Volumes/{catalog}/{schema}/e2etests/{filename}" + + try: + with self.connection() as conn: + with conn.cursor() as cursor: + self._cleanup_test_file(file_path) + + with io.BytesIO(test_data) as stream: + cursor.execute( + f"PUT '__input_stream__' INTO '{file_path}'", + input_stream=stream + ) + + # Verify file exists + cursor.execute(f"LIST '/Volumes/{catalog}/{schema}/e2etests/'") + files = cursor.fetchall() + + # Check if our file is in the list + file_paths = [row[0] for row in files] + assert file_path in file_paths, f"File {file_path} not found in {file_paths}" + finally: + self._cleanup_test_file(file_path) + + def test_streaming_put_missing_stream(self, catalog, schema): + """Test that missing stream raises appropriate error.""" + + with self.connection() as conn: + with conn.cursor() as cursor: + # Test without providing stream + with pytest.raises(Exception): # Should fail + cursor.execute( + f"PUT '__input_stream__' INTO '/Volumes/{catalog}/{schema}/e2etests/test.txt'" + # Note: No input_stream parameter + ) + + def _cleanup_test_file(self, file_path): + """Clean up a test file if it exists.""" + try: + with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: + with conn.cursor() as cursor: + cursor.execute(f"REMOVE '{file_path}'") + logger.info("Successfully cleaned up test file: %s", file_path) + except Exception as e: + logger.error("Cleanup failed for %s: %s", file_path, e) \ No newline at end of file diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 53b7383e..52f6e4a2 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -50,6 +50,7 @@ from tests.e2e.common.retry_test_mixins import PySQLRetryTestsMixin from tests.e2e.common.uc_volume_tests import PySQLUCVolumeTestSuiteMixin +from tests.e2e.common.streaming_put_tests import PySQLStreamingPutTestSuiteMixin from databricks.sql.exc import SessionAlreadyClosedError @@ -290,6 +291,7 @@ class TestPySQLCoreSuite( PySQLStagingIngestionTestSuiteMixin, PySQLRetryTestsMixin, PySQLUCVolumeTestSuiteMixin, + PySQLStreamingPutTestSuiteMixin, ): validate_row_value_type = True validate_result = True diff --git a/tests/unit/test_streaming_put.py b/tests/unit/test_streaming_put.py new file mode 100644 index 00000000..2b9a9e6d --- /dev/null +++ b/tests/unit/test_streaming_put.py @@ -0,0 +1,113 @@ +import io +from unittest.mock import patch, Mock, MagicMock + +import pytest + +import databricks.sql.client as client + + +class TestStreamingPut: + """Unit tests for streaming PUT functionality.""" + + @pytest.fixture + def cursor(self): + return client.Cursor(connection=Mock(), backend=Mock()) + + def _setup_mock_staging_put_stream_response(self, mock_backend): + """Helper method to set up mock staging PUT stream response.""" + mock_result_set = Mock() + mock_result_set.is_staging_operation = True + mock_backend.execute_command.return_value = mock_result_set + + mock_row = Mock() + mock_row.operation = "PUT" + mock_row.localFile = "__input_stream__" + mock_row.presignedUrl = "https://example.com/upload" + mock_row.headers = "{}" + mock_result_set.fetchone.return_value = mock_row + + return mock_result_set + + def test_execute_with_valid_stream(self, cursor): + """Test execute method with valid input stream.""" + + # Mock the backend response + self._setup_mock_staging_put_stream_response(cursor.backend) + + # Test with valid stream + test_stream = io.BytesIO(b"test data") + + with patch.object(cursor, "_handle_staging_put_stream") as mock_handler: + cursor.execute( + "PUT '__input_stream__' INTO '/Volumes/test/cat/schema/vol/file.txt'", + input_stream=test_stream, + ) + + # Verify staging handler was called + mock_handler.assert_called_once() + + def test_execute_with_none_stream_for_staging_put(self, cursor): + """Test execute method rejects None stream for streaming PUT operations.""" + + # Mock staging operation response for None case + self._setup_mock_staging_put_stream_response(cursor.backend) + + # None with __input_stream__ raises ProgrammingError + with pytest.raises(client.ProgrammingError) as excinfo: + cursor.execute( + "PUT '__input_stream__' INTO '/Volumes/test/cat/schema/vol/file.txt'", + input_stream=None, + ) + error_msg = str(excinfo.value) + assert "No input stream provided for streaming operation" in error_msg + + def test_handle_staging_put_stream_success(self, cursor): + """Test successful streaming PUT operation.""" + + presigned_url = "https://example.com/upload" + headers = {"Content-Type": "text/plain"} + + with patch.object( + cursor.connection.http_client, "request" + ) as mock_http_request: + mock_response = MagicMock() + mock_response.status = 200 + mock_response.data = b"success" + mock_http_request.return_value = mock_response + + test_stream = io.BytesIO(b"test data") + cursor._handle_staging_put_stream( + presigned_url=presigned_url, stream=test_stream, headers=headers + ) + + # Verify the HTTP client was called correctly + mock_http_request.assert_called_once() + call_args = mock_http_request.call_args + # Check positional arguments: (method, url, body=..., headers=...) + assert call_args[0][0].value == "PUT" # First positional arg is method + assert call_args[0][1] == presigned_url # Second positional arg is url + # Check keyword arguments + assert call_args[1]["body"] == b"test data" + assert call_args[1]["headers"] == headers + + def test_handle_staging_put_stream_http_error(self, cursor): + """Test streaming PUT operation with HTTP error.""" + + presigned_url = "https://example.com/upload" + + with patch.object( + cursor.connection.http_client, "request" + ) as mock_http_request: + mock_response = MagicMock() + mock_response.status = 500 + mock_response.data = b"Internal Server Error" + mock_http_request.return_value = mock_response + + test_stream = io.BytesIO(b"test data") + with pytest.raises(client.OperationalError) as excinfo: + cursor._handle_staging_put_stream( + presigned_url=presigned_url, stream=test_stream + ) + + # Check for the actual error message format + assert "500" in str(excinfo.value)