Skip to content

Commit 0c283a8

Browse files
committed
Revert "[SPARK-50922][ML][PYTHON][CONNECT] Support OneVsRest on Connect"
This reverts commit 22bac2e.
1 parent 4e5a813 commit 0c283a8

File tree

5 files changed

+19
-251
lines changed

5 files changed

+19
-251
lines changed

dev/sparktestsupport/modules.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -676,7 +676,6 @@ def __hash__(self):
676676
"pyspark.ml.tests.test_persistence",
677677
"pyspark.ml.tests.test_pipeline",
678678
"pyspark.ml.tests.test_tuning",
679-
"pyspark.ml.tests.test_ovr",
680679
"pyspark.ml.tests.test_stat",
681680
"pyspark.ml.tests.test_training_summary",
682681
"pyspark.ml.tests.tuning.test_tuning",
@@ -1130,7 +1129,6 @@ def __hash__(self):
11301129
"pyspark.ml.tests.connect.test_parity_feature",
11311130
"pyspark.ml.tests.connect.test_parity_pipeline",
11321131
"pyspark.ml.tests.connect.test_parity_tuning",
1133-
"pyspark.ml.tests.connect.test_parity_ovr",
11341132
],
11351133
excluded_python_implementations=[
11361134
"PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and

python/pyspark/ml/classification.py

Lines changed: 19 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,6 @@
8585
MLWriter,
8686
MLWritable,
8787
HasTrainingSummary,
88-
try_remote_read,
89-
try_remote_write,
9088
try_remote_attribute_relation,
9189
)
9290
from pyspark.ml.wrapper import JavaParams, JavaPredictor, JavaPredictionModel, JavaWrapper
@@ -96,7 +94,6 @@
9694
from pyspark.sql.functions import udf, when
9795
from pyspark.sql.types import ArrayType, DoubleType
9896
from pyspark.storagelevel import StorageLevel
99-
from pyspark.sql.utils import is_remote
10097

10198
if TYPE_CHECKING:
10299
from pyspark.ml._typing import P, ParamMap
@@ -3575,45 +3572,31 @@ def _fit(self, dataset: DataFrame) -> "OneVsRestModel":
35753572
if handlePersistence:
35763573
multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
35773574

3578-
def _oneClassFitTasks(numClasses: int):
3579-
indices = iter(range(numClasses))
3580-
3581-
def trainSingleClass() -> CM:
3582-
index = next(indices)
3583-
3584-
binaryLabelCol = "mc2b$" + str(index)
3585-
trainingDataset = multiclassLabeled.withColumn(
3586-
binaryLabelCol,
3587-
when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0),
3588-
)
3589-
paramMap = dict(
3590-
[
3591-
(classifier.labelCol, binaryLabelCol),
3592-
(classifier.featuresCol, featuresCol),
3593-
(classifier.predictionCol, predictionCol),
3594-
]
3595-
)
3596-
if weightCol:
3597-
paramMap[cast(HasWeightCol, classifier).weightCol] = weightCol
3598-
return index, classifier.fit(trainingDataset, paramMap)
3599-
3600-
return [trainSingleClass] * numClasses
3575+
def trainSingleClass(index: int) -> CM:
3576+
binaryLabelCol = "mc2b$" + str(index)
3577+
trainingDataset = multiclassLabeled.withColumn(
3578+
binaryLabelCol,
3579+
when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0),
3580+
)
3581+
paramMap = dict(
3582+
[
3583+
(classifier.labelCol, binaryLabelCol),
3584+
(classifier.featuresCol, featuresCol),
3585+
(classifier.predictionCol, predictionCol),
3586+
]
3587+
)
3588+
if weightCol:
3589+
paramMap[cast(HasWeightCol, classifier).weightCol] = weightCol
3590+
return classifier.fit(trainingDataset, paramMap)
36013591

3602-
tasks = map(
3603-
inheritable_thread_target(dataset.sparkSession),
3604-
_oneClassFitTasks(numClasses),
3605-
)
36063592
pool = ThreadPool(processes=min(self.getParallelism(), numClasses))
36073593

3608-
subModels = [None] * numClasses
3609-
for j, subModel in pool.imap_unordered(lambda f: f(), tasks):
3610-
assert subModels is not None
3611-
subModels[j] = subModel
3594+
models = pool.map(inheritable_thread_target(trainSingleClass), range(numClasses))
36123595

36133596
if handlePersistence:
36143597
multiclassLabeled.unpersist()
36153598

3616-
return self._copyValues(OneVsRestModel(models=subModels))
3599+
return self._copyValues(OneVsRestModel(models=models))
36173600

