Skip to content

Commit

Permalink
[SPARK-51197][ML][PYTHON][CONNECT][TESTS] Unit test clean up
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
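Clean up the ML unit tests: replace the hand-written per-test `setUp`/`tearDown` Spark session management with the shared `ReusedSQLTestCase` and `ReusedConnectTestCase` base classes, remove the `test_assert_remote_mode` checks from the Connect parity tests, switch `FeatureTests` from `SparkSessionTestCase` to `ReusedSQLTestCase`, and inline the shared `df` property of `RegressionTestsMixin` into the tests that use it.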

### Why are the changes needed?
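Test code cleanup: the same session setup/teardown boilerplate was duplicated across the ML test suites and their Connect parity counterparts; using the shared base classes removes that duplication.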

### Does this PR introduce _any_ user-facing change?
No, this is a test-only change.

### How was this patch tested?
Tested by the existing CI.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #49927 from zhengruifeng/ml_connect_test_cleanup.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed Feb 13, 2025
1 parent 74d88b6 commit 8e39b7f
Show file tree
Hide file tree
Showing 10 changed files with 88 additions and 118 deletions.
18 changes: 3 additions & 15 deletions python/pyspark/ml/tests/connect/test_parity_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,14 @@
# limitations under the License.
#

import os
import unittest

from pyspark.ml.tests.test_classification import ClassificationTestsMixin
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import ReusedConnectTestCase


class ClassificationParityTests(ClassificationTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = SparkSession.builder.remote(
os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
).getOrCreate()

def test_assert_remote_mode(self):
from pyspark.sql import is_remote

self.assertTrue(is_remote())

def tearDown(self) -> None:
self.spark.stop()
class ClassificationParityTests(ClassificationTestsMixin, ReusedConnectTestCase):
pass


if __name__ == "__main__":
Expand Down
18 changes: 3 additions & 15 deletions python/pyspark/ml/tests/connect/test_parity_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,14 @@
# limitations under the License.
#

import os
import unittest

from pyspark.ml.tests.test_clustering import ClusteringTestsMixin
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import ReusedConnectTestCase


class ClusteringParityTests(ClusteringTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = SparkSession.builder.remote(
os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
).getOrCreate()

def test_assert_remote_mode(self):
from pyspark.sql import is_remote

self.assertTrue(is_remote())

def tearDown(self) -> None:
self.spark.stop()
class ClusteringParityTests(ClusteringTestsMixin, ReusedConnectTestCase):
pass


if __name__ == "__main__":
Expand Down
18 changes: 3 additions & 15 deletions python/pyspark/ml/tests/connect/test_parity_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,14 @@
# limitations under the License.
#

import os
import unittest

from pyspark.ml.tests.test_evaluation import EvaluatorTestsMixin
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import ReusedConnectTestCase


class EvaluatorParityTests(EvaluatorTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = SparkSession.builder.remote(
os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
).getOrCreate()

def test_assert_remote_mode(self):
from pyspark.sql import is_remote

self.assertTrue(is_remote())

def tearDown(self) -> None:
self.spark.stop()
class EvaluatorParityTests(EvaluatorTestsMixin, ReusedConnectTestCase):
pass


if __name__ == "__main__":
Expand Down
18 changes: 3 additions & 15 deletions python/pyspark/ml/tests/connect/test_parity_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,14 @@
# limitations under the License.
#

import os
import unittest

from pyspark.ml.tests.test_regression import RegressionTestsMixin
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import ReusedConnectTestCase


class RegressionParityTests(RegressionTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = SparkSession.builder.remote(
os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
).getOrCreate()

def test_assert_remote_mode(self):
from pyspark.sql import is_remote

self.assertTrue(is_remote())

def tearDown(self) -> None:
self.spark.stop()
class RegressionParityTests(RegressionTestsMixin, ReusedConnectTestCase):
pass


if __name__ == "__main__":
Expand Down
11 changes: 4 additions & 7 deletions python/pyspark/ml/tests/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import numpy as np

from pyspark.ml.linalg import Vectors, Matrices
from pyspark.sql import SparkSession, DataFrame, Row
from pyspark.sql import DataFrame, Row
from pyspark.ml.classification import (
NaiveBayes,
NaiveBayesModel,
Expand Down Expand Up @@ -54,6 +54,7 @@
MultilayerPerceptronClassificationTrainingSummary,
)
from pyspark.ml.regression import DecisionTreeRegressionModel
from pyspark.testing.sqlutils import ReusedSQLTestCase


class ClassificationTestsMixin:
Expand Down Expand Up @@ -978,12 +979,8 @@ def test_mlp(self):
self.assertEqual(str(model), str(model2))


class ClassificationTests(ClassificationTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = SparkSession.builder.master("local[4]").getOrCreate()

def tearDown(self) -> None:
self.spark.stop()
class ClassificationTests(ClassificationTestsMixin, ReusedSQLTestCase):
pass


if __name__ == "__main__":
Expand Down
10 changes: 3 additions & 7 deletions python/pyspark/ml/tests/test_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import numpy as np

from pyspark.ml.linalg import Vectors, SparseVector
from pyspark.sql import SparkSession
from pyspark.ml.clustering import (
KMeans,
KMeansModel,
Expand All @@ -38,6 +37,7 @@
DistributedLDAModel,
PowerIterationClustering,
)
from pyspark.testing.sqlutils import ReusedSQLTestCase


class ClusteringTestsMixin:
Expand Down Expand Up @@ -506,12 +506,8 @@ def test_power_iteration_clustering(self):
self.assertEqual(pic.getWeightCol(), pic2.getWeightCol())


class ClusteringTests(ClusteringTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = SparkSession.builder.master("local[4]").getOrCreate()

def tearDown(self) -> None:
self.spark.stop()
class ClusteringTests(ClusteringTestsMixin, ReusedSQLTestCase):
pass


if __name__ == "__main__":
Expand Down
11 changes: 3 additions & 8 deletions python/pyspark/ml/tests/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
RankingEvaluator,
)
from pyspark.ml.linalg import Vectors
from pyspark.sql import Row, SparkSession
from pyspark.sql import Row
from pyspark.testing.sqlutils import ReusedSQLTestCase


class EvaluatorTestsMixin:
Expand Down Expand Up @@ -355,13 +356,7 @@ def test_regression_evaluator(self):
self.assertTrue(evaluator.isLargerBetter())


class EvaluatorTests(EvaluatorTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = SparkSession.builder.master("local[4]").getOrCreate()

def tearDown(self) -> None:
self.spark.stop()

class EvaluatorTests(EvaluatorTestsMixin, ReusedSQLTestCase):
def test_evaluate_invalid_type(self):
evaluator = RegressionEvaluator(metricName="r2")
df = self.spark.createDataFrame([Row(label=1.0, prediction=1.1)])
Expand Down
4 changes: 2 additions & 2 deletions python/pyspark/ml/tests/test_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
)
from pyspark.ml.linalg import DenseVector, SparseVector, Vectors
from pyspark.sql import Row
from pyspark.testing.mlutils import SparkSessionTestCase
from pyspark.testing.sqlutils import ReusedSQLTestCase


class FeatureTestsMixin:
Expand Down Expand Up @@ -1772,7 +1772,7 @@ def test_min_hash_lsh(self):
self.assertEqual(str(model), str(model2))


class FeatureTests(FeatureTestsMixin, SparkSessionTestCase):
class FeatureTests(FeatureTestsMixin, ReusedSQLTestCase):
pass


Expand Down
11 changes: 4 additions & 7 deletions python/pyspark/ml/tests/test_fpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
import tempfile
import unittest

from pyspark.sql import SparkSession, Row
from pyspark.sql import Row
import pyspark.sql.functions as sf
from pyspark.ml.fpm import (
FPGrowth,
FPGrowthModel,
PrefixSpan,
)
from pyspark.testing.sqlutils import ReusedSQLTestCase


class FPMTestsMixin:
Expand Down Expand Up @@ -99,12 +100,8 @@ def test_prefix_span(self):
self.assertEqual(head.freq, 3)


class FPMTests(FPMTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = SparkSession.builder.master("local[4]").getOrCreate()

def tearDown(self) -> None:
self.spark.stop()
class FPMTests(FPMTestsMixin, ReusedSQLTestCase):
pass


if __name__ == "__main__":
Expand Down
87 changes: 60 additions & 27 deletions python/pyspark/ml/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import numpy as np

from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession
from pyspark.ml.regression import (
AFTSurvivalRegression,
AFTSurvivalRegressionModel,
Expand All @@ -44,25 +43,10 @@
GBTRegressor,
GBTRegressionModel,
)
from pyspark.testing.sqlutils import ReusedSQLTestCase


class RegressionTestsMixin:
@property
def df(self):
return (
self.spark.createDataFrame(
[
(1.0, 1.0, Vectors.dense(0.0, 5.0)),
(0.0, 2.0, Vectors.dense(1.0, 2.0)),
(1.5, 3.0, Vectors.dense(2.0, 1.0)),
(0.7, 4.0, Vectors.dense(1.5, 3.0)),
],
["label", "weight", "features"],
)
.coalesce(1)
.sortWithinPartitions("weight")
)

def test_aft_survival(self):
spark = self.spark
df = spark.createDataFrame(
Expand Down Expand Up @@ -162,7 +146,21 @@ def test_isotonic_regression(self):
self.assertEqual(str(model), str(model2))

def test_linear_regression(self):
df = self.df
spark = self.spark
df = (
spark.createDataFrame(
[
(1.0, 1.0, Vectors.dense(0.0, 5.0)),
(0.0, 2.0, Vectors.dense(1.0, 2.0)),
(1.5, 3.0, Vectors.dense(2.0, 1.0)),
(0.7, 4.0, Vectors.dense(1.5, 3.0)),
],
["label", "weight", "features"],
)
.coalesce(1)
.sortWithinPartitions("weight")
)

lr = LinearRegression(
regParam=0.0,
maxIter=2,
Expand Down Expand Up @@ -434,7 +432,20 @@ def test_factorization_machine(self):
self.assertEqual(str(model), str(model2))

def test_decision_tree_regressor(self):
df = self.df
spark = self.spark
df = (
spark.createDataFrame(
[
(1.0, 1.0, Vectors.dense(0.0, 5.0)),
(0.0, 2.0, Vectors.dense(1.0, 2.0)),
(1.5, 3.0, Vectors.dense(2.0, 1.0)),
(0.7, 4.0, Vectors.dense(1.5, 3.0)),
],
["label", "weight", "features"],
)
.coalesce(1)
.sortWithinPartitions("weight")
)

dt = DecisionTreeRegressor(
maxDepth=2,
Expand Down Expand Up @@ -490,7 +501,20 @@ def test_decision_tree_regressor(self):
self.assertEqual(model.toDebugString, model2.toDebugString)

def test_gbt_regressor(self):
df = self.df
spark = self.spark
df = (
spark.createDataFrame(
[
(1.0, 1.0, Vectors.dense(0.0, 5.0)),
(0.0, 2.0, Vectors.dense(1.0, 2.0)),
(1.5, 3.0, Vectors.dense(2.0, 1.0)),
(0.7, 4.0, Vectors.dense(1.5, 3.0)),
],
["label", "weight", "features"],
)
.coalesce(1)
.sortWithinPartitions("weight")
)

gbt = GBTRegressor(
maxIter=3,
Expand Down Expand Up @@ -575,7 +599,20 @@ def test_gbt_regressor(self):
self.assertEqual(model.toDebugString, model2.toDebugString)

def test_random_forest_regressor(self):
df = self.df
spark = self.spark
df = (
spark.createDataFrame(
[
(1.0, 1.0, Vectors.dense(0.0, 5.0)),
(0.0, 2.0, Vectors.dense(1.0, 2.0)),
(1.5, 3.0, Vectors.dense(2.0, 1.0)),
(0.7, 4.0, Vectors.dense(1.5, 3.0)),
],
["label", "weight", "features"],
)
.coalesce(1)
.sortWithinPartitions("weight")
)

rf = RandomForestRegressor(
numTrees=3,
Expand Down Expand Up @@ -643,12 +680,8 @@ def test_random_forest_regressor(self):
self.assertEqual(model.toDebugString, model2.toDebugString)


class RegressionTests(RegressionTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = SparkSession.builder.master("local[4]").getOrCreate()

def tearDown(self) -> None:
self.spark.stop()
class RegressionTests(RegressionTestsMixin, ReusedSQLTestCase):
pass


if __name__ == "__main__":
Expand Down

0 comments on commit 8e39b7f

Please sign in to comment.