[SPARK-49907][ML][CONNECT] Support spark.ml on Connect
### What changes were proposed in this pull request?

This PR, derived from #40479 authored by WeichenXu123, enables running spark.ml on Connect. Currently, it supports the following functionality:

- Fit operation in LogisticRegression
- Transform/predict operations in LogisticRegressionModel
- Retrieving attributes of LogisticRegressionModel
- Retrieving the summary and its attributes from LogisticRegressionModel
- Read/write operations for LogisticRegressionModel
- Read/write operations for LogisticRegression
- Evaluating a dataset with LogisticRegressionModel, which returns a summary whose attributes can then be retrieved

### Why are the changes needed?

It's a new feature that enables spark.ml to run in a Spark Connect environment.

### Does this PR introduce _any_ user-facing change?

Yes, new feature.

### How was this patch tested?

Make sure the CI (especially the newly added tests) passes. The code below can also be run manually without raising any exception:

``` python
(pyspark) userbobby:~ $ pyspark --remote sc://localhost
Python 3.11.10 (main, Oct  3 2024, 07:29:13) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.

Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /__ / .__/\_,_/_/ /_/\_\   version 4.0.0.dev0
      /_/

Using Python version 3.11.10 (main, Oct  3 2024 07:29:13)
Client connected to the Spark Connect server at localhost
SparkSession available as 'spark'.

>>> from pyspark.ml.classification import (LogisticRegression, LogisticRegressionModel)
>>> from pyspark.ml.linalg import Vectors
>>>
>>> df = spark.createDataFrame([
...     (Vectors.dense([1.0, 2.0]), 1),
...     (Vectors.dense([2.0, -1.0]), 1),
...     (Vectors.dense([-3.0, -2.0]), 0),
...     (Vectors.dense([-1.0, -2.0]), 0),
... ], schema=['features', 'label'])
>>> lr = LogisticRegression()
>>> lr.setMaxIter(30)
LogisticRegression_27d6a4e4f39d
>>> lr.setThreshold(0.8)
LogisticRegression_27d6a4e4f39d
>>> lr.write().overwrite().save("/tmp/connect-ml-demo/estimator")
>>> loaded_lr = LogisticRegression.load("/tmp/connect-ml-demo/estimator")
>>> assert (loaded_lr.getThreshold() == 0.8)
>>> assert loaded_lr.getMaxIter() == 30
>>>
>>> model: LogisticRegressionModel = lr.fit(df)
>>> assert (model.getThreshold() == 0.8)
>>> assert model.getMaxIter() == 30
>>> model.predictRaw(Vectors.dense([1.0, 2.0]))
DenseVector([-21.1048, 21.1048])
>>> model.summary.roc.show()
+---+---+
|FPR|TPR|
+---+---+
|0.0|0.0|
|0.0|0.5|
|0.0|1.0|
|0.5|1.0|
|1.0|1.0|
|1.0|1.0|
+---+---+

>>> model.summary.weightedRecall
1.0
>>> model.summary.recallByLabel
[1.0, 1.0]
>>> model.coefficients
DenseVector([10.3964, 4.513])
>>> model.intercept
1.682348909633995
>>> model.transform(df).show()
+-----------+-----+--------------------+--------------------+----------+
|   features|label|       rawPrediction|         probability|prediction|
+-----------+-----+--------------------+--------------------+----------+
|  [1.0,2.0]|    1|[-21.104818251026...|[6.82800596289009...|       1.0|
| [2.0,-1.0]|    1|[-17.962094978515...|[1.58183529116627...|       1.0|
|[-3.0,-2.0]|    0|[38.5329050234205...|           [1.0,0.0]|       0.0|
|[-1.0,-2.0]|    0|[17.7401204317581...|[0.99999998025016...|       0.0|
+-----------+-----+--------------------+--------------------+----------+

>>> model.write().overwrite().save("/tmp/connect-ml-demo/model")
>>> loaded_model = LogisticRegressionModel.load("/tmp/connect-ml-demo/model")
>>> assert loaded_model.getMaxIter() == 30
>>> loaded_model.transform(df).show()
+-----------+-----+--------------------+--------------------+----------+
|   features|label|       rawPrediction|         probability|prediction|
+-----------+-----+--------------------+--------------------+----------+
|  [1.0,2.0]|    1|[-21.104818251026...|[6.82800596289009...|       1.0|
| [2.0,-1.0]|    1|[-17.962094978515...|[1.58183529116627...|       1.0|
|[-3.0,-2.0]|    0|[38.5329050234205...|           [1.0,0.0]|       0.0|
|[-1.0,-2.0]|    0|[17.7401204317581...|[0.99999998025016...|       0.0|
+-----------+-----+--------------------+--------------------+----------+

>>>
>>> summary = loaded_model.evaluate(df)
>>> summary.weightCol
'weightCol'
>>> summary.recallByLabel
[1.0, 1.0]
>>> summary.accuracy
1.0
>>> summary.predictions.show()
+-----------+-----+--------------------+--------------------+----------+
|   features|label|       rawPrediction|         probability|prediction|
+-----------+-----+--------------------+--------------------+----------+
|  [1.0,2.0]|    1|[-21.104818251026...|[6.82800596289009...|       1.0|
| [2.0,-1.0]|    1|[-17.962094978515...|[1.58183529116627...|       1.0|
|[-3.0,-2.0]|    0|[38.5329050234205...|           [1.0,0.0]|       0.0|
|[-1.0,-2.0]|    0|[17.7401204317581...|[0.99999998025016...|       0.0|
+-----------+-----+--------------------+--------------------+----------+

```

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #48791 from wbo4958/connect-ml.

Authored-by: Bobby Wang <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
wbo4958 authored and zhengruifeng committed Jan 14, 2025
1 parent f96417f commit fafe43c
Showing 47 changed files with 4,384 additions and 640 deletions.
18 changes: 18 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -752,6 +752,24 @@
    },
    "sqlState" : "56K00"
  },
  "CONNECT_ML" : {
    "message" : [
      "Generic Spark Connect ML error."
    ],
    "subClass" : {
      "ATTRIBUTE_NOT_ALLOWED" : {
        "message" : [
          "<attribute> is not allowed to be accessed."
        ]
      },
      "UNSUPPORTED_EXCEPTION" : {
        "message" : [
          "<message>"
        ]
      }
    },
    "sqlState" : "XX000"
  },
  "CONVERSION_INVALID_INPUT" : {
    "message" : [
      "The value <str> (<fmt>) cannot be converted to <targetType> because it is malformed. Correct the value as per the syntax, or change its format. Use <suggestion> to tolerate malformed input and return NULL instead."
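For context, a hedged sketch of how the new `CONNECT_ML.ATTRIBUTE_NOT_ALLOWED` condition might surface on the Python client. The triggering attribute access is hypothetical, and `getCondition()` assumes Spark 4.0's error framework:

```python
from pyspark.errors import PySparkException

try:
    # Hypothetical trigger: fetch a server-side attribute that is not on the
    # Connect ML allowlist; the server raises CONNECT_ML.ATTRIBUTE_NOT_ALLOWED.
    model.summary._call_java("someDisallowedAttribute")
except PySparkException as e:
    # The error condition name distinguishes Connect ML failures
    # from other Spark errors.
    print(e.getCondition())  # e.g. "CONNECT_ML.ATTRIBUTE_NOT_ALLOWED"
```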
2 changes: 2 additions & 0 deletions dev/sparktestsupport/modules.py
@@ -686,6 +686,7 @@ def __hash__(self):
"pyspark.ml.tests.connect.test_legacy_mode_classification",
"pyspark.ml.tests.connect.test_legacy_mode_pipeline",
"pyspark.ml.tests.connect.test_legacy_mode_tuning",
"pyspark.ml.tests.test_classification",
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there
@@ -1106,6 +1107,7 @@ def __hash__(self):
"pyspark.ml.tests.connect.test_connect_classification",
"pyspark.ml.tests.connect.test_connect_pipeline",
"pyspark.ml.tests.connect.test_connect_tuning",
"pyspark.ml.tests.connect.test_connect_spark_ml_classification",
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
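The test modules registered above can be run individually with Spark's Python test runner; a usage sketch, shown as comments in the same shell-session style as the demo (run from the Spark repo root):

```python
# Shell commands, shown as comments:
#   $ python/run-tests --testnames 'pyspark.ml.tests.test_classification'
#   $ python/run-tests --testnames 'pyspark.ml.tests.connect.test_connect_spark_ml_classification'
```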
@@ -0,0 +1,20 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Spark Connect ML uses ServiceLoader to discover the supported Spark ML estimators,
# so register the estimator here if you're adding a new one.
org.apache.spark.ml.classification.LogisticRegression
@@ -0,0 +1,20 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Spark Connect ML uses ServiceLoader to discover the supported Spark ML non-model transformers,
# so register the transformer here if you're adding a new one.
org.apache.spark.ml.feature.VectorAssembler
@@ -19,6 +19,7 @@ package org.apache.spark.ml.classification

import org.apache.spark.annotation.Since
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.util.Summary
import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, lit}
@@ -28,7 +29,7 @@ import org.apache.spark.sql.types.DoubleType
/**
 * Abstraction for multiclass classification results for a given model.
 */
private[classification] trait ClassificationSummary extends Serializable {
private[classification] trait ClassificationSummary extends Summary with Serializable {

  /**
   * Dataframe output by the model's `transform` method.
9 changes: 7 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -19,11 +19,11 @@ package org.apache.spark.ml.param

import java.lang.reflect.Modifier
import java.util.{List => JList}
import java.util.NoSuchElementException

import scala.annotation.varargs
import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag

import org.json4s._
import org.json4s.jackson.JsonMethods._
@@ -45,9 +45,14 @@ import org.apache.spark.util.ArrayImplicits._
 * See [[ParamValidators]] for factory methods for common validation functions.
 * @tparam T param value type
 */
class Param[T](val parent: String, val name: String, val doc: String, val isValid: T => Boolean)
class Param[T: ClassTag](
    val parent: String, val name: String, val doc: String, val isValid: T => Boolean)
  extends Serializable {

  // Spark Connect ML needs the T type information, which is erased at compile time.
  // Use a ClassTag to preserve the T type.
  val paramValueClassTag = implicitly[ClassTag[T]]

  def this(parent: Identifiable, name: String, doc: String, isValid: T => Boolean) =
    this(parent.uid, name, doc, isValid)

28 changes: 28 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/util/Summary.scala
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.util

import org.apache.spark.annotation.Since

/**
 * Trait for model summaries.
 * All summaries should extend this trait in order to
 * support Spark Connect.
 */
@Since("4.0.0")
private[spark] trait Summary
@@ -21,6 +21,7 @@
import java.util.List;

import org.apache.spark.ml.util.Identifiable$;
import scala.reflect.ClassTag;

/**
 * A subclass of Params for testing.
@@ -110,7 +111,7 @@ private void init() {
        ParamValidators.inRange(0.0, 1.0));
    List<String> validStrings = Arrays.asList("a", "b");
    myStringParam_ = new Param<>(this, "myStringParam", "this is a string param",
      ParamValidators.inArray(validStrings));
      ParamValidators.inArray(validStrings), ClassTag.apply(String.class));
    myDoubleArrayParam_ =
        new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param");

7 changes: 7 additions & 0 deletions python/pyspark/ml/classification.py
@@ -62,6 +62,7 @@
    HasSolver,
    HasParallelism,
)
from pyspark.ml.remote.util import try_remote_attribute_relation
from pyspark.ml.tree import (
    _DecisionTreeModel,
    _DecisionTreeParams,
@@ -336,6 +337,7 @@ class _ClassificationSummary(JavaWrapper):

    @property
    @since("3.1.0")
    @try_remote_attribute_relation
    def predictions(self) -> DataFrame:
        """
        Dataframe outputted by the model's `transform` method.
@@ -521,6 +523,7 @@ def scoreCol(self) -> str:
        return self._call_java("scoreCol")

    @property
    @try_remote_attribute_relation
    def roc(self) -> DataFrame:
        """
        Returns the receiver operating characteristic (ROC) curve,
@@ -546,6 +549,7 @@

    @property
    @since("3.1.0")
    @try_remote_attribute_relation
    def pr(self) -> DataFrame:
        """
        Returns the precision-recall curve, which is a Dataframe
@@ -556,6 +560,7 @@

    @property
    @since("3.1.0")
    @try_remote_attribute_relation
    def fMeasureByThreshold(self) -> DataFrame:
        """
        Returns a dataframe with two fields (threshold, F-Measure) curve
@@ -565,6 +570,7 @@

    @property
    @since("3.1.0")
    @try_remote_attribute_relation
    def precisionByThreshold(self) -> DataFrame:
        """
        Returns a dataframe with two fields (threshold, precision) curve.
@@ -575,6 +581,7 @@

    @property
    @since("3.1.0")
    @try_remote_attribute_relation
    def recallByThreshold(self) -> DataFrame:
        """
        Returns a dataframe with two fields (threshold, recall) curve.
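The `try_remote_attribute_relation` decorator applied above lives in `python/pyspark/ml/remote/util.py`, which is not shown in this excerpt. As a rough illustration of the dispatch pattern, here is a minimal sketch; `_remote_attribute_df` is a hypothetical stand-in for the real logic, which builds an `AttributeRelation` plan (see `proto.py` below):

```python
import functools
from typing import Any, Callable, TypeVar, cast

from pyspark.sql.utils import is_remote

FuncT = TypeVar("FuncT", bound=Callable[..., Any])


def _remote_attribute_df(instance: Any, attr_name: str) -> Any:
    # Placeholder for the real logic, which builds an AttributeRelation plan
    # against the server-side cached object and returns a client DataFrame.
    raise NotImplementedError(f"remote fetch of {attr_name!r}")


def try_remote_attribute_relation(f: FuncT) -> FuncT:
    """Route a DataFrame-returning attribute through Spark Connect when active."""

    @functools.wraps(f)
    def wrapped(self: Any, *args: Any, **kwargs: Any) -> Any:
        if is_remote():
            return _remote_attribute_df(self, f.__name__)
        return f(self, *args, **kwargs)

    return cast(FuncT, wrapped)
```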
16 changes: 16 additions & 0 deletions python/pyspark/ml/remote/__init__.py
@@ -0,0 +1,16 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
76 changes: 76 additions & 0 deletions python/pyspark/ml/remote/proto.py
@@ -0,0 +1,76 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, TYPE_CHECKING, List

import pyspark.sql.connect.proto as pb2
from pyspark.sql.connect.plan import LogicalPlan

if TYPE_CHECKING:
    from pyspark.sql.connect.client import SparkConnectClient


class TransformerRelation(LogicalPlan):
    """A logical plan for the transform operation of a transformer, which could be
    a cached model or a non-model transformer like VectorAssembler."""

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        name: str,
        ml_params: pb2.MlParams,
        uid: str = "",
        is_model: bool = True,
    ) -> None:
        super().__init__(child)
        self._name = name
        self._ml_params = ml_params
        self._uid = uid
        self._is_model = is_model

    def plan(self, session: "SparkConnectClient") -> pb2.Relation:
        assert self._child is not None
        plan = self._create_proto_relation()
        plan.ml_relation.transform.input.CopyFrom(self._child.plan(session))

        if self._is_model:
            plan.ml_relation.transform.obj_ref.CopyFrom(pb2.ObjectRef(id=self._name))
        else:
            plan.ml_relation.transform.transformer.CopyFrom(
                pb2.MlOperator(name=self._name, uid=self._uid, type=pb2.MlOperator.TRANSFORMER)
            )

        if self._ml_params is not None:
            plan.ml_relation.transform.params.CopyFrom(self._ml_params)

        return plan


class AttributeRelation(LogicalPlan):
    """A logical plan used in ML to represent an attribute of an instance, which
    could be a model or a summary. This attribute returns a DataFrame.
    """

    def __init__(self, ref_id: str, methods: List[pb2.Fetch.Method]) -> None:
        super().__init__(None)
        self._ref_id = ref_id
        self._methods = methods

    def plan(self, session: "SparkConnectClient") -> pb2.Relation:
        plan = self._create_proto_relation()
        plan.ml_relation.fetch.obj_ref.CopyFrom(pb2.ObjectRef(id=self._ref_id))
        plan.ml_relation.fetch.methods.extend(self._methods)
        return plan
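To illustrate how `AttributeRelation` might be consumed on the client, a minimal sketch; the `Fetch.Method` field name (`func`) and the connect `DataFrame(plan, session)` constructor are assumptions, not API confirmed by this diff:

```python
import pyspark.sql.connect.proto as pb2
from pyspark.ml.remote.proto import AttributeRelation
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.session import SparkSession


def attribute_df(session: SparkSession, ref_id: str, attr: str) -> DataFrame:
    # Fetch a DataFrame-typed attribute (e.g. "roc") of a server-side cached
    # object identified by ref_id. The "func" field name is an assumption.
    method = pb2.Fetch.Method(func=attr)
    plan = AttributeRelation(ref_id, [method])
    # Assumed constructor: connect DataFrames wrap a LogicalPlan and a session.
    return DataFrame(plan, session)
```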
