Skip to content

Commit 702bf80

Browse files
committed
test
1 parent 1a8859d commit 702bf80

File tree

12 files changed

+76
-14
lines changed

12 files changed

+76
-14
lines changed

core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ private[spark] object PythonEvalType {
7070
// Arrow UDFs
7171
val SQL_SCALAR_ARROW_UDF = 250
7272
val SQL_SCALAR_ARROW_ITER_UDF = 251
73+
val SQL_GROUPED_AGG_ARROW_UDF = 252
7374

7475
val SQL_TABLE_UDF = 300
7576
val SQL_ARROW_TABLE_UDF = 301
@@ -101,6 +102,7 @@ private[spark] object PythonEvalType {
101102
// Arrow UDFs
102103
case SQL_SCALAR_ARROW_UDF => "SQL_SCALAR_ARROW_UDF"
103104
case SQL_SCALAR_ARROW_ITER_UDF => "SQL_SCALAR_ARROW_ITER_UDF"
105+
case SQL_GROUPED_AGG_ARROW_UDF => "SQL_GROUPED_AGG_ARROW_UDF"
104106
}
105107
}
106108

python/pyspark/sql/connect/udf.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,13 +280,14 @@ def register(
280280
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
281281
PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
282282
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
283+
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
283284
]:
284285
raise PySparkTypeError(
285286
errorClass="INVALID_UDF_EVAL_TYPE",
286287
messageParameters={
287288
"eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, "
288-
"SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_PANDAS_ITER_UDF or "
289-
"SQL_GROUPED_AGG_PANDAS_UDF"
289+
"SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_PANDAS_ITER_UDF, "
290+
"SQL_GROUPED_AGG_PANDAS_UDF or SQL_GROUPED_AGG_ARROW_UDF"
290291
},
291292
)
292293
self.sparkSession._client.register_udf(

python/pyspark/sql/pandas/_typing/__init__.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ GroupedMapUDFTransformWithStateInitStateType = Literal[214]
6363
# Arrow UDFs
6464
ArrowScalarUDFType = Literal[250]
6565
ArrowScalarIterUDFType = Literal[251]
66+
ArrowGroupedAggUDFType = Literal[252]
6667

6768
class ArrowVariadicScalarToScalarFunction(Protocol):
6869
def __call__(self, *_: pyarrow.Array) -> pyarrow.Array: ...

python/pyspark/sql/pandas/functions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ class ArrowUDFType:
4848

4949
SCALAR_ITER = PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF
5050

51+
GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF
52+
5153

5254
def arrow_udf(f=None, returnType=None, functionType=None):
5355
return vectorized_udf(f, returnType, functionType, "arrow")

python/pyspark/sql/pandas/typehints.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
PandasGroupedAggUDFType,
2828
ArrowScalarUDFType,
2929
ArrowScalarIterUDFType,
30+
ArrowGroupedAggUDFType,
3031
)
3132

3233

