
Commit 2de9be2

xinrong-meng authored and dongjoon-hyun committed
[SPARK-48516][PYTHON][CONNECT] Turn on Arrow optimization for Python UDFs by default
### What changes were proposed in this pull request?

Turn on Arrow optimization for Python UDFs by default.

### Why are the changes needed?

Arrow optimization was introduced in 3.4.0. See [SPARK-40307](https://issues.apache.org/jira/browse/SPARK-40307) for more context. An Arrow-optimized Python UDF is approximately 1.6 times faster than the original pickled Python UDF. More details can be found in [this blog post](https://www.databricks.com/blog/arrow-optimized-python-udfs-apache-sparktm-35).

In version 4.0.0, we propose enabling the optimization by default. If PyArrow is not installed, it will fall back to the original pickled Python UDF.

### Does this PR introduce _any_ user-facing change?

Yes.

### How was this patch tested?

Existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #49482 from xinrong-meng/arrow_on.

Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
(cherry picked from commit 59dd406)
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 1a50e21 commit 2de9be2
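
As a rough illustration of the user-facing change (the session, DataFrame, and UDF below are made up for this example, not part of the patch), a plain Python UDF created with the 4.0.0 defaults now goes through the Arrow-optimized path whenever PyArrow is available, and the previous behavior can be restored per UDF or per session:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.getOrCreate()
df = spark.range(5)

# With spark.sql.execution.pythonUDF.arrow.enabled now defaulting to true,
# this UDF is Arrow-optimized if PyArrow is installed, otherwise pickled.
plus_one = udf(lambda x: x + 1, IntegerType())
df.select(plus_one("id").alias("plus_one")).show()

# Opt out for a single UDF ...
plus_one_pickled = udf(lambda x: x + 1, IntegerType(), useArrow=False)

# ... or for the whole session.
spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "false")
```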

File tree: 10 files changed (+47, -11 lines)


python/docs/source/user_guide/sql/arrow_pandas.rst

Lines changed: 2 additions & 2 deletions
@@ -356,8 +356,8 @@ Arrow Python UDFs are user defined functions that are executed row-by-row, utili
 transfer and serialization. To define an Arrow Python UDF, you can use the :meth:`udf` decorator or wrap the function
 with the :meth:`udf` method, ensuring the ``useArrow`` parameter is set to True. Additionally, you can enable Arrow
 optimization for Python UDFs throughout the entire SparkSession by setting the Spark configuration
-``spark.sql.execution.pythonUDF.arrow.enabled`` to true. It's important to note that the Spark configuration takes
-effect only when ``useArrow`` is either not set or set to None.
+``spark.sql.execution.pythonUDF.arrow.enabled`` to true, which is the default. It's important to note that the Spark
+configuration takes effect only when ``useArrow`` is either not set or set to None.

 The type hints for Arrow Python UDFs should be specified in the same way as for default, pickled Python UDFs.

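For reference, a minimal sketch of the per-UDF form the documentation describes (the function name and the `df` object are illustrative, not from the patch): `useArrow=True` forces the Arrow path for that UDF, while the session config only applies when `useArrow` is left unset (None).

```python
from pyspark.sql.functions import udf

# Per-UDF opt-in: useArrow takes precedence over
# spark.sql.execution.pythonUDF.arrow.enabled (now true by default).
@udf(returnType="int", useArrow=True)
def str_len(s: str) -> int:
    return len(s)

df.select(str_len("name")).show()  # 'df' and its 'name' column are placeholders
```
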
python/docs/source/user_guide/sql/type_conversions.rst

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ are listed below:
       - Default
   * - spark.sql.execution.pythonUDF.arrow.enabled
     - Enable PyArrow in PySpark. See more `here <arrow_pandas.rst>`_.
-    - False
+    - True
   * - spark.sql.pyspark.inferNestedDictAsStruct.enabled
     - When enabled, nested dictionaries are inferred as StructType. Otherwise, they are inferred as MapType.
     - False

python/pyspark/ml/base.py

Lines changed: 2 additions & 1 deletion
@@ -328,7 +328,8 @@ def transformSchema(self, schema: StructType) -> StructType:

     def _transform(self, dataset: DataFrame) -> DataFrame:
         self.transformSchema(dataset.schema)
-        transformUDF = udf(self.createTransformFunc(), self.outputDataType())
+        # TODO(SPARK-48515): Use Arrow Python UDF
+        transformUDF = udf(self.createTransformFunc(), self.outputDataType(), useArrow=False)
         transformedDataset = dataset.withColumn(
             self.getOutputCol(), transformUDF(dataset[self.getInputCol()])
         )

python/pyspark/ml/classification.py

Lines changed: 5 additions & 2 deletions
@@ -3835,12 +3835,13 @@ def __init__(self, models: List[ClassificationModel]):
         )

     def _transform(self, dataset: DataFrame) -> DataFrame:
+        # TODO(SPARK-48515): Use Arrow Python UDF
         # determine the input columns: these need to be passed through
         origCols = dataset.columns

         # add an accumulator column to store predictions of all the models
         accColName = "mbc$acc" + str(uuid.uuid4())
-        initUDF = udf(lambda _: [], ArrayType(DoubleType()))
+        initUDF = udf(lambda _: [], ArrayType(DoubleType()), useArrow=False)
         newDataset = dataset.withColumn(accColName, initUDF(dataset[origCols[0]]))

         # persist if underlying dataset is not persistent.
@@ -3860,6 +3861,7 @@ def _transform(self, dataset: DataFrame) -> DataFrame:
             updateUDF = udf(
                 lambda predictions, prediction: predictions + [prediction.tolist()[1]],
                 ArrayType(DoubleType()),
+                useArrow=False,
             )
             transformedDataset = model.transform(aggregatedDataset).select(*columns)
             updatedDataset = transformedDataset.withColumn(
@@ -3884,7 +3886,7 @@ def func(predictions: Iterable[float]) -> Vector:
                     predArray.append(x)
                 return Vectors.dense(predArray)

-            rawPredictionUDF = udf(func, VectorUDT())
+            rawPredictionUDF = udf(func, VectorUDT(), useArrow=False)
             aggregatedDataset = aggregatedDataset.withColumn(
                 self.getRawPredictionCol(), rawPredictionUDF(aggregatedDataset[accColName])
             )
@@ -3896,6 +3898,7 @@ def func(predictions: Iterable[float]) -> Vector:
                     max(enumerate(predictions), key=operator.itemgetter(1))[0]
                 ),
                 DoubleType(),
+                useArrow=False,
             )
             aggregatedDataset = aggregatedDataset.withColumn(
                 self.getPredictionCol(), labelUDF(aggregatedDataset[accColName])

python/pyspark/ml/tuning.py

Lines changed: 5 additions & 1 deletion
@@ -906,8 +906,12 @@ def checker(foldNum: int) -> bool:
             from pyspark.sql.connect.udf import UserDefinedFunction
         else:
             from pyspark.sql.functions import UserDefinedFunction  # type: ignore[assignment]
+        from pyspark.util import PythonEvalType

-        checker_udf = UserDefinedFunction(checker, BooleanType())
+        # TODO(SPARK-48515): Use Arrow Python UDF
+        checker_udf = UserDefinedFunction(
+            checker, BooleanType(), evalType=PythonEvalType.SQL_BATCHED_UDF
+        )
         for i in range(nFolds):
             training = dataset.filter(checker_udf(dataset[foldCol]) & (col(foldCol) != lit(i)))
             validation = dataset.filter(

python/pyspark/sql/connect/udf.py

Lines changed: 10 additions & 0 deletions
@@ -41,6 +41,7 @@
     UDFRegistration as PySparkUDFRegistration,
     UserDefinedFunction as PySparkUserDefinedFunction,
 )
+from pyspark.sql.utils import has_arrow
 from pyspark.errors import PySparkTypeError, PySparkRuntimeError

 if TYPE_CHECKING:
@@ -58,6 +59,7 @@ def _create_py_udf(
     returnType: "DataTypeOrString",
     useArrow: Optional[bool] = None,
 ) -> "UserDefinedFunctionLike":
+    is_arrow_enabled = False
     if useArrow is None:
         is_arrow_enabled = False
         try:
@@ -78,6 +80,14 @@

     eval_type: int = PythonEvalType.SQL_BATCHED_UDF

+    if is_arrow_enabled and not has_arrow:
+        is_arrow_enabled = False
+        warnings.warn(
+            "Arrow optimization failed to enable because PyArrow is not installed. "
+            "Falling back to a non-Arrow-optimized UDF.",
+            RuntimeWarning,
+        )
+
     if is_arrow_enabled:
         try:
             is_func_with_args = len(getfullargspec(f).args) > 0

python/pyspark/sql/functions/builtin.py

Lines changed: 2 additions & 1 deletion
@@ -26308,7 +26308,8 @@ def udf(
         Defaults to :class:`StringType`.
     useArrow : bool, optional
         whether to use Arrow to optimize the (de)serialization. When it is None, the
-        Spark config "spark.sql.execution.pythonUDF.arrow.enabled" takes effect.
+        Spark config "spark.sql.execution.pythonUDF.arrow.enabled" takes effect,
+        which is "true" by default.

     Examples
     --------

python/pyspark/sql/udf.py

Lines changed: 10 additions & 2 deletions
@@ -34,7 +34,7 @@
     StructType,
     _parse_datatype_string,
 )
-from pyspark.sql.utils import get_active_spark_context
+from pyspark.sql.utils import get_active_spark_context, has_arrow
 from pyspark.sql.pandas.types import to_arrow_type
 from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version
 from pyspark.errors import PySparkTypeError, PySparkNotImplementedError, PySparkRuntimeError
@@ -118,7 +118,7 @@ def _create_py_udf(
     # Note: The values of 'SQL Type' are DDL formatted strings, which can be used as `returnType`s.
     # Note: The values inside the table are generated by `repr`. X' means it throws an exception
     # during the conversion.
-
+    is_arrow_enabled = False
     if useArrow is None:
         from pyspark.sql import SparkSession

@@ -131,6 +131,14 @@
     else:
         is_arrow_enabled = useArrow

+    if is_arrow_enabled and not has_arrow:
+        is_arrow_enabled = False
+        warnings.warn(
+            "Arrow optimization failed to enable because PyArrow is not installed. "
+            "Falling back to a non-Arrow-optimized UDF.",
+            RuntimeWarning,
+        )
+
     eval_type: int = PythonEvalType.SQL_BATCHED_UDF

     if is_arrow_enabled:
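
A small sketch of how the new fallback could be observed (assumption: PyArrow is absent from the environment; the snippet itself is not part of the patch). Instead of failing, `_create_py_udf` downgrades the UDF and emits the `RuntimeWarning` added above:

```python
import warnings

from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

# On an environment without PyArrow, requesting an Arrow UDF should warn
# and silently fall back to the pickled (non-Arrow) path.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    plus_one = udf(lambda x: x + 1, IntegerType(), useArrow=True)

for w in caught:
    print(w.category.__name__, w.message)
# Expected (when PyArrow is missing): a RuntimeWarning about falling back
# to a non-Arrow-optimized UDF.
```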

python/pyspark/sql/utils.py

Lines changed: 9 additions & 0 deletions
@@ -63,6 +63,15 @@
     from pyspark.pandas._typing import IndexOpsLike, SeriesOrIndex


+has_arrow: bool = False
+try:
+    import pyarrow  # noqa: F401
+
+    has_arrow = True
+except ImportError:
+    pass
+
+
 FuncT = TypeVar("FuncT", bound=Callable[..., Any])

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 1 addition & 1 deletion
@@ -3513,7 +3513,7 @@ object SQLConf {
         "can only be enabled when the given function takes at least one argument.")
       .version("3.4.0")
       .booleanConf
-      .createWithDefault(false)
+      .createWithDefault(true)

   val PYTHON_UDF_ARROW_CONCURRENCY_LEVEL =
     buildConf("spark.sql.execution.pythonUDF.arrow.concurrency.level")
