Commit 00d4e03

fallback input

1 parent 8070634 commit 00d4e03

File tree

1 file changed (+17, -3 lines)

1 file changed

+17
-3
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala

Lines changed: 17 additions & 3 deletions
@@ -22,11 +22,13 @@ import scala.collection.mutable.ArrayBuffer

 import org.apache.spark.SparkException
 import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern._
+import org.apache.spark.sql.types.UserDefinedType


 /**
@@ -157,7 +159,7 @@ object ExtractGroupingPythonUDFFromAggregate extends Rule[LogicalPlan] {
  * This has the limitation that the input to the Python UDF is not allowed include attributes from
  * multiple child operators.
  */
-object ExtractPythonUDFs extends Rule[LogicalPlan] {
+object ExtractPythonUDFs extends Rule[LogicalPlan] with Logging {

   private type EvalType = Int
   private type EvalTypeChecker = EvalType => Boolean
@@ -271,9 +273,21 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] {
       val evaluation = evalType match {
         case PythonEvalType.SQL_BATCHED_UDF =>
           BatchEvalPython(validUdfs, resultAttrs, child)
-        case PythonEvalType.SQL_SCALAR_PANDAS_UDF | PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
-          | PythonEvalType.SQL_ARROW_BATCHED_UDF =>
+        case PythonEvalType.SQL_SCALAR_PANDAS_UDF | PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF =>
           ArrowEvalPython(validUdfs, resultAttrs, child, evalType)
+        case PythonEvalType.SQL_ARROW_BATCHED_UDF =>
+          // Check if any input columns are UDTs for SQL_ARROW_BATCHED_UDF
+          val hasUDTInput = child.output.exists(
+            attr => attr.dataType.isInstanceOf[UserDefinedType[_]])
+
+          if (hasUDTInput) {
+            // Use BatchEvalPython if UDT is detected
+            logWarning("Arrow optimization disabled due to UDT input. " +
+              "Falling back to non-Arrow-optimized UDF execution.")
+            BatchEvalPython(validUdfs, resultAttrs, child)
+          } else {
+            ArrowEvalPython(validUdfs, resultAttrs, child, evalType)
+          }
         case _ =>
           throw SparkException.internalError("Unexpected UDF evalType")
       }
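
For context, the new SQL_ARROW_BATCHED_UDF branch falls back from ArrowEvalPython to BatchEvalPython whenever any output attribute of the child plan is typed with a UserDefinedType, since (per the added warning) Arrow optimization is disabled for UDT inputs. Below is a minimal standalone sketch of the same check applied to a DataFrame schema instead of a plan's output attributes; the helper name schemaHasUDT and the example schema are hypothetical, while the Spark types used are real.

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType, UserDefinedType}

object UDTFallbackSketch {
  // Mirrors the commit's `hasUDTInput` check: true if any top-level column type
  // is a UserDefinedType. Like the commit's check, it does not look inside
  // nested structs, arrays, or maps.
  def schemaHasUDT(schema: StructType): Boolean =
    schema.fields.exists(_.dataType.isInstanceOf[UserDefinedType[_]])

  def main(args: Array[String]): Unit = {
    val plainSchema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("name", StringType)))
    // Only built-in types: the Arrow-optimized path would be kept.
    println(schemaHasUDT(plainSchema)) // false
    // A column whose dataType is a UserDefinedType subclass (for example an ML
    // vector UDT) would make this return true, i.e. trigger the fallback.
  }
}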
