Skip to content

Commit 03ba250

Browse files
committed
[SPARK-51014][ML][PYTHON][CONNECT] Support RFormula on connect
### What changes were proposed in this pull request?
Support RFormula on connect.

### Why are the changes needed?
Feature parity.

### Does this PR introduce _any_ user-facing change?
Yes, a new algorithm is supported on connect.

### How was this patch tested?
Added test.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #49703 from zhengruifeng/ml_connect_rformula.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
(cherry picked from commit 9a45019)
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent 5e0ddb5 commit 03ba250

File tree

5 files changed: +20 -9 lines changed

5 files changed

+20
-9
lines changed

mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ org.apache.spark.ml.recommendation.ALS
5151
org.apache.spark.ml.fpm.FPGrowth
5252

5353
# feature
54+
org.apache.spark.ml.feature.RFormula
5455
org.apache.spark.ml.feature.Imputer
5556
org.apache.spark.ml.feature.StandardScaler
5657
org.apache.spark.ml.feature.MaxAbsScaler

mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ org.apache.spark.ml.recommendation.ALSModel
7373
org.apache.spark.ml.fpm.FPGrowthModel
7474

7575
# feature
76+
org.apache.spark.ml.feature.RFormulaModel
7677
org.apache.spark.ml.feature.ImputerModel
7778
org.apache.spark.ml.feature.StandardScalerModel
7879
org.apache.spark.ml.feature.MaxAbsScalerModel

mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,8 @@ class RFormulaModel private[feature](
349349
private[ml] val pipelineModel: PipelineModel)
350350
extends Model[RFormulaModel] with RFormulaBase with MLWritable {
351351

352+
private[ml] def this() = this(Identifiable.randomUID("rFormula"), null, null)
353+
352354
@Since("2.0.0")
353355
override def transform(dataset: Dataset[_]): DataFrame = {
354356
checkCanTransform(dataset.schema)

python/pyspark/ml/tests/connect/test_parity_feature.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,6 @@ def test_count_vectorizer_with_maxDF(self):
3434
def test_count_vectorizer_from_vocab(self):
3535
super().test_count_vectorizer_from_vocab()
3636

37-
@unittest.skip("Need to support.")
38-
def test_rformula_force_index_label(self):
39-
super().test_rformula_force_index_label()
40-
41-
@unittest.skip("Need to support.")
42-
def test_rformula_string_indexer_order_type(self):
43-
super().test_rformula_string_indexer_order_type()
44-
4537
@unittest.skip("Need to support.")
4638
def test_string_indexer_handle_invalid(self):
4739
super().test_string_indexer_handle_invalid()

python/pyspark/ml/tests/test_feature.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
Normalizer,
4242
Interaction,
4343
RFormula,
44+
RFormulaModel,
4445
Tokenizer,
4546
SQLTransformer,
4647
RegexTokenizer,
@@ -1295,12 +1296,26 @@ def test_rformula_string_indexer_order_type(self):
12951296
)
12961297
rf = RFormula(formula="y ~ x + s", stringIndexerOrderType="alphabetDesc")
12971298
self.assertEqual(rf.getStringIndexerOrderType(), "alphabetDesc")
1298-
transformedDF = rf.fit(df).transform(df)
1299+
model = rf.fit(df)
1300+
self.assertEqual(rf.uid, model.uid)
1301+
transformedDF = model.transform(df)
12991302
observed = transformedDF.select("features").collect()
13001303
expected = [[1.0, 0.0], [2.0, 1.0], [0.0, 0.0]]
13011304
for i in range(0, len(expected)):
13021305
self.assertTrue(all(observed[i]["features"].toArray() == expected[i]))
13031306

1307+
# save & load
1308+
with tempfile.TemporaryDirectory(prefix="rformula") as d:
1309+
rf.write().overwrite().save(d)
1310+
rf2 = RFormula.load(d)
1311+
self.assertEqual(str(rf), str(rf2))
1312+
1313+
model.write().overwrite().save(d)
1314+
model2 = RFormulaModel.load(d)
1315+
# TODO: fix str(model)
1316+
# self.assertEqual(str(model), str(model2))
1317+
self.assertEqual(model.getFormula(), model2.getFormula())
1318+
13041319
def test_string_indexer_handle_invalid(self):
13051320
df = self.spark.createDataFrame([(0, "a"), (1, "d"), (2, None)], ["id", "label"])
13061321

0 commit comments

Comments
 (0)