[SPARK-53029][PYTHON] Support return type coercion for Arrow Python UDTFs #52140
Conversation
""" | ||
|
||
def __init__(self, table_arg_offsets=None): | ||
def __init__(self, table_arg_offsets=None, arrow_cast=False): |
the default value should be True?
yep! I changed it to True and added a SQLConf to gate it.
    yield pa.Table.from_struct_array(pa.array([{}] * 3))
-   assertDataFrameEqual(EmptyResultUDTF(), [Row(), Row(), Row()])
+   assertDataFrameEqual(EmptyResultUDTF(), [None, None, None])
I guess this change is unexpected?
Good catch! I have reverted it and created an empty batch with the number of rows set.
Thanks for supporting this!
if batch.num_columns == 0:
    # When batch has no column, it should still create
    # an empty batch with the number of rows set.
    struct = pa.array([{}] * batch.num_rows)
    coerced_batch = pa.RecordBatch.from_arrays([struct], ["_0"])
I don't think we need to handle this case? cc @ueshin
This is to ensure the test case "test_arrow_udtf_with_empty_column_result" works. Please refer to #52140 (comment) for the unexpected behavior change.
I guess this will be done in super().dump_stream(), too?
""" | ||
|
||
def __init__(self, table_arg_offsets=None): | ||
def __init__(self, table_arg_offsets=None, arrow_cast=True): |
Let's enable arrow_cast by default for ArrowUDTFs (it's a new feature) so we don't need a flag here.
if arr.type == arrow_type:
    return arr
elif self._arrow_cast:
    return arr.cast(target_type=arrow_type, safe=True)
What's the difference between safe=True vs safe=False?
safe=True only allows casts that are guaranteed not to lose information: truncation (floats to ints), narrowing (int64 → int8), and precision loss are all rejected. Will add a comment.
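For illustration (not code from this PR), a standalone PyArrow sketch of the difference:

import pyarrow as pa

arr = pa.array([1.5, 2.7], type=pa.float64())

# safe=False truncates silently and yields [1, 2]
print(arr.cast(pa.int32(), safe=False))

# safe=True refuses any lossy conversion and raises ArrowInvalid
try:
    arr.cast(pa.int32(), safe=True)
except pa.ArrowInvalid as e:
    print(e)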
cc @zhengruifeng is this the same behavior as Arrow UDFs?
Arrow UDF has the same implementation. cc: @zhengruifeng please keep me honest.
Update: we now changed it to RecordBatch.cast, which should have the same behavior as arr.cast but is more performant. cc: @ueshin
assert isinstance(
    batch, pa.RecordBatch
), f"Expected pa.RecordBatch, got {type(batch)}"
I think we already check this in worker.py, so no need to duplicate this check :)
raise PySparkRuntimeError(
    errorClass="UDTF_RETURN_SCHEMA_MISMATCH",
    messageParameters={
        "expected": str(len(arrow_return_type)),
        "actual": str(batch.num_columns),
        "func": "ArrowUDTF",
    },
)
Ditto. I think we already check whether the returned columns mismatch the expected return schema in worker.py. Would you mind double checking?
Do you mean verify_arrow_result in worker.py? I removed it since verify_arrow_result requires the return type to strictly match arrow_return_type in the conversion pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))).
verify_arrow_result(
pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))),
assign_cols_by_name=False,
expected_cols_and_types=[
(col.name, to_arrow_type(col.dataType)) for col in return_type.fields
],
)
The column length is checked before it. Please take a look at:
if result.num_columns != return_type_size:
    ...
in verify_result.
if arr.type == arrow_type:
    return arr
elif self._arrow_cast:
    return arr.cast(target_type=arrow_type, safe=True)
Also, it would be great to list the type coercion rules here!
Added a comment.
with self.assertRaisesRegex(PythonException, "Schema at index 0 was different"):
    result_df = MismatchedSchemaUDTF()
    result_df.collect()
if self.spark.conf.get("spark.sql.execution.pythonUDTF.typeCoercion.enabled").lower() == "false":
You can use with self.sql_conf(...) here.
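A minimal sketch of that suggestion inside the existing test class (assuming PySpark's standard sql_conf test helper, which sets the conf for the block and restores it on exit):

with self.sql_conf({"spark.sql.execution.pythonUDTF.typeCoercion.enabled": False}):
    with self.assertRaisesRegex(PythonException, "Schema at index 0 was different"):
        MismatchedSchemaUDTF().collect()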
val PYTHON_TABLE_UDF_TYPE_CORERION_ENABLED =
  buildConf("spark.sql.execution.pythonUDTF.typeCoercion.enabled")
Let's enable Arrow cast for Arrow Python UDTFs by default so we don't need this config :)
sure, on it
Update: done
{
    "wrong_col": pa.array([1], type=pa.int32()),
    "another_wrong_col": pa.array([2.5], type=pa.float64()),
    "col_with_arrow_cast": pa.array([1], type=pa.int32()),
What if we have input to be int64 and output to be int32? Does arrow cast throw an exception in this case?
Yes, it will. We have a test case "test_return_type_coercion_overflow" for it.
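As a standalone illustration of that overflow behavior with a safe Arrow cast (not the PR's test itself):

import pyarrow as pa

big = pa.array([2**40], type=pa.int64())
try:
    big.cast(pa.int32(), safe=True)  # 2**40 does not fit into int32
except pa.ArrowInvalid as e:
    print(e)  # integer value not in range for int32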
if arr.type == arrow_type:
    return arr
elif self._arrow_cast:
    return arr.cast(target_type=arrow_type, safe=True)
cc @zhengruifeng is this the same behavior as Arrow UDFs?
        result_df = MismatchedSchemaUDTF()
        result_df.collect()
else:
    with self.assertRaisesRegex(PythonException, "Failed to parse string: 'wrong_col' as a scalar of type int32"):
Hmm looks like without arrow cast, the error message looks better.
I added a try-catch block to polish the error message with the arrow cast
for packed in iterator:
    batch, arrow_return_type = packed
nit:
-   for packed in iterator:
-       batch, arrow_return_type = packed
+   for batch, arrow_return_type in iterator:
.booleanConf
.createWithDefault(false)
nit: revert this?
python/pyspark/worker.py
ser = ArrowStreamArrowUDTFSerializer(
    table_arg_offsets=table_arg_offsets
)
Looks like an unnecessary change?
pass
ditto.
Could you run ./dev/reformat-python to make the linter happy?
raise PySparkRuntimeError(
    errorClass="UDTF_RETURN_SCHEMA_MISMATCH",
    messageParameters={
        "expected": str(len(arrow_return_type)),
        "actual": str(batch.num_columns),
        "func": "ArrowUDTF",
    },
)
The column length is checked before it. Please take a look at:
if result.num_columns != return_type_size:
    ...
in verify_result.
if should_write_start_length:
    write_int(SpecialLengths.START_ARROW_STREAM, stream)
    should_write_start_length = False

yield coerced_batch
These are done in super().dump_stream(). What we should do here is just the type-casting.
I'm just wondering whether we can use RecordBatch.cast for this instead of casting each column?
Makes sense! I changed it to RecordBatch.cast.
yield coerced_batch, arrow_return_type

return super(ArrowStreamArrowUDTFSerializer, self).dump_stream(
nit: super().dump_stream ...
if expected_type and actual_type:
    error_msg = f"Expected: {expected_type}, but got: {actual_type} in field '{expected_field.name}'."
else:
    error_msg = f"Expected: {target_schema}, but got: {batch.schema}."
I guess this case is enough as an error message?
Allison was asking for a better error message :)
#52140 (comment)
error_msg = f"Expected: {target_schema}, but got: {batch.schema}."

raise PySparkTypeError(
    "Arrow UDTFs require the return type to match the expected Arrow type."
nit: "... Arrow type. " to have a space between this and error_msg.
LGTM, pending tests.
if batch.num_columns == 0:
    coerced_batch = batch  # skip type coercion
else:
    expected_field_names = [field.name for field in arrow_return_type]
nit: we can use arrow_return_type.names instead?
coerced_array = self._create_array(original_array, field.type)
coerced_arrays.append(coerced_array)
coerced_batch = pa.RecordBatch.from_arrays(
    coerced_arrays, names=arrow_return_type.names
nit: expected_field_names or actual_field_names?
they are the same here :)
    result_table = pa.table(
        {
-           "id": pa.array(["abc", "def", "xyz"], type=pa.string()),
+           "id": pa.array(["1", "2", "xyz"], type=pa.string()),
Does this work if it's pa.array(["1", "2", "3"])? Shall we have a test to confirm?
Casting from "1" to 1 should work. I added test_arrow_udtf_type_coercion_string_to_int_safe.
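A quick standalone PyArrow check of both outcomes (illustrative only, not the PR's test):

import pyarrow as pa

print(pa.array(["1", "2", "3"]).cast(pa.int32()))  # parses cleanly to [1, 2, 3]
try:
    pa.array(["1", "2", "xyz"]).cast(pa.int32())
except pa.ArrowInvalid as e:
    print(e)  # Failed to parse string: 'xyz' as a scalar of type int32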
raise PySparkTypeError(
    "Arrow UDTFs require the return type to match the expected Arrow type. "
    f"Expected: {arrow_type}, but got: {arr.type}."
Nit: Can we use an error class here?
Added :) The exception thrown is not structured though. I use assertRaisesRegex to check the error:
An exception was thrown from the Python worker. Please see the stack trace below.
expected_field_names = arrow_return_type.names
actual_field_names = batch.schema.names

if expected_field_names != actual_field_names:
    raise PySparkTypeError(
        "Target schema's field names are not matching the record batch's "
        "field names. "
        f"Expected: {expected_field_names}, but got: {actual_field_names}."
    )
Hmm, we didn't check this in verify_result? It's a much better error message than before, but we should take a look at how to merge this with verify_result.
Do you mean verify_arrow_result in worker.py? I removed it since verify_arrow_result requires the return type to strictly match arrow_return_type in the conversion pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))).
verify_arrow_result(
pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))),
assign_cols_by_name=False,
expected_cols_and_types=[
(col.name, to_arrow_type(col.dataType)) for col in return_type.fields
],
)
Can I refactor verify_arrow_result in a follow-up PR and integrate the result verification there, to unblock this PR? The current verify_arrow_result is primarily for Arrow UDFs.
original_array = batch.column(i)
coerced_array = self._create_array(original_array, field.type)
coerced_arrays.append(coerced_array)
Can we directly use batch.cast(arrow_return_type)? And the default safe parameter should be True.
I discussed this with @ueshin offline. Unfortunately, RecordBatch.cast isn’t available in the currently required minimal PyArrow version. We’ll need to bump the minimum requirement to support it.
Got it. Can we add a comment here mentioning why we don't use RecordBatch.cast and what's the minimum PyArrow version to support it?
Thanks for supporting this!
},
"RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF": {
    "message": [
        "Column names of the returned pyarrow.Table do not match specified schema. Expected: <expected> Actual: <actual>"
This is not necessarily a pyarrow.Table (it can be a columnar batch). How about we just say, "Column names of the returned table do not match ..."?
original_array = batch.column(i)
coerced_array = self._create_array(original_array, field.type)
coerced_arrays.append(coerced_array)
Got it. Can we add a comment here mentioning why we don't use RecordBatch.cast and what's the minimum PyArrow version to support it?
# when safe is True, the cast will fail if there's an overflow or other
# unsafe conversion.
# RecordBatch.cast(...) isn't used as the minimum PyArrow version
# required for RecordBatch.cast(...) is v21.0.0
RecordBatch.cast is available since 16.0.
https://arrow.apache.org/docs/16.0/python/generated/pyarrow.RecordBatch.html#pyarrow.RecordBatch.cast
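For reference, a minimal example of that API on PyArrow >= 16.0 (illustrative):

import pyarrow as pa

batch = pa.RecordBatch.from_arrays([pa.array([1, 2], type=pa.int32())], names=["x"])
target = pa.schema([pa.field("x", pa.int64())])
print(batch.cast(target, safe=True))  # widening int32 -> int64 is lossless; lossy casts raise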
Thanks! Merging to master.
What changes were proposed in this pull request?
Support return type coercion for Arrow Python UDTFs by doing arrow_cast by default.
Why are the changes needed?
Consistent behavior across Arrow UDFs and Arrow UDTFs
Does this PR introduce any user-facing change?
No, Arrow UDTF is not a public API yet
How was this patch tested?
New and existing UTs
Was this patch authored or co-authored using generative AI tooling?
No