@@ -1416,37 +1416,38 @@ def test_arrow_udtf_table_partition_by_single_column(self):
         class PartitionSumUDTF:
             def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
                 table = pa.table(table_data)
-
+
                 # Each partition will have records with the same category
                 if table.num_rows > 0:
                     category = table.column("category")[0].as_py()
                     total = pa.compute.sum(table.column("value")).as_py()
-
-                    result_table = pa.table({
-                        "partition_key": pa.array([category], type=pa.string()),
-                        "total_value": pa.array([total], type=pa.int64())
-                    })
+
+                    result_table = pa.table(
+                        {
+                            "partition_key": pa.array([category], type=pa.string()),
+                            "total_value": pa.array([total], type=pa.int64()),
+                        }
+                    )
                     yield result_table
 
         self.spark.udtf.register("partition_sum_udtf", PartitionSumUDTF)
-
+
         # Create test data with categories
-        test_data = [
-            ("A", 10), ("A", 20), ("B", 30), ("B", 40), ("C", 50)
-        ]
+        test_data = [("A", 10), ("A", 20), ("B", 30), ("B", 40), ("C", 50)]
         test_df = self.spark.createDataFrame(test_data, "category string, value int")
         test_df.createOrReplaceTempView("partition_test_data")
 
-
-        result_df = self.spark.sql("""
+        result_df = self.spark.sql(
+            """
             SELECT * FROM partition_sum_udtf(
                 TABLE(partition_test_data) PARTITION BY category
             ) ORDER BY partition_key
-        """)
-
-        expected_df = self.spark.createDataFrame([
-            ("A", 30), ("B", 70), ("C", 50)
-        ], "partition_key string, total_value bigint")
+        """
+        )
+
+        expected_df = self.spark.createDataFrame(
+            [("A", 30), ("B", 70), ("C", 50)], "partition_key string, total_value bigint"
+        )
         assertDataFrameEqual(result_df, expected_df)
 
     @unittest.skip("SPARK-53387: Support PARTIITON BY with Arrow UDTF")
@@ -1455,44 +1456,61 @@ def test_arrow_udtf_table_partition_by_multiple_columns(self):
         class DeptStatusCountUDTF:
             def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
                 table = pa.table(table_data)
-
+
                 if table.num_rows > 0:
                     dept = table.column("department")[0].as_py()
                     status = table.column("status")[0].as_py()
                     count = table.num_rows
-
-                    result_table = pa.table({
-                        "dept": pa.array([dept], type=pa.string()),
-                        "status": pa.array([status], type=pa.string()),
-                        "count_employees": pa.array([count], type=pa.int64())
-                    })
+
+                    result_table = pa.table(
+                        {
+                            "dept": pa.array([dept], type=pa.string()),
+                            "status": pa.array([status], type=pa.string()),
+                            "count_employees": pa.array([count], type=pa.int64()),
+                        }
+                    )
                     yield result_table
 
         self.spark.udtf.register("dept_status_count_udtf", DeptStatusCountUDTF)
-
+
         test_data = [
-            ("IT", "active"), ("IT", "active"), ("IT", "inactive"),
-            ("HR", "active"), ("HR", "inactive"), ("Finance", "active")
+            ("IT", "active"),
+            ("IT", "active"),
+            ("IT", "inactive"),
+            ("HR", "active"),
+            ("HR", "inactive"),
+            ("Finance", "active"),
         ]
         test_df = self.spark.createDataFrame(test_data, "department string, status string")
         test_df.createOrReplaceTempView("employee_data")
-
-        result_df = self.spark.sql("""
+
+        result_df = self.spark.sql(
+            """
             SELECT * FROM dept_status_count_udtf(
                 TABLE(SELECT * FROM employee_data)
                 PARTITION BY department, status
            ) ORDER BY dept, status
-        """)
-
-        expected_df = self.spark.createDataFrame([
-            ("Finance", "active", 1), ("HR", "active", 1), ("HR", "inactive", 1), ("IT", "active", 2), ("IT", "inactive", 1)
-        ], "dept string, status string, count_employees bigint")
+        """
+        )
+
+        expected_df = self.spark.createDataFrame(
+            [
+                ("Finance", "active", 1),
+                ("HR", "active", 1),
+                ("HR", "inactive", 1),
+                ("IT", "active", 2),
+                ("IT", "inactive", 1),
+            ],
+            "dept string, status string, count_employees bigint",
+        )
         assertDataFrameEqual(result_df, expected_df)
 
     def test_arrow_udtf_with_scalar_first_table_second(self):
         @arrow_udtf(returnType="filtered_id bigint")
         class ScalarFirstTableSecondUDTF:
