@@ -22,11 +22,13 @@ import scala.collection.mutable.ArrayBuffer
22
22
23
23
import org .apache .spark .SparkException
24
24
import org .apache .spark .api .python .PythonEvalType
25
+ import org .apache .spark .internal .Logging
25
26
import org .apache .spark .sql .catalyst .expressions ._
26
27
import org .apache .spark .sql .catalyst .expressions .aggregate .AggregateExpression
27
28
import org .apache .spark .sql .catalyst .plans .logical ._
28
29
import org .apache .spark .sql .catalyst .rules .Rule
29
30
import org .apache .spark .sql .catalyst .trees .TreePattern ._
31
+ import org .apache .spark .sql .types .UserDefinedType
30
32
31
33
32
34
/**
@@ -157,7 +159,7 @@ object ExtractGroupingPythonUDFFromAggregate extends Rule[LogicalPlan] {
157
159
* This has the limitation that the input to the Python UDF is not allowed to include attributes from
158
160
* multiple child operators.
159
161
*/
160
- object ExtractPythonUDFs extends Rule [LogicalPlan ] {
162
+ object ExtractPythonUDFs extends Rule [LogicalPlan ] with Logging {
161
163
162
164
private type EvalType = Int
163
165
private type EvalTypeChecker = EvalType => Boolean
@@ -271,9 +273,21 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] {
271
273
val evaluation = evalType match {
272
274
case PythonEvalType .SQL_BATCHED_UDF =>
273
275
BatchEvalPython (validUdfs, resultAttrs, child)
274
- case PythonEvalType .SQL_SCALAR_PANDAS_UDF | PythonEvalType .SQL_SCALAR_PANDAS_ITER_UDF
275
- | PythonEvalType .SQL_ARROW_BATCHED_UDF =>
276
+ case PythonEvalType .SQL_SCALAR_PANDAS_UDF | PythonEvalType .SQL_SCALAR_PANDAS_ITER_UDF =>
276
277
ArrowEvalPython (validUdfs, resultAttrs, child, evalType)
278
+ case PythonEvalType .SQL_ARROW_BATCHED_UDF =>
279
+ // Check if any input columns are UDTs for SQL_ARROW_BATCHED_UDF
280
+ val hasUDTInput = child.output.exists(
281
+ attr => attr.dataType.isInstanceOf [UserDefinedType [_]])
282
+
283
+ if (hasUDTInput) {
284
+ // Use BatchEvalPython if UDT is detected
285
+ logWarning(" Arrow optimization disabled due to UDT input. " +
286
+ " Falling back to non-Arrow-optimized UDF execution." )
287
+ BatchEvalPython (validUdfs, resultAttrs, child)
288
+ } else {
289
+ ArrowEvalPython (validUdfs, resultAttrs, child, evalType)
290
+ }
277
291
case _ =>
278
292
throw SparkException .internalError(" Unexpected UDF evalType" )
279
293
}
0 commit comments