
Commit d13c7ce

rebase
1 parent 2f6b58b

1 file changed: +83 −59 lines


python/pyspark/sql/tests/arrow/test_arrow_udtf.py

Lines changed: 83 additions & 59 deletions
@@ -1416,37 +1416,38 @@ def test_arrow_udtf_table_partition_by_single_column(self):
         class PartitionSumUDTF:
             def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
                 table = pa.table(table_data)
-
+
                 # Each partition will have records with the same category
                 if table.num_rows > 0:
                     category = table.column("category")[0].as_py()
                     total = pa.compute.sum(table.column("value")).as_py()
-
-                    result_table = pa.table({
-                        "partition_key": pa.array([category], type=pa.string()),
-                        "total_value": pa.array([total], type=pa.int64())
-                    })
+
+                    result_table = pa.table(
+                        {
+                            "partition_key": pa.array([category], type=pa.string()),
+                            "total_value": pa.array([total], type=pa.int64()),
+                        }
+                    )
                     yield result_table

         self.spark.udtf.register("partition_sum_udtf", PartitionSumUDTF)
-
+
         # Create test data with categories
-        test_data = [
-            ("A", 10), ("A", 20), ("B", 30), ("B", 40), ("C", 50)
-        ]
+        test_data = [("A", 10), ("A", 20), ("B", 30), ("B", 40), ("C", 50)]
         test_df = self.spark.createDataFrame(test_data, "category string, value int")
         test_df.createOrReplaceTempView("partition_test_data")

-
-        result_df = self.spark.sql("""
+        result_df = self.spark.sql(
+            """
             SELECT * FROM partition_sum_udtf(
                 TABLE(partition_test_data) PARTITION BY category
             ) ORDER BY partition_key
-        """)
-
-        expected_df = self.spark.createDataFrame([
-            ("A", 30), ("B", 70), ("C", 50)
-        ], "partition_key string, total_value bigint")
+        """
+        )
+
+        expected_df = self.spark.createDataFrame(
+            [("A", 30), ("B", 70), ("C", 50)], "partition_key string, total_value bigint"
+        )
         assertDataFrameEqual(result_df, expected_df)

         @unittest.skip("SPARK-53387: Support PARTIITON BY with Arrow UDTF")
@@ -1455,44 +1456,61 @@ def test_arrow_udtf_table_partition_by_multiple_columns(self):
         class DeptStatusCountUDTF:
             def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
                 table = pa.table(table_data)
-
+
                 if table.num_rows > 0:
                     dept = table.column("department")[0].as_py()
                     status = table.column("status")[0].as_py()
                     count = table.num_rows
-
-                    result_table = pa.table({
-                        "dept": pa.array([dept], type=pa.string()),
-                        "status": pa.array([status], type=pa.string()),
-                        "count_employees": pa.array([count], type=pa.int64())
-                    })
+
+                    result_table = pa.table(
+                        {
+                            "dept": pa.array([dept], type=pa.string()),
+                            "status": pa.array([status], type=pa.string()),
+                            "count_employees": pa.array([count], type=pa.int64()),
+                        }
+                    )
                     yield result_table

         self.spark.udtf.register("dept_status_count_udtf", DeptStatusCountUDTF)
-
+
         test_data = [
-            ("IT", "active"), ("IT", "active"), ("IT", "inactive"),
-            ("HR", "active"), ("HR", "inactive"), ("Finance", "active")
+            ("IT", "active"),
+            ("IT", "active"),
+            ("IT", "inactive"),
+            ("HR", "active"),
+            ("HR", "inactive"),
+            ("Finance", "active"),
         ]
         test_df = self.spark.createDataFrame(test_data, "department string, status string")
         test_df.createOrReplaceTempView("employee_data")
-
-        result_df = self.spark.sql("""
+
+        result_df = self.spark.sql(
+            """
             SELECT * FROM dept_status_count_udtf(
                 TABLE(SELECT * FROM employee_data)
                 PARTITION BY department, status
             ) ORDER BY dept, status
-        """)
-
-        expected_df = self.spark.createDataFrame([
-            ("Finance", "active", 1), ("HR", "active", 1), ("HR", "inactive", 1), ("IT", "active", 2), ("IT", "inactive", 1)
-        ], "dept string, status string, count_employees bigint")
+        """
+        )
+
+        expected_df = self.spark.createDataFrame(
+            [
+                ("Finance", "active", 1),
+                ("HR", "active", 1),
+                ("HR", "inactive", 1),
+                ("IT", "active", 2),
+                ("IT", "inactive", 1),
+            ],
+            "dept string, status string, count_employees bigint",
+        )
         assertDataFrameEqual(result_df, expected_df)

     def test_arrow_udtf_with_scalar_first_table_second(self):
         @arrow_udtf(returnType="filtered_id bigint")
         class ScalarFirstTableSecondUDTF:
-            def eval(self, threshold: "pa.Array", table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+            def eval(
+                self, threshold: "pa.Array", table_data: "pa.RecordBatch"
+            ) -> Iterator["pa.Table"]:
                 assert isinstance(
                     threshold, pa.Array
                 ), f"Expected pa.Array for threshold, got {type(threshold)}"
@@ -1529,13 +1547,14 @@ def eval(self, threshold: "pa.Array", table_data: "pa.RecordBatch") -> Iterator[

     def test_arrow_udtf_with_table_argument_in_middle(self):
         """Test Arrow UDTF with table argument in the middle of multiple scalar arguments."""
+
         @arrow_udtf(returnType="filtered_id bigint")
         class TableInMiddleUDTF:
             def eval(
-                self,
-                min_threshold: "pa.Array",
-                table_data: "pa.RecordBatch",
-                max_threshold: "pa.Array"
+                self,
+                min_threshold: "pa.Array",
+                table_data: "pa.RecordBatch",
+                max_threshold: "pa.Array",
             ) -> Iterator["pa.Table"]:
                 assert isinstance(
                     min_threshold, pa.Array
@@ -1553,11 +1572,11 @@ def eval(
                 # Convert record batch to table
                 table = pa.table(table_data)
                 id_column = table.column("id")
-
+
                 # Filter rows where min_val < id < max_val
                 mask = pa.compute.and_(
                     pa.compute.greater(id_column, pa.scalar(min_val)),
-                    pa.compute.less(id_column, pa.scalar(max_val))
+                    pa.compute.less(id_column, pa.scalar(max_val)),
                 )
                 filtered_table = table.filter(mask)

@@ -1584,9 +1603,7 @@ def test_arrow_udtf_with_named_arguments(self):
         @arrow_udtf(returnType="result_id bigint, multiplier_used int")
         class NamedArgsUDTF:
             def eval(
-                self,
-                table_data: "pa.RecordBatch",
-                multiplier: "pa.Array"
+                self, table_data: "pa.RecordBatch", multiplier: "pa.Array"
             ) -> Iterator["pa.Table"]:
                 assert isinstance(
                     table_data, pa.RecordBatch
@@ -1600,54 +1617,61 @@ def eval(
                 # Convert record batch to table
                 table = pa.table(table_data)
                 id_column = table.column("id")
-
+
                 # Multiply each id by the multiplier
                 multiplied_ids = pa.compute.multiply(id_column, pa.scalar(multiplier_val))
-
-                result_table = pa.table({
-                    "result_id": multiplied_ids,
-                    "multiplier_used": pa.array([multiplier_val] * table.num_rows, type=pa.int32())
-                })
+
+                result_table = pa.table(
+                    {
+                        "result_id": multiplied_ids,
+                        "multiplier_used": pa.array(
+                            [multiplier_val] * table.num_rows, type=pa.int32()
+                        ),
+                    }
+                )
                 yield result_table

         # Test with DataFrame API using named arguments
         # TODO(SPARK-53426): Support named table argument with DataFrame API
         # input_df = self.spark.range(3) # [0, 1, 2]
         # result_df = NamedArgsUDTF(table_data=input_df.asTable(), multiplier=lit(5))
         expected_df = self.spark.createDataFrame(
-            [(0, 5), (5, 5), (10, 5)],
-            "result_id bigint, multiplier_used int"
+            [(0, 5), (5, 5), (10, 5)], "result_id bigint, multiplier_used int"
         )
         # assertDataFrameEqual(result_df, expected_df)

         # Test with DataFrame API using different named argument order
         # TODO(SPARK-53426): Support named table argument with DataFrame API
         # result_df2 = NamedArgsUDTF(multiplier=lit(3), table_data=input_df.asTable())
         expected_df2 = self.spark.createDataFrame(
-            [(0, 3), (3, 3), (6, 3)],
-            "result_id bigint, multiplier_used int"
+            [(0, 3), (3, 3), (6, 3)], "result_id bigint, multiplier_used int"
         )
         # assertDataFrameEqual(result_df2, expected_df2)

         # Test SQL registration and usage with named arguments
         self.spark.udtf.register("test_named_args_udtf", NamedArgsUDTF)
-
-        sql_result_df = self.spark.sql("""
+
+        sql_result_df = self.spark.sql(
+            """
             SELECT * FROM test_named_args_udtf(
                 table_data => TABLE(SELECT id FROM range(0, 3))
                 multiplier => 5
             )
-        """)
+        """
+        )
         assertDataFrameEqual(sql_result_df, expected_df)

-        sql_result_df2 = self.spark.sql("""
+        sql_result_df2 = self.spark.sql(
+            """
             SELECT * FROM test_named_args_udtf(
                 multiplier => 3,
                 table_data => TABLE(SELECT id FROM range(0, 3))
             )
-        """)
+        """
+        )
         assertDataFrameEqual(sql_result_df2, expected_df2)

+
 class ArrowUDTFTests(ArrowUDTFTestsMixin, ReusedSQLTestCase):
     pass

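For context, a minimal standalone sketch of the pattern these tests exercise: an Arrow UDTF whose eval receives one partition of a TABLE argument per call and yields an aggregated pyarrow Table. The SparkSession setup and the arrow_udtf import path below are assumptions for illustration, not part of this commit, and PARTITION BY support for Arrow UDTFs is itself still tracked by SPARK-53387, so the final query may not run until that lands.

    from typing import Iterator

    import pyarrow as pa
    import pyarrow.compute as pc
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import arrow_udtf  # assumed import path

    spark = SparkSession.builder.getOrCreate()


    @arrow_udtf(returnType="partition_key string, total_value bigint")
    class PartitionSumUDTF:
        def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
            # Each call sees the rows of a single partition (one "category" value).
            table = pa.table(table_data)
            if table.num_rows > 0:
                category = table.column("category")[0].as_py()
                total = pc.sum(table.column("value")).as_py()
                yield pa.table(
                    {
                        "partition_key": pa.array([category], type=pa.string()),
                        "total_value": pa.array([total], type=pa.int64()),
                    }
                )


    spark.udtf.register("partition_sum_udtf", PartitionSumUDTF)
    spark.createDataFrame(
        [("A", 10), ("A", 20), ("B", 30)], "category string, value int"
    ).createOrReplaceTempView("partition_test_data")

    spark.sql(
        """
        SELECT * FROM partition_sum_udtf(
            TABLE(partition_test_data) PARTITION BY category
        ) ORDER BY partition_key
        """
    ).show()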