diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py
index afa70bc89..2a9b8b41c 100644
--- a/src/databricks/sql/backend/sea/result_set.py
+++ b/src/databricks/sql/backend/sea/result_set.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import json
 from typing import Any, List, Optional, TYPE_CHECKING
 
 import logging
@@ -26,6 +27,8 @@
 class SeaResultSet(ResultSet):
     """ResultSet implementation for SEA backend."""
 
+    backend: SeaDatabricksClient
+
     def __init__(
         self,
         connection: Connection,
@@ -82,6 +85,43 @@ def __init__(
             arrow_schema_bytes=execute_response.arrow_schema_bytes,
         )
 
+    def _convert_complex_types_to_string(
+        self, rows: "pyarrow.Table"
+    ) -> "pyarrow.Table":
+        """
+        Convert complex types (array, struct, map) to string representation.
+
+        Non-null values are JSON-encoded; nested values json cannot encode
+        (e.g. Decimal, date, datetime, bytes) fall back to str().
+
+        Args:
+            rows: Input PyArrow table
+        Returns:
+            PyArrow table with complex types converted to strings
+        """
+
+        if not pyarrow:
+            raise ImportError(
+                "PyArrow is not installed: _use_arrow_native_complex_types = False requires pyarrow"
+            )
+
+        def to_json_strings(col: "pyarrow.ChunkedArray") -> "pyarrow.Array":
+            # default=str: nested Decimal/date/bytes would otherwise raise TypeError.
+            json_strings = [
+                None if val is None else json.dumps(val, default=str)
+                for val in col.to_pylist()
+            ]
+            return pyarrow.array(json_strings, type=pyarrow.string())
+
+        types = pyarrow.types
+        checks = (types.is_list, types.is_large_list, types.is_struct, types.is_map)
+        converted_columns = [
+            to_json_strings(col) if any(check(col.type) for check in checks) else col
+            for col in rows.columns
+        ]
+
+        return pyarrow.Table.from_arrays(converted_columns, names=rows.column_names)
+
     def _convert_json_types(self, row: List[str]) -> List[Any]:
         """
         Convert string values in the row to appropriate Python types based on column metadata.
@@ -199,6 +239,9 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": if isinstance(self.results, JsonQueue): results = self._convert_json_to_arrow_table(results) + if not self.backend._use_arrow_native_complex_types: + results = self._convert_complex_types_to_string(results) + self._next_row_index += results.num_rows return results @@ -212,6 +255,9 @@ def fetchall_arrow(self) -> "pyarrow.Table": if isinstance(self.results, JsonQueue): results = self._convert_json_to_arrow_table(results) + if not self.backend._use_arrow_native_complex_types: + results = self._convert_complex_types_to_string(results) + self._next_row_index += results.num_rows return results diff --git a/tests/e2e/test_complex_types.py b/tests/e2e/test_complex_types.py index 212ddf916..9b64b8794 100644 --- a/tests/e2e/test_complex_types.py +++ b/tests/e2e/test_complex_types.py @@ -56,10 +56,19 @@ def table_fixture(self, connection_details): ("map_array_col", list), ], ) - def test_read_complex_types_as_arrow(self, field, expected_type, table_fixture): + @pytest.mark.parametrize( + "backend_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_read_complex_types_as_arrow( + self, field, expected_type, table_fixture, backend_params + ): """Confirms the return types of a complex type field when reading as arrow""" - with self.cursor() as cursor: + with self.cursor(extra_params=backend_params) as cursor: result = cursor.execute( "SELECT * FROM pysql_test_complex_types_table LIMIT 1" ).fetchone() @@ -77,11 +86,17 @@ def test_read_complex_types_as_arrow(self, field, expected_type, table_fixture): ("map_array_col"), ], ) - def test_read_complex_types_as_string(self, field, table_fixture): + @pytest.mark.parametrize( + "backend_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_read_complex_types_as_string(self, field, table_fixture, backend_params): """Confirms the return type of a complex type that is returned as a string""" - with self.cursor( - 
extra_params={"_use_arrow_native_complex_types": False} - ) as cursor: + extra_params = {**backend_params, "_use_arrow_native_complex_types": False} + with self.cursor(extra_params=extra_params) as cursor: result = cursor.execute( "SELECT * FROM pysql_test_complex_types_table LIMIT 1" ).fetchone()