
Commit 68615c4
1 parent: d3fad47

[SPARK-50922][ML][PYTHON][CONNECT] Support OneVsRest on Connect

### What changes were proposed in this pull request?
Support OneVsRest on Connect.

### Why are the changes needed?
Feature parity.

### Does this PR introduce _any_ user-facing change?
Yes.

### How was this patch tested?
New tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #49693 from zhengruifeng/ml_connect_ovr.

Lead-authored-by: Ruifeng Zheng <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
(cherry picked from commit 22bac2e)
Signed-off-by: Hyukjin Kwon <[email protected]>
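For context, a minimal usage sketch of what this commit enables: the same PySpark OneVsRest API now runs over a Spark Connect session. This sketch is illustrative only; the remote URL and toy data are assumptions, not part of the PR.

    from pyspark.sql import SparkSession
    from pyspark.ml.linalg import Vectors
    from pyspark.ml.classification import LinearSVC, OneVsRest

    # Connect to a Spark Connect server (URL is illustrative).
    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

    df = spark.createDataFrame(
        [(0.0, Vectors.dense(0.0, 5.0)), (1.0, Vectors.dense(1.0, 2.0))],
        ["label", "features"],
    )

    # OneVsRest wraps a binary classifier and fits one sub-model per class.
    ovr = OneVsRest(classifier=LinearSVC(maxIter=1), parallelism=1)
    model = ovr.fit(df)
    model.transform(df).show()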

File tree

  dev/sparktestsupport/modules.py
  python/pyspark/ml/classification.py
  python/pyspark/ml/connect/readwrite.py
  python/pyspark/ml/tests/connect/test_parity_ovr.py
  python/pyspark/ml/tests/test_ovr.py

5 files changed: +251 −19 lines changed

dev/sparktestsupport/modules.py
Lines changed: 2 additions & 0 deletions

@@ -676,6 +676,7 @@ def __hash__(self):
         "pyspark.ml.tests.test_persistence",
         "pyspark.ml.tests.test_pipeline",
         "pyspark.ml.tests.test_tuning",
+        "pyspark.ml.tests.test_ovr",
         "pyspark.ml.tests.test_stat",
         "pyspark.ml.tests.test_training_summary",
         "pyspark.ml.tests.tuning.test_tuning",
@@ -1129,6 +1130,7 @@ def __hash__(self):
         "pyspark.ml.tests.connect.test_parity_feature",
         "pyspark.ml.tests.connect.test_parity_pipeline",
         "pyspark.ml.tests.connect.test_parity_tuning",
+        "pyspark.ml.tests.connect.test_parity_ovr",
     ],
     excluded_python_implementations=[
         "PyPy"  # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and

python/pyspark/ml/classification.py
Lines changed: 40 additions & 19 deletions

@@ -85,6 +85,8 @@
     MLWriter,
     MLWritable,
     HasTrainingSummary,
+    try_remote_read,
+    try_remote_write,
     try_remote_attribute_relation,
 )
 from pyspark.ml.wrapper import JavaParams, JavaPredictor, JavaPredictionModel, JavaWrapper
@@ -94,6 +96,7 @@
 from pyspark.sql.functions import udf, when
 from pyspark.sql.types import ArrayType, DoubleType
 from pyspark.storagelevel import StorageLevel
+from pyspark.sql.utils import is_remote

 if TYPE_CHECKING:
     from pyspark.ml._typing import P, ParamMap
@@ -3572,31 +3575,45 @@ def _fit(self, dataset: DataFrame) -> "OneVsRestModel":
         if handlePersistence:
             multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)

-        def trainSingleClass(index: int) -> CM:
-            binaryLabelCol = "mc2b$" + str(index)
-            trainingDataset = multiclassLabeled.withColumn(
-                binaryLabelCol,
-                when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0),
-            )
-            paramMap = dict(
-                [
-                    (classifier.labelCol, binaryLabelCol),
-                    (classifier.featuresCol, featuresCol),
-                    (classifier.predictionCol, predictionCol),
-                ]
-            )
-            if weightCol:
-                paramMap[cast(HasWeightCol, classifier).weightCol] = weightCol
-            return classifier.fit(trainingDataset, paramMap)
+        def _oneClassFitTasks(numClasses: int):
+            indices = iter(range(numClasses))
+
+            def trainSingleClass() -> CM:
+                index = next(indices)
+
+                binaryLabelCol = "mc2b$" + str(index)
+                trainingDataset = multiclassLabeled.withColumn(
+                    binaryLabelCol,
+                    when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0),
+                )
+                paramMap = dict(
+                    [
+                        (classifier.labelCol, binaryLabelCol),
+                        (classifier.featuresCol, featuresCol),
+                        (classifier.predictionCol, predictionCol),
+                    ]
+                )
+                if weightCol:
+                    paramMap[cast(HasWeightCol, classifier).weightCol] = weightCol
+                return index, classifier.fit(trainingDataset, paramMap)

+            return [trainSingleClass] * numClasses
+
+        tasks = map(
+            inheritable_thread_target(dataset.sparkSession),
+            _oneClassFitTasks(numClasses),
+        )
         pool = ThreadPool(processes=min(self.getParallelism(), numClasses))

-        models = pool.map(inheritable_thread_target(trainSingleClass), range(numClasses))
+        subModels = [None] * numClasses
+        for j, subModel in pool.imap_unordered(lambda f: f(), tasks):
+            assert subModels is not None
+            subModels[j] = subModel

         if handlePersistence:
             multiclassLabeled.unpersist()

-        return self._copyValues(OneVsRestModel(models=models))
+        return self._copyValues(OneVsRestModel(models=subModels))

     def copy(self, extra: Optional["ParamMap"] = None) -> "OneVsRest":
         """
@@ -3671,9 +3688,11 @@ def _to_java(self) -> "JavaObject":
         return _java_obj

     @classmethod
+    @try_remote_read
     def read(cls) -> "OneVsRestReader":
         return OneVsRestReader(cls)

+    @try_remote_write
     def write(self) -> MLWriter:
         if isinstance(self.getClassifier(), JavaMLWritable):
             return JavaMLWriter(self)  # type: ignore[arg-type]
@@ -3787,7 +3806,7 @@ def __init__(self, models: List[ClassificationModel]):
         from pyspark.core.context import SparkContext

         self.models = models
-        if not isinstance(models[0], JavaMLWritable):
+        if is_remote() or not isinstance(models[0], JavaMLWritable):
             return
         # set java instance
         java_models = [cast(_JavaClassificationModel, model)._to_java() for model in self.models]
@@ -3955,9 +3974,11 @@ def _to_java(self) -> "JavaObject":
         return _java_obj

     @classmethod
+    @try_remote_read
     def read(cls) -> "OneVsRestModelReader":
         return OneVsRestModelReader(cls)

+    @try_remote_write
     def write(self) -> MLWriter:
         if all(
             map(
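The core refactor in `_fit`: each training closure now returns an `(index, model)` pair, so `pool.imap_unordered` can hand back results in completion order while the carried index restores class order, and `inheritable_thread_target(dataset.sparkSession)` propagates the (possibly remote) session to worker threads. A minimal standalone sketch of the ordering pattern, with a toy computation standing in for `classifier.fit`:

    from multiprocessing.pool import ThreadPool

    def one_class_fit_tasks(num_classes):
        indices = iter(range(num_classes))

        def train_single_class():
            index = next(indices)  # each call consumes the next class index
            return index, index * index  # stand-in for (index, fitted sub-model)

        return [train_single_class] * num_classes

    num_classes = 4
    pool = ThreadPool(processes=2)
    sub_models = [None] * num_classes
    # imap_unordered yields pairs as tasks finish, possibly out of order;
    # the carried index slots each result back into its class position.
    for j, sub_model in pool.imap_unordered(lambda f: f(), one_class_fit_tasks(num_classes)):
        sub_models[j] = sub_model
    pool.close()
    pool.join()
    assert sub_models == [0, 1, 4, 9]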

python/pyspark/ml/connect/readwrite.py
Lines changed: 37 additions & 0 deletions

@@ -95,6 +95,7 @@ def saveInstance(
     from pyspark.ml.evaluation import JavaEvaluator
     from pyspark.ml.pipeline import Pipeline, PipelineModel
     from pyspark.ml.tuning import CrossValidator, TrainValidationSplit
+    from pyspark.ml.classification import OneVsRest, OneVsRestModel

     # Spark Connect ML is built on scala Spark.ML, that means we're only
     # supporting JavaModel or JavaEstimator or JavaEvaluator
@@ -187,6 +188,27 @@ def saveInstance(
             warnings.warn("Overwrite doesn't take effect for TrainValidationSplitModel")
         tvsm_writer = RemoteTrainValidationSplitModelWriter(instance, optionMap, session)
         tvsm_writer.save(path)
+    elif isinstance(instance, OneVsRest):
+        from pyspark.ml.classification import OneVsRestWriter
+
+        if shouldOverwrite:
+            # TODO(SPARK-50954): Support client side model path overwrite
+            warnings.warn("Overwrite doesn't take effect for OneVsRest")
+
+        writer = OneVsRestWriter(instance)
+        writer.session(session)
+        writer.save(path)
+        # _OneVsRestSharedReadWrite.saveImpl(self.instance, self.sparkSession, path)
+    elif isinstance(instance, OneVsRestModel):
+        from pyspark.ml.classification import OneVsRestModelWriter
+
+        if shouldOverwrite:
+            # TODO(SPARK-50954): Support client side model path overwrite
+            warnings.warn("Overwrite doesn't take effect for OneVsRestModel")
+
+        writer = OneVsRestModelWriter(instance)
+        writer.session(session)
+        writer.save(path)
     else:
         raise NotImplementedError(f"Unsupported write for {instance.__class__}")

@@ -215,6 +237,7 @@ def loadInstance(
     from pyspark.ml.evaluation import JavaEvaluator
     from pyspark.ml.pipeline import Pipeline, PipelineModel
     from pyspark.ml.tuning import CrossValidator, TrainValidationSplit
+    from pyspark.ml.classification import OneVsRest, OneVsRestModel

     if (
         issubclass(clazz, JavaModel)
@@ -307,5 +330,19 @@ def _get_class() -> Type[RL]:
         tvs_reader.session(session)
         return tvs_reader.load(path)

+    elif issubclass(clazz, OneVsRest):
+        from pyspark.ml.classification import OneVsRestReader
+
+        ovr_reader = OneVsRestReader(OneVsRest)
+        ovr_reader.session(session)
+        return ovr_reader.load(path)
+
+    elif issubclass(clazz, OneVsRestModel):
+        from pyspark.ml.classification import OneVsRestModelReader
+
+        ovr_reader = OneVsRestModelReader(OneVsRestModel)
+        ovr_reader.session(session)
+        return ovr_reader.load(path)
+
     else:
         raise RuntimeError(f"Unsupported read for {clazz}")
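With this dispatch in place, `write()` and `load()` on a Connect client route through `OneVsRestWriter`/`OneVsRestReader` (and their Model variants) instead of the JVM-backed writers. A hedged usage sketch; the remote URL is an assumption, and since the TODO above notes that `overwrite()` has no effect on Connect, plain `save` is used:

    import os
    import tempfile

    from pyspark.sql import SparkSession
    from pyspark.ml.classification import LinearSVC, OneVsRest

    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
    ovr = OneVsRest(classifier=LinearSVC(maxIter=1))

    with tempfile.TemporaryDirectory() as d:
        path = os.path.join(d, "ovr")
        ovr.write().save(path)  # saveInstance dispatches to OneVsRestWriter
        ovr2 = OneVsRest.load(path)  # loadInstance dispatches to OneVsRestReader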
python/pyspark/ml/tests/connect/test_parity_ovr.py
Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.ml.tests.test_ovr import OneVsRestTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class OneVsRestParityTests(OneVsRestTestsMixin, ReusedConnectTestCase):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.connect.test_parity_ovr import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
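The parity file carries no test bodies of its own: `OneVsRestTestsMixin` supplies the assertions, and the base class decides whether they run against classic Spark (`ReusedSQLTestCase`, in test_ovr.py below) or Spark Connect (`ReusedConnectTestCase`). A minimal sketch of the pattern with stand-in base classes:

    import unittest


    class SharedTestsMixin:
        # Assertions written once, against whatever backend the base class provides.
        def test_backend_is_set(self):
            self.assertIn(self.backend, {"classic", "connect"})


    class ClassicBase(unittest.TestCase):
        backend = "classic"  # stands in for ReusedSQLTestCase


    class ConnectBase(unittest.TestCase):
        backend = "connect"  # stands in for ReusedConnectTestCase


    class SharedClassicTests(SharedTestsMixin, ClassicBase):
        pass


    class SharedParityTests(SharedTestsMixin, ConnectBase):
        pass


    if __name__ == "__main__":
        unittest.main(verbosity=2)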

python/pyspark/ml/tests/test_ovr.py
Lines changed: 135 additions & 0 deletions

@@ -0,0 +1,135 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import tempfile
+import unittest
+
+import numpy as np
+
+from pyspark.ml.linalg import Vectors
+from pyspark.ml.classification import (
+    LinearSVC,
+    LinearSVCModel,
+    OneVsRest,
+    OneVsRestModel,
+)
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class OneVsRestTestsMixin:
+    def test_one_vs_rest(self):
+        spark = self.spark
+        df = (
+            spark.createDataFrame(
+                [
+                    (0, 1.0, Vectors.dense(0.0, 5.0)),
+                    (1, 0.0, Vectors.dense(1.0, 2.0)),
+                    (2, 1.0, Vectors.dense(2.0, 1.0)),
+                    (3, 2.0, Vectors.dense(3.0, 3.0)),
+                ],
+                ["index", "label", "features"],
+            )
+            .coalesce(1)
+            .sortWithinPartitions("index")
+            .select("label", "features")
+        )
+
+        svc = LinearSVC(maxIter=1, regParam=1.0)
+        self.assertEqual(svc.getMaxIter(), 1)
+        self.assertEqual(svc.getRegParam(), 1.0)
+
+        ovr = OneVsRest(classifier=svc, parallelism=1)
+        self.assertEqual(ovr.getParallelism(), 1)
+
+        model = ovr.fit(df)
+        self.assertIsInstance(model, OneVsRestModel)
+        self.assertEqual(len(model.models), 3)
+        for submodel in model.models:
+            self.assertIsInstance(submodel, LinearSVCModel)
+
+        self.assertTrue(
+            np.allclose(model.models[0].intercept, 0.06279247869226989, atol=1e-4),
+            model.models[0].intercept,
+        )
+        self.assertTrue(
+            np.allclose(
+                model.models[0].coefficients.toArray(),
+                [-0.1198765502306968, -0.1027513287691687],
+                atol=1e-4,
+            ),
+            model.models[0].coefficients,
+        )
+
+        self.assertTrue(
+            np.allclose(model.models[1].intercept, 0.025877458475338313, atol=1e-4),
+            model.models[1].intercept,
+        )
+        self.assertTrue(
+            np.allclose(
+                model.models[1].coefficients.toArray(),
+                [-0.0362284418654736, 0.010350983390135305],
+                atol=1e-4,
+            ),
+            model.models[1].coefficients,
+        )
+
+        self.assertTrue(
+            np.allclose(model.models[2].intercept, -0.37024065419409624, atol=1e-4),
+            model.models[2].intercept,
+        )
+        self.assertTrue(
+            np.allclose(
+                model.models[2].coefficients.toArray(),
+                [0.12886829400126, 0.012273170857262873],
+                atol=1e-4,
+            ),
+            model.models[2].coefficients,
+        )
+
+        output = model.transform(df)
+        expected_cols = ["label", "features", "rawPrediction", "prediction"]
+        self.assertEqual(output.columns, expected_cols)
+        self.assertEqual(output.count(), 4)
+
+        # Model save & load
+        with tempfile.TemporaryDirectory(prefix="linear_svc") as d:
+            path1 = os.path.join(d, "ovr")
+            ovr.write().overwrite().save(path1)
+            ovr2 = OneVsRest.load(path1)
+            self.assertEqual(str(ovr), str(ovr2))
+
+            path2 = os.path.join(d, "ovr_model")
+            model.write().overwrite().save(path2)
+            model2 = OneVsRestModel.load(path2)
+            self.assertEqual(str(model), str(model2))
+
+
+class OneVsRestTests(OneVsRestTestsMixin, ReusedSQLTestCase):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.test_ovr import *  # noqa: F401,F403
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
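For intuition about the `transform` assertions: OneVsRestModel scores each row with all three LinearSVC sub-models and predicts the class whose binary model produces the highest raw score. A toy illustration of that decision rule (the scores are made up, not taken from the test):

    import numpy as np

    # One raw score per class, as produced by the three binary sub-models.
    raw_scores = np.array([0.9, -0.2, 1.4])
    prediction = float(np.argmax(raw_scores))  # the most confident class wins
    assert prediction == 2.0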
