Skip to content

Commit e252d36

Browse files
committed
streaming put support
Signed-off-by: Sreekanth Vadigi <[email protected]>
1 parent 87fed36 commit e252d36

File tree

5 files changed

+324
-14
lines changed

5 files changed

+324
-14
lines changed

examples/streaming_put.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
#!/usr/bin/env python3
"""
Simple example of streaming PUT operations.

This demonstrates the basic usage of streaming PUT with the __input_stream__ token.
"""

import io
import os

from databricks import sql

# Connection parameters are read from the environment.
with sql.connect(
    server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
    http_path=os.getenv("DATABRICKS_HTTP_PATH"),
    access_token=os.getenv("DATABRICKS_TOKEN"),
) as connection:
    with connection.cursor() as cursor:
        # Build a simple in-memory binary stream to upload.
        data = b"Hello, streaming world!"
        stream = io.BytesIO(data)

        # Target Unity Catalog volume, taken from environment variables.
        catalog = os.getenv("DATABRICKS_CATALOG")
        schema = os.getenv("DATABRICKS_SCHEMA")
        volume = os.getenv("DATABRICKS_VOLUME")

        # The '__input_stream__' token tells the driver to read the upload
        # body from `input_stream` rather than from a local file path.
        cursor.execute(
            f"PUT '__input_stream__' INTO '/Volumes/{catalog}/{schema}/{volume}/hello.txt' OVERWRITE",
            input_stream=stream,
        )
        print("File uploaded successfully!")

src/databricks/sql/client.py

Lines changed: 69 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import time
2-
from typing import Dict, Tuple, List, Optional, Any, Union, Sequence
2+
from typing import Dict, Tuple, List, Optional, Any, Union, Sequence, BinaryIO
33
import pandas
44

55
try:
@@ -662,7 +662,9 @@ def _check_not_closed(self):
662662
)
663663

664664
def _handle_staging_operation(
665-
self, staging_allowed_local_path: Union[None, str, List[str]]
665+
self,
666+
staging_allowed_local_path: Union[None, str, List[str]],
667+
input_stream: Optional[BinaryIO] = None,
666668
):
667669
"""Fetch the HTTP request instruction from a staging ingestion command
668670
and call the designated handler.
@@ -671,6 +673,28 @@ def _handle_staging_operation(
671673
is not descended from staging_allowed_local_path.
672674
"""
673675

676+
assert self.active_result_set is not None
677+
row = self.active_result_set.fetchone()
678+
assert row is not None
679+
680+
# May be real headers, or could be json string
681+
headers = (
682+
json.loads(row.headers) if isinstance(row.headers, str) else row.headers
683+
)
684+
headers = dict(headers) if headers else {}
685+
686+
# Handle __input_stream__ token for PUT operations
687+
if (
688+
row.operation == "PUT"
689+
and getattr(row, "localFile", None) == "__input_stream__"
690+
):
691+
return self._handle_staging_put_stream(
692+
presigned_url=row.presignedUrl,
693+
stream=input_stream,
694+
headers=headers,
695+
)
696+
697+
# For non-streaming operations, validate staging_allowed_local_path
674698
if isinstance(staging_allowed_local_path, type(str())):
675699
_staging_allowed_local_paths = [staging_allowed_local_path]
676700
elif isinstance(staging_allowed_local_path, type(list())):
@@ -685,10 +709,6 @@ def _handle_staging_operation(
685709
os.path.abspath(i) for i in _staging_allowed_local_paths
686710
]
687711

688-
assert self.active_result_set is not None
689-
row = self.active_result_set.fetchone()
690-
assert row is not None
691-
692712
# Must set to None in cases where server response does not include localFile
693713
abs_localFile = None
694714

@@ -711,19 +731,16 @@ def _handle_staging_operation(
711731
session_id_hex=self.connection.get_session_id_hex(),
712732
)
713733

714-
# May be real headers, or could be json string
715-
headers = (
716-
json.loads(row.headers) if isinstance(row.headers, str) else row.headers
717-
)
718-
719734
handler_args = {
720735
"presigned_url": row.presignedUrl,
721736
"local_file": abs_localFile,
722-
"headers": dict(headers) or {},
737+
"headers": headers,
723738
}
724739

725740
logger.debug(
726-
f"Attempting staging operation indicated by server: {row.operation} - {getattr(row, 'localFile', '')}"
741+
"Attempting staging operation indicated by server: %s - %s",
742+
row.operation,
743+
getattr(row, "localFile", ""),
727744
)
728745

729746
# TODO: Create a retry loop here to re-attempt if the request times out or fails
@@ -762,6 +779,10 @@ def _handle_staging_put(
762779
HttpMethod.PUT, presigned_url, body=fh.read(), headers=headers
763780
)
764781

782+
self._handle_staging_http_response(r)
783+
784+
def _handle_staging_http_response(self, r):
785+
765786
# fmt: off
766787
# HTTP status codes
767788
OK = 200
@@ -784,6 +805,38 @@ def _handle_staging_put(
784805
+ "but not yet applied on the server. It's possible this command may fail later."
785806
)
786807

