python/pyspark/sql/tests/arrow/test_arrow_udtf.py (277 changes: 276 additions & 1 deletion)
@@ -699,7 +699,7 @@ def eval(self, input_val: int):
expected_df = self.spark.createDataFrame([(60, 180)], "computed_value int, multiplied int")
assertDataFrameEqual(result_df, expected_df)

def test_arrow_udtf_with_named_arguments(self):
def test_arrow_udtf_with_named_arguments_scalar_only(self):
@arrow_udtf(returnType="x int, y int, sum int")
class NamedArgsUDTF:
def eval(self, x: "pa.Array", y: "pa.Array") -> Iterator["pa.Table"]:
@@ -1410,6 +1410,281 @@ def eval(self, table_arg, struct_arg) -> Iterator["pa.Table"]:
assert row.table_is_batch is True
assert row.struct_is_array is True

def test_arrow_udtf_table_partition_by_single_column(self):
@arrow_udtf(returnType="partition_key string, total_value bigint")
class PartitionSumUDTF:
def __init__(self):
self._category = None
self._total = 0

def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
table = pa.table(table_data)

# Each partition will have records with the same category
if table.num_rows > 0:
self._category = table.column("category")[0].as_py()
self._total += pa.compute.sum(table.column("value")).as_py()
# Don't yield here - accumulate and yield in terminate
return iter(())

def terminate(self) -> Iterator["pa.Table"]:
if self._category is not None:
result_table = pa.table(
{
"partition_key": pa.array([self._category], type=pa.string()),
"total_value": pa.array([self._total], type=pa.int64()),
}
)
yield result_table

self.spark.udtf.register("partition_sum_udtf", PartitionSumUDTF)

# Create test data with categories
test_data = [("A", 10), ("A", 20), ("B", 30), ("B", 40), ("C", 50)]
test_df = self.spark.createDataFrame(test_data, "category string, value int")
test_df.createOrReplaceTempView("partition_test_data")

result_df = self.spark.sql(
"""
SELECT * FROM partition_sum_udtf(
TABLE(partition_test_data) PARTITION BY category
) ORDER BY partition_key
"""
)
Comment on lines +1447 to +1453

Member: This is also potentially flaky, the same as the tests in the previous PR. Use terminate to be more stable?

Contributor (Author): Thanks for pointing this out. Fixed.
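For reference, a minimal sketch of the accumulate-in-eval / yield-in-terminate pattern the reviewer is suggesting; the class, schema, and column names below are illustrative only (not part of this patch), and the arrow_udtf decorator is assumed to be the same one imported by this test module:

from typing import Iterator

import pyarrow as pa

@arrow_udtf(returnType="total bigint")  # assumed: same decorator used by the tests above
class StableSumUDTF:  # hypothetical name, for illustration only
    def __init__(self):
        self._total = 0

    def eval(self, batch: "pa.RecordBatch") -> Iterator["pa.Table"]:
        tbl = pa.table(batch)
        if tbl.num_rows > 0:
            # Accumulate only; yielding per batch would make the output depend
            # on how Spark splits the partition into record batches.
            self._total += pa.compute.sum(tbl.column("value")).as_py()
        return iter(())

    def terminate(self) -> Iterator["pa.Table"]:
        # Emit a single row per partition once every batch has been consumed.
        yield pa.table({"total": pa.array([self._total], type=pa.int64())})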


expected_df = self.spark.createDataFrame(
[("A", 30), ("B", 70), ("C", 50)], "partition_key string, total_value bigint"
)
assertDataFrameEqual(result_df, expected_df)

def test_arrow_udtf_table_partition_by_multiple_columns(self):
@arrow_udtf(returnType="dept string, status string, count_employees bigint")
class DeptStatusCountUDTF:
def __init__(self):
self._dept = None
self._status = None
self._count = 0

def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
table = pa.table(table_data)

if table.num_rows > 0:
self._dept = table.column("department")[0].as_py()
self._status = table.column("status")[0].as_py()
self._count += table.num_rows
# Don't yield here - accumulate and yield in terminate
return iter(())

def terminate(self) -> Iterator["pa.Table"]:
if self._dept is not None and self._status is not None:
result_table = pa.table(
{
"dept": pa.array([self._dept], type=pa.string()),
"status": pa.array([self._status], type=pa.string()),
"count_employees": pa.array([self._count], type=pa.int64()),
}
)
yield result_table

self.spark.udtf.register("dept_status_count_udtf", DeptStatusCountUDTF)

test_data = [
("IT", "active"),
("IT", "active"),
("IT", "inactive"),
("HR", "active"),
("HR", "inactive"),
("Finance", "active"),
]
test_df = self.spark.createDataFrame(test_data, "department string, status string")
test_df.createOrReplaceTempView("employee_data")

result_df = self.spark.sql(
"""
SELECT * FROM dept_status_count_udtf(
TABLE(SELECT * FROM employee_data)
PARTITION BY (department, status)
) ORDER BY dept, status
"""
)

expected_df = self.spark.createDataFrame(
[
("Finance", "active", 1),
("HR", "active", 1),
("HR", "inactive", 1),
("IT", "active", 2),
("IT", "inactive", 1),
],
"dept string, status string, count_employees bigint",
)
assertDataFrameEqual(result_df, expected_df)

def test_arrow_udtf_with_scalar_first_table_second(self):
@arrow_udtf(returnType="filtered_id bigint")
class ScalarFirstTableSecondUDTF:
def eval(
self, threshold: "pa.Array", table_data: "pa.RecordBatch"
) -> Iterator["pa.Table"]:
assert isinstance(
threshold, pa.Array
), f"Expected pa.Array for threshold, got {type(threshold)}"
assert isinstance(
table_data, pa.RecordBatch
), f"Expected pa.RecordBatch for table_data, got {type(table_data)}"

threshold_val = threshold[0].as_py()

# Convert record batch to table
table = pa.table(table_data)
id_column = table.column("id")
mask = pa.compute.greater(id_column, pa.scalar(threshold_val))
filtered_table = table.filter(mask)

if filtered_table.num_rows > 0:
result_table = pa.table(
{"filtered_id": filtered_table.column("id")} # Keep original type
)
yield result_table

# Test with DataFrame API - scalar first, table second
input_df = self.spark.range(8)
result_df = ScalarFirstTableSecondUDTF(lit(4), input_df.asTable())
expected_df = self.spark.createDataFrame([(5,), (6,), (7,)], "filtered_id bigint")
assertDataFrameEqual(result_df, expected_df)

# Test SQL registration and usage
self.spark.udtf.register("test_scalar_first_table_second_udtf", ScalarFirstTableSecondUDTF)
sql_result_df = self.spark.sql(
"SELECT * FROM test_scalar_first_table_second_udtf("
"4, TABLE(SELECT id FROM range(0, 8)))"
)
assertDataFrameEqual(sql_result_df, expected_df)

def test_arrow_udtf_with_table_argument_in_middle(self):
"""Test Arrow UDTF with table argument in the middle of multiple scalar arguments."""

@arrow_udtf(returnType="filtered_id bigint")
class TableInMiddleUDTF:
def eval(
self,
min_threshold: "pa.Array",
table_data: "pa.RecordBatch",
max_threshold: "pa.Array",
) -> Iterator["pa.Table"]:
assert isinstance(
min_threshold, pa.Array
), f"Expected pa.Array for min_threshold, got {type(min_threshold)}"
assert isinstance(
table_data, pa.RecordBatch
), f"Expected pa.RecordBatch for table_data, got {type(table_data)}"
assert isinstance(
max_threshold, pa.Array
), f"Expected pa.Array for max_threshold, got {type(max_threshold)}"

min_val = min_threshold[0].as_py()
max_val = max_threshold[0].as_py()

# Convert record batch to table
table = pa.table(table_data)
id_column = table.column("id")

# Filter rows where min_val < id < max_val
mask = pa.compute.and_(
pa.compute.greater(id_column, pa.scalar(min_val)),
pa.compute.less(id_column, pa.scalar(max_val)),
)
filtered_table = table.filter(mask)

if filtered_table.num_rows > 0:
result_table = pa.table(
{"filtered_id": filtered_table.column("id")} # Keep original type
)
yield result_table

# Test with DataFrame API - scalar, table, scalar
input_df = self.spark.range(10)
result_df = TableInMiddleUDTF(lit(2), input_df.asTable(), lit(7))
expected_df = self.spark.createDataFrame([(3,), (4,), (5,), (6,)], "filtered_id bigint")
assertDataFrameEqual(result_df, expected_df)

# Test SQL registration and usage
self.spark.udtf.register("test_table_in_middle_udtf", TableInMiddleUDTF)
sql_result_df = self.spark.sql(
"SELECT * FROM test_table_in_middle_udtf(2, TABLE(SELECT id FROM range(0, 10)), 7)"
)
assertDataFrameEqual(sql_result_df, expected_df)

def test_arrow_udtf_with_named_arguments(self):
@arrow_udtf(returnType="result_id bigint, multiplier_used int")
class NamedArgsUDTF:
def eval(
self, table_data: "pa.RecordBatch", multiplier: "pa.Array"
) -> Iterator["pa.Table"]:
assert isinstance(
table_data, pa.RecordBatch
), f"Expected pa.RecordBatch for table_data, got {type(table_data)}"
assert isinstance(
multiplier, pa.Array
), f"Expected pa.Array for multiplier, got {type(multiplier)}"

multiplier_val = multiplier[0].as_py()

# Convert record batch to table
table = pa.table(table_data)
id_column = table.column("id")

# Multiply each id by the multiplier
multiplied_ids = pa.compute.multiply(id_column, pa.scalar(multiplier_val))

result_table = pa.table(
{
"result_id": multiplied_ids,
"multiplier_used": pa.array(
[multiplier_val] * table.num_rows, type=pa.int32()
),
}
)
yield result_table

# Test with DataFrame API using named arguments
input_df = self.spark.range(3) # [0, 1, 2]
result_df = NamedArgsUDTF(table_data=input_df.asTable(), multiplier=lit(5))
expected_df = self.spark.createDataFrame(
[(0, 5), (5, 5), (10, 5)], "result_id bigint, multiplier_used int"
)
assertDataFrameEqual(result_df, expected_df)

# Test with DataFrame API using different named argument order
result_df2 = NamedArgsUDTF(multiplier=lit(3), table_data=input_df.asTable())
expected_df2 = self.spark.createDataFrame(
[(0, 3), (3, 3), (6, 3)], "result_id bigint, multiplier_used int"
)
assertDataFrameEqual(result_df2, expected_df2)

# Test SQL registration and usage with named arguments
self.spark.udtf.register("test_named_args_udtf", NamedArgsUDTF)

sql_result_df = self.spark.sql(
"""
SELECT * FROM test_named_args_udtf(
table_data => TABLE(SELECT id FROM range(0, 3)),
multiplier => 5
)
"""
)
assertDataFrameEqual(sql_result_df, expected_df)

sql_result_df2 = self.spark.sql(
"""
SELECT * FROM test_named_args_udtf(
multiplier => 3,
table_data => TABLE(SELECT id FROM range(0, 3))
)
"""
)
assertDataFrameEqual(sql_result_df2, expected_df2)


class ArrowUDTFTests(ArrowUDTFTestsMixin, ReusedSQLTestCase):
pass