[SPARK-53029][PYTHON] Support return type coercion for Arrow Python UDTFs #52140
@@ -227,6 +227,68 @@ def load_stream(self, stream):
                 result_batches.append(batch.column(i))
             yield result_batches

+    def _create_array(self, arr, arrow_type):
+        import pyarrow as pa
+
+        assert isinstance(arr, pa.Array)
+        assert isinstance(arrow_type, pa.DataType)
+        if arr.type == arrow_type:
+            return arr
+        else:
+            try:
+                # When safe is True, the cast will fail if there's an overflow
+                # or other unsafe conversion.
+                # RecordBatch.cast(...) isn't used because the minimum PyArrow
+                # version required for RecordBatch.cast(...) is v16.0.
+                return arr.cast(target_type=arrow_type, safe=True)
+            except (pa.ArrowInvalid, pa.ArrowTypeError):
+                raise PySparkRuntimeError(
+                    errorClass="RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF",
+                    messageParameters={
+                        "expected": str(arrow_type),
+                        "actual": str(arr.type),
+                    },
+                )
+
+    def dump_stream(self, iterator, stream):
+        """
+        Override to handle type coercion for ArrowUDTF outputs.
+        ArrowUDTF returns an iterator of (pa.RecordBatch, arrow_return_type) tuples.
+        """
+        import pyarrow as pa
+
+        def apply_type_coercion():
+            for batch, arrow_return_type in iterator:
+                assert isinstance(
+                    arrow_return_type, pa.StructType
+                ), f"Expected pa.StructType, got {type(arrow_return_type)}"
+
+                # Handle the empty struct case specially
+                if batch.num_columns == 0:
+                    coerced_batch = batch  # skip type coercion
+                else:
+                    expected_field_names = arrow_return_type.names
+                    actual_field_names = batch.schema.names
+
+                    if expected_field_names != actual_field_names:
+                        raise PySparkTypeError(
+                            "Target schema's field names are not matching the record batch's "
+                            "field names. "
+                            f"Expected: {expected_field_names}, but got: {actual_field_names}."
+                        )
+
+                    coerced_arrays = []
+                    for i, field in enumerate(arrow_return_type):
+                        original_array = batch.column(i)
+                        coerced_array = self._create_array(original_array, field.type)
+                        coerced_arrays.append(coerced_array)
+
+                    coerced_batch = pa.RecordBatch.from_arrays(
+                        coerced_arrays, names=arrow_return_type.names
+                    )
+                yield coerced_batch, arrow_return_type
+
+        return super().dump_stream(apply_type_coercion(), stream)
+
+
 class ArrowStreamGroupUDFSerializer(ArrowStreamUDFSerializer):
     """

Review thread on lines +282 to +284 (the coercion loop):
- Reviewer: Can we directly use `RecordBatch.cast` here?
- Author: I discussed this with @ueshin offline. Unfortunately, `RecordBatch.cast` isn't available in the currently required minimal PyArrow version. We'll need to bump the minimum requirement to support it.
- Reviewer: Got it. Can we add a comment here mentioning why we don't use `RecordBatch.cast` and what's the minimum PyArrow version to support it?

Review thread on the `pa.RecordBatch.from_arrays(...)` call:
- Reviewer: nit:
- Author: they are the same here :)
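Following the `RecordBatch.cast` discussion above, here is a sketch of how the per-column loop could collapse into one call once the minimum PyArrow version is bumped to 16.0, where `pa.RecordBatch.cast(...)` becomes available. This illustrates the reviewers' suggestion under that version assumption; it is not code from this PR:

import pyarrow as pa

batch = pa.record_batch({"id": pa.array([1, 2, 3], type=pa.int64())})
target_schema = pa.schema([pa.field("id", pa.int32())])

# pa.RecordBatch.cast(...) exists from PyArrow 16.0 onward; safe=True keeps
# the overflow/lossy-conversion checks that _create_array enforces per column.
coerced = batch.cast(target_schema, safe=True)
assert coerced.schema == target_schema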
@@ -189,7 +189,10 @@ def eval(self) -> Iterator["pa.Table"]:
                 )
                 yield result_table

-        with self.assertRaisesRegex(PythonException, "Schema at index 0 was different"):
+        with self.assertRaisesRegex(
+            PythonException,
+            "Target schema's field names are not matching the record batch's field names",
+        ):
             result_df = MismatchedSchemaUDTF()
             result_df.collect()
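The regex asserted here comes from the new field-name check in dump_stream. A tiny standalone sketch (plain PyArrow with hypothetical names, independent of Spark) of the condition that triggers it:

import pyarrow as pa

expected_names = ["id"]  # field names from the declared return type (hypothetical)
batch = pa.record_batch({"wrong_name": pa.array([1], type=pa.int32())})

# dump_stream compares names positionally before casting anything;
# a mismatch raises PySparkTypeError with the message tested above.
print(batch.schema.names != expected_names)  # True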
@@ -330,9 +333,10 @@ def eval(self) -> Iterator["pa.Table"]:
                 )
                 yield result_table

-        with self.assertRaisesRegex(PythonException, "Schema at index 0 was different"):
-            result_df = LongToIntUDTF()
-            result_df.collect()
+        # Should succeed with automatic coercion
+        result_df = LongToIntUDTF()
+        expected_df = self.spark.createDataFrame([(1,), (2,), (3,)], "id int")
+        assertDataFrameEqual(result_df, expected_df)

     def test_arrow_udtf_type_coercion_string_to_int(self):
         @arrow_udtf(returnType="id int")
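For reference, the long -> int coercion this test now relies on can be reproduced with plain PyArrow (a standalone sketch, independent of Spark):

import pyarrow as pa

arr = pa.array([1, 2, 3], type=pa.int64())
# All values fit in int32, so the safe cast used by _create_array succeeds.
print(arr.cast(pa.int32(), safe=True))  # -> int32 array [1, 2, 3]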
@@ -341,15 +345,103 @@ def eval(self) -> Iterator["pa.Table"]:
                 # Return string values that cannot be coerced to int
                 result_table = pa.table(
                     {
-                        "id": pa.array(["abc", "def", "xyz"], type=pa.string()),
+                        "id": pa.array(["1", "2", "xyz"], type=pa.string()),
                     }
                 )
                 yield result_table

-        with self.assertRaisesRegex(PythonException, "Schema at index 0 was different"):
+        # Should fail with Arrow cast exception since string cannot be cast to int
+        with self.assertRaisesRegex(
+            PythonException,
+            "PySparkRuntimeError: \\[RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF\\] "
+            "Column names of the returned pyarrow.Table or pyarrow.RecordBatch do not match "
+            "specified schema. Expected: int32 Actual: string",
+        ):
             result_df = StringToIntUDTF()
             result_df.collect()

+    def test_arrow_udtf_type_coercion_string_to_int_safe(self):
+        @arrow_udtf(returnType="id int")
+        class StringToIntUDTF:
+            def eval(self) -> Iterator["pa.Table"]:
+                result_table = pa.table(
+                    {
+                        "id": pa.array(["1", "2", "3"], type=pa.string()),
+                    }
+                )
+                yield result_table
+
+        result_df = StringToIntUDTF()
+        expected_df = self.spark.createDataFrame([(1,), (2,), (3,)], "id int")
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_type_corecion_int64_to_int32_safe(self):
+        @arrow_udtf(returnType="id int")
+        class Int64ToInt32UDTF:
+            def eval(self) -> Iterator["pa.Table"]:
+                result_table = pa.table(
+                    {
+                        "id": pa.array([1, 2, 3], type=pa.int64()),  # long values
+                    }
+                )
+                yield result_table
+
+        result_df = Int64ToInt32UDTF()
+        expected_df = self.spark.createDataFrame([(1,), (2,), (3,)], "id int")
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_return_type_coercion_success(self):
+        @arrow_udtf(returnType="value int")
+        class CoercionSuccessUDTF:
+            def eval(self) -> Iterator["pa.Table"]:
+                result_table = pa.table(
+                    {
+                        "value": pa.array([10, 20, 30], type=pa.int64()),  # long -> int coercion
+                    }
+                )
+                yield result_table
+
+        result_df = CoercionSuccessUDTF()
+        expected_df = self.spark.createDataFrame([(10,), (20,), (30,)], "value int")
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_return_type_coercion_overflow(self):
+        @arrow_udtf(returnType="value int")
+        class CoercionOverflowUDTF:
+            def eval(self) -> Iterator["pa.Table"]:
+                # Return values that will cause overflow when casting long to int
+                result_table = pa.table(
+                    {
+                        "value": pa.array([2147483647 + 1], type=pa.int64()),  # int32 max + 1
+                    }
+                )
+                yield result_table
+
+        # Should fail with a PyArrow overflow exception
+        with self.assertRaises(Exception):
+            result_df = CoercionOverflowUDTF()
+            result_df.collect()
+
+    def test_return_type_coercion_multiple_columns(self):
+        @arrow_udtf(returnType="id int, price float")
+        class MultipleColumnCoercionUDTF:
+            def eval(self) -> Iterator["pa.Table"]:
+                result_table = pa.table(
+                    {
+                        "id": pa.array([1, 2, 3], type=pa.int64()),  # long -> int coercion
+                        "price": pa.array(
+                            [10.5, 20.7, 30.9], type=pa.float64()
+                        ),  # double -> float coercion
+                    }
+                )
+                yield result_table
+
+        result_df = MultipleColumnCoercionUDTF()
+        expected_df = self.spark.createDataFrame(
+            [(1, 10.5), (2, 20.7), (3, 30.9)], "id int, price float"
+        )
+        assertDataFrameEqual(result_df, expected_df)
+
     def test_arrow_udtf_with_empty_column_result(self):
         @arrow_udtf(returnType=StructType())
         class EmptyResultUDTF:

Review thread on the changed `pa.array` values:
- Reviewer: Does this work if it's `["1", "2", "3"]`?
- Author: Casting from "1" to 1 should work. I added a test_arrow_udtf_type_coercion_string_to_int_safe.
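A standalone sketch (plain PyArrow, independent of Spark) of why "1" and "2" coerce while "xyz" fails under the safe cast used by the serializer:

import pyarrow as pa

print(pa.array(["1", "2"]).cast(pa.int32(), safe=True))  # parses to [1, 2]

try:
    pa.array(["xyz"]).cast(pa.int32(), safe=True)
except pa.ArrowInvalid:
    print("non-numeric strings raise ArrowInvalid, surfaced as the error tested above")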
Review thread:
- Reviewer: Hmm, we didn't check this in verify_result? It's a much better error message than before, but we should take a look at how to merge this with verify_result.
- Author: Do you mean verify_arrow_result in worker.py? I removed it, since verify_arrow_result requires the return type to strictly match arrow_return_type in the conversion pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))).
- Author: Can I refactor verify_arrow_result in a follow-up PR and integrate the results verification there, to unblock this PR? The current verify_arrow_result is primarily for Arrow UDFs.
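For context on the thread above, a minimal standalone illustration (assumed equivalent PyArrow behavior, not code from worker.py) of the strict schema matching that verify_arrow_result performs via pa.Table.from_batches:

import pyarrow as pa

batch = pa.record_batch({"id": pa.array([1, 2, 3], type=pa.int64())})
strict_schema = pa.schema([pa.field("id", pa.int32())])

try:
    # Without the coercion step added in this PR, the strict conversion rejects the batch.
    pa.Table.from_batches([batch], schema=strict_schema)
except pa.ArrowInvalid as e:
    print(e)  # "Schema at index 0 was different" -- the old error message in the tests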