-
Notifications
You must be signed in to change notification settings - Fork 124
Enhance Arrow to Pandas conversion with type overrides and additional kwargs #579
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
048af73
0b1b05b
647ed39
31b44d4
2f32c6c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -213,6 +213,11 @@ def read(self) -> Optional[OAuthToken]: | |
# (True by default) | ||
# use_cloud_fetch | ||
# Enable use of cloud fetch to extract large query results in parallel via cloud storage | ||
# _arrow_pandas_type_override | ||
# Override the default pandas dtype mapping for Arrow types. | ||
# This is a dictionary of Arrow types to pandas dtypes. | ||
# _arrow_to_pandas_kwargs | ||
# Additional or modified arguments to pass to pandas.DataFrame constructor. | ||
|
||
logger.debug( | ||
"Connection.__init__(server_hostname=%s, http_path=%s)", | ||
|
@@ -229,6 +234,8 @@ def read(self) -> Optional[OAuthToken]: | |
self.port = kwargs.get("_port", 443) | ||
self.disable_pandas = kwargs.get("_disable_pandas", False) | ||
self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True) | ||
self._arrow_pandas_type_override = kwargs.get("_arrow_pandas_type_override", {}) | ||
self._arrow_to_pandas_kwargs = kwargs.get("_arrow_to_pandas_kwargs", {}) | ||
|
||
auth_provider = get_python_sql_connector_auth_provider( | ||
server_hostname, **kwargs | ||
|
@@ -1346,7 +1353,9 @@ def _convert_arrow_table(self, table): | |
# Need to use nullable types, as otherwise type can change when there are missing values. | ||
# See https://arrow.apache.org/docs/python/pandas.html#nullable-types | ||
# NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html | ||
dtype_mapping = { | ||
DEFAULT_DTYPE_MAPPING: Dict[ | ||
pyarrow.DataType, pandas.api.extensions.ExtensionDtype | ||
] = { | ||
pyarrow.int8(): pandas.Int8Dtype(), | ||
pyarrow.int16(): pandas.Int16Dtype(), | ||
pyarrow.int32(): pandas.Int32Dtype(), | ||
|
@@ -1361,13 +1370,35 @@ def _convert_arrow_table(self, table): | |
pyarrow.string(): pandas.StringDtype(), | ||
} | ||
|
||
arrow_pandas_type_override = self.connection._arrow_pandas_type_override | ||
if not isinstance(arrow_pandas_type_override, dict): | ||
logger.debug( | ||
"_arrow_pandas_type_override on connection was not a dict, using default type mapping" | ||
) | ||
arrow_pandas_type_override = {} | ||
|
||
dtype_mapping = { | ||
**DEFAULT_DTYPE_MAPPING, | ||
**arrow_pandas_type_override, | ||
} | ||
|
||
to_pandas_kwargs: dict[str, Any] = { | ||
"types_mapper": dtype_mapping.get, | ||
"date_as_object": True, | ||
"timestamp_as_object": True, | ||
} | ||
|
||
arrow_to_pandas_kwargs = self.connection._arrow_to_pandas_kwargs | ||
if isinstance(arrow_to_pandas_kwargs, dict): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same let it fail when the input format is incorrect. The default python interpreter error of type mismatch is enough |
||
to_pandas_kwargs.update(arrow_to_pandas_kwargs) | ||
else: | ||
logger.debug( | ||
"_arrow_to_pandas_kwargs on connection was not a dict, using default arguments" | ||
) | ||
|
||
# Need to rename columns, as the to_pandas function cannot handle duplicate column names | ||
table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) | ||
df = table_renamed.to_pandas( | ||
types_mapper=dtype_mapping.get, | ||
date_as_object=True, | ||
timestamp_as_object=True, | ||
) | ||
df = table_renamed.to_pandas(**to_pandas_kwargs) | ||
|
||
res = df.to_numpy(na_value=None, dtype="object") | ||
return [ResultRow(*v) for v in res] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
import pytest | ||
|
||
try: | ||
import pyarrow as pa | ||
except ImportError: | ||
pa = None | ||
import pandas | ||
import datetime | ||
import unittest | ||
from unittest.mock import MagicMock | ||
|
||
from databricks.sql.client import ResultSet, Connection, ExecuteResponse | ||
from databricks.sql.types import Row | ||
from databricks.sql.utils import ArrowQueue | ||
|
||
@pytest.mark.skipif(pa is None, reason="PyArrow is not installed") | ||
class ArrowConversionTests(unittest.TestCase): | ||
@staticmethod | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move this to use fixtures |
||
def mock_connection_static(): | ||
conn = MagicMock(spec=Connection) | ||
conn.disable_pandas = False | ||
conn._arrow_pandas_type_override = {} | ||
conn._arrow_to_pandas_kwargs = {} | ||
return conn | ||
|
||
@staticmethod | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use fixtures or just normal functions, don't need static methods |
||
def sample_arrow_table_static(): | ||
data = [ | ||
pa.array([1, 2, 3], type=pa.int32()), | ||
pa.array(["a", "b", "c"], type=pa.string()), | ||
] | ||
schema = pa.schema([("col_int", pa.int32()), ("col_str", pa.string())]) | ||
return pa.Table.from_arrays(data, schema=schema) | ||
|
||
@staticmethod | ||
def mock_thrift_backend_static(): | ||
sample_table = ArrowConversionTests.sample_arrow_table_static() | ||
tb = MagicMock() | ||
empty_arrays = [pa.array([], type=field.type) for field in sample_table.schema] | ||
empty_table = pa.Table.from_arrays(empty_arrays, schema=sample_table.schema) | ||
tb.fetch_results.return_value = (ArrowQueue(empty_table, 0), False) | ||
return tb | ||
|
||
@staticmethod | ||
def mock_raw_execute_response_static(): | ||
er = MagicMock(spec=ExecuteResponse) | ||
er.description = [ | ||
("col_int", "int", None, None, None, None, None), | ||
("col_str", "string", None, None, None, None, None), | ||
] | ||
er.arrow_schema_bytes = None | ||
er.arrow_queue = None | ||
er.has_more_rows = False | ||
er.lz4_compressed = False | ||
er.command_handle = MagicMock() | ||
er.status = MagicMock() | ||
er.has_been_closed_server_side = False | ||
er.is_staging_operation = False | ||
return er | ||
|
||
def test_convert_arrow_table_default(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The test_convert_arrow_table_deafult, test_convert_arrow_table_disable_pandas and test_convert_arrow_table_type_override are essentially the same test flow just with different arguments. Plz use pytest's parameterized tests for such tests where only arguments change |
||
mock_connection = ArrowConversionTests.mock_connection_static() | ||
sample_arrow_table = ArrowConversionTests.sample_arrow_table_static() | ||
mock_thrift_backend = ArrowConversionTests.mock_thrift_backend_static() | ||
mock_raw_execute_response = ( | ||
ArrowConversionTests.mock_raw_execute_response_static() | ||
) | ||
|
||
mock_raw_execute_response.arrow_queue = ArrowQueue( | ||
sample_arrow_table, sample_arrow_table.num_rows | ||
) | ||
rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) | ||
result_one = rs.fetchone() | ||
self.assertIsInstance(result_one, Row) | ||
self.assertEqual(result_one.col_int, 1) | ||
self.assertEqual(result_one.col_str, "a") | ||
|
||
mock_raw_execute_response.arrow_queue = ArrowQueue( | ||
sample_arrow_table, sample_arrow_table.num_rows | ||
) | ||
rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) | ||
result_all = rs.fetchall() | ||
self.assertEqual(len(result_all), 3) | ||
self.assertIsInstance(result_all[0], Row) | ||
self.assertEqual(result_all[0].col_int, 1) | ||
self.assertEqual(result_all[1].col_str, "b") | ||
|
||
def test_convert_arrow_table_disable_pandas(self): | ||
mock_connection = ArrowConversionTests.mock_connection_static() | ||
sample_arrow_table = ArrowConversionTests.sample_arrow_table_static() | ||
mock_thrift_backend = ArrowConversionTests.mock_thrift_backend_static() | ||
mock_raw_execute_response = ( | ||
ArrowConversionTests.mock_raw_execute_response_static() | ||
) | ||
|
||
mock_connection.disable_pandas = True | ||
mock_raw_execute_response.arrow_queue = ArrowQueue( | ||
sample_arrow_table, sample_arrow_table.num_rows | ||
) | ||
rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) | ||
result = rs.fetchall() | ||
self.assertEqual(len(result), 3) | ||
self.assertIsInstance(result[0], Row) | ||
self.assertEqual(result[0].col_int, 1) | ||
self.assertEqual(result[0].col_str, "a") | ||
self.assertIsInstance(sample_arrow_table.column(0)[0].as_py(), int) | ||
self.assertIsInstance(sample_arrow_table.column(1)[0].as_py(), str) | ||
|
||
def test_convert_arrow_table_type_override(self): | ||
mock_connection = ArrowConversionTests.mock_connection_static() | ||
sample_arrow_table = ArrowConversionTests.sample_arrow_table_static() | ||
mock_thrift_backend = ArrowConversionTests.mock_thrift_backend_static() | ||
mock_raw_execute_response = ( | ||
ArrowConversionTests.mock_raw_execute_response_static() | ||
) | ||
|
||
mock_connection._arrow_pandas_type_override = { | ||
pa.int32(): pandas.Float64Dtype() | ||
} | ||
mock_raw_execute_response.arrow_queue = ArrowQueue( | ||
sample_arrow_table, sample_arrow_table.num_rows | ||
) | ||
rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) | ||
result = rs.fetchall() | ||
self.assertEqual(len(result), 3) | ||
self.assertIsInstance(result[0].col_int, float) | ||
self.assertEqual(result[0].col_int, 1.0) | ||
self.assertEqual(result[0].col_str, "a") | ||
|
||
def test_convert_arrow_table_to_pandas_kwargs(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Too much code duplication. Can you create a parameterized test, where in this |
||
mock_connection = ArrowConversionTests.mock_connection_static() | ||
mock_thrift_backend = ( | ||
ArrowConversionTests.mock_thrift_backend_static() | ||
) # Does not use sample_arrow_table | ||
mock_raw_execute_response = ( | ||
ArrowConversionTests.mock_raw_execute_response_static() | ||
) | ||
|
||
dt_obj = datetime.datetime(2021, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc) | ||
ts_array = pa.array([dt_obj], type=pa.timestamp("us", tz="UTC")) | ||
ts_schema = pa.schema([("col_ts", pa.timestamp("us", tz="UTC"))]) | ||
ts_table = pa.Table.from_arrays([ts_array], schema=ts_schema) | ||
|
||
mock_raw_execute_response.description = [ | ||
("col_ts", "timestamp", None, None, None, None, None) | ||
] | ||
mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows) | ||
|
||
# Scenario 1: timestamp_as_object = True. Observed as datetime.datetime in Row. | ||
mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": True} | ||
rs_ts_true = ResultSet( | ||
mock_connection, mock_raw_execute_response, mock_thrift_backend | ||
) | ||
result_true = rs_ts_true.fetchall() | ||
self.assertEqual(len(result_true), 1) | ||
self.assertIsInstance(result_true[0].col_ts, datetime.datetime) | ||
|
||
# Scenario 2: timestamp_as_object = False. Observed as pandas.Timestamp in Row for this input. | ||
mock_raw_execute_response.arrow_queue = ArrowQueue( | ||
ts_table, ts_table.num_rows | ||
) # Reset queue | ||
mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": False} | ||
rs_ts_false = ResultSet( | ||
mock_connection, mock_raw_execute_response, mock_thrift_backend | ||
) | ||
result_false = rs_ts_false.fetchall() | ||
self.assertEqual(len(result_false), 1) | ||
self.assertIsInstance(result_false[0].col_ts, pandas.Timestamp) | ||
|
||
# Scenario 3: no override. Observed as datetime.datetime in Row since timestamp_as_object is True by default. | ||
mock_raw_execute_response.arrow_queue = ArrowQueue( | ||
ts_table, ts_table.num_rows | ||
) # Reset queue | ||
mock_connection._arrow_to_pandas_kwargs = {} | ||
rs_ts_default = ResultSet( | ||
mock_connection, mock_raw_execute_response, mock_thrift_backend | ||
) | ||
result_default = rs_ts_default.fetchall() | ||
self.assertEqual(len(result_default), 1) | ||
self.assertIsInstance(result_default[0].col_ts, datetime.datetime) | ||
|
||
|
||
if __name__ == "__main__": | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is not needed, as we run tests using pytest. Also can you move everything to pytest and remove unittest |
||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This if block is not needed, let it fail here itself. Don't want the user to give something incorrect and then everything works. This is a new change and nothing to be backward compatible.
just leave it at this -
arrow_pandas_type_override = self.connection._arrow_pandas_type_override