-            def eval(self, threshold: "pa.Array", table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+            def eval(
+                self, threshold: "pa.Array", table_data: "pa.RecordBatch"
+            ) -> Iterator["pa.Table"]:
                 assert isinstance(
                     threshold, pa.Array
                 ), f"Expected pa.Array for threshold, got {type(threshold)}"
@@ -1529,13 +1547,14 @@ def eval(self, threshold: "pa.Array", table_data: "pa.RecordBatch") -> Iterator[
 
     def test_arrow_udtf_with_table_argument_in_middle(self):
         """Test Arrow UDTF with table argument in the middle of multiple scalar arguments."""
+
         @arrow_udtf(returnType="filtered_id bigint")
         class TableInMiddleUDTF:
             def eval(
-                self,
-                min_threshold: "pa.Array",
-                table_data: "pa.RecordBatch",
-                max_threshold: "pa.Array"
+                self,
+                min_threshold: "pa.Array",
+                table_data: "pa.RecordBatch",
+                max_threshold: "pa.Array",
             ) -> Iterator["pa.Table"]:
                 assert isinstance(
                     min_threshold, pa.Array
@@ -1553,11 +1572,11 @@ def eval(
                 # Convert record batch to table
                 table = pa.table(table_data)
                 id_column = table.column("id")
-
+
                 # Filter rows where min_val < id < max_val
                 mask = pa.compute.and_(
                     pa.compute.greater(id_column, pa.scalar(min_val)),
-                    pa.compute.less(id_column, pa.scalar(max_val))
+                    pa.compute.less(id_column, pa.scalar(max_val)),
                 )
                 filtered_table = table.filter(mask)
 
@@ -1584,9 +1603,7 @@ def test_arrow_udtf_with_named_arguments(self):
         @arrow_udtf(returnType="result_id bigint, multiplier_used int")
         class NamedArgsUDTF:
             def eval(
-                self,
-                table_data: "pa.RecordBatch",
-                multiplier: "pa.Array"
+                self, table_data: "pa.RecordBatch", multiplier: "pa.Array"
             ) -> Iterator["pa.Table"]:
                 assert isinstance(
                     table_data, pa.RecordBatch
@@ -1600,54 +1617,61 @@ def eval(
                 # Convert record batch to table
                 table = pa.table(table_data)
                 id_column = table.column("id")
-
+
                 # Multiply each id by the multiplier
                 multiplied_ids = pa.compute.multiply(id_column, pa.scalar(multiplier_val))
-
-                result_table = pa.table({
-                    "result_id": multiplied_ids,
-                    "multiplier_used": pa.array([multiplier_val] * table.num_rows, type=pa.int32())
-                })
+
+                result_table = pa.table(
+                    {
+                        "result_id": multiplied_ids,
+                        "multiplier_used": pa.array(
+                            [multiplier_val] * table.num_rows, type=pa.int32()
+                        ),
+                    }
+                )
                 yield result_table
 
         # Test with DataFrame API using named arguments
         # TODO(SPARK-53426): Support named table argument with DataFrame API
         # input_df = self.spark.range(3)  # [0, 1, 2]
         # result_df = NamedArgsUDTF(table_data=input_df.asTable(), multiplier=lit(5))
         expected_df = self.spark.createDataFrame(
-            [(0, 5), (5, 5), (10, 5)],
-            "result_id bigint, multiplier_used int"
+            [(0, 5), (5, 5), (10, 5)], "result_id bigint, multiplier_used int"
        )
         # assertDataFrameEqual(result_df, expected_df)
 
         # Test with DataFrame API using different named argument order
         # TODO(SPARK-53426): Support named table argument with DataFrame API
         # result_df2 = NamedArgsUDTF(multiplier=lit(3), table_data=input_df.asTable())
         expected_df2 = self.spark.createDataFrame(
-            [(0, 3), (3, 3), (6, 3)],
-            "result_id bigint, multiplier_used int"
+            [(0, 3), (3, 3), (6, 3)], "result_id bigint, multiplier_used int"
         )
         # assertDataFrameEqual(result_df2, expected_df2)
 
         # Test SQL registration and usage with named arguments
         self.spark.udtf.register("test_named_args_udtf", NamedArgsUDTF)
-
-        sql_result_df = self.spark.sql("""
+
+        sql_result_df = self.spark.sql(
+            """
             SELECT * FROM test_named_args_udtf(
                 table_data => TABLE(SELECT id FROM range(0, 3))
                 multiplier => 5
             )
-        """)
+        """
+        )
         assertDataFrameEqual(sql_result_df, expected_df)
 
-        sql_result_df2 = self.spark.sql("""
+        sql_result_df2 = self.spark.sql(
+            """
             SELECT * FROM test_named_args_udtf(
                 multiplier => 3,
                 table_data => TABLE(SELECT id FROM range(0, 3))
             )
-        """)
+        """
+        )
         assertDataFrameEqual(sql_result_df2, expected_df2)
 
+
 class ArrowUDTFTests(ArrowUDTFTestsMixin, ReusedSQLTestCase):
     pass
 