808+
@log_latency(StatementType.SQL)
def _handle_staging_put_stream(
    self,
    presigned_url: str,
    stream: BinaryIO,
    headers: Optional[dict] = None,
) -> None:
    """Handle PUT operation with streaming data.

    Uploads the contents of *stream* to *presigned_url* via an HTTP PUT,
    then delegates status handling to ``_handle_staging_http_response``.

    Args:
        presigned_url: The presigned URL for upload
        stream: Binary stream to upload
        headers: HTTP headers to send with the request (optional)

    Raises:
        ProgrammingError: If no input stream is provided
        OperationalError: If the upload fails
    """

    if not stream:
        raise ProgrammingError(
            "No input stream provided for streaming operation",
            session_id_hex=self.connection.get_session_id_hex(),
        )

    # NOTE(review): stream.read() buffers the whole payload in memory;
    # acceptable for modest uploads, same behavior as _handle_staging_put.
    r = self.connection.http_client.request(
        HttpMethod.PUT,
        presigned_url,
        body=stream.read(),
        # Avoid a mutable default argument: normalize None to a fresh dict.
        headers=headers or {},
    )

    self._handle_staging_http_response(r)
838+
839+
787840
@log_latency(StatementType.SQL)
788841
def _handle_staging_get(
789842
self, local_file: str, presigned_url: str, headers: Optional[dict] = None
@@ -840,6 +893,7 @@ def execute(
840893
operation: str,
841894
parameters: Optional[TParameterCollection] = None,
842895
enforce_embedded_schema_correctness=False,
896+
input_stream: Optional[BinaryIO] = None,
843897
) -> "Cursor":
844898
"""
845899
Execute a query and wait for execution to complete.
@@ -914,7 +968,8 @@ def execute(
914968

915969
if self.active_result_set and self.active_result_set.is_staging_operation:
916970
self._handle_staging_operation(
917-
staging_allowed_local_path=self.connection.staging_allowed_local_path
971+
staging_allowed_local_path=self.connection.staging_allowed_local_path,
972+
input_stream=input_stream,
918973
)
919974

920975
return self
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#!/usr/bin/env python3
2+
"""
3+
E2E tests for streaming PUT operations.
4+
"""
5+
6+
import io
7+
import logging
8+
import pytest
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
class PySQLStreamingPutTestSuiteMixin:
    """Test suite for streaming PUT operations."""

    def test_streaming_put_basic(self, catalog, schema):
        """Test basic streaming PUT functionality."""

        # Create test data
        test_data = b"Hello, streaming world! This is test data."
        filename = "streaming_put_test.txt"
        # Interpolate the filename so the upload and the LIST verification
        # below refer to the same path.
        file_path = f"/Volumes/{catalog}/{schema}/e2etests/{filename}"

        try:
            with self.connection() as conn:
                with conn.cursor() as cursor:
                    # Remove any leftover from a previous run so the LIST
                    # check below actually verifies this upload.
                    self._cleanup_test_file(file_path)

                    with io.BytesIO(test_data) as stream:
                        cursor.execute(
                            f"PUT '__input_stream__' INTO '{file_path}'",
                            input_stream=stream,
                        )

                    # Verify file exists
                    cursor.execute(f"LIST '/Volumes/{catalog}/{schema}/e2etests/'")
                    files = cursor.fetchall()

                    # Check if our file is in the list
                    file_paths = [row[0] for row in files]
                    assert file_path in file_paths, f"File {file_path} not found in {file_paths}"
        finally:
            self._cleanup_test_file(file_path)

    def test_streaming_put_missing_stream(self, catalog, schema):
        """Test that missing stream raises appropriate error."""

        with self.connection() as conn:
            with conn.cursor() as cursor:
                # Test without providing stream
                with pytest.raises(Exception):  # Should fail
                    cursor.execute(
                        f"PUT '__input_stream__' INTO '/Volumes/{catalog}/{schema}/e2etests/test.txt'"
                        # Note: No input_stream parameter
                    )

    def _cleanup_test_file(self, file_path):
        """Clean up a test file if it exists; failures are logged, not raised."""
        try:
            with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn:
                with conn.cursor() as cursor:
                    cursor.execute(f"REMOVE '{file_path}'")
                    logger.info("Successfully cleaned up test file: %s", file_path)
        except Exception as e:
            logger.error("Cleanup failed for %s: %s", file_path, e)

tests/e2e/test_driver.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from tests.e2e.common.retry_test_mixins import PySQLRetryTestsMixin
5151

5252
from tests.e2e.common.uc_volume_tests import PySQLUCVolumeTestSuiteMixin
53+
from tests.e2e.common.streaming_put_tests import PySQLStreamingPutTestSuiteMixin
5354

5455
from databricks.sql.exc import SessionAlreadyClosedError
5556

@@ -290,6 +291,7 @@ class TestPySQLCoreSuite(
290291
PySQLStagingIngestionTestSuiteMixin,
291292
PySQLRetryTestsMixin,
292293
PySQLUCVolumeTestSuiteMixin,
294+
PySQLStreamingPutTestSuiteMixin,
293295
):
294296
validate_row_value_type = True
295297
validate_result = True

0 commit comments

Comments
 (0)