
[SPARK-48516][PYTHON][CONNECT] Turn on Arrow optimization for Python UDFs by default #49482

Closed
wants to merge 5 commits into from
4 changes: 2 additions & 2 deletions python/docs/source/user_guide/sql/arrow_pandas.rst
@@ -356,8 +356,8 @@ Arrow Python UDFs are user defined functions that are executed row-by-row, utili
transfer and serialization. To define an Arrow Python UDF, you can use the :meth:`udf` decorator or wrap the function
with the :meth:`udf` method, ensuring the ``useArrow`` parameter is set to True. Additionally, you can enable Arrow
optimization for Python UDFs throughout the entire SparkSession by setting the Spark configuration
``spark.sql.execution.pythonUDF.arrow.enabled`` to true. It's important to note that the Spark configuration takes
effect only when ``useArrow`` is either not set or set to None.
``spark.sql.execution.pythonUDF.arrow.enabled`` to true, which is the default. It's important to note that the Spark
configuration takes effect only when ``useArrow`` is either not set or set to None.

The type hints for Arrow Python UDFs should be specified in the same way as for default, pickled Python UDFs.

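For reference, here is a minimal sketch of the behavior this doc change describes. It is not part of the PR and assumes an active SparkSession bound to `spark`:

```python
from pyspark.sql.functions import udf

# With spark.sql.execution.pythonUDF.arrow.enabled now true by default, a plain
# @udf already takes the Arrow-optimized path; useArrow only needs to be passed
# explicitly to override the session config.

@udf(returnType="long")                  # useArrow=None -> follows the config (Arrow by default)
def add_one(x: int) -> int:
    return x + 1

@udf(returnType="long", useArrow=False)  # force the legacy pickled UDF path
def add_one_pickled(x: int) -> int:
    return x + 1

df = spark.range(3)
df.select(add_one("id").alias("arrow"), add_one_pickled("id").alias("pickled")).show()
```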
2 changes: 1 addition & 1 deletion python/docs/source/user_guide/sql/type_conversions.rst
@@ -57,7 +57,7 @@ are listed below:
- Default
* - spark.sql.execution.pythonUDF.arrow.enabled
- Enable PyArrow in PySpark. See more `here <arrow_pandas.rst>`_.
- False
- True
* - spark.sql.pyspark.inferNestedDictAsStruct.enabled
- When enabled, nested dictionaries are inferred as StructType. Otherwise, they are inferred as MapType.
- False
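As a quick, hypothetical check of the two configs in this table (assumes an active SparkSession named `spark`; not part of this diff):

```python
# After this PR the Arrow UDF config defaults to "true"; nested-dict inference stays off.
print(spark.conf.get("spark.sql.execution.pythonUDF.arrow.enabled"))        # 'true'
print(spark.conf.get("spark.sql.pyspark.inferNestedDictAsStruct.enabled"))  # 'false'

# With inferNestedDictAsStruct enabled, a nested dictionary is inferred as a
# struct column rather than a map when building a DataFrame from Python dicts.
spark.conf.set("spark.sql.pyspark.inferNestedDictAsStruct.enabled", "true")
spark.createDataFrame([{"outer": {"inner": 1}}]).printSchema()
```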
3 changes: 2 additions & 1 deletion python/pyspark/ml/base.py
@@ -328,7 +328,8 @@ def transformSchema(self, schema: StructType) -> StructType:

def _transform(self, dataset: DataFrame) -> DataFrame:
self.transformSchema(dataset.schema)
transformUDF = udf(self.createTransformFunc(), self.outputDataType())
# TODO(SPARK-48515): Use Arrow Python UDF
transformUDF = udf(self.createTransformFunc(), self.outputDataType(), useArrow=False)
transformedDataset = dataset.withColumn(
self.getOutputCol(), transformUDF(dataset[self.getInputCol()])
)
7 changes: 5 additions & 2 deletions python/pyspark/ml/classification.py
@@ -3834,12 +3834,13 @@ def __init__(self, models: List[ClassificationModel]):
)

def _transform(self, dataset: DataFrame) -> DataFrame:
# TODO(SPARK-48515): Use Arrow Python UDF
# determine the input columns: these need to be passed through
origCols = dataset.columns

# add an accumulator column to store predictions of all the models
accColName = "mbc$acc" + str(uuid.uuid4())
initUDF = udf(lambda _: [], ArrayType(DoubleType()))
initUDF = udf(lambda _: [], ArrayType(DoubleType()), useArrow=False)
newDataset = dataset.withColumn(accColName, initUDF(dataset[origCols[0]]))

# persist if underlying dataset is not persistent.
@@ -3859,6 +3860,7 @@ def _transform(self, dataset: DataFrame) -> DataFrame:
updateUDF = udf(
lambda predictions, prediction: predictions + [prediction.tolist()[1]],
ArrayType(DoubleType()),
useArrow=False,
)
transformedDataset = model.transform(aggregatedDataset).select(*columns)
updatedDataset = transformedDataset.withColumn(
@@ -3883,7 +3885,7 @@ def func(predictions: Iterable[float]) -> Vector:
predArray.append(x)
return Vectors.dense(predArray)

rawPredictionUDF = udf(func, VectorUDT())
rawPredictionUDF = udf(func, VectorUDT(), useArrow=False)
aggregatedDataset = aggregatedDataset.withColumn(
self.getRawPredictionCol(), rawPredictionUDF(aggregatedDataset[accColName])
)
@@ -3895,6 +3897,7 @@ def func(predictions: Iterable[float]) -> Vector:
max(enumerate(predictions), key=operator.itemgetter(1))[0]
),
DoubleType(),
useArrow=False,
)
aggregatedDataset = aggregatedDataset.withColumn(
self.getPredictionCol(), labelUDF(aggregatedDataset[accColName])
6 changes: 5 additions & 1 deletion python/pyspark/ml/tuning.py
@@ -906,8 +906,12 @@ def checker(foldNum: int) -> bool:
from pyspark.sql.connect.udf import UserDefinedFunction
else:
from pyspark.sql.functions import UserDefinedFunction # type: ignore[assignment]
from pyspark.util import PythonEvalType

checker_udf = UserDefinedFunction(checker, BooleanType())
# TODO(SPARK-48515): Use Arrow Python UDF
checker_udf = UserDefinedFunction(
checker, BooleanType(), evalType=PythonEvalType.SQL_BATCHED_UDF
)
for i in range(nFolds):
training = dataset.filter(checker_udf(dataset[foldCol]) & (col(foldCol) != lit(i)))
validation = dataset.filter(
10 changes: 10 additions & 0 deletions python/pyspark/sql/connect/udf.py
@@ -41,6 +41,7 @@
UDFRegistration as PySparkUDFRegistration,
UserDefinedFunction as PySparkUserDefinedFunction,
)
from pyspark.sql.utils import has_arrow
from pyspark.errors import PySparkTypeError, PySparkRuntimeError

if TYPE_CHECKING:
@@ -58,6 +59,7 @@ def _create_py_udf(
returnType: "DataTypeOrString",
useArrow: Optional[bool] = None,
) -> "UserDefinedFunctionLike":
is_arrow_enabled = False
if useArrow is None:
is_arrow_enabled = False
try:
Expand All @@ -78,6 +80,14 @@ def _create_py_udf(

eval_type: int = PythonEvalType.SQL_BATCHED_UDF

if is_arrow_enabled and not has_arrow:
is_arrow_enabled = False
warnings.warn(
"Arrow optimization failed to enable because PyArrow is not installed. "
"Falling back to a non-Arrow-optimized UDF.",
RuntimeWarning,
)

if is_arrow_enabled:
try:
is_func_with_args = len(getfullargspec(f).args) > 0
3 changes: 2 additions & 1 deletion python/pyspark/sql/functions/builtin.py
@@ -26291,7 +26291,8 @@ def udf(
Defaults to :class:`StringType`.
useArrow : bool, optional
whether to use Arrow to optimize the (de)serialization. When it is None, the
Spark config "spark.sql.execution.pythonUDF.arrow.enabled" takes effect.
Spark config "spark.sql.execution.pythonUDF.arrow.enabled" takes effect,
which is "true" by default.

Examples
--------
12 changes: 10 additions & 2 deletions python/pyspark/sql/udf.py
@@ -34,7 +34,7 @@
StructType,
_parse_datatype_string,
)
from pyspark.sql.utils import get_active_spark_context
from pyspark.sql.utils import get_active_spark_context, has_arrow
from pyspark.sql.pandas.types import to_arrow_type
from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version
from pyspark.errors import PySparkTypeError, PySparkNotImplementedError, PySparkRuntimeError
@@ -118,7 +118,7 @@ def _create_py_udf(
# Note: The values of 'SQL Type' are DDL formatted strings, which can be used as `returnType`s.
# Note: The values inside the table are generated by `repr`. X' means it throws an exception
# during the conversion.

is_arrow_enabled = False
if useArrow is None:
from pyspark.sql import SparkSession

Expand All @@ -131,6 +131,14 @@ def _create_py_udf(
else:
is_arrow_enabled = useArrow

if is_arrow_enabled and not has_arrow:
is_arrow_enabled = False
warnings.warn(
"Arrow optimization failed to enable because PyArrow is not installed. "
"Falling back to a non-Arrow-optimized UDF.",
RuntimeWarning,
)

eval_type: int = PythonEvalType.SQL_BATCHED_UDF

if is_arrow_enabled:
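The fallback added above runs at UDF definition time. A rough sketch of how it would surface to a user, assuming PyArrow is not installed (this snippet is illustrative, not code from the PR):

```python
import warnings
from pyspark.sql.functions import udf

# Explicitly requesting Arrow without PyArrow installed should no longer fail:
# _create_py_udf emits a RuntimeWarning and builds a regular pickled UDF instead.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    plus_one = udf(lambda x: x + 1, "long", useArrow=True)

for w in caught:
    if issubclass(w.category, RuntimeWarning):
        print(w.message)  # "Arrow optimization failed to enable because PyArrow is not installed. ..."
```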
9 changes: 9 additions & 0 deletions python/pyspark/sql/utils.py
@@ -63,6 +63,15 @@
from pyspark.pandas._typing import IndexOpsLike, SeriesOrIndex


has_arrow: bool = False
Contributor: I think we can use `from pyspark.testing.utils import have_pyarrow`

Member: Can we address this comment @xinrong-meng

Member Author: Good point! Resolved, thank you

Member Author: There will be a circular import if we do that. Let me follow up with a separate PR instead.

Contributor: probably we can import it inside the function/class body

try:
    import pyarrow  # noqa: F401

    has_arrow = True
except ImportError:
    pass


FuncT = TypeVar("FuncT", bound=Callable[..., Any])


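For the circular-import concern discussed in the review thread above, here is a hypothetical sketch of the deferred-import variant the reviewer suggests; the helper name `_arrow_available` is invented for illustration and is not what this PR commits:

```python
def _arrow_available() -> bool:
    """Return whether PyArrow can be used, checking the flag lazily.

    Importing pyspark.testing.utils at module import time would create a
    circular import, so the import is deferred into the function body.
    """
    try:
        from pyspark.testing.utils import have_pyarrow
        return have_pyarrow
    except ImportError:
        return False
```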
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3497,7 +3497,7 @@ object SQLConf {
"can only be enabled when the given function takes at least one argument.")
.version("3.4.0")
.booleanConf
.createWithDefault(false)
.createWithDefault(true)

val PYTHON_UDF_ARROW_CONCURRENCY_LEVEL =
buildConf("spark.sql.execution.pythonUDF.arrow.concurrency.level")
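Because the Scala default flips to true here, sessions that need the previous behavior must opt out explicitly. A minimal sketch (assumes an active SparkSession named `spark`; not part of the diff):

```python
# Disable Arrow-optimized Python UDFs for the whole session (restores the old default).
spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "false")
print(spark.conf.get("spark.sql.execution.pythonUDF.arrow.enabled"))  # 'false'
```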