diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udtf.py b/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
index 006084e88cd6e..02a5e39bf0b02 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
@@ -699,7 +699,7 @@ def eval(self, input_val: int):
         expected_df = self.spark.createDataFrame([(60, 180)], "computed_value int, multiplied int")
         assertDataFrameEqual(result_df, expected_df)
 
-    def test_arrow_udtf_with_named_arguments(self):
+    def test_arrow_udtf_with_named_arguments_scalar_only(self):
         @arrow_udtf(returnType="x int, y int, sum int")
         class NamedArgsUDTF:
             def eval(self, x: "pa.Array", y: "pa.Array") -> Iterator["pa.Table"]:
@@ -1410,6 +1410,281 @@ def eval(self, table_arg, struct_arg) -> Iterator["pa.Table"]:
         assert row.table_is_batch is True
         assert row.struct_is_array is True
 
+    def test_arrow_udtf_table_partition_by_single_column(self):
+        @arrow_udtf(returnType="partition_key string, total_value bigint")
+        class PartitionSumUDTF:
+            def __init__(self):
+                self._category = None
+                self._total = 0
+
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                table = pa.table(table_data)
+
+                # Each partition will have records with the same category
+                if table.num_rows > 0:
+                    self._category = table.column("category")[0].as_py()
+                    self._total += pa.compute.sum(table.column("value")).as_py()
+                # Don't yield here - accumulate and yield in terminate
+                return iter(())
+
+            def terminate(self) -> Iterator["pa.Table"]:
+                if self._category is not None:
+                    result_table = pa.table(
+                        {
+                            "partition_key": pa.array([self._category], type=pa.string()),
+                            "total_value": pa.array([self._total], type=pa.int64()),
+                        }
+                    )
+                    yield result_table
+
+        self.spark.udtf.register("partition_sum_udtf", PartitionSumUDTF)
+
+        # Create test data with categories
+        test_data = [("A", 10), ("A", 20), ("B", 30), ("B", 40), ("C", 50)]
+        test_df = self.spark.createDataFrame(test_data, "category string, value int")
+        test_df.createOrReplaceTempView("partition_test_data")
+
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM partition_sum_udtf(
+                TABLE(partition_test_data) PARTITION BY category
+            ) ORDER BY partition_key
+            """
+        )
+
+        expected_df = self.spark.createDataFrame(
+            [("A", 30), ("B", 70), ("C", 50)], "partition_key string, total_value bigint"
+        )
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_table_partition_by_multiple_columns(self):
+        @arrow_udtf(returnType="dept string, status string, count_employees bigint")
+        class DeptStatusCountUDTF:
+            def __init__(self):
+                self._dept = None
+                self._status = None
+                self._count = 0
+
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                table = pa.table(table_data)
+
+                if table.num_rows > 0:
+                    self._dept = table.column("department")[0].as_py()
+                    self._status = table.column("status")[0].as_py()
+                    self._count += table.num_rows
+                # Don't yield here - accumulate and yield in terminate
+                return iter(())
+
+            def terminate(self) -> Iterator["pa.Table"]:
+                if self._dept is not None and self._status is not None:
+                    result_table = pa.table(
+                        {
+                            "dept": pa.array([self._dept], type=pa.string()),
+                            "status": pa.array([self._status], type=pa.string()),
+                            "count_employees": pa.array([self._count], type=pa.int64()),
+                        }
+                    )
+                    yield result_table
+
+        self.spark.udtf.register("dept_status_count_udtf", DeptStatusCountUDTF)
+
+        test_data = [
+            ("IT", "active"),
+            ("IT", "active"),
+            ("IT", "inactive"),
+            ("HR", "active"),
+            ("HR", "inactive"),
+            ("Finance", "active"),
+        ]
+        test_df = self.spark.createDataFrame(test_data, "department string, status string")
+        test_df.createOrReplaceTempView("employee_data")
+
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM dept_status_count_udtf(
+                TABLE(SELECT * FROM employee_data)
+                PARTITION BY (department, status)
+            ) ORDER BY dept, status
+            """
+        )
+
+        expected_df = self.spark.createDataFrame(
+            [
+                ("Finance", "active", 1),
+                ("HR", "active", 1),
+                ("HR", "inactive", 1),
+                ("IT", "active", 2),
+                ("IT", "inactive", 1),
+            ],
+            "dept string, status string, count_employees bigint",
+        )
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_with_scalar_first_table_second(self):
+        @arrow_udtf(returnType="filtered_id bigint")
+        class ScalarFirstTableSecondUDTF:
+            def eval(
+                self, threshold: "pa.Array", table_data: "pa.RecordBatch"
+            ) -> Iterator["pa.Table"]:
+                assert isinstance(
+                    threshold, pa.Array
+                ), f"Expected pa.Array for threshold, got {type(threshold)}"
+                assert isinstance(
+                    table_data, pa.RecordBatch
+                ), f"Expected pa.RecordBatch for table_data, got {type(table_data)}"
+
+                threshold_val = threshold[0].as_py()
+
+                # Convert record batch to table
+                table = pa.table(table_data)
+                id_column = table.column("id")
+                mask = pa.compute.greater(id_column, pa.scalar(threshold_val))
+                filtered_table = table.filter(mask)
+
+                if filtered_table.num_rows > 0:
+                    result_table = pa.table(
+                        {"filtered_id": filtered_table.column("id")}  # Keep original type
+                    )
+                    yield result_table
+
+        # Test with DataFrame API - scalar first, table second
+        input_df = self.spark.range(8)
+        result_df = ScalarFirstTableSecondUDTF(lit(4), input_df.asTable())
+        expected_df = self.spark.createDataFrame([(5,), (6,), (7,)], "filtered_id bigint")
+        assertDataFrameEqual(result_df, expected_df)
+
+        # Test SQL registration and usage
+        self.spark.udtf.register("test_scalar_first_table_second_udtf", ScalarFirstTableSecondUDTF)
+        sql_result_df = self.spark.sql(
+            "SELECT * FROM test_scalar_first_table_second_udtf("
+            "4, TABLE(SELECT id FROM range(0, 8)))"
+        )
+        assertDataFrameEqual(sql_result_df, expected_df)
+
+    def test_arrow_udtf_with_table_argument_in_middle(self):
+        """Test Arrow UDTF with table argument in the middle of multiple scalar arguments."""
+
+        @arrow_udtf(returnType="filtered_id bigint")
+        class TableInMiddleUDTF:
+            def eval(
+                self,
+                min_threshold: "pa.Array",
+                table_data: "pa.RecordBatch",
+                max_threshold: "pa.Array",
+            ) -> Iterator["pa.Table"]:
+                assert isinstance(
+                    min_threshold, pa.Array
+                ), f"Expected pa.Array for min_threshold, got {type(min_threshold)}"
+                assert isinstance(
+                    table_data, pa.RecordBatch
+                ), f"Expected pa.RecordBatch for table_data, got {type(table_data)}"
+                assert isinstance(
+                    max_threshold, pa.Array
+                ), f"Expected pa.Array for max_threshold, got {type(max_threshold)}"
+
+                min_val = min_threshold[0].as_py()
+                max_val = max_threshold[0].as_py()
+
+                # Convert record batch to table
+                table = pa.table(table_data)
+                id_column = table.column("id")
+
+                # Filter rows where min_val < id < max_val
+                mask = pa.compute.and_(
+                    pa.compute.greater(id_column, pa.scalar(min_val)),
+                    pa.compute.less(id_column, pa.scalar(max_val)),
+                )
+                filtered_table = table.filter(mask)
+
+                if filtered_table.num_rows > 0:
+                    result_table = pa.table(
+                        {"filtered_id": filtered_table.column("id")}  # Keep original type
+                    )
+                    yield result_table
+
+        # Test with DataFrame API - scalar, table, scalar
+        input_df = self.spark.range(10)
+        result_df = TableInMiddleUDTF(lit(2), input_df.asTable(), lit(7))
+        expected_df = self.spark.createDataFrame([(3,), (4,), (5,), (6,)], "filtered_id bigint")
+        assertDataFrameEqual(result_df, expected_df)
+
+        # Test SQL registration and usage
+        self.spark.udtf.register("test_table_in_middle_udtf", TableInMiddleUDTF)
+        sql_result_df = self.spark.sql(
+            "SELECT * FROM test_table_in_middle_udtf(2, TABLE(SELECT id FROM range(0, 10)), 7)"
+        )
+        assertDataFrameEqual(sql_result_df, expected_df)
+
+    def test_arrow_udtf_with_named_arguments(self):
+        @arrow_udtf(returnType="result_id bigint, multiplier_used int")
+        class NamedArgsUDTF:
+            def eval(
+                self, table_data: "pa.RecordBatch", multiplier: "pa.Array"
+            ) -> Iterator["pa.Table"]:
+                assert isinstance(
+                    table_data, pa.RecordBatch
+                ), f"Expected pa.RecordBatch for table_data, got {type(table_data)}"
+                assert isinstance(
+                    multiplier, pa.Array
+                ), f"Expected pa.Array for multiplier, got {type(multiplier)}"
+
+                multiplier_val = multiplier[0].as_py()
+
+                # Convert record batch to table
+                table = pa.table(table_data)
+                id_column = table.column("id")
+
+                # Multiply each id by the multiplier
+                multiplied_ids = pa.compute.multiply(id_column, pa.scalar(multiplier_val))
+
+                result_table = pa.table(
+                    {
+                        "result_id": multiplied_ids,
+                        "multiplier_used": pa.array(
+                            [multiplier_val] * table.num_rows, type=pa.int32()
+                        ),
+                    }
+                )
+                yield result_table
+
+        # Test with DataFrame API using named arguments
+        input_df = self.spark.range(3)  # [0, 1, 2]
+        result_df = NamedArgsUDTF(table_data=input_df.asTable(), multiplier=lit(5))
+        expected_df = self.spark.createDataFrame(
+            [(0, 5), (5, 5), (10, 5)], "result_id bigint, multiplier_used int"
+        )
+        assertDataFrameEqual(result_df, expected_df)
+
+        # Test with DataFrame API using different named argument order
+        result_df2 = NamedArgsUDTF(multiplier=lit(3), table_data=input_df.asTable())
+        expected_df2 = self.spark.createDataFrame(
+            [(0, 3), (3, 3), (6, 3)], "result_id bigint, multiplier_used int"
+        )
+        assertDataFrameEqual(result_df2, expected_df2)
+
+        # Test SQL registration and usage with named arguments
+        self.spark.udtf.register("test_named_args_udtf", NamedArgsUDTF)
+
+        sql_result_df = self.spark.sql(
+            """
+            SELECT * FROM test_named_args_udtf(
+                table_data => TABLE(SELECT id FROM range(0, 3)),
+                multiplier => 5
+            )
+            """
+        )
+        assertDataFrameEqual(sql_result_df, expected_df)
+
+        sql_result_df2 = self.spark.sql(
+            """
+            SELECT * FROM test_named_args_udtf(
+                multiplier => 3,
+                table_data => TABLE(SELECT id FROM range(0, 3))
+            )
+            """
+        )
+        assertDataFrameEqual(sql_result_df2, expected_df2)
+
 
 class ArrowUDTFTests(ArrowUDTFTestsMixin, ReusedSQLTestCase):
     pass