Changes from all commits (36 commits)
83ad11a
init
shujingyang-db Aug 27, 2025
f27409a
ckp
shujingyang-db Aug 27, 2025
25a9742
revert changes
shujingyang-db Aug 27, 2025
f6ff4c7
tests
shujingyang-db Aug 27, 2025
481a3a5
lint
shujingyang-db Aug 27, 2025
6df8f57
default to true
shujingyang-db Aug 28, 2025
e29f5d1
add a sqlconf for spark.sql.execution.pythonUDTF.typeCoercion.enabled
shujingyang-db Aug 28, 2025
44a46b0
add udtfTypeCoercion
shujingyang-db Aug 28, 2025
505db61
handle empty rows
shujingyang-db Aug 28, 2025
65fa7a1
fix tests
shujingyang-db Aug 28, 2025
e49e59b
rm conf - ckp
shujingyang-db Aug 29, 2025
0b5337b
rm sql conf
shujingyang-db Aug 29, 2025
4ed23a4
polish error message and rm checks
shujingyang-db Aug 29, 2025
8174be8
add comments
shujingyang-db Aug 29, 2025
b1602fb
ckp
shujingyang-db Sep 2, 2025
5fcf52e
errorMsg
shujingyang-db Sep 2, 2025
6760de6
comments
shujingyang-db Sep 2, 2025
3550ae6
nit
shujingyang-db Sep 2, 2025
02c9592
clean up
shujingyang-db Sep 2, 2025
4586fc3
Merge remote-tracking branch 'spark/master' into arrow-udtf-type-core…
shujingyang-db Sep 2, 2025
581fcbb
clean up
shujingyang-db Sep 2, 2025
ecafd74
format + fix tests
shujingyang-db Sep 2, 2025
1c4d5d5
revert to _create_array
shujingyang-db Sep 3, 2025
0d6755a
lint
shujingyang-db Sep 3, 2025
7abebc6
nit
shujingyang-db Sep 3, 2025
10a28cd
test_arrow_udtf_type_coercion_string_to_int_safe
shujingyang-db Sep 3, 2025
2c01a3b
naming
shujingyang-db Sep 3, 2025
fee5c23
lint
shujingyang-db Sep 3, 2025
ff33015
fix lint
shujingyang-db Sep 3, 2025
5c292ba
error class
shujingyang-db Sep 3, 2025
17923b0
add error class
shujingyang-db Sep 4, 2025
3a97fa6
lint
shujingyang-db Sep 4, 2025
f6771ba
fix lint
shujingyang-db Sep 4, 2025
c90574d
commenst and error messages
shujingyang-db Sep 5, 2025
97ac86c
fix
shujingyang-db Sep 7, 2025
98f97c1
Update serializers.py
shujingyang-db Sep 8, 2025
5 changes: 5 additions & 0 deletions python/pyspark/errors/error-conditions.json
@@ -967,6 +967,11 @@
"Column names of the returned pyarrow.Table do not match specified schema.<missing><extra>"
]
},
"RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF": {
"message": [
"Column names of the returned pyarrow.Table or pyarrow.RecordBatch do not match specified schema. Expected: <expected> Actual: <actual>"
]
},
"RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF": {
"message": [
"Column names of the returned pandas.DataFrame do not match specified schema.<missing><extra>"
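For reference, a minimal sketch of how this new error condition is meant to surface (assuming a PySpark build that already contains this entry): the <expected> and <actual> placeholders are filled from messageParameters, which is exactly what the serializer change below does.

from pyspark.errors import PySparkRuntimeError

# Standalone illustration of the message template expansion; values are made up.
err = PySparkRuntimeError(
    errorClass="RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF",
    messageParameters={"expected": "int32", "actual": "string"},
)
# str(err) should render roughly as:
# [RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF] Column names of the returned pyarrow.Table
# or pyarrow.RecordBatch do not match specified schema. Expected: int32 Actual: string
print(str(err))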
62 changes: 62 additions & 0 deletions python/pyspark/sql/pandas/serializers.py
@@ -227,6 +227,68 @@ def load_stream(self, stream):
result_batches.append(batch.column(i))
yield result_batches

def _create_array(self, arr, arrow_type):
import pyarrow as pa

assert isinstance(arr, pa.Array)
assert isinstance(arrow_type, pa.DataType)
if arr.type == arrow_type:
return arr
else:
try:
# when safe is True, the cast will fail if there's an overflow or other
# unsafe conversion.
# RecordBatch.cast(...) isn't used as minimum PyArrow version
# required for RecordBatch.cast(...) is v16.0
return arr.cast(target_type=arrow_type, safe=True)
except (pa.ArrowInvalid, pa.ArrowTypeError):
raise PySparkRuntimeError(
errorClass="RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF",
messageParameters={
"expected": str(arrow_type),
"actual": str(arr.type),
},
)

def dump_stream(self, iterator, stream):
"""
Override to handle type coercion for ArrowUDTF outputs.
ArrowUDTF returns iterator of (pa.RecordBatch, arrow_return_type) tuples.
"""
import pyarrow as pa

def apply_type_coercion():
for batch, arrow_return_type in iterator:
assert isinstance(
arrow_return_type, pa.StructType
), f"Expected pa.StructType, got {type(arrow_return_type)}"

# Handle empty struct case specially
if batch.num_columns == 0:
coerced_batch = batch # skip type coercion
else:
expected_field_names = arrow_return_type.names
actual_field_names = batch.schema.names

if expected_field_names != actual_field_names:
raise PySparkTypeError(
"Target schema's field names are not matching the record batch's "
"field names. "
f"Expected: {expected_field_names}, but got: {actual_field_names}."
)
Comment on lines +270 to +278
Contributor
Hmm we didn't check this in verify_result? It's a much better error message than before but we should take a look on how to merge this with verify_result.

Contributor Author
Do you mean verify_arrow_result in worker.py? I removed it since verify_arrow_result requires return type to strictly match arrow_return_type in the conversion of pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))).

verify_arrow_result(
    pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))),
    assign_cols_by_name=False,
    expected_cols_and_types=[
        (col.name, to_arrow_type(col.dataType)) for col in return_type.fields
    ],
)

Contributor Author
@shujingyang-db Sep 3, 2025
Can I refactor verify_arrow_result in a follow-up PR and integrate the results verification to unblock this PR?

The current verify_arrow_result is primarily for Arrow UDFs


coerced_arrays = []
for i, field in enumerate(arrow_return_type):
original_array = batch.column(i)
coerced_array = self._create_array(original_array, field.type)
coerced_arrays.append(coerced_array)
Comment on lines +282 to +284
Contributor
Can we directly use batch.cast(arrow_return_type)? and the default safe parameter should be True.

Contributor Author
I discussed this with @ueshin offline. Unfortunately, RecordBatch.cast isn’t available in the currently required minimal PyArrow version. We’ll need to bump the minimum requirement to support it.

Contributor
Got it. Can we add a comment here mentioning why we don't use RecordBatch.cast and what's the minimum pyarrow version to support it?

coerced_batch = pa.RecordBatch.from_arrays(
coerced_arrays, names=arrow_return_type.names
Member
nit: expected_field_names or actual_field_names?

Contributor Author
they are same here :)

)
yield coerced_batch, arrow_return_type

return super().dump_stream(apply_type_coercion(), stream)


class ArrowStreamGroupUDFSerializer(ArrowStreamUDFSerializer):
"""
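For readers unfamiliar with PyArrow's casting rules, here is a minimal, Spark-free sketch of the behavior _create_array relies on: with safe=True, Array.cast coerces compatible values but raises pa.ArrowInvalid on overflow or unparseable input, which the serializer then rewraps as RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF.

import pyarrow as pa

# Narrowing succeeds when all values fit: int64 -> int32.
pa.array([1, 2, 3], type=pa.int64()).cast(target_type=pa.int32(), safe=True)

# Parsable strings are coerced as well: string -> int32.
pa.array(["1", "2", "3"], type=pa.string()).cast(target_type=pa.int32(), safe=True)

# Overflow (or an unparseable string such as "xyz") raises pa.ArrowInvalid.
try:
    pa.array([2147483648], type=pa.int64()).cast(target_type=pa.int32(), safe=True)
except pa.ArrowInvalid as e:
    print(e)

On PyArrow >= 16.0 the per-column loop could likely be replaced by a single RecordBatch.cast(target_schema) call, as discussed in the review thread above; the per-column version is kept because the current minimum PyArrow requirement is older.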
104 changes: 98 additions & 6 deletions python/pyspark/sql/tests/arrow/test_arrow_udtf.py
@@ -189,7 +189,10 @@ def eval(self) -> Iterator["pa.Table"]:
)
yield result_table

with self.assertRaisesRegex(PythonException, "Schema at index 0 was different"):
with self.assertRaisesRegex(
PythonException,
"Target schema's field names are not matching the record batch's field names",
):
result_df = MismatchedSchemaUDTF()
result_df.collect()

@@ -330,9 +333,10 @@ def eval(self) -> Iterator["pa.Table"]:
)
yield result_table

with self.assertRaisesRegex(PythonException, "Schema at index 0 was different"):
result_df = LongToIntUDTF()
result_df.collect()
# Should succeed with automatic coercion
result_df = LongToIntUDTF()
expected_df = self.spark.createDataFrame([(1,), (2,), (3,)], "id int")
assertDataFrameEqual(result_df, expected_df)

def test_arrow_udtf_type_coercion_string_to_int(self):
@arrow_udtf(returnType="id int")
@@ -341,15 +345,103 @@ def eval(self) -> Iterator["pa.Table"]:
# Return string values that cannot be coerced to int
result_table = pa.table(
{
"id": pa.array(["abc", "def", "xyz"], type=pa.string()),
"id": pa.array(["1", "2", "xyz"], type=pa.string()),
Member
Does this work if it's pa.array(["1", "2", "3"])? Shall we have the test to confirm?

Contributor Author
Casting from "1" to 1 should work. I added a test_arrow_udtf_type_coercion_string_to_int_safe.

}
)
yield result_table

with self.assertRaisesRegex(PythonException, "Schema at index 0 was different"):
# Should fail with Arrow cast exception since string cannot be cast to int
with self.assertRaisesRegex(
PythonException,
"PySparkRuntimeError: \\[RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF\\] "
"Column names of the returned pyarrow.Table or pyarrow.RecordBatch do not match "
"specified schema. Expected: int32 Actual: string",
):
result_df = StringToIntUDTF()
result_df.collect()

def test_arrow_udtf_type_coercion_string_to_int_safe(self):
@arrow_udtf(returnType="id int")
class StringToIntUDTF:
def eval(self) -> Iterator["pa.Table"]:
result_table = pa.table(
{
"id": pa.array(["1", "2", "3"], type=pa.string()),
}
)
yield result_table

result_df = StringToIntUDTF()
expected_df = self.spark.createDataFrame([(1,), (2,), (3,)], "id int")
assertDataFrameEqual(result_df, expected_df)

def test_arrow_udtf_type_corecion_int64_to_int32_safe(self):
@arrow_udtf(returnType="id int")
class Int64ToInt32UDTF:
def eval(self) -> Iterator["pa.Table"]:
result_table = pa.table(
{
"id": pa.array([1, 2, 3], type=pa.int64()), # long values
}
)
yield result_table

result_df = Int64ToInt32UDTF()
expected_df = self.spark.createDataFrame([(1,), (2,), (3,)], "id int")
assertDataFrameEqual(result_df, expected_df)

def test_return_type_coercion_success(self):
@arrow_udtf(returnType="value int")
class CoercionSuccessUDTF:
def eval(self) -> Iterator["pa.Table"]:
result_table = pa.table(
{
"value": pa.array([10, 20, 30], type=pa.int64()), # long -> int coercion
}
)
yield result_table

result_df = CoercionSuccessUDTF()
expected_df = self.spark.createDataFrame([(10,), (20,), (30,)], "value int")
assertDataFrameEqual(result_df, expected_df)

def test_return_type_coercion_overflow(self):
@arrow_udtf(returnType="value int")
class CoercionOverflowUDTF:
def eval(self) -> Iterator["pa.Table"]:
# Return values that will cause overflow when casting long to int
result_table = pa.table(
{
"value": pa.array([2147483647 + 1], type=pa.int64()), # int32 max + 1
}
)
yield result_table

# Should fail with PyArrow overflow exception
with self.assertRaises(Exception):
result_df = CoercionOverflowUDTF()
result_df.collect()

def test_return_type_coercion_multiple_columns(self):
@arrow_udtf(returnType="id int, price float")
class MultipleColumnCoercionUDTF:
def eval(self) -> Iterator["pa.Table"]:
result_table = pa.table(
{
"id": pa.array([1, 2, 3], type=pa.int64()), # long -> int coercion
"price": pa.array(
[10.5, 20.7, 30.9], type=pa.float64()
), # double -> float coercion
}
)
yield result_table

result_df = MultipleColumnCoercionUDTF()
expected_df = self.spark.createDataFrame(
[(1, 10.5), (2, 20.7), (3, 30.9)], "id int, price float"
)
assertDataFrameEqual(result_df, expected_df)

def test_arrow_udtf_with_empty_column_result(self):
@arrow_udtf(returnType=StructType())
class EmptyResultUDTF:
10 changes: 2 additions & 8 deletions python/pyspark/worker.py
@@ -1970,14 +1970,8 @@ def verify_result(result):
},
)

# Verify the type and the schema of the result.
verify_arrow_result(
pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))),
assign_cols_by_name=False,
expected_cols_and_types=[
(col.name, to_arrow_type(col.dataType)) for col in return_type.fields
],
)
# We verify the type of the result and do type coercion
# in the serializer
return result

# Wrap the exception thrown from the UDTF in a PySparkRuntimeError.
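For context on the removed verification: the old path materialized the result via pa.Table.from_batches with an explicit target schema, and PyArrow rejects any batch whose schema does not match exactly, which is where the previous "Schema at index 0 was different" test expectations came from. A minimal illustrative sketch of that strictness (plain PyArrow, not the Spark code path):

import pyarrow as pa

batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3], type=pa.int64())], names=["id"])
target = pa.schema([pa.field("id", pa.int32())])

# Raises pa.ArrowInvalid ("Schema at index 0 was different: ...") because the
# batch's int64 column does not exactly match the declared int32 field.
try:
    pa.Table.from_batches([batch], schema=target)
except pa.ArrowInvalid as e:
    print(e)

With the serializer-side coercion in place, such a batch is now safely cast to the declared type instead of being rejected.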