Conversation

shujingyang-db
Contributor

What changes were proposed in this pull request?

Support return type coercion for Arrow Python UDTFs by applying arrow_cast by default.
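In effect, each record batch the UDTF yields is safe-cast to the declared return schema. A minimal PyArrow-only sketch of the idea (the UDTF registration itself is omitted since the API is not public yet):

import pyarrow as pa

# The UDTF declares "value bigint" but yields int32 data.
declared = pa.schema([pa.field("value", pa.int64())])
yielded = pa.RecordBatch.from_arrays(
    [pa.array([1, 2, 3], type=pa.int32())], names=["value"]
)

# With coercion enabled, each column is safe-cast to the declared type.
coerced = pa.RecordBatch.from_arrays(
    [yielded.column(0).cast(declared.field("value").type, safe=True)],
    names=declared.names,
)
assert coerced.schema == declared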

Why are the changes needed?

Consistent behavior across Arrow UDFs and Arrow UDTFs

Does this PR introduce any user-facing change?

No. Arrow UDTFs are not a public API yet.

How was this patch tested?

New and existing unit tests.

Was this patch authored or co-authored using generative AI tooling?

No

zhengruifeng previously approved these changes Aug 27, 2025
"""

def __init__(self, table_arg_offsets=None):
def __init__(self, table_arg_offsets=None, arrow_cast=False):
Contributor

Should the default value be True?

Contributor Author

Yep! I changed it to True and added a SQLConf to gate it.

yield pa.Table.from_struct_array(pa.array([{}] * 3))

assertDataFrameEqual(EmptyResultUDTF(), [Row(), Row(), Row()])
assertDataFrameEqual(EmptyResultUDTF(), [None, None, None])
Contributor

I guess this change is unexpected?

Contributor Author

Good catch! I have reverted it and now create an empty batch with the number of rows set.

@github-actions github-actions bot added the CORE label Aug 28, 2025
Contributor

@allisonwang-db allisonwang-db left a comment

Thanks for supporting this!

Comment on lines 267 to 271
if batch.num_columns == 0:
    # When batch has no column, it should still create
    # an empty batch with the number of rows set.
    struct = pa.array([{}] * batch.num_rows)
    coerced_batch = pa.RecordBatch.from_arrays([struct], ["_0"])
Contributor

I don't think we need to handle this case? cc @ueshin

Contributor Author

This is to ensure the test case "test_arrow_udtf_with_empty_column_result" works. Please refer to the #52140 (comment) thread for the unexpected behavior change.

Member

I guess this will be done in super().dump_stream(), too?

"""

def __init__(self, table_arg_offsets=None):
def __init__(self, table_arg_offsets=None, arrow_cast=True):
Contributor

Let's enable arrow_cast by default for ArrowUDTFs (it's a new feature) so we don't need a flag here.

if arr.type == arrow_type:
    return arr
elif self._arrow_cast:
    return arr.cast(target_type=arrow_type, safe=True)
Contributor

What's the difference between safe=True and safe=False?

Contributor Author

With safe=True, it will only allow casts that are guaranteed not to lose information: truncation (floats to ints), narrowing (int64 → int8), and other precision loss are not allowed. Will add a comment.
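For illustration, the difference in plain PyArrow (illustrative values):

import pyarrow as pa

arr = pa.array([1.5, 2.0], type=pa.float64())
arr.cast(pa.int64(), safe=False)  # unsafe: silently truncates to [1, 2]
arr.cast(pa.int64(), safe=True)   # raises pa.ArrowInvalid because 1.5 would be truncated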

Contributor

cc @zhengruifeng is this the same behavior as Arrow UDFs?

Contributor Author

@shujingyang-db shujingyang-db Sep 2, 2025

Arrow UDF has the same implementation. cc: @zhengruifeng please keep me honest.

Update: we have now changed it to RecordBatch.cast, which should have the same behavior as arr.cast but is more performant. cc: @ueshin
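For reference, RecordBatch.cast takes a target schema and casts all columns in one call, with safe casting by default (a sketch; it needs a PyArrow version that ships RecordBatch.cast, see the version discussion further down):

import pyarrow as pa

batch = pa.RecordBatch.from_arrays([pa.array([1, 2], type=pa.int32())], names=["v"])
target = pa.schema([pa.field("v", pa.int64())])
coerced = batch.cast(target)  # same result as safe-casting each column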

Comment on lines 259 to 261
assert isinstance(
    batch, pa.RecordBatch
), f"Expected pa.RecordBatch, got {type(batch)}"
Contributor

I think we already check this in worker.py, so no need to duplicate this check :)

Comment on lines 281 to 288
raise PySparkRuntimeError(
    errorClass="UDTF_RETURN_SCHEMA_MISMATCH",
    messageParameters={
        "expected": str(len(arrow_return_type)),
        "actual": str(batch.num_columns),
        "func": "ArrowUDTF",
    },
)
Contributor

Ditto. I think we already check whether the returned columns mismatch the expected return schema in worker.py. Would you mind double-checking?

Contributor Author

Do you mean verify_arrow_result in worker.py? I removed it, since verify_arrow_result requires the return type to strictly match arrow_return_type in the conversion pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))).

verify_arrow_result(
    pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))),
    assign_cols_by_name=False,
    expected_cols_and_types=[
        (col.name, to_arrow_type(col.dataType)) for col in return_type.fields
    ],
)

Member

The column length is checked before it. Please take a look at:

if result.num_columns != return_type_size:
...

in verify_result.

if arr.type == arrow_type:
    return arr
elif self._arrow_cast:
    return arr.cast(target_type=arrow_type, safe=True)
Contributor

Also, it would be great to list the type coercion rules here!

Contributor Author

Added a comment.

if self.spark.conf.get("spark.sql.execution.pythonUDTF.typeCoercion.enabled").lower() == "false":
    with self.assertRaisesRegex(PythonException, "Schema at index 0 was different"):
        result_df = MismatchedSchemaUDTF()
        result_df.collect()
Contributor

You can use with self.sql_conf(...) instead.
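i.e. something like (the conf name is the one added in this PR, which is later removed again):

with self.sql_conf({"spark.sql.execution.pythonUDTF.typeCoercion.enabled": "false"}):
    with self.assertRaisesRegex(PythonException, "Schema at index 0 was different"):
        MismatchedSchemaUDTF().collect()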

Comment on lines 4006 to 4007
val PYTHON_TABLE_UDF_TYPE_COERCION_ENABLED =
  buildConf("spark.sql.execution.pythonUDTF.typeCoercion.enabled")
Contributor

Let's enable Arrow cast for Arrow Python UDTFs by default so we don't need this config :)

Contributor Author

@shujingyang-db shujingyang-db Aug 29, 2025

sure, on it

Update: done

{
    "wrong_col": pa.array([1], type=pa.int32()),
    "another_wrong_col": pa.array([2.5], type=pa.float64()),
    "col_with_arrow_cast": pa.array([1], type=pa.int32()),
Contributor

What if the input is int64 and the output is int32? Does the Arrow cast throw an exception in this case?

Contributor Author

@shujingyang-db shujingyang-db Aug 29, 2025

Yes, it will. We added a test case, "test_return_type_coercion_overflow".
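Concretely, in plain PyArrow the safe int64 → int32 cast only raises when a value is out of range (illustrative snippet, not the test itself):

import pyarrow as pa

pa.array([1, 2], type=pa.int64()).cast(pa.int32(), safe=True)   # fine, values fit
pa.array([2**40], type=pa.int64()).cast(pa.int32(), safe=True)  # raises pa.ArrowInvalid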

        result_df = MismatchedSchemaUDTF()
        result_df.collect()
else:
    with self.assertRaisesRegex(PythonException, "Failed to parse string: 'wrong_col' as a scalar of type int32"):
Contributor

Hmm, it looks like the error message is better without the Arrow cast.

Contributor Author

@shujingyang-db shujingyang-db Aug 29, 2025

I added a try-except block to polish the error message for the Arrow cast.
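Roughly this shape (an illustrative sketch, not the exact PR code):

import pyarrow as pa
from pyspark.errors import PySparkTypeError

def coerce(arr: pa.Array, arrow_type: pa.DataType) -> pa.Array:
    try:
        return arr.cast(target_type=arrow_type, safe=True)
    except pa.ArrowInvalid as e:
        # Re-raise with a friendlier message than PyArrow's default.
        raise PySparkTypeError(
            "Arrow UDTFs require the return type to match the expected Arrow type. "
            f"Expected: {arrow_type}, but got: {arr.type}."
        ) from e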

Comment on lines 257 to 258
for packed in iterator:
    batch, arrow_return_type = packed
Member

nit:

Suggested change
for packed in iterator:
    batch, arrow_return_type = packed
for batch, arrow_return_type in iterator:

  .booleanConf
  .createWithDefault(false)


Member

nit: revert this?

Comment on lines 1347 to 1349
ser = ArrowStreamArrowUDTFSerializer(
    table_arg_offsets=table_arg_offsets
)
Member

Looks like an unnecessary change?

pass



Member

Ditto.

Could you run:

./dev/reformat-python

to make the linter happy?

Comment on lines 293 to 297
if should_write_start_length:
    write_int(SpecialLengths.START_ARROW_STREAM, stream)
    should_write_start_length = False

yield coerced_batch
Member

These are done in super().dump_stream(). What we should do here is just the type casting.

Member

I'm just wondering whether we can use RecordBatch.cast for this instead of casting each column?

Contributor Author

Makes sense! I changed it to RecordBatch.cast.


yield coerced_batch, arrow_return_type

return super(ArrowStreamArrowUDTFSerializer, self).dump_stream(
Member

nit: super().dump_stream ...

if expected_type and actual_type:
    error_msg = f"Expected: {expected_type}, but got: {actual_type} in field '{expected_field.name}'."
else:
    error_msg = f"Expected: {target_schema}, but got: {batch.schema}."
Member

I guess this case is enough as an error message?

Contributor Author

Allison was asking for a better error message :)
#52140 (comment)

error_msg = f"Expected: {target_schema}, but got: {batch.schema}."

raise PySparkTypeError(
    "Arrow UDTFs require the return type to match the expected Arrow type."
Member

nit: make the string end with "... Arrow type. " so there is a space between it and error_msg.

Member

@ueshin ueshin left a comment

LGTM, pending tests.

if batch.num_columns == 0:
    coerced_batch = batch  # skip type coercion
else:
    expected_field_names = [field.name for field in arrow_return_type]
Member

nit: we can use arrow_return_type.names instead?

    coerced_array = self._create_array(original_array, field.type)
    coerced_arrays.append(coerced_array)
coerced_batch = pa.RecordBatch.from_arrays(
    coerced_arrays, names=arrow_return_type.names
Member

nit: expected_field_names or actual_field_names?

Contributor Author

They are the same here :)

result_table = pa.table(
    {
        "id": pa.array(["abc", "def", "xyz"], type=pa.string()),
        "id": pa.array(["1", "2", "xyz"], type=pa.string()),
Member

Does this work if it's pa.array(["1", "2", "3"])? Shall we have a test to confirm?

Contributor Author

Casting from "1" to 1 should work. I added test_arrow_udtf_type_coercion_string_to_int_safe.
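For the record, the underlying PyArrow behavior being relied on (illustrative):

import pyarrow as pa

pa.array(["1", "2", "3"]).cast(pa.int32())    # parses cleanly to [1, 2, 3]
pa.array(["1", "2", "xyz"]).cast(pa.int32())  # raises: Failed to parse string: 'xyz' ...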

Comment on lines 243 to 245
raise PySparkTypeError(
    "Arrow UDTFs require the return type to match the expected Arrow type. "
    f"Expected: {arrow_type}, but got: {arr.type}."
Contributor

Nit: Can we use an error class here?

Contributor Author

Added :) The exception thrown is not structured, though, so I use assertRaisesRegex to check the error:

An exception was thrown from the Python worker. Please see the stack trace below.

Comment on lines +265 to +273
expected_field_names = arrow_return_type.names
actual_field_names = batch.schema.names

if expected_field_names != actual_field_names:
    raise PySparkTypeError(
        "Target schema's field names are not matching the record batch's "
        "field names. "
        f"Expected: {expected_field_names}, but got: {actual_field_names}."
    )
Contributor

Hmm, we didn't check this in verify_result? It's a much better error message than before, but we should take a look at how to merge this with verify_result.

Contributor Author

Do you mean verify_arrow_result in worker.py? I removed it, since verify_arrow_result requires the return type to strictly match arrow_return_type in the conversion pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))).

verify_arrow_result(
    pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))),
    assign_cols_by_name=False,
    expected_cols_and_types=[
        (col.name, to_arrow_type(col.dataType)) for col in return_type.fields
    ],
)

Contributor Author

@shujingyang-db shujingyang-db Sep 3, 2025

Can I refactor verify_arrow_result and integrate the result verification in a follow-up PR, to unblock this PR?

The current verify_arrow_result is primarily for Arrow UDFs.

Comment on lines +277 to +279
original_array = batch.column(i)
coerced_array = self._create_array(original_array, field.type)
coerced_arrays.append(coerced_array)
Contributor

Can we directly use batch.cast(arrow_return_type)? And the default safe parameter should be True.

Contributor Author

I discussed this with @ueshin offline. Unfortunately, RecordBatch.cast isn’t available in the currently required minimal PyArrow version. We’ll need to bump the minimum requirement to support it.

Contributor

Got it. Can we add a comment here mentioning why we don't use RecordBatch.cast and what's the minimum PyArrow version to support it?

},
"RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF": {
    "message": [
        "Column names of the returned pyarrow.Table do not match specified schema. Expected: <expected> Actual: <actual>"
Contributor

This is not necessarily a pyarrow.Table (it can be a columnar batch). How about we just say, "Column names of the returned table do not match ..."?

# When safe is True, the cast will fail if there's an overflow or other
# unsafe conversion.
# RecordBatch.cast(...) isn't used because the minimum PyArrow version
# required for RecordBatch.cast(...) is v21.0.0.
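A sketch of the per-column fallback that the comment describes (the helper name is illustrative; the PR's actual loop may differ):

import pyarrow as pa

def coerce_batch(batch: pa.RecordBatch, target: pa.Schema) -> pa.RecordBatch:
    # Cast column by column; on newer PyArrow this is equivalent
    # to batch.cast(target), which safe-casts by default.
    arrays = [
        batch.column(i).cast(field.type, safe=True)
        for i, field in enumerate(target)
    ]
    return pa.RecordBatch.from_arrays(arrays, names=target.names)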
Member

@ueshin
Member

ueshin commented Sep 8, 2025

Thanks! merging to master.

@ueshin ueshin closed this in d16e92d Sep 8, 2025