@@ -38,6 +39,7 @@ def infer_eval_type(
3839
"PandasGroupedAggUDFType",
3940
"ArrowScalarUDFType",
4041
"ArrowScalarIterUDFType",
42+
"ArrowGroupedAggUDFType",
4143
]:
4244
"""
4345
Infers the evaluation type in :class:`pyspark.util.PythonEvalType` from
@@ -175,6 +177,13 @@ def infer_eval_type(
175177
and not check_tuple_annotation(return_annotation)
176178
)
177179

180+
# pa.Array, ... -> Any
181+
is_array_agg = all(a == pa.Array for a in parameters_sig) and (
182+
return_annotation != pa.Array
183+
and not check_iterator_annotation(return_annotation)
184+
and not check_tuple_annotation(return_annotation)
185+
)
186+
178187
if is_series_or_frame:
179188
return PandasUDFType.SCALAR
180189
elif is_arrow_array:
@@ -185,6 +194,8 @@ def infer_eval_type(
185194
return ArrowUDFType.SCALAR_ITER
186195
elif is_series_or_frame_agg:
187196
return PandasUDFType.GROUPED_AGG
197+
elif is_array_agg:
198+
return ArrowUDFType.GROUPED_AGG
188199
else:
189200
raise PySparkNotImplementedError(
190201
errorClass="UNSUPPORTED_SIGNATURE",

python/pyspark/sql/udf.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -655,13 +655,14 @@ def register(
655655
PythonEvalType.SQL_SCALAR_ARROW_UDF,
656656
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
657657
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
658+
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
658659
]:
659660
raise PySparkTypeError(
660661
errorClass="INVALID_UDF_EVAL_TYPE",
661662
messageParameters={
662663
"eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, "
663-
"SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_PANDAS_ITER_UDF or "
664-
"SQL_GROUPED_AGG_PANDAS_UDF"
664+
"SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_PANDAS_ITER_UDF, "
665+
"SQL_GROUPED_AGG_PANDAS_UDF or SQL_GROUPED_AGG_ARROW_UDF"
665666
},
666667
)
667668
source_udf = _create_udf(

python/pyspark/util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
GroupedMapUDFTransformWithStateInitStateType,
6868
ArrowScalarUDFType,
6969
ArrowScalarIterUDFType,
70+
ArrowGroupedAggUDFType,
7071
)
7172
from pyspark.sql._typing import (
7273
SQLArrowBatchedUDFType,
@@ -651,6 +652,7 @@ class PythonEvalType:
651652
# Arrow UDFs
652653
SQL_SCALAR_ARROW_UDF: "ArrowScalarUDFType" = 250
653654
SQL_SCALAR_ARROW_ITER_UDF: "ArrowScalarIterUDFType" = 251
655+
SQL_GROUPED_AGG_ARROW_UDF: "ArrowGroupedAggUDFType" = 252
654656

655657
SQL_TABLE_UDF: "SQLTableUDFType" = 300
656658
SQL_ARROW_TABLE_UDF: "SQLArrowTableUDFType" = 301

python/pyspark/worker.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,25 @@ def wrapped(*series):
796796
)
797797

798798

799+
def wrap_grouped_agg_arrow_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf):
800+
func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets)
801+
802+
arrow_return_type = to_arrow_type(
803+
return_type, prefers_large_types=use_large_var_types(runner_conf)
804+
)
805+
806+
def wrapped(*series):
807+
import pyarrow as pa
808+
809+
result = func(*series)
810+
return pa.array([result])
811+
812+
return (
813+
args_kwargs_offsets,
814+
lambda *a: (wrapped(*a), arrow_return_type),
815+
)
816+
817+
799818
def wrap_window_agg_pandas_udf(
800819
f, args_offsets, kwargs_offsets, return_type, runner_conf, udf_index
801820
):
@@ -974,6 +993,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil
974993
# The below doesn't support named argument, but shares the same protocol.
975994
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
976995
PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
996+
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
977997
):
978998
args_offsets = []
979999
kwargs_offsets = {}
@@ -1070,6 +1090,10 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil
10701090
return wrap_grouped_agg_pandas_udf(
10711091
func, args_offsets, kwargs_offsets, return_type, runner_conf
10721092
)
1093+
elif eval_type == PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF:
1094+
return wrap_grouped_agg_arrow_udf(
1095+
func, args_offsets, kwargs_offsets, return_type, runner_conf
1096+
)
10731097
elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
10741098
return wrap_window_agg_pandas_udf(
10751099
func, args_offsets, kwargs_offsets, return_type, runner_conf, udf_index
@@ -1815,6 +1839,7 @@ def read_udfs(pickleSer, infile, eval_type):
18151839
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
18161840
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
18171841
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
1842+
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
18181843
PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
18191844
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
18201845
PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
@@ -1911,6 +1936,7 @@ def read_udfs(pickleSer, infile, eval_type):
19111936
elif eval_type in (
19121937
PythonEvalType.SQL_SCALAR_ARROW_UDF,
19131938
PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
1939+
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
19141940
):
19151941
# Arrow cast for type coercion is disabled by default
19161942
ser = ArrowStreamArrowUDFSerializer(timezone, safecheck, _assign_cols_by_name, False)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,10 @@ case class PythonUDAF(
120120
dataType: DataType,
121121
children: Seq[Expression],
122122
udfDeterministic: Boolean,
123+
evalType: Int,
123124
resultId: ExprId = NamedExpression.newExprId)
124125
extends UnevaluableAggregateFunc with PythonFuncExpression {
125126

126-
override def evalType: Int = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF
127-
128127
override def sql(isDistinct: Boolean): String = {
129128
val distinct = if (isDistinct) "DISTINCT " else ""
130129
s"$name($distinct${children.mkString(", ")})"

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -644,7 +644,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
644644

645645
case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child)
646646
if aggExpressions.forall(_.aggregateFunction.isInstanceOf[PythonUDAF]) =>
647-
Seq(execution.python.AggregateInPandasExec(
647+
Seq(execution.python.ArrowAggregatePythonExec(
648648
groupingExpressions,
649649
aggExpressions,
650650
resultExpressions,

sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala renamed to sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import java.io.File
2222
import scala.collection.mutable.ArrayBuffer
2323

2424
import org.apache.spark.{JobArtifactSet, SparkEnv, TaskContext}
25-
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
25+
import org.apache.spark.api.python.ChainedPythonFunctions
2626
import org.apache.spark.rdd.RDD
2727
import org.apache.spark.sql.catalyst.InternalRow
2828
import org.apache.spark.sql.catalyst.expressions._
@@ -42,11 +42,12 @@ import org.apache.spark.util.Utils
4242
* finally the executor evaluates any post-aggregation expressions and joins the result with the
4343
* grouped key.
4444
*/
45-
case class AggregateInPandasExec(
45+
case class ArrowAggregatePythonExec(
4646
groupingExpressions: Seq[NamedExpression],
4747
aggExpressions: Seq[AggregateExpression],
4848
resultExpressions: Seq[NamedExpression],
49-
child: SparkPlan)
49+
child: SparkPlan,
50+
evalType: Int)
5051
extends UnaryExecNode with PythonSQLMetrics {
5152

5253
override val output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
@@ -173,7 +174,7 @@ case class AggregateInPandasExec(
173174

174175
val columnarBatchIter = new ArrowPythonWithNamedArgumentRunner(
175176
pyFuncs,
176-
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
177+
evalType,
177178
argMetas,
178179
aggInputSchema,
179180
sessionLocalTimeZone,
@@ -216,3 +217,17 @@ case class AggregateInPandasExec(
216217
newIter
217218
}
218219
}
220+
221+
object ArrowAggregatePythonExec {
222+
def apply(
223+
groupingExpressions: Seq[NamedExpression],
224+
aggExpressions: Seq[AggregateExpression],
225+
resultExpressions: Seq[NamedExpression],
226+
child: SparkPlan): ArrowAggregatePythonExec = {
227+
val evalTypes = aggExpressions.map(_.aggregateFunction.asInstanceOf[PythonUDAF].evalType)
228+
assert(evalTypes.distinct.size == 1,
229+
"All aggregate functions must have the same eval type in ArrowAggregatePythonExec")
230+
new ArrowAggregatePythonExec(
231+
groupingExpressions, aggExpressions, resultExpressions, child, evalTypes.head)
232+
}
233+
}

sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ case class UserDefinedPythonFunction(
5050
|| pythonEvalType == PythonEvalType.SQL_ARROW_BATCHED_UDF
5151
|| pythonEvalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF
5252
|| pythonEvalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF
53-
|| pythonEvalType == PythonEvalType.SQL_SCALAR_ARROW_UDF) {
53+
|| pythonEvalType == PythonEvalType.SQL_SCALAR_ARROW_UDF
54+
|| pythonEvalType == PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF) {
5455
/*
5556
* Check if the named arguments:
5657
* - don't have duplicated names
@@ -61,8 +62,9 @@ case class UserDefinedPythonFunction(
6162
throw QueryCompilationErrors.namedArgumentsNotSupported(name)
6263
}
6364

64-
if (pythonEvalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) {
65-
PythonUDAF(name, func, dataType, e, udfDeterministic)
65+
if (pythonEvalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF
66+
|| pythonEvalType == PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF) {
67+
PythonUDAF(name, func, dataType, e, udfDeterministic, pythonEvalType)
6668
} else {
6769
PythonUDF(name, func, dataType, e, pythonEvalType, udfDeterministic)
6870
}

0 commit comments

Comments
 (0)