@@ -699,7 +699,7 @@ def eval(self, input_val: int):
        expected_df = self.spark.createDataFrame([(60, 180)], "computed_value int, multiplied int")
        assertDataFrameEqual(result_df, expected_df)

-    def test_arrow_udtf_with_named_arguments(self):
+    def test_arrow_udtf_with_named_arguments_scalar_only(self):
        @arrow_udtf(returnType="x int, y int, sum int")
        class NamedArgsUDTF:
            def eval(self, x: "pa.Array", y: "pa.Array") -> Iterator["pa.Table"]:
@@ -1410,6 +1410,281 @@ def eval(self, table_arg, struct_arg) -> Iterator["pa.Table"]:
        assert row.table_is_batch is True
        assert row.struct_is_array is True

+    def test_arrow_udtf_table_partition_by_single_column(self):
+        @arrow_udtf(returnType="partition_key string, total_value bigint")
+        class PartitionSumUDTF:
+            def __init__(self):
+                self._category = None
+                self._total = 0
+
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                table = pa.table(table_data)
+
+                # Each partition will have records with the same category
+                if table.num_rows > 0:
+                    self._category = table.column("category")[0].as_py()
+                    self._total += pa.compute.sum(table.column("value")).as_py()
+                # Don't yield here - accumulate and yield in terminate
+                return iter(())
+
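+            # terminate() runs once per partition, so the per-category total accumulated
+            # across eval() calls is emitted as a single result row here.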
+            def terminate(self) -> Iterator["pa.Table"]:
+                if self._category is not None:
+                    result_table = pa.table(
+                        {
+                            "partition_key": pa.array([self._category], type=pa.string()),
+                            "total_value": pa.array([self._total], type=pa.int64()),
+                        }
+                    )
+                    yield result_table
+
+        self.spark.udtf.register("partition_sum_udtf", PartitionSumUDTF)
+
+        # Create test data with categories
+        test_data = [("A", 10), ("A", 20), ("B", 30), ("B", 40), ("C", 50)]
+        test_df = self.spark.createDataFrame(test_data, "category string, value int")
+        test_df.createOrReplaceTempView("partition_test_data")
+
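+        # TABLE(...) PARTITION BY category routes all rows with the same category to the
+        # same UDTF instance, possibly split across multiple eval() batches.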
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM partition_sum_udtf(
+                TABLE(partition_test_data) PARTITION BY category
+            ) ORDER BY partition_key
+            """
+        )
+
+        expected_df = self.spark.createDataFrame(
+            [("A", 30), ("B", 70), ("C", 50)], "partition_key string, total_value bigint"
+        )
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_table_partition_by_multiple_columns(self):
+        @arrow_udtf(returnType="dept string, status string, count_employees bigint")
+        class DeptStatusCountUDTF:
+            def __init__(self):
+                self._dept = None
+                self._status = None
+                self._count = 0
+
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                table = pa.table(table_data)
+
+                if table.num_rows > 0:
+                    self._dept = table.column("department")[0].as_py()
+                    self._status = table.column("status")[0].as_py()
+                    self._count += table.num_rows
+                # Don't yield here - accumulate and yield in terminate
+                return iter(())
+
+            def terminate(self) -> Iterator["pa.Table"]:
+                if self._dept is not None and self._status is not None:
+                    result_table = pa.table(
+                        {
+                            "dept": pa.array([self._dept], type=pa.string()),
+                            "status": pa.array([self._status], type=pa.string()),
+                            "count_employees": pa.array([self._count], type=pa.int64()),
+                        }
+                    )
+                    yield result_table
+
+        self.spark.udtf.register("dept_status_count_udtf", DeptStatusCountUDTF)
+
+        test_data = [
+            ("IT", "active"),
+            ("IT", "active"),
+            ("IT", "inactive"),
+            ("HR", "active"),
+            ("HR", "inactive"),
+            ("Finance", "active"),
+        ]
+        test_df = self.spark.createDataFrame(test_data, "department string, status string")
+        test_df.createOrReplaceTempView("employee_data")
+
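+        # PARTITION BY (department, status) partitions by the column pair, so each UDTF
+        # instance counts rows for exactly one (department, status) combination.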
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM dept_status_count_udtf(
+                TABLE(SELECT * FROM employee_data)
+                PARTITION BY (department, status)
+            ) ORDER BY dept, status
+            """
+        )
+
+        expected_df = self.spark.createDataFrame(
+            [
+                ("Finance", "active", 1),
+                ("HR", "active", 1),
+                ("HR", "inactive", 1),
+                ("IT", "active", 2),
+                ("IT", "inactive", 1),
+            ],
+            "dept string, status string, count_employees bigint",
+        )
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_with_scalar_first_table_second(self):
+        @arrow_udtf(returnType="filtered_id bigint")
+        class ScalarFirstTableSecondUDTF:
+            def eval(
+                self, threshold: "pa.Array", table_data: "pa.RecordBatch"
+            ) -> Iterator["pa.Table"]:
+                assert isinstance(
+                    threshold, pa.Array
+                ), f"Expected pa.Array for threshold, got {type(threshold)}"
+                assert isinstance(
+                    table_data, pa.RecordBatch
+                ), f"Expected pa.RecordBatch for table_data, got {type(table_data)}"
+
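+                # The scalar argument is a pa.Array; its first element is the threshold value.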
+                threshold_val = threshold[0].as_py()
+
+                # Convert record batch to table
+                table = pa.table(table_data)
+                id_column = table.column("id")
+                mask = pa.compute.greater(id_column, pa.scalar(threshold_val))
+                filtered_table = table.filter(mask)
+
+                if filtered_table.num_rows > 0:
+                    result_table = pa.table(
+                        {"filtered_id": filtered_table.column("id")}  # Keep original type
+                    )
+                    yield result_table
+
+        # Test with DataFrame API - scalar first, table second
+        input_df = self.spark.range(8)
+        result_df = ScalarFirstTableSecondUDTF(lit(4), input_df.asTable())
+        expected_df = self.spark.createDataFrame([(5,), (6,), (7,)], "filtered_id bigint")
+        assertDataFrameEqual(result_df, expected_df)
+
+        # Test SQL registration and usage
+        self.spark.udtf.register("test_scalar_first_table_second_udtf", ScalarFirstTableSecondUDTF)
+        sql_result_df = self.spark.sql(
+            "SELECT * FROM test_scalar_first_table_second_udtf("
+            "4, TABLE(SELECT id FROM range(0, 8)))"
+        )
+        assertDataFrameEqual(sql_result_df, expected_df)
+
+    def test_arrow_udtf_with_table_argument_in_middle(self):
+        """Test Arrow UDTF with table argument in the middle of multiple scalar arguments."""
+
+        @arrow_udtf(returnType="filtered_id bigint")
+        class TableInMiddleUDTF:
+            def eval(
+                self,
+                min_threshold: "pa.Array",
+                table_data: "pa.RecordBatch",
+                max_threshold: "pa.Array",
+            ) -> Iterator["pa.Table"]:
+                assert isinstance(
+                    min_threshold, pa.Array
+                ), f"Expected pa.Array for min_threshold, got {type(min_threshold)}"
+                assert isinstance(
+                    table_data, pa.RecordBatch
+                ), f"Expected pa.RecordBatch for table_data, got {type(table_data)}"
+                assert isinstance(
+                    max_threshold, pa.Array
+                ), f"Expected pa.Array for max_threshold, got {type(max_threshold)}"
+
+                min_val = min_threshold[0].as_py()
+                max_val = max_threshold[0].as_py()
+
+                # Convert record batch to table
+                table = pa.table(table_data)
+                id_column = table.column("id")
+
+                # Filter rows where min_val < id < max_val
+                mask = pa.compute.and_(
+                    pa.compute.greater(id_column, pa.scalar(min_val)),
+                    pa.compute.less(id_column, pa.scalar(max_val)),
+                )
+                filtered_table = table.filter(mask)
+
+                if filtered_table.num_rows > 0:
+                    result_table = pa.table(
+                        {"filtered_id": filtered_table.column("id")}  # Keep original type
+                    )
+                    yield result_table
+
+        # Test with DataFrame API - scalar, table, scalar
+        input_df = self.spark.range(10)
+        result_df = TableInMiddleUDTF(lit(2), input_df.asTable(), lit(7))
+        expected_df = self.spark.createDataFrame([(3,), (4,), (5,), (6,)], "filtered_id bigint")
+        assertDataFrameEqual(result_df, expected_df)
+
+        # Test SQL registration and usage
+        self.spark.udtf.register("test_table_in_middle_udtf", TableInMiddleUDTF)
+        sql_result_df = self.spark.sql(
+            "SELECT * FROM test_table_in_middle_udtf(2, TABLE(SELECT id FROM range(0, 10)), 7)"
+        )
+        assertDataFrameEqual(sql_result_df, expected_df)
+
+    def test_arrow_udtf_with_named_arguments(self):
+        @arrow_udtf(returnType="result_id bigint, multiplier_used int")
+        class NamedArgsUDTF:
+            def eval(
+                self, table_data: "pa.RecordBatch", multiplier: "pa.Array"
+            ) -> Iterator["pa.Table"]:
+                assert isinstance(
+                    table_data, pa.RecordBatch
+                ), f"Expected pa.RecordBatch for table_data, got {type(table_data)}"
+                assert isinstance(
+                    multiplier, pa.Array
+                ), f"Expected pa.Array for multiplier, got {type(multiplier)}"
+
+                multiplier_val = multiplier[0].as_py()
+
+                # Convert record batch to table
+                table = pa.table(table_data)
+                id_column = table.column("id")
+
+                # Multiply each id by the multiplier
+                multiplied_ids = pa.compute.multiply(id_column, pa.scalar(multiplier_val))
+
+                result_table = pa.table(
+                    {
+                        "result_id": multiplied_ids,
+                        "multiplier_used": pa.array(
+                            [multiplier_val] * table.num_rows, type=pa.int32()
+                        ),
+                    }
+                )
+                yield result_table
+
+        # Test with DataFrame API using named arguments
+        input_df = self.spark.range(3)  # [0, 1, 2]
+        result_df = NamedArgsUDTF(table_data=input_df.asTable(), multiplier=lit(5))
+        expected_df = self.spark.createDataFrame(
+            [(0, 5), (5, 5), (10, 5)], "result_id bigint, multiplier_used int"
+        )
+        assertDataFrameEqual(result_df, expected_df)
+
+        # Test with DataFrame API using different named argument order
+        result_df2 = NamedArgsUDTF(multiplier=lit(3), table_data=input_df.asTable())
+        expected_df2 = self.spark.createDataFrame(
+            [(0, 3), (3, 3), (6, 3)], "result_id bigint, multiplier_used int"
+        )
+        assertDataFrameEqual(result_df2, expected_df2)
+
+        # Test SQL registration and usage with named arguments
+        self.spark.udtf.register("test_named_args_udtf", NamedArgsUDTF)
+
+        sql_result_df = self.spark.sql(
+            """
+            SELECT * FROM test_named_args_udtf(
+                table_data => TABLE(SELECT id FROM range(0, 3)),
+                multiplier => 5
+            )
+            """
+        )
+        assertDataFrameEqual(sql_result_df, expected_df)
+
+        sql_result_df2 = self.spark.sql(
+            """
+            SELECT * FROM test_named_args_udtf(
+                multiplier => 3,
+                table_data => TABLE(SELECT id FROM range(0, 3))
+            )
+            """
+        )
+        assertDataFrameEqual(sql_result_df2, expected_df2)
+

class ArrowUDTFTests(ArrowUDTFTestsMixin, ReusedSQLTestCase):
    pass