diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py
index 02c88aa6..d2b10e71 100644
--- a/src/databricks/sql/backend/thrift_backend.py
+++ b/src/databricks/sql/backend/thrift_backend.py
@@ -735,7 +735,7 @@ def convert_col(t_column_desc):
         return pyarrow.schema([convert_col(col) for col in t_table_schema.columns])

     @staticmethod
-    def _col_to_description(col, session_id_hex=None):
+    def _col_to_description(col, field=None, session_id_hex=None):
         type_entry = col.typeDesc.types[0]

         if type_entry.primitiveEntry:
@@ -764,12 +764,39 @@ def _col_to_description(col, session_id_hex=None):
         else:
             precision, scale = None, None

+        # Extract variant type from field if available
+        if field is not None:
+            try:
+                # Check for variant type in metadata
+                if field.metadata and b"Spark:DataType:SqlName" in field.metadata:
+                    sql_type = field.metadata.get(b"Spark:DataType:SqlName")
+                    if sql_type == b"VARIANT":
+                        cleaned_type = "variant"
+            except Exception as e:
+                logger.debug(f"Could not extract variant type from field: {e}")
+
         return col.columnName, cleaned_type, None, None, precision, scale, None

     @staticmethod
-    def _hive_schema_to_description(t_table_schema, session_id_hex=None):
+    def _hive_schema_to_description(
+        t_table_schema, schema_bytes=None, session_id_hex=None
+    ):
+        field_dict = {}
+        if pyarrow and schema_bytes:
+            try:
+                arrow_schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(schema_bytes))
+                # Build a dictionary mapping column names to fields
+                for field in arrow_schema:
+                    field_dict[field.name] = field
+            except Exception as e:
+                logger.debug(f"Could not parse arrow schema: {e}")
+
         return [
-            ThriftDatabricksClient._col_to_description(col, session_id_hex)
+            ThriftDatabricksClient._col_to_description(
+                col,
+                field_dict.get(col.columnName) if field_dict else None,
+                session_id_hex,
+            )
             for col in t_table_schema.columns
         ]

@@ -802,11 +829,6 @@ def _results_message_to_execute_response(self, resp, operation_state):
                 or direct_results.resultSet.hasMoreRows
             )

-        description = self._hive_schema_to_description(
-            t_result_set_metadata_resp.schema,
-            self._session_id_hex,
-        )
-
         if pyarrow:
             schema_bytes = (
                 t_result_set_metadata_resp.arrowSchema
@@ -819,6 +841,12 @@
         else:
             schema_bytes = None

+        description = self._hive_schema_to_description(
+            t_result_set_metadata_resp.schema,
+            schema_bytes,
+            self._session_id_hex,
+        )
+
         lz4_compressed = t_result_set_metadata_resp.lz4Compressed
         command_id = CommandId.from_thrift_handle(resp.operationHandle)

@@ -863,11 +891,6 @@ def get_execution_result(

         t_result_set_metadata_resp = resp.resultSetMetadata

-        description = self._hive_schema_to_description(
-            t_result_set_metadata_resp.schema,
-            self._session_id_hex,
-        )
-
         if pyarrow:
             schema_bytes = (
                 t_result_set_metadata_resp.arrowSchema
@@ -880,6 +903,12 @@ def get_execution_result(
         else:
             schema_bytes = None

+        description = self._hive_schema_to_description(
+            t_result_set_metadata_resp.schema,
+            schema_bytes,
+            self._session_id_hex,
+        )
+
         lz4_compressed = t_result_set_metadata_resp.lz4Compressed
         is_staging_operation = t_result_set_metadata_resp.isStagingOperation
         has_more_rows = resp.hasMoreRows
diff --git a/tests/e2e/test_variant_types.py b/tests/e2e/test_variant_types.py
new file mode 100644
index 00000000..b5dc1f42
--- /dev/null
+++ b/tests/e2e/test_variant_types.py
@@ -0,0 +1,91 @@
+import pytest
+from datetime import datetime
+import json
+
+try:
+    import pyarrow
+except ImportError:
+    pyarrow = None
+
+from tests.e2e.test_driver import PySQLPytestTestCase
+from tests.e2e.common.predicates import pysql_supports_arrow
+
+
+@pytest.mark.skipif(not pysql_supports_arrow(), reason="Requires arrow support")
+class TestVariantTypes(PySQLPytestTestCase):
+    """Tests for the proper detection and handling of VARIANT type columns"""
+
+    @pytest.fixture(scope="class")
+    def variant_table(self, connection_details):
+        """A pytest fixture that creates a test table and cleans up after tests"""
+        self.arguments = connection_details.copy()
+        table_name = "pysql_test_variant_types_table"
+
+        with self.cursor() as cursor:
+            try:
+                # Create the table with variant columns
+                cursor.execute(
+                    """
+                    CREATE TABLE IF NOT EXISTS pysql_test_variant_types_table (
+                        id INTEGER,
+                        variant_col VARIANT,
+                        regular_string_col STRING
+                    )
+                    """
+                )
+
+                # Insert test records with different variant values
+                cursor.execute(
+                    """
+                    INSERT INTO pysql_test_variant_types_table
+                    VALUES
+                    (1, PARSE_JSON('{"name": "John", "age": 30}'), 'regular string'),
+                    (2, PARSE_JSON('[1, 2, 3, 4]'), 'another string')
+                    """
+                )
+                yield table_name
+            finally:
+                cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
+
+    def test_variant_type_detection(self, variant_table):
+        """Test that VARIANT type columns are properly detected in schema"""
+        with self.cursor() as cursor:
+            cursor.execute(f"SELECT * FROM {variant_table} LIMIT 0")
+
+            # Verify column types in description
+            assert (
+                cursor.description[0][1] == "int"
+            ), "Integer column type not correctly identified"
+            assert (
+                cursor.description[1][1] == "variant"
+            ), "VARIANT column type not correctly identified"
+            assert (
+                cursor.description[2][1] == "string"
+            ), "String column type not correctly identified"
+
+    def test_variant_data_retrieval(self, variant_table):
+        """Test that VARIANT data is properly retrieved and can be accessed as JSON"""
+        with self.cursor() as cursor:
+            cursor.execute(f"SELECT * FROM {variant_table} ORDER BY id")
+            rows = cursor.fetchall()
+
+            # First row should have a JSON object
+            json_obj = rows[0][1]
+            assert isinstance(
+                json_obj, str
+            ), "VARIANT column should be returned as string"
+
+            parsed = json.loads(json_obj)
+            assert parsed.get("name") == "John"
+            assert parsed.get("age") == 30
+
+            # Second row should have a JSON array
+            json_array = rows[1][1]
+            assert isinstance(
+                json_array, str
+            ), "VARIANT array should be returned as string"
+
+            # Parsing to verify it's valid JSON array
+            parsed_array = json.loads(json_array)
+            assert isinstance(parsed_array, list)
+            assert parsed_array == [1, 2, 3, 4]
diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py
index 0445ace3..7254b66c 100644
--- a/tests/unit/test_thrift_backend.py
+++ b/tests/unit/test_thrift_backend.py
@@ -2330,7 +2330,7 @@ def test_execute_command_sets_complex_type_fields_correctly(
             [],
             auth_provider=AuthProvider(),
             ssl_options=SSLOptions(),
-            http_client=MagicMock(), 
+            http_client=MagicMock(),
             **complex_arg_types,
         )
         thrift_backend.execute_command(
@@ -2356,6 +2356,86 @@
             t_execute_statement_req.useArrowNativeTypes.intervalTypesAsArrow
         )

+    @unittest.skipIf(pyarrow is None, "Requires pyarrow")
+    def test_col_to_description(self):
+        test_cases = [
+            ("variant_col", {b"Spark:DataType:SqlName": b"VARIANT"}, "variant"),
+            ("normal_col", {}, "string"),
+            ("weird_field", {b"Spark:DataType:SqlName": b"Some unexpected value"}, "string"),
+            ("missing_field", None, "string"),  # None field case
+        ]
+
+        for column_name, field_metadata, expected_type in test_cases:
+            with self.subTest(column_name=column_name, expected_type=expected_type):
+                col = ttypes.TColumnDesc(
+                    columnName=column_name,
+                    typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
+                )
+
+                field = (
+                    None
+                    if field_metadata is None
+                    else pyarrow.field(column_name, pyarrow.string(), metadata=field_metadata)
+                )
+
+                result = ThriftDatabricksClient._col_to_description(col, field)
+
+                self.assertEqual(result[0], column_name)
+                self.assertEqual(result[1], expected_type)
+                self.assertIsNone(result[2])
+                self.assertIsNone(result[3])
+                self.assertIsNone(result[4])
+                self.assertIsNone(result[5])
+                self.assertIsNone(result[6])
+
+    @unittest.skipIf(pyarrow is None, "Requires pyarrow")
+    def test_hive_schema_to_description(self):
+        test_cases = [
+            (
+                [
+                    ("regular_col", ttypes.TTypeId.STRING_TYPE),
+                    ("variant_col", ttypes.TTypeId.STRING_TYPE),
+                ],
+                [
+                    ("regular_col", {}),
+                    ("variant_col", {b"Spark:DataType:SqlName": b"VARIANT"}),
+                ],
+                [("regular_col", "string"), ("variant_col", "variant")],
+            ),
+            (
+                [("regular_col", ttypes.TTypeId.STRING_TYPE)],
+                None,  # No arrow schema
+                [("regular_col", "string")],
+            ),
+        ]
+
+        for columns, arrow_fields, expected_types in test_cases:
+            with self.subTest(arrow_fields=arrow_fields is not None):
+                t_table_schema = ttypes.TTableSchema(
+                    columns=[
+                        ttypes.TColumnDesc(
+                            columnName=name, typeDesc=self._make_type_desc(col_type)
+                        )
+                        for name, col_type in columns
+                    ]
+                )
+
+                schema_bytes = None
+                if arrow_fields:
+                    fields = [
+                        pyarrow.field(name, pyarrow.string(), metadata=metadata)
+                        for name, metadata in arrow_fields
+                    ]
+                    schema_bytes = pyarrow.schema(fields).serialize().to_pybytes()
+
+                description = ThriftDatabricksClient._hive_schema_to_description(
+                    t_table_schema, schema_bytes
+                )
+
+                for i, (expected_name, expected_type) in enumerate(expected_types):
+                    self.assertEqual(description[i][0], expected_name)
+                    self.assertEqual(description[i][1], expected_type)
+

 if __name__ == "__main__":
     unittest.main()
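
For reference, a minimal standalone sketch of the metadata round-trip this patch keys on (assumes pyarrow is installed; the field name variant_col is illustrative only): the serialized Arrow schema tags VARIANT columns under the Spark:DataType:SqlName metadata key, and _hive_schema_to_description reads that tag back when building cursor.description, falling back to the Thrift column type otherwise.

import pyarrow

# A VARIANT column arrives as an Arrow string field tagged via field metadata.
field = pyarrow.field(
    "variant_col", pyarrow.string(), metadata={b"Spark:DataType:SqlName": b"VARIANT"}
)
schema_bytes = pyarrow.schema([field]).serialize().to_pybytes()

# The client deserializes the schema bytes and inspects each field's metadata
# before falling back to the Thrift type name ("string" in this case).
arrow_schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(schema_bytes))
for f in arrow_schema:
    if f.metadata and f.metadata.get(b"Spark:DataType:SqlName") == b"VARIANT":
        print(f.name, "-> variant")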