Skip to content

[WIP][PYTHON] Arrow UDF for aggregation #51292

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ private[spark] object PythonEvalType {
// Arrow UDFs
val SQL_SCALAR_ARROW_UDF = 250
val SQL_SCALAR_ARROW_ITER_UDF = 251
val SQL_GROUPED_AGG_ARROW_UDF = 252

val SQL_TABLE_UDF = 300
val SQL_ARROW_TABLE_UDF = 301
Expand Down Expand Up @@ -101,6 +102,7 @@ private[spark] object PythonEvalType {
// Arrow UDFs
case SQL_SCALAR_ARROW_UDF => "SQL_SCALAR_ARROW_UDF"
case SQL_SCALAR_ARROW_ITER_UDF => "SQL_SCALAR_ARROW_ITER_UDF"
case SQL_GROUPED_AGG_ARROW_UDF => "SQL_GROUPED_AGG_ARROW_UDF"
}
}

Expand Down
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,7 @@ def __hash__(self):
"pyspark.sql.tests.arrow.test_arrow_grouped_map",
"pyspark.sql.tests.arrow.test_arrow_python_udf",
"pyspark.sql.tests.arrow.test_arrow_udf",
"pyspark.sql.tests.arrow.test_arrow_udf_grouped_agg",
"pyspark.sql.tests.arrow.test_arrow_udf_scalar",
"pyspark.sql.tests.pandas.test_pandas_cogrouped_map",
"pyspark.sql.tests.pandas.test_pandas_grouped_map",
Expand Down
5 changes: 3 additions & 2 deletions python/pyspark/sql/connect/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,13 +280,14 @@ def register(
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
]:
raise PySparkTypeError(
errorClass="INVALID_UDF_EVAL_TYPE",
messageParameters={
"eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, "
"SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_PANDAS_ITER_UDF or "
"SQL_GROUPED_AGG_PANDAS_UDF"
"SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_PANDAS_ITER_UDF, "
"SQL_GROUPED_AGG_PANDAS_UDF or SQL_GROUPED_AGG_ARROW_UDF"
},
)
self.sparkSession._client.register_udf(
Expand Down
1 change: 1 addition & 0 deletions python/pyspark/sql/pandas/_typing/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ GroupedMapUDFTransformWithStateInitStateType = Literal[214]
# Arrow UDFs
ArrowScalarUDFType = Literal[250]
ArrowScalarIterUDFType = Literal[251]
ArrowGroupedAggUDFType = Literal[252]

class ArrowVariadicScalarToScalarFunction(Protocol):
def __call__(self, *_: pyarrow.Array) -> pyarrow.Array: ...
Expand Down
3 changes: 3 additions & 0 deletions python/pyspark/sql/pandas/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class ArrowUDFType:

SCALAR_ITER = PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF

GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF


def arrow_udf(f=None, returnType=None, functionType=None):
    """Create an Arrow-flavoured vectorized UDF.

    Thin wrapper that forwards ``f``, ``returnType`` and ``functionType``
    unchanged to ``vectorized_udf`` and returns whatever it returns.
    """
    # "arrow" is passed as the fourth positional argument; presumably this is
    # the UDF "kind" selector -- confirm against ``vectorized_udf``.
    udf_kind = "arrow"
    return vectorized_udf(f, returnType, functionType, udf_kind)
Expand Down Expand Up @@ -454,6 +456,7 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
if kind == "arrow" and eval_type not in [
PythonEvalType.SQL_SCALAR_ARROW_UDF,
PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
None,
]: # None means it should infer the type from type hints.
raise PySparkTypeError(
Expand Down
11 changes: 11 additions & 0 deletions python/pyspark/sql/pandas/typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
PandasGroupedAggUDFType,
ArrowScalarUDFType,
ArrowScalarIterUDFType,
ArrowGroupedAggUDFType,
)


Expand All @@ -38,6 +39,7 @@ def infer_eval_type(
"PandasGroupedAggUDFType",
"ArrowScalarUDFType",
"ArrowScalarIterUDFType",
"ArrowGroupedAggUDFType",
]:
"""
Infers the evaluation type in :class:`pyspark.util.PythonEvalType` from
Expand Down Expand Up @@ -175,6 +177,13 @@ def infer_eval_type(
and not check_tuple_annotation(return_annotation)
)

# pa.Array, ... -> Any
is_array_agg = all(a == pa.Array for a in parameters_sig) and (
return_annotation != pa.Array
and not check_iterator_annotation(return_annotation)
and not check_tuple_annotation(return_annotation)
)

if is_series_or_frame:
return PandasUDFType.SCALAR
elif is_arrow_array:
Expand All @@ -185,6 +194,8 @@ def infer_eval_type(
return ArrowUDFType.SCALAR_ITER
elif is_series_or_frame_agg:
return PandasUDFType.GROUPED_AGG
elif is_array_agg:
return ArrowUDFType.GROUPED_AGG
else:
raise PySparkNotImplementedError(
errorClass="UNSUPPORTED_SIGNATURE",
Expand Down
Loading