Commit fa9e787

[SPARK-53425][PYTHON][TESTS] Add more table argument tests for Arrow Python UDTFs
### What changes were proposed in this pull request?

This PR adds more tests covering table argument support for Arrow Python UDTFs. It also exposed some existing issues that need to be fixed:
- [x] [SPARK-53387](https://issues.apache.org/jira/browse/SPARK-53387): Support PARTITION BY clause with Python Arrow UDTF
- [x] [SPARK-53426](https://issues.apache.org/jira/browse/SPARK-53426): Support named table argument with asTable() API

### Why are the changes needed?

To improve test coverage.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests.

### Was this patch authored or co-authored using generative AI tooling?

Yes.

Closes #52170 from allisonwang-db/spark-53425-tbl-arg-tests.

Authored-by: Allison Wang <[email protected]>
Signed-off-by: Allison Wang <[email protected]>
1 parent 2a9999f commit fa9e787
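
For readers unfamiliar with the two features these tests exercise, here is a minimal sketch of the call shapes involved. It is condensed from the tests added below and is not part of the commit itself; the `EchoUDTF` class is hypothetical, the `arrow_udtf` import path is assumed, and a running SparkSession bound to `spark` is assumed.

```python
import pyarrow as pa
from typing import Iterator

from pyspark.sql.functions import arrow_udtf  # import path assumed

# Hypothetical Arrow UDTF that simply echoes its table argument back.
@arrow_udtf(returnType="id bigint")
class EchoUDTF:
    def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
        yield pa.table(table_data)

# Named table argument via asTable() in the DataFrame API (SPARK-53426).
df = spark.range(3)
EchoUDTF(table_data=df.asTable()).show()

# PARTITION BY clause on a table argument in SQL (SPARK-53387).
spark.udtf.register("echo_udtf", EchoUDTF)
spark.sql(
    "SELECT * FROM echo_udtf(TABLE(SELECT id FROM range(0, 3)) PARTITION BY id)"
).show()
```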

File tree: 1 file changed (+276, -1 lines)

python/pyspark/sql/tests/arrow/test_arrow_udtf.py

Lines changed: 276 additions & 1 deletion
@@ -699,7 +699,7 @@ def eval(self, input_val: int):
         expected_df = self.spark.createDataFrame([(60, 180)], "computed_value int, multiplied int")
         assertDataFrameEqual(result_df, expected_df)
 
-    def test_arrow_udtf_with_named_arguments(self):
+    def test_arrow_udtf_with_named_arguments_scalar_only(self):
         @arrow_udtf(returnType="x int, y int, sum int")
         class NamedArgsUDTF:
             def eval(self, x: "pa.Array", y: "pa.Array") -> Iterator["pa.Table"]:
@@ -1410,6 +1410,281 @@ def eval(self, table_arg, struct_arg) -> Iterator["pa.Table"]:
         assert row.table_is_batch is True
         assert row.struct_is_array is True
 
+    def test_arrow_udtf_table_partition_by_single_column(self):
+        @arrow_udtf(returnType="partition_key string, total_value bigint")
+        class PartitionSumUDTF:
+            def __init__(self):
+                self._category = None
+                self._total = 0
+
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                table = pa.table(table_data)
+
+                # Each partition will have records with the same category
+                if table.num_rows > 0:
+                    self._category = table.column("category")[0].as_py()
+                    self._total += pa.compute.sum(table.column("value")).as_py()
+                # Don't yield here - accumulate and yield in terminate
+                return iter(())
+
+            def terminate(self) -> Iterator["pa.Table"]:
+                if self._category is not None:
+                    result_table = pa.table(
+                        {
+                            "partition_key": pa.array([self._category], type=pa.string()),
+                            "total_value": pa.array([self._total], type=pa.int64()),
+                        }
+                    )
+                    yield result_table
+
+        self.spark.udtf.register("partition_sum_udtf", PartitionSumUDTF)
+
+        # Create test data with categories
+        test_data = [("A", 10), ("A", 20), ("B", 30), ("B", 40), ("C", 50)]
+        test_df = self.spark.createDataFrame(test_data, "category string, value int")
+        test_df.createOrReplaceTempView("partition_test_data")
+
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM partition_sum_udtf(
+                TABLE(partition_test_data) PARTITION BY category
+            ) ORDER BY partition_key
+            """
+        )
+
+        expected_df = self.spark.createDataFrame(
+            [("A", 30), ("B", 70), ("C", 50)], "partition_key string, total_value bigint"
+        )
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_table_partition_by_multiple_columns(self):
+        @arrow_udtf(returnType="dept string, status string, count_employees bigint")
+        class DeptStatusCountUDTF:
+            def __init__(self):
+                self._dept = None
+                self._status = None
+                self._count = 0
+
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                table = pa.table(table_data)
+
+                if table.num_rows > 0:
+                    self._dept = table.column("department")[0].as_py()
+                    self._status = table.column("status")[0].as_py()
+                    self._count += table.num_rows
+                # Don't yield here - accumulate and yield in terminate
+                return iter(())
+
+            def terminate(self) -> Iterator["pa.Table"]:
+                if self._dept is not None and self._status is not None:
+                    result_table = pa.table(
+                        {
+                            "dept": pa.array([self._dept], type=pa.string()),
+                            "status": pa.array([self._status], type=pa.string()),
+                            "count_employees": pa.array([self._count], type=pa.int64()),
+                        }
+                    )
+                    yield result_table
+
+        self.spark.udtf.register("dept_status_count_udtf", DeptStatusCountUDTF)
+
+        test_data = [
+            ("IT", "active"),
+            ("IT", "active"),
+            ("IT", "inactive"),
+            ("HR", "active"),
+            ("HR", "inactive"),
+            ("Finance", "active"),
+        ]
+        test_df = self.spark.createDataFrame(test_data, "department string, status string")
+        test_df.createOrReplaceTempView("employee_data")
+
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM dept_status_count_udtf(
+                TABLE(SELECT * FROM employee_data)
+                PARTITION BY (department, status)
+            ) ORDER BY dept, status
+            """
+        )
+
+        expected_df = self.spark.createDataFrame(
+            [
+                ("Finance", "active", 1),
+                ("HR", "active", 1),
+                ("HR", "inactive", 1),
+                ("IT", "active", 2),
+                ("IT", "inactive", 1),
+            ],
+            "dept string, status string, count_employees bigint",
+        )
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_with_scalar_first_table_second(self):
+        @arrow_udtf(returnType="filtered_id bigint")
+        class ScalarFirstTableSecondUDTF:
+            def eval(
+                self, threshold: "pa.Array", table_data: "pa.RecordBatch"
+            ) -> Iterator["pa.Table"]:
+                assert isinstance(
+                    threshold, pa.Array
+                ), f"Expected pa.Array for threshold, got {type(threshold)}"
+                assert isinstance(
+                    table_data, pa.RecordBatch
+                ), f"Expected pa.RecordBatch for table_data, got {type(table_data)}"
+
+                threshold_val = threshold[0].as_py()
+
+                # Convert record batch to table
+                table = pa.table(table_data)
+                id_column = table.column("id")
+                mask = pa.compute.greater(id_column, pa.scalar(threshold_val))
+                filtered_table = table.filter(mask)
+
+                if filtered_table.num_rows > 0:
+                    result_table = pa.table(
+                        {"filtered_id": filtered_table.column("id")}  # Keep original type
+                    )
+                    yield result_table
+
+        # Test with DataFrame API - scalar first, table second
+        input_df = self.spark.range(8)
+        result_df = ScalarFirstTableSecondUDTF(lit(4), input_df.asTable())
+        expected_df = self.spark.createDataFrame([(5,), (6,), (7,)], "filtered_id bigint")
+        assertDataFrameEqual(result_df, expected_df)
+
+        # Test SQL registration and usage
+        self.spark.udtf.register("test_scalar_first_table_second_udtf", ScalarFirstTableSecondUDTF)
+        sql_result_df = self.spark.sql(
+            "SELECT * FROM test_scalar_first_table_second_udtf("
+            "4, TABLE(SELECT id FROM range(0, 8)))"
+        )
+        assertDataFrameEqual(sql_result_df, expected_df)
+
+    def test_arrow_udtf_with_table_argument_in_middle(self):
+        """Test Arrow UDTF with table argument in the middle of multiple scalar arguments."""
+
+        @arrow_udtf(returnType="filtered_id bigint")
+        class TableInMiddleUDTF:
+            def eval(
+                self,
+                min_threshold: "pa.Array",
+                table_data: "pa.RecordBatch",
+                max_threshold: "pa.Array",
+            ) -> Iterator["pa.Table"]:
+                assert isinstance(
+                    min_threshold, pa.Array
+                ), f"Expected pa.Array for min_threshold, got {type(min_threshold)}"
+                assert isinstance(
+                    table_data, pa.RecordBatch
+                ), f"Expected pa.RecordBatch for table_data, got {type(table_data)}"
+                assert isinstance(
+                    max_threshold, pa.Array
+                ), f"Expected pa.Array for max_threshold, got {type(max_threshold)}"
+
+                min_val = min_threshold[0].as_py()
+                max_val = max_threshold[0].as_py()
+
+                # Convert record batch to table
+                table = pa.table(table_data)
+                id_column = table.column("id")
+
+                # Filter rows where min_val < id < max_val
+                mask = pa.compute.and_(
+                    pa.compute.greater(id_column, pa.scalar(min_val)),
+                    pa.compute.less(id_column, pa.scalar(max_val)),
+                )
+                filtered_table = table.filter(mask)
+
+                if filtered_table.num_rows > 0:
+                    result_table = pa.table(
+                        {"filtered_id": filtered_table.column("id")}  # Keep original type
+                    )
+                    yield result_table
+
+        # Test with DataFrame API - scalar, table, scalar
+        input_df = self.spark.range(10)
+        result_df = TableInMiddleUDTF(lit(2), input_df.asTable(), lit(7))
+        expected_df = self.spark.createDataFrame([(3,), (4,), (5,), (6,)], "filtered_id bigint")
+        assertDataFrameEqual(result_df, expected_df)
+
+        # Test SQL registration and usage
+        self.spark.udtf.register("test_table_in_middle_udtf", TableInMiddleUDTF)
+        sql_result_df = self.spark.sql(
+            "SELECT * FROM test_table_in_middle_udtf(2, TABLE(SELECT id FROM range(0, 10)), 7)"
+        )
+        assertDataFrameEqual(sql_result_df, expected_df)
+
+    def test_arrow_udtf_with_named_arguments(self):
+        @arrow_udtf(returnType="result_id bigint, multiplier_used int")
+        class NamedArgsUDTF:
+            def eval(
+                self, table_data: "pa.RecordBatch", multiplier: "pa.Array"
+            ) -> Iterator["pa.Table"]:
+                assert isinstance(
+                    table_data, pa.RecordBatch
+                ), f"Expected pa.RecordBatch for table_data, got {type(table_data)}"
+                assert isinstance(
+                    multiplier, pa.Array
+                ), f"Expected pa.Array for multiplier, got {type(multiplier)}"
+
+                multiplier_val = multiplier[0].as_py()
+
+                # Convert record batch to table
+                table = pa.table(table_data)
+                id_column = table.column("id")
+
+                # Multiply each id by the multiplier
+                multiplied_ids = pa.compute.multiply(id_column, pa.scalar(multiplier_val))
+
+                result_table = pa.table(
+                    {
+                        "result_id": multiplied_ids,
+                        "multiplier_used": pa.array(
+                            [multiplier_val] * table.num_rows, type=pa.int32()
+                        ),
+                    }
+                )
+                yield result_table
+
+        # Test with DataFrame API using named arguments
+        input_df = self.spark.range(3)  # [0, 1, 2]
+        result_df = NamedArgsUDTF(table_data=input_df.asTable(), multiplier=lit(5))
+        expected_df = self.spark.createDataFrame(
+            [(0, 5), (5, 5), (10, 5)], "result_id bigint, multiplier_used int"
+        )
+        assertDataFrameEqual(result_df, expected_df)
+
+        # Test with DataFrame API using different named argument order
+        result_df2 = NamedArgsUDTF(multiplier=lit(3), table_data=input_df.asTable())
+        expected_df2 = self.spark.createDataFrame(
+            [(0, 3), (3, 3), (6, 3)], "result_id bigint, multiplier_used int"
+        )
+        assertDataFrameEqual(result_df2, expected_df2)
+
+        # Test SQL registration and usage with named arguments
+        self.spark.udtf.register("test_named_args_udtf", NamedArgsUDTF)
+
+        sql_result_df = self.spark.sql(
+            """
+            SELECT * FROM test_named_args_udtf(
+                table_data => TABLE(SELECT id FROM range(0, 3)),
+                multiplier => 5
+            )
+            """
+        )
+        assertDataFrameEqual(sql_result_df, expected_df)
+
+        sql_result_df2 = self.spark.sql(
+            """
+            SELECT * FROM test_named_args_udtf(
+                multiplier => 3,
+                table_data => TABLE(SELECT id FROM range(0, 3))
+            )
+            """
+        )
+        assertDataFrameEqual(sql_result_df2, expected_df2)
+
 
 class ArrowUDTFTests(ArrowUDTFTestsMixin, ReusedSQLTestCase):
     pass
