
Commit fafe43c

wbo4958 authored and zhengruifeng committed
[SPARK-49907][ML][CONNECT] Support spark.ml on Connect
### What changes were proposed in this pull request?

This PR, derived from #40479 authored by WeichenXu123, enables running spark.ml on Connect. Currently, it supports the following functionality:

- Fit operation in LogisticRegression
- Transform/predict operation in LogisticRegressionModel
- Retrieving attributes of LogisticRegressionModel
- Retrieving the summary and its attributes from LogisticRegressionModel
- Read/write operations for LogisticRegressionModel
- Read/write operations for LogisticRegression
- Evaluating a dataset with LogisticRegressionModel, returning a summary and retrieving its attributes

### Why are the changes needed?

It's a new feature that makes spark.ml run in a Connect environment.

### Does this PR introduce _any_ user-facing change?

Yes, new feature.

### How was this patch tested?

Make sure the CI (especially the newly added tests) passes. The code below can also be run manually without any exception:

```python
(pyspark) userbobby:~ $ pyspark --remote sc://localhost
Python 3.11.10 (main, Oct 3 2024, 07:29:13) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /__ / .__/\_,_/_/ /_/\_\   version 4.0.0.dev0
      /_/

Using Python version 3.11.10 (main, Oct 3 2024 07:29:13)
Client connected to the Spark Connect server at localhost
SparkSession available as 'spark'.
>>> from pyspark.ml.classification import (LogisticRegression, LogisticRegressionModel)
>>> from pyspark.ml.linalg import Vectors
>>>
>>> df = spark.createDataFrame([
...     (Vectors.dense([1.0, 2.0]), 1),
...     (Vectors.dense([2.0, -1.0]), 1),
...     (Vectors.dense([-3.0, -2.0]), 0),
...     (Vectors.dense([-1.0, -2.0]), 0),
... ], schema=['features', 'label'])
>>> lr = LogisticRegression()
>>> lr.setMaxIter(30)
LogisticRegression_27d6a4e4f39d
>>> lr.setThreshold(0.8)
LogisticRegression_27d6a4e4f39d
>>> lr.write().overwrite().save("/tmp/connect-ml-demo/estimator")
>>> loaded_lr = LogisticRegression.load("/tmp/connect-ml-demo/estimator")
>>> assert (loaded_lr.getThreshold() == 0.8)
>>> assert loaded_lr.getMaxIter() == 30
>>>
>>> model: LogisticRegressionModel = lr.fit(df)
>>> assert (model.getThreshold() == 0.8)
>>> assert model.getMaxIter() == 30
>>> model.predictRaw(Vectors.dense([1.0, 2.0]))
DenseVector([-21.1048, 21.1048])
>>> model.summary.roc.show()
+---+---+
|FPR|TPR|
+---+---+
|0.0|0.0|
|0.0|0.5|
|0.0|1.0|
|0.5|1.0|
|1.0|1.0|
|1.0|1.0|
+---+---+

>>> model.summary.weightedRecall
1.0
>>> model.summary.recallByLabel
[1.0, 1.0]
>>> model.coefficients
DenseVector([10.3964, 4.513])
>>> model.intercept
1.682348909633995
>>> model.transform(df).show()
+-----------+-----+--------------------+--------------------+----------+
|   features|label|       rawPrediction|         probability|prediction|
+-----------+-----+--------------------+--------------------+----------+
|  [1.0,2.0]|    1|[-21.104818251026...|[6.82800596289009...|       1.0|
| [2.0,-1.0]|    1|[-17.962094978515...|[1.58183529116627...|       1.0|
|[-3.0,-2.0]|    0|[38.5329050234205...|           [1.0,0.0]|       0.0|
|[-1.0,-2.0]|    0|[17.7401204317581...|[0.99999998025016...|       0.0|
+-----------+-----+--------------------+--------------------+----------+

>>> model.write().overwrite().save("/tmp/connect-ml-demo/model")
>>> loaded_model = LogisticRegressionModel.load("/tmp/connect-ml-demo/model")
>>> assert loaded_model.getMaxIter() == 30
>>> loaded_model.transform(df).show()
+-----------+-----+--------------------+--------------------+----------+
|   features|label|       rawPrediction|         probability|prediction|
+-----------+-----+--------------------+--------------------+----------+
|  [1.0,2.0]|    1|[-21.104818251026...|[6.82800596289009...|       1.0|
| [2.0,-1.0]|    1|[-17.962094978515...|[1.58183529116627...|       1.0|
|[-3.0,-2.0]|    0|[38.5329050234205...|           [1.0,0.0]|       0.0|
|[-1.0,-2.0]|    0|[17.7401204317581...|[0.99999998025016...|       0.0|
+-----------+-----+--------------------+--------------------+----------+

>>>
>>> summary = loaded_model.evaluate(df)
>>> summary.weightCol
'weightCol'
>>> summary.recallByLabel
[1.0, 1.0]
>>> summary.accuracy
1.0
>>> summary.predictions.show()
+-----------+-----+--------------------+--------------------+----------+
|   features|label|       rawPrediction|         probability|prediction|
+-----------+-----+--------------------+--------------------+----------+
|  [1.0,2.0]|    1|[-21.104818251026...|[6.82800596289009...|       1.0|
| [2.0,-1.0]|    1|[-17.962094978515...|[1.58183529116627...|       1.0|
|[-3.0,-2.0]|    0|[38.5329050234205...|           [1.0,0.0]|       0.0|
|[-1.0,-2.0]|    0|[17.7401204317581...|[0.99999998025016...|       0.0|
+-----------+-----+--------------------+--------------------+----------+
```

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #48791 from wbo4958/connect-ml.

Authored-by: Bobby Wang <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent f96417f commit fafe43c

Note: large commits have some content hidden by default, so a few file paths below are not shown.

47 files changed: +4,384 -640 lines

common/utils/src/main/resources/error/error-conditions.json (18 additions, 0 deletions)

```diff
@@ -752,6 +752,24 @@
     },
     "sqlState" : "56K00"
   },
+  "CONNECT_ML" : {
+    "message" : [
+      "Generic Spark Connect ML error."
+    ],
+    "subClass" : {
+      "ATTRIBUTE_NOT_ALLOWED" : {
+        "message" : [
+          "<attribute> is not allowed to be accessed."
+        ]
+      },
+      "UNSUPPORTED_EXCEPTION" : {
+        "message" : [
+          "<message>"
+        ]
+      }
+    },
+    "sqlState" : "XX000"
+  },
   "CONVERSION_INVALID_INPUT" : {
     "message" : [
       "The value <str> (<fmt>) cannot be converted to <targetType> because it is malformed. Correct the value as per the syntax, or change its format. Use <suggestion> to tolerate malformed input and return NULL instead."
```

dev/sparktestsupport/modules.py (2 additions, 0 deletions)

```diff
@@ -686,6 +686,7 @@ def __hash__(self):
         "pyspark.ml.tests.connect.test_legacy_mode_classification",
         "pyspark.ml.tests.connect.test_legacy_mode_pipeline",
         "pyspark.ml.tests.connect.test_legacy_mode_tuning",
+        "pyspark.ml.tests.test_classification",
     ],
     excluded_python_implementations=[
         "PyPy"  # Skip these tests under PyPy since they require numpy and it isn't available there
@@ -1106,6 +1107,7 @@ def __hash__(self):
         "pyspark.ml.tests.connect.test_connect_classification",
         "pyspark.ml.tests.connect.test_connect_pipeline",
         "pyspark.ml.tests.connect.test_connect_tuning",
+        "pyspark.ml.tests.connect.test_connect_spark_ml_classification",
     ],
     excluded_python_implementations=[
         "PyPy"  # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
```

New file, path hidden (20 additions, 0 deletions)

```diff
@@ -0,0 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Spark Connect ML uses ServiceLoader to find out the supported Spark Ml estimators.
+# So register the supported estimator here if you're trying to add a new one.
+org.apache.spark.ml.classification.LogisticRegression
```
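
For context, not part of the diff: `ServiceLoader` is the standard JVM discovery mechanism. It reads a classpath resource under `META-INF/services` named after the service interface and instantiates each listed class via its no-arg constructor. A minimal sketch of how such a registry can be consumed, assuming `org.apache.spark.ml.Estimator` is the service interface (the actual server-side lookup in this commit may differ):

```scala
import java.util.ServiceLoader
import scala.jdk.CollectionConverters._

import org.apache.spark.ml.Estimator

// Sketch only: enumerate every estimator registered in a
// META-INF/services file named after the service interface.
val registered = ServiceLoader.load(classOf[Estimator[_]]).asScala.toSeq
registered.foreach(est => println(est.getClass.getName))
// With the registry above on the classpath, this would print:
// org.apache.spark.ml.classification.LogisticRegression
```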

New file, path hidden (20 additions, 0 deletions)

```diff
@@ -0,0 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Spark Connect ML uses ServiceLoader to find out the supported Spark Ml non-model transformer.
+# So register the supported transformer here if you're trying to add a new one.
+org.apache.spark.ml.feature.VectorAssembler
```

mllib/src/main/scala/org/apache/spark/ml/classification/ClassificationSummary.scala (2 additions, 1 deletion)

```diff
@@ -19,6 +19,7 @@ package org.apache.spark.ml.classification

 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.util.Summary
 import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics}
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions.{col, lit}
@@ -28,7 +29,7 @@ import org.apache.spark.sql.types.DoubleType
 /**
  * Abstraction for multiclass classification results for a given model.
  */
-private[classification] trait ClassificationSummary extends Serializable {
+private[classification] trait ClassificationSummary extends Summary with Serializable {

   /**
    * Dataframe output by the model's `transform` method.
```

mllib/src/main/scala/org/apache/spark/ml/param/params.scala (7 additions, 2 deletions)

```diff
@@ -19,11 +19,11 @@ package org.apache.spark.ml.param

 import java.lang.reflect.Modifier
 import java.util.{List => JList}
-import java.util.NoSuchElementException

 import scala.annotation.varargs
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
+import scala.reflect.ClassTag

 import org.json4s._
 import org.json4s.jackson.JsonMethods._
@@ -45,9 +45,14 @@ import org.apache.spark.util.ArrayImplicits._
  * See [[ParamValidators]] for factory methods for common validation functions.
  * @tparam T param value type
  */
-class Param[T](val parent: String, val name: String, val doc: String, val isValid: T => Boolean)
+class Param[T: ClassTag](
+    val parent: String, val name: String, val doc: String, val isValid: T => Boolean)
   extends Serializable {

+  // Spark Connect ML needs T type information which has been erased when compiling,
+  // Use classTag to preserve the T type.
+  val paramValueClassTag = implicitly[ClassTag[T]]
+
   def this(parent: Identifiable, name: String, doc: String, isValid: T => Boolean) =
     this(parent.uid, name, doc, isValid)
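
For context, not part of the diff: the JVM erases `T` at compile time, so a plain `Param[T]` cannot recover its value type at runtime. The `[T: ClassTag]` context bound asks the compiler to pass an implicit `ClassTag[T]` alongside the constructor arguments, which is what `paramValueClassTag` captures. A self-contained sketch of the pattern, using a hypothetical `Holder` class:

```scala
import scala.reflect.ClassTag

// Sketch of the [T: ClassTag] pattern used by Param above:
// the context bound preserves the erased type for runtime inspection.
class Holder[T: ClassTag](val value: T) {
  val tag: ClassTag[T] = implicitly[ClassTag[T]]
}

val i = new Holder(30)                    // Holder[Int]
println(i.tag.runtimeClass)               // prints: int
val s = new Holder("LogisticRegression")  // Holder[String]
println(s.tag.runtimeClass)               // prints: class java.lang.String
```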

New file, path hidden (28 additions, 0 deletions)

```diff
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */

+package org.apache.spark.ml.util

+import org.apache.spark.annotation.Since

+/**
+ * Trait for the Summary
+ * All the summaries should extend from this Summary in order to
+ * support connect.
+ */
+@Since("4.0.0")
+private[spark] trait Summary
```
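
For context, not part of the diff: `Summary` is a marker trait with no members; extending it (as `ClassificationSummary` now does) is what opts a summary type into being held server-side and queried over Connect. A hypothetical example of a new summary opting in; `SimpleFitSummary` and its fields are invented for illustration, and the code must live inside the `org.apache.spark` namespace because the trait is `private[spark]`:

```scala
package org.apache.spark.ml.classification

import org.apache.spark.ml.util.Summary

// Hypothetical summary type (not in this commit): mixing in the Summary
// marker trait is the opt-in that makes it usable over Spark Connect.
private[spark] class SimpleFitSummary(
    val accuracy: Double,
    val totalIterations: Int) extends Summary
```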

mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java (2 additions, 1 deletion)

```diff
@@ -21,6 +21,7 @@
 import java.util.List;

 import org.apache.spark.ml.util.Identifiable$;
+import scala.reflect.ClassTag;

 /**
  * A subclass of Params for testing.
@@ -110,7 +111,7 @@ private void init() {
       ParamValidators.inRange(0.0, 1.0));
     List<String> validStrings = Arrays.asList("a", "b");
     myStringParam_ = new Param<>(this, "myStringParam", "this is a string param",
-      ParamValidators.inArray(validStrings));
+      ParamValidators.inArray(validStrings), ClassTag.apply(String.class));
     myDoubleArrayParam_ =
       new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param");
```

python/pyspark/ml/classification.py (7 additions, 0 deletions)

```diff
@@ -62,6 +62,7 @@
     HasSolver,
     HasParallelism,
 )
+from pyspark.ml.remote.util import try_remote_attribute_relation
 from pyspark.ml.tree import (
     _DecisionTreeModel,
     _DecisionTreeParams,
@@ -336,6 +337,7 @@ class _ClassificationSummary(JavaWrapper):

     @property
     @since("3.1.0")
+    @try_remote_attribute_relation
     def predictions(self) -> DataFrame:
         """
         Dataframe outputted by the model's `transform` method.
@@ -521,6 +523,7 @@ def scoreCol(self) -> str:
         return self._call_java("scoreCol")

     @property
+    @try_remote_attribute_relation
     def roc(self) -> DataFrame:
         """
         Returns the receiver operating characteristic (ROC) curve,
@@ -546,6 +549,7 @@ def areaUnderROC(self) -> float:

     @property
     @since("3.1.0")
+    @try_remote_attribute_relation
     def pr(self) -> DataFrame:
         """
         Returns the precision-recall curve, which is a Dataframe
@@ -556,6 +560,7 @@ def pr(self) -> DataFrame:

     @property
     @since("3.1.0")
+    @try_remote_attribute_relation
     def fMeasureByThreshold(self) -> DataFrame:
         """
         Returns a dataframe with two fields (threshold, F-Measure) curve
@@ -565,6 +570,7 @@ def fMeasureByThreshold(self) -> DataFrame:

     @property
     @since("3.1.0")
+    @try_remote_attribute_relation
     def precisionByThreshold(self) -> DataFrame:
         """
         Returns a dataframe with two fields (threshold, precision) curve.
@@ -575,6 +581,7 @@ def precisionByThreshold(self) -> DataFrame:

     @property
     @since("3.1.0")
+    @try_remote_attribute_relation
     def recallByThreshold(self) -> DataFrame:
         """
         Returns a dataframe with two fields (threshold, recall) curve.
```

python/pyspark/ml/remote/__init__.py (16 additions, 0 deletions)

```diff
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
```

python/pyspark/ml/remote/proto.py (76 additions, 0 deletions)

```diff
@@ -0,0 +1,76 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Optional, TYPE_CHECKING, List
+
+import pyspark.sql.connect.proto as pb2
+from pyspark.sql.connect.plan import LogicalPlan
+
+if TYPE_CHECKING:
+    from pyspark.sql.connect.client import SparkConnectClient
+
+
+class TransformerRelation(LogicalPlan):
+    """A logical plan for transforming of a transformer which could be a cached model
+    or a non-model transformer like VectorAssembler."""
+
+    def __init__(
+        self,
+        child: Optional["LogicalPlan"],
+        name: str,
+        ml_params: pb2.MlParams,
+        uid: str = "",
+        is_model: bool = True,
+    ) -> None:
+        super().__init__(child)
+        self._name = name
+        self._ml_params = ml_params
+        self._uid = uid
+        self._is_model = is_model
+
+    def plan(self, session: "SparkConnectClient") -> pb2.Relation:
+        assert self._child is not None
+        plan = self._create_proto_relation()
+        plan.ml_relation.transform.input.CopyFrom(self._child.plan(session))
+
+        if self._is_model:
+            plan.ml_relation.transform.obj_ref.CopyFrom(pb2.ObjectRef(id=self._name))
+        else:
+            plan.ml_relation.transform.transformer.CopyFrom(
+                pb2.MlOperator(name=self._name, uid=self._uid, type=pb2.MlOperator.TRANSFORMER)
+            )
+
+        if self._ml_params is not None:
+            plan.ml_relation.transform.params.CopyFrom(self._ml_params)
+
+        return plan
+
+
+class AttributeRelation(LogicalPlan):
+    """A logical plan used in ML to represent an attribute of an instance, which
+    could be a model or a summary. This attribute returns a DataFrame.
+    """
+
+    def __init__(self, ref_id: str, methods: List[pb2.Fetch.Method]) -> None:
+        super().__init__(None)
+        self._ref_id = ref_id
+        self._methods = methods
+
+    def plan(self, session: "SparkConnectClient") -> pb2.Relation:
+        plan = self._create_proto_relation()
+        plan.ml_relation.fetch.obj_ref.CopyFrom(pb2.ObjectRef(id=self._ref_id))
+        plan.ml_relation.fetch.methods.extend(self._methods)
+        return plan
```
