Skip to content

Commit 93e56c2

Browse files
committed
[SPARK-50874][ML][PYTHON][CONNECT] Support LinearRegression on connect
### What changes were proposed in this pull request?
Support LinearRegression on connect

### Why are the changes needed?
feature parity for connect

### Does this PR introduce _any_ user-facing change?
yes, new feature

### How was this patch tested?
added tests

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #49553 from zhengruifeng/ml_regression.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent ef363b6 commit 93e56c2

File tree

8 files changed

+255
-2
lines changed

8 files changed

+255
-2
lines changed

dev/sparktestsupport/modules.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,7 @@ def __hash__(self):
696696
"pyspark.ml.tests.connect.test_legacy_mode_pipeline",
697697
"pyspark.ml.tests.connect.test_legacy_mode_tuning",
698698
"pyspark.ml.tests.test_classification",
699+
"pyspark.ml.tests.test_regression",
699700
],
700701
excluded_python_implementations=[
701702
"PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there
@@ -1117,6 +1118,7 @@ def __hash__(self):
11171118
"pyspark.ml.tests.connect.test_connect_pipeline",
11181119
"pyspark.ml.tests.connect.test_connect_tuning",
11191120
"pyspark.ml.tests.connect.test_parity_classification",
1121+
"pyspark.ml.tests.connect.test_parity_regression",
11201122
"pyspark.ml.tests.connect.test_parity_evaluation",
11211123
],
11221124
excluded_python_implementations=[

mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,7 @@ org.apache.spark.ml.classification.LogisticRegression
2323
org.apache.spark.ml.classification.DecisionTreeClassifier
2424
org.apache.spark.ml.classification.RandomForestClassifier
2525
org.apache.spark.ml.classification.GBTClassifier
26+
27+
28+
# regression
29+
org.apache.spark.ml.regression.LinearRegression

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -903,7 +903,7 @@ class LinearRegressionSummary private[regression] (
903903
val labelCol: String,
904904
val featuresCol: String,
905905
private val privateModel: LinearRegressionModel,
906-
private val diagInvAtWA: Array[Double]) extends Serializable {
906+
private val diagInvAtWA: Array[Double]) extends Summary with Serializable {
907907

908908
@transient private val metrics = {
909909
val weightCol =

python/pyspark/ml/regression.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
HasLoss,
4545
HasVarianceCol,
4646
)
47+
from pyspark.ml.util import try_remote_attribute_relation
4748
from pyspark.ml.tree import (
4849
_DecisionTreeModel,
4950
_DecisionTreeParams,
@@ -517,6 +518,7 @@ class LinearRegressionSummary(JavaWrapper):
517518

518519
@property
519520
@since("2.0.0")
521+
@try_remote_attribute_relation
520522
def predictions(self) -> DataFrame:
521523
"""
522524
Dataframe outputted by the model's `transform` method.
@@ -651,6 +653,7 @@ def r2adj(self) -> float:
651653

652654
@property
653655
@since("2.0.0")
656+
@try_remote_attribute_relation
654657
def residuals(self) -> DataFrame:
655658
"""
656659
Residuals (label - predicted value)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import os
19+
import unittest
20+
21+
from pyspark.ml.tests.test_regression import RegressionTestsMixin
22+
from pyspark.sql import SparkSession
23+
24+
25+
class RegressionParityTests(RegressionTestsMixin, unittest.TestCase):
26+
def setUp(self) -> None:
27+
self.spark = SparkSession.builder.remote(
28+
os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
29+
).getOrCreate()
30+
31+
def test_assert_remote_mode(self):
32+
from pyspark.sql import is_remote
33+
34+
self.assertTrue(is_remote())
35+
36+
def tearDown(self) -> None:
37+
self.spark.stop()
38+
39+
40+
if __name__ == "__main__":
41+
from pyspark.ml.tests.connect.test_parity_regression import * # noqa: F401
42+
43+
try:
44+
import xmlrunner # type: ignore[import]
45+
46+
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
47+
except ImportError:
48+
testRunner = None
49+
unittest.main(testRunner=testRunner, verbosity=2)
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import tempfile
19+
import unittest
20+
21+
import numpy as np
22+
23+
from pyspark.ml.linalg import Vectors
24+
from pyspark.sql import SparkSession
25+
from pyspark.ml.regression import (
26+
LinearRegression,
27+
LinearRegressionModel,
28+
LinearRegressionSummary,
29+
LinearRegressionTrainingSummary,
30+
)
31+
32+
33+
class RegressionTestsMixin:
34+
@property
35+
def df(self):
36+
return (
37+
self.spark.createDataFrame(
38+
[
39+
(1.0, 1.0, Vectors.dense(0.0, 5.0)),
40+
(0.0, 2.0, Vectors.dense(1.0, 2.0)),
41+
(1.5, 3.0, Vectors.dense(2.0, 1.0)),
42+
(0.7, 4.0, Vectors.dense(1.5, 3.0)),
43+
],
44+
["label", "weight", "features"],
45+
)
46+
.coalesce(1)
47+
.sortWithinPartitions("weight")
48+
)
49+
50+
def test_linear_regression(self):
51+
df = self.df
52+
lr = LinearRegression(
53+
regParam=0.0,
54+
maxIter=2,
55+
solver="normal",
56+
weightCol="weight",
57+
)
58+
self.assertEqual(lr.getRegParam(), 0)
59+
self.assertEqual(lr.getMaxIter(), 2)
60+
self.assertEqual(lr.getSolver(), "normal")
61+
self.assertEqual(lr.getWeightCol(), "weight")
62+
63+
# Estimator save & load
64+
with tempfile.TemporaryDirectory(prefix="linear_regression") as d:
65+
lr.write().overwrite().save(d)
66+
lr2 = LinearRegression.load(d)
67+
self.assertEqual(str(lr), str(lr2))
68+
69+
model = lr.fit(df)
70+
self.assertEqual(model.numFeatures, 2)
71+
self.assertTrue(np.allclose(model.scale, 1.0, atol=1e-4))
72+
self.assertTrue(np.allclose(model.intercept, -0.35, atol=1e-4))
73+
self.assertTrue(np.allclose(model.coefficients, [0.65, 0.1125], atol=1e-4))
74+
75+
output = model.transform(df)
76+
expected_cols = [
77+
"label",
78+
"weight",
79+
"features",
80+
"prediction",
81+
]
82+
self.assertEqual(output.columns, expected_cols)
83+
self.assertEqual(output.count(), 4)
84+
85+
self.assertTrue(
86+
np.allclose(model.predict(Vectors.dense(0.0, 5.0)), 0.21249999999999963, atol=1e-4)
87+
)
88+
89+
# Model summary
90+
summary = model.summary
91+
self.assertTrue(isinstance(summary, LinearRegressionSummary))
92+
self.assertTrue(isinstance(summary, LinearRegressionTrainingSummary))
93+
self.assertEqual(summary.predictions.columns, expected_cols)
94+
self.assertEqual(summary.predictions.count(), 4)
95+
self.assertEqual(summary.residuals.columns, ["residuals"])
96+
self.assertEqual(summary.residuals.count(), 4)
97+
98+
self.assertEqual(summary.degreesOfFreedom, 1)
99+
self.assertEqual(summary.numInstances, 4)
100+
self.assertEqual(summary.objectiveHistory, [0.0])
101+
self.assertTrue(
102+
np.allclose(
103+
summary.coefficientStandardErrors,
104+
[1.2859821149611763, 0.6248749874975031, 3.1645497310044184],
105+
atol=1e-4,
106+
)
107+
)
108+
self.assertTrue(
109+
np.allclose(
110+
summary.devianceResiduals, [-0.7424621202458727, 0.7875000000000003], atol=1e-4
111+
)
112+
)
113+
self.assertTrue(
114+
np.allclose(
115+
summary.pValues,
116+
[0.7020630236843428, 0.8866003086182783, 0.9298746994547682],
117+
atol=1e-4,
118+
)
119+
)
120+
self.assertTrue(
121+
np.allclose(
122+
summary.tValues,
123+
[0.5054502643838291, 0.1800360108036021, -0.11060025272186746],
124+
atol=1e-4,
125+
)
126+
)
127+
self.assertTrue(np.allclose(summary.explainedVariance, 0.07997500000000031, atol=1e-4))
128+
self.assertTrue(np.allclose(summary.meanAbsoluteError, 0.4200000000000002, atol=1e-4))
129+
self.assertTrue(np.allclose(summary.meanSquaredError, 0.20212500000000005, atol=1e-4))
130+
self.assertTrue(np.allclose(summary.rootMeanSquaredError, 0.44958314025327956, atol=1e-4))
131+
self.assertTrue(np.allclose(summary.r2, 0.4427212572373862, atol=1e-4))
132+
self.assertTrue(np.allclose(summary.r2adj, -0.6718362282878414, atol=1e-4))
133+
134+
summary2 = model.evaluate(df)
135+
self.assertTrue(isinstance(summary2, LinearRegressionSummary))
136+
self.assertFalse(isinstance(summary2, LinearRegressionTrainingSummary))
137+
self.assertEqual(summary2.predictions.columns, expected_cols)
138+
self.assertEqual(summary2.predictions.count(), 4)
139+
self.assertEqual(summary2.residuals.columns, ["residuals"])
140+
self.assertEqual(summary2.residuals.count(), 4)
141+
142+
self.assertEqual(summary2.degreesOfFreedom, 1)
143+
self.assertEqual(summary2.numInstances, 4)
144+
self.assertTrue(
145+
np.allclose(
146+
summary2.devianceResiduals, [-0.7424621202458727, 0.7875000000000003], atol=1e-4
147+
)
148+
)
149+
self.assertTrue(np.allclose(summary2.explainedVariance, 0.07997500000000031, atol=1e-4))
150+
self.assertTrue(np.allclose(summary2.meanAbsoluteError, 0.4200000000000002, atol=1e-4))
151+
self.assertTrue(np.allclose(summary2.meanSquaredError, 0.20212500000000005, atol=1e-4))
152+
self.assertTrue(np.allclose(summary2.rootMeanSquaredError, 0.44958314025327956, atol=1e-4))
153+
self.assertTrue(np.allclose(summary2.r2, 0.4427212572373862, atol=1e-4))
154+
self.assertTrue(np.allclose(summary2.r2adj, -0.6718362282878414, atol=1e-4))
155+
156+
# Model save & load
157+
with tempfile.TemporaryDirectory(prefix="linear_regression_model") as d:
158+
model.write().overwrite().save(d)
159+
model2 = LinearRegressionModel.load(d)
160+
self.assertEqual(str(model), str(model2))
161+
162+
163+
class RegressionTests(RegressionTestsMixin, unittest.TestCase):
164+
def setUp(self) -> None:
165+
self.spark = SparkSession.builder.master("local[4]").getOrCreate()
166+
167+
def tearDown(self) -> None:
168+
self.spark.stop()
169+
170+
171+
if __name__ == "__main__":
172+
from pyspark.ml.tests.test_regression import * # noqa: F401,F403
173+
174+
try:
175+
import xmlrunner # type: ignore[import]
176+
177+
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
178+
except ImportError:
179+
testRunner = None
180+
unittest.main(testRunner=testRunner, verbosity=2)

python/pyspark/ml/util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,7 @@ class GeneralJavaMLWritable(JavaMLWritable):
580580
(Private) Mixin for ML instances that provide :py:class:`GeneralJavaMLWriter`.
581581
"""
582582

583+
@try_remote_write
583584
def write(self) -> GeneralJavaMLWriter:
584585
"""Returns an GeneralMLWriter instance for this ML instance."""
585586
return GeneralJavaMLWriter(self)

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,7 @@ private[ml] object MLUtils {
394394
"featureImportances", // Tree models
395395
"predictRaw", // ClassificationModel
396396
"predictProbability", // ProbabilisticClassificationModel
397+
"scale", // LinearRegressionModel
397398
"coefficients",
398399
"intercept",
399400
"coefficientMatrix",
@@ -428,7 +429,20 @@ private[ml] object MLUtils {
428429
"probabilityCol",
429430
"featuresCol", // LogisticRegressionSummary
430431
"objectiveHistory",
431-
"totalIterations" // _TrainingSummary
432+
"coefficientStandardErrors", // LinearRegressionSummary
433+
"degreesOfFreedom", // LinearRegressionSummary
434+
"devianceResiduals", // LinearRegressionSummary
435+
"explainedVariance", // LinearRegressionSummary
436+
"meanAbsoluteError", // LinearRegressionSummary
437+
"meanSquaredError", // LinearRegressionSummary
438+
"numInstances", // LinearRegressionSummary
439+
"pValues", // LinearRegressionSummary
440+
"r2", // LinearRegressionSummary
441+
"r2adj", // LinearRegressionSummary
442+
"residuals", // LinearRegressionSummary
443+
"rootMeanSquaredError", // LinearRegressionSummary
444+
"tValues", // LinearRegressionSummary
445+
"totalIterations" // _TrainingSummary
432446
)
433447

434448
def invokeMethodAllowed(obj: Object, methodName: String): Object = {

0 commit comments

Comments
 (0)