36183601
def copy(self, extra: Optional["ParamMap"] = None) -> "OneVsRest":
36193602
"""
@@ -3688,11 +3671,9 @@ def _to_java(self) -> "JavaObject":
36883671
return _java_obj
36893672

36903673
@classmethod
3691-
@try_remote_read
36923674
def read(cls) -> "OneVsRestReader":
36933675
return OneVsRestReader(cls)
36943676

3695-
@try_remote_write
36963677
def write(self) -> MLWriter:
36973678
if isinstance(self.getClassifier(), JavaMLWritable):
36983679
return JavaMLWriter(self) # type: ignore[arg-type]
@@ -3806,7 +3787,7 @@ def __init__(self, models: List[ClassificationModel]):
38063787
from pyspark.core.context import SparkContext
38073788

38083789
self.models = models
3809-
if is_remote() or not isinstance(models[0], JavaMLWritable):
3790+
if not isinstance(models[0], JavaMLWritable):
38103791
return
38113792
# set java instance
38123793
java_models = [cast(_JavaClassificationModel, model)._to_java() for model in self.models]
@@ -3974,11 +3955,9 @@ def _to_java(self) -> "JavaObject":
39743955
return _java_obj
39753956

39763957
@classmethod
3977-
@try_remote_read
39783958
def read(cls) -> "OneVsRestModelReader":
39793959
return OneVsRestModelReader(cls)
39803960

3981-
@try_remote_write
39823961
def write(self) -> MLWriter:
39833962
if all(
39843963
map(

python/pyspark/ml/connect/readwrite.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ def saveInstance(
9595
from pyspark.ml.evaluation import JavaEvaluator
9696
from pyspark.ml.pipeline import Pipeline, PipelineModel
9797
from pyspark.ml.tuning import CrossValidator, TrainValidationSplit
98-
from pyspark.ml.classification import OneVsRest, OneVsRestModel
9998

10099
# Spark Connect ML is built on scala Spark.ML, that means we're only
101100
# supporting JavaModel or JavaEstimator or JavaEvaluator
@@ -188,27 +187,6 @@ def saveInstance(
188187
warnings.warn("Overwrite doesn't take effect for TrainValidationSplitModel")
189188
tvsm_writer = RemoteTrainValidationSplitModelWriter(instance, optionMap, session)
190189
tvsm_writer.save(path)
191-
elif isinstance(instance, OneVsRest):
192-
from pyspark.ml.classification import OneVsRestWriter
193-
194-
if shouldOverwrite:
195-
# TODO(SPARK-50954): Support client side model path overwrite
196-
warnings.warn("Overwrite doesn't take effect for OneVsRest")
197-
198-
writer = OneVsRestWriter(instance)
199-
writer.session(session)
200-
writer.save(path)
201-
# _OneVsRestSharedReadWrite.saveImpl(self.instance, self.sparkSession, path)
202-
elif isinstance(instance, OneVsRestModel):
203-
from pyspark.ml.classification import OneVsRestModelWriter
204-
205-
if shouldOverwrite:
206-
# TODO(SPARK-50954): Support client side model path overwrite
207-
warnings.warn("Overwrite doesn't take effect for OneVsRestModel")
208-
209-
writer = OneVsRestModelWriter(instance)
210-
writer.session(session)
211-
writer.save(path)
212190
else:
213191
raise NotImplementedError(f"Unsupported write for {instance.__class__}")
214192

@@ -237,7 +215,6 @@ def loadInstance(
237215
from pyspark.ml.evaluation import JavaEvaluator
238216
from pyspark.ml.pipeline import Pipeline, PipelineModel
239217
from pyspark.ml.tuning import CrossValidator, TrainValidationSplit
240-
from pyspark.ml.classification import OneVsRest, OneVsRestModel
241218

242219
if (
243220
issubclass(clazz, JavaModel)
@@ -330,19 +307,5 @@ def _get_class() -> Type[RL]:
330307
tvs_reader.session(session)
331308
return tvs_reader.load(path)
332309

333-
elif issubclass(clazz, OneVsRest):
334-
from pyspark.ml.classification import OneVsRestReader
335-
336-
ovr_reader = OneVsRestReader(OneVsRest)
337-
ovr_reader.session(session)
338-
return ovr_reader.load(path)
339-
340-
elif issubclass(clazz, OneVsRestModel):
341-
from pyspark.ml.classification import OneVsRestModelReader
342-
343-
ovr_reader = OneVsRestModelReader(OneVsRestModel)
344-
ovr_reader.session(session)
345-
return ovr_reader.load(path)
346-
347310
else:
348311
raise RuntimeError(f"Unsupported read for {clazz}")

python/pyspark/ml/tests/connect/test_parity_ovr.py

Lines changed: 0 additions & 37 deletions
This file was deleted.

python/pyspark/ml/tests/test_ovr.py

Lines changed: 0 additions & 135 deletions
This file was deleted.

0 commit comments

Comments (0)