Skip to content

Commit

Permalink
[SPARK-50869][ML][CONNECT][PYTHON] Support evaluators on ML Connect
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This PR adds support for the following evaluators on ML Connect:

- org.apache.spark.ml.evaluation.RegressionEvaluator
- org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
- org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
- org.apache.spark.ml.evaluation.MultilabelClassificationEvaluator
- org.apache.spark.ml.evaluation.ClusteringEvaluator
- org.apache.spark.ml.evaluation.RankingEvaluator

### Why are the changes needed?
For parity with Spark Classic.

### Does this PR introduce _any_ user-facing change?
Yes, the new evaluators are supported on ML Connect.

### How was this patch tested?
The newly added tests pass.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #49547 from wbo4958/evaluator.ml.connect.

Authored-by: Bobby Wang <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
wbo4958 authored and zhengruifeng committed Jan 18, 2025
1 parent 205e382 commit ef363b6
Show file tree
Hide file tree
Showing 17 changed files with 778 additions and 87 deletions.
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1117,6 +1117,7 @@ def __hash__(self):
"pyspark.ml.tests.connect.test_connect_pipeline",
"pyspark.ml.tests.connect.test_connect_tuning",
"pyspark.ml.tests.connect.test_parity_classification",
"pyspark.ml.tests.connect.test_parity_evaluation",
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Spark Connect ML uses ServiceLoader to discover the supported Spark ML evaluators.
# To add support for a new evaluator, register it here.

org.apache.spark.ml.evaluation.RegressionEvaluator
org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
org.apache.spark.ml.evaluation.MultilabelClassificationEvaluator
org.apache.spark.ml.evaluation.ClusteringEvaluator
org.apache.spark.ml.evaluation.RankingEvaluator
17 changes: 17 additions & 0 deletions python/pyspark/ml/connect/readwrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def sc(self) -> "SparkContext":

def save(self, path: str) -> None:
from pyspark.ml.wrapper import JavaModel, JavaEstimator
from pyspark.ml.evaluation import JavaEvaluator
from pyspark.sql.connect.session import SparkSession

session = SparkSession.getActiveSession()
Expand Down Expand Up @@ -69,6 +70,19 @@ def save(self, path: str) -> None:
should_overwrite=self.shouldOverwrite,
options=self.optionMap,
)
elif isinstance(self._instance, JavaEvaluator):
evaluator = cast("JavaEvaluator", self._instance)
params = serialize_ml_params(evaluator, session.client)
assert isinstance(evaluator._java_obj, str)
writer = pb2.MlCommand.Write(
operator=pb2.MlOperator(
name=evaluator._java_obj, uid=evaluator.uid, type=pb2.MlOperator.EVALUATOR
),
params=params,
path=path,
should_overwrite=self.shouldOverwrite,
options=self.optionMap,
)
else:
raise NotImplementedError(f"Unsupported writing for {self._instance}")

Expand All @@ -85,6 +99,7 @@ def __init__(self, clazz: Type["JavaMLReadable[RL]"]) -> None:
def load(self, path: str) -> RL:
from pyspark.sql.connect.session import SparkSession
from pyspark.ml.wrapper import JavaModel, JavaEstimator
from pyspark.ml.evaluation import JavaEvaluator

session = SparkSession.getActiveSession()
assert session is not None
Expand All @@ -99,6 +114,8 @@ def load(self, path: str) -> RL:
ml_type = pb2.MlOperator.MODEL
elif issubclass(self._clazz, JavaEstimator):
ml_type = pb2.MlOperator.ESTIMATOR
elif issubclass(self._clazz, JavaEvaluator):
ml_type = pb2.MlOperator.EVALUATOR
else:
raise ValueError(f"Unsupported reading for {java_qualified_class_name}")

Expand Down
3 changes: 2 additions & 1 deletion python/pyspark/ml/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
HasWeightCol,
)
from pyspark.ml.common import inherit_doc
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
from pyspark.ml.util import JavaMLReadable, JavaMLWritable, try_remote_evaluate
from pyspark.sql.dataframe import DataFrame

if TYPE_CHECKING:
Expand Down Expand Up @@ -128,6 +128,7 @@ class JavaEvaluator(JavaParams, Evaluator, metaclass=ABCMeta):
implementations.
"""

@try_remote_evaluate
def _evaluate(self, dataset: DataFrame) -> float:
"""
Evaluates the output.
Expand Down
49 changes: 49 additions & 0 deletions python/pyspark/ml/tests/connect/test_parity_evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import unittest

from pyspark.ml.tests.test_evaluation import EvaluatorTestsMixin
from pyspark.sql import SparkSession


class EvaluatorParityTests(EvaluatorTestsMixin, unittest.TestCase):
    """Run the tests inherited from ``EvaluatorTestsMixin`` on a remote session.

    Each test gets its own session, built with ``SparkSession.builder.remote``
    in ``setUp`` and stopped in ``tearDown``.
    """

    def setUp(self) -> None:
        """Create the session used by the test."""
        # The endpoint comes from SPARK_CONNECT_TESTING_REMOTE when it is set;
        # otherwise "local[2]" is used.
        remote_url = os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
        self.spark = SparkSession.builder.remote(remote_url).getOrCreate()

    def tearDown(self) -> None:
        """Stop the session created in ``setUp``."""
        self.spark.stop()

    def test_assert_remote_mode(self):
        """Check that ``is_remote()`` reports true for this test run."""
        from pyspark.sql import is_remote

        self.assertTrue(is_remote())


if __name__ == "__main__":
    # Script entry point: import this module's tests and run them with unittest.
    from pyspark.ml.tests.connect.test_parity_evaluation import *  # noqa: F401

    try:
        # xmlrunner is optional; when importable, its runner is configured to
        # write output under target/test-reports.
        import xmlrunner  # type: ignore[import]

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        # Without xmlrunner, pass None so unittest uses its default runner.
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)
Loading

0 comments on commit ef363b6

Please sign in to comment.