[SPARK-50976][ML][PYTHON] Fix the save/load of TargetEncoder
### What changes were proposed in this pull request?
1. Fix the save/load of `TargetEncoder`.
2. Hide `TargetEncoderModel.stats`.

### Why are the changes needed?

1. The existing `save/load` implementation does not actually work: the writer stores the fitted statistics in a Parquet column named `stats`, while the reader selects a nonexistent `encodings` column, so loading fails (see the reproduction below).

2. On the Python side, `TargetEncoderModel.stats` returns a `JavaObject`, which cannot be used. We should find a better way to expose the model coefficients.

```
In [1]: from pyspark.ml.feature import *
   ...:
   ...: df = spark.createDataFrame(
   ...: [
   ...:       (0, 3, 5.0, 0.0),
   ...:       (1, 4, 5.0, 1.0),
   ...:       (2, 3, 5.0, 0.0),
   ...:       (0, 4, 6.0, 1.0),
   ...:       (1, 3, 6.0, 0.0),
   ...:       (2, 4, 6.0, 1.0),
   ...:       (0, 3, 7.0, 0.0),
   ...:       (1, 4, 8.0, 1.0),
   ...:       (2, 3, 9.0, 0.0),
   ...: ],
   ...: schema="input1 short, input2 int, input3 double, label double",
   ...: )
   ...: encoder = TargetEncoder(
   ...: inputCols=["input1", "input2", "input3"],
   ...: outputCols=["output", "output2", "output3"],
   ...: labelCol="label",
   ...: targetType="binary",
   ...: )
   ...: model = encoder.fit(df)

In [2]: model.stats
Out[2]: JavaObject id=o92

In [5]: model.write().overwrite().save("/tmp/ta")

In [6]: TargetEncoderModel.load("/tmp/ta")
{"ts": "2025-01-24 19:06:54,598", "level": "ERROR", "logger": "DataFrameQueryContextLogger", "msg": "[UNRESOLVED_COLUMN.WITH_SUGGESTION] A column, variable, or function parameter with name `encodings` cannot be resolved. Did you mean one of the following? [`stats`]. SQLSTATE: 42703", "context": {"file":

...

AnalysisException: [UNRESOLVED_COLUMN.WITH_SUGGESTION] A column, variable, or function parameter with name `encodings` cannot be resolved. Did you mean one of the following? [`stats`]. SQLSTATE: 42703;
'Project ['encodings]
+- Relation [stats#37] parquet

```
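
The fix flattens each per-column map into one flat record (`index`, `categories`, `counts`, `stats`) before writing, and rebuilds the maps on load. Below is a minimal, Spark-free sketch of that round trip in plain Scala; `flatten`/`rebuild` mirror the writer and reader logic in this patch, while the object name, the `main` driver, and the example values are illustrative only.

```scala
object TargetEncoderStatsRoundTrip {
  // Mirrors the patched writer's Data case class: one flat record per input column.
  case class Data(index: Int, categories: Array[Double],
      counts: Array[Double], stats: Array[Double])

  // Writer side: split each map of category -> (count, stat) into parallel arrays.
  def flatten(stats: Array[Map[Double, (Double, Double)]]): Seq[Data] =
    stats.iterator.zipWithIndex.map { case (stat, index) =>
      val (categories, countsAndStats) = stat.toSeq.unzip
      val (counts, statValues) = countsAndStats.unzip
      Data(index, categories.toArray, counts.toArray, statValues.toArray)
    }.toSeq

  // Reader side: restore column order by index (Parquet does not guarantee
  // row order, hence the sortBy) and zip the arrays back into maps.
  def rebuild(rows: Seq[Data]): Array[Map[Double, (Double, Double)]] =
    rows.sortBy(_.index)
      .map(d => d.categories.zip(d.counts.zip(d.stats)).toMap)
      .toArray

  def main(args: Array[String]): Unit = {
    // Hypothetical stats for two input columns: category -> (count, stat).
    val stats: Array[Map[Double, (Double, Double)]] = Array(
      Map(0.0 -> (3.0, 1.0 / 3), 1.0 -> (3.0, 2.0 / 3), 2.0 -> (3.0, 1.0 / 3)),
      Map(3.0 -> (5.0, 0.0), 4.0 -> (4.0, 1.0))
    )
    assert(rebuild(flatten(stats)).toSeq == stats.toSeq) // lossless round trip
    println("round-trip OK")
  }
}
```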

### Does this PR introduce _any_ user-facing change?
No, since this algorithm is new in 4.0.

### How was this patch tested?
Updated the tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #49649 from zhengruifeng/ml_target_save_load.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
zhengruifeng authored and dongjoon-hyun committed Jan 24, 2025
1 parent 5db31ae commit 44966c9
Showing 3 changed files with 38 additions and 159 deletions.

31 changes: 22 additions & 9 deletions mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala

@@ -282,8 +282,8 @@ object TargetEncoder extends DefaultParamsReadable[TargetEncoder] {
  */
 @Since("4.0.0")
 class TargetEncoderModel private[ml] (
-    @Since("4.0.0") override val uid: String,
-    @Since("4.0.0") val stats: Array[Map[Double, (Double, Double)]])
+    @Since("4.0.0") override val uid: String,
+    @Since("4.0.0") private[ml] val stats: Array[Map[Double, (Double, Double)]])
   extends Model[TargetEncoderModel] with TargetEncoderBase with MLWritable {

   /** @group setParam */
@@ -403,13 +403,18 @@ object TargetEncoderModel extends MLReadable[TargetEncoderModel] {
   private[TargetEncoderModel]
   class TargetEncoderModelWriter(instance: TargetEncoderModel) extends MLWriter {

-    private case class Data(stats: Array[Map[Double, (Double, Double)]])
+    private case class Data(index: Int, categories: Array[Double],
+      counts: Array[Double], stats: Array[Double])

     override protected def saveImpl(path: String): Unit = {
       DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
-      val data = Data(instance.stats)
+      val datum = instance.stats.iterator.zipWithIndex.map { case (stat, index) =>
+        val (_categories, _countsAndStats) = stat.toSeq.unzip
+        val (_counts, _stats) = _countsAndStats.unzip
+        Data(index, _categories.toArray, _counts.toArray, _stats.toArray)
+      }.toSeq
       val dataPath = new Path(path, "data").toString
-      sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
+      sparkSession.createDataFrame(datum).write.parquet(dataPath)
     }
   }

@@ -420,10 +425,18 @@
     override def load(path: String): TargetEncoderModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
       val dataPath = new Path(path, "data").toString
-      val data = sparkSession.read.parquet(dataPath)
-        .select("encodings")
-        .head()
-      val stats = data.getAs[Array[Map[Double, (Double, Double)]]](0)
+
+      val stats = sparkSession.read.parquet(dataPath)
+        .select("index", "categories", "counts", "stats")
+        .collect()
+        .map { row =>
+          val index = row.getInt(0)
+          val categories = row.getAs[Seq[Double]](1).toArray
+          val counts = row.getAs[Seq[Double]](2).toArray
+          val stats = row.getAs[Seq[Double]](3).toArray
+          (index, categories.zip(counts.zip(stats)).toMap)
+        }.sortBy(_._1).map(_._2)
+
       val model = new TargetEncoderModel(metadata.uid, stats)
       metadata.getAndSetParams(model)
       model

9 changes: 0 additions & 9 deletions python/pyspark/ml/feature.py
@@ -5500,15 +5500,6 @@ def setSmoothing(self, value: float) -> "TargetEncoderModel":
         """
         return self._set(smoothing=value)

-    @property
-    @since("4.0.0")
-    def stats(self) -> List[Dict[float, Tuple[float, float]]]:
-        """
-        Fitted statistics for each feature to being encoded.
-        The list contains a dictionary for each input column.
-        """
-        return self._call_java("stats")
-

 @inherit_doc
 class Tokenizer(

157 changes: 16 additions & 141 deletions python/pyspark/ml/tests/test_feature.py
@@ -55,6 +55,7 @@
     StringIndexer,
     StringIndexerModel,
     TargetEncoder,
+    TargetEncoderModel,
     VectorSizeHint,
     VectorAssembler,
     PCA,
@@ -1113,148 +1114,22 @@ def test_target_encoder_binary(self):
             targetType="binary",
         )
         model = encoder.fit(df)
-        te = model.transform(df)
-        actual = te.drop("label").collect()
-        expected = [
-            Row(input1=0, input2=3, input3=5.0, output1=1.0 / 3, output2=0.0, output3=1.0 / 3),
-            Row(input1=1, input2=4, input3=5.0, output1=2.0 / 3, output2=1.0, output3=1.0 / 3),
-            Row(input1=2, input2=3, input3=5.0, output1=1.0 / 3, output2=0.0, output3=1.0 / 3),
-            Row(input1=0, input2=4, input3=6.0, output1=1.0 / 3, output2=1.0, output3=2.0 / 3),
-            Row(input1=1, input2=3, input3=6.0, output1=2.0 / 3, output2=0.0, output3=2.0 / 3),
-            Row(input1=2, input2=4, input3=6.0, output1=1.0 / 3, output2=1.0, output3=2.0 / 3),
-            Row(input1=0, input2=3, input3=7.0, output1=1.0 / 3, output2=0.0, output3=0.0),
-            Row(input1=1, input2=4, input3=8.0, output1=2.0 / 3, output2=1.0, output3=1.0),
-            Row(input1=2, input2=3, input3=9.0, output1=1.0 / 3, output2=0.0, output3=0.0),
-        ]
-        self.assertEqual(actual, expected)
-        te = model.setSmoothing(1.0).transform(df)
-        actual = te.drop("label").collect()
-        expected = [
-            Row(
-                input1=0,
-                input2=3,
-                input3=5.0,
-                output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
-                output2=(1 - 5 / 6) * (4 / 9),
-                output3=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
-            ),
-            Row(
-                input1=1,
-                input2=4,
-                input3=5.0,
-                output1=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9),
-                output2=(4 / 5) * 1 + (1 - 4 / 5) * (4 / 9),
-                output3=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
-            ),
-            Row(
-                input1=2,
-                input2=3,
-                input3=5.0,
-                output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
-                output2=(1 - 5 / 6) * (4 / 9),
-                output3=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
-            ),
-            Row(
-                input1=0,
-                input2=4,
-                input3=6.0,
-                output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
-                output2=(4 / 5) * 1 + (1 - 4 / 5) * (4 / 9),
-                output3=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9),
-            ),
-            Row(
-                input1=1,
-                input2=3,
-                input3=6.0,
-                output1=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9),
-                output2=(1 - 5 / 6) * (4 / 9),
-                output3=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9),
-            ),
-            Row(
-                input1=2,
-                input2=4,
-                input3=6.0,
-                output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
-                output2=(4 / 5) * 1 + (1 - 4 / 5) * (4 / 9),
-                output3=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9),
-            ),
-            Row(
-                input1=0,
-                input2=3,
-                input3=7.0,
-                output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
-                output2=(1 - 5 / 6) * (4 / 9),
-                output3=(1 - 1 / 2) * (4 / 9),
-            ),
-            Row(
-                input1=1,
-                input2=4,
-                input3=8.0,
-                output1=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9),
-                output2=(4 / 5) * 1 + (1 - 4 / 5) * (4 / 9),
-                output3=(1 / 2) + (1 - 1 / 2) * (4 / 9),
-            ),
-            Row(
-                input1=2,
-                input2=3,
-                input3=9.0,
-                output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
-                output2=(1 - 5 / 6) * (4 / 9),
-                output3=(1 - 1 / 2) * (4 / 9),
-            ),
-        ]
-        self.assertEqual(actual, expected)
-
-    def test_target_encoder_continuous(self):
-        df = self.spark.createDataFrame(
-            [
-                (0, 3, 5.0, 10.0),
-                (1, 4, 5.0, 20.0),
-                (2, 3, 5.0, 30.0),
-                (0, 4, 6.0, 40.0),
-                (1, 3, 6.0, 50.0),
-                (2, 4, 6.0, 60.0),
-                (0, 3, 7.0, 70.0),
-                (1, 4, 8.0, 80.0),
-                (2, 3, 9.0, 90.0),
-            ],
-            schema="input1 short, input2 int, input3 double, label double",
-        )
-        encoder = TargetEncoder(
-            inputCols=["input1", "input2", "input3"],
-            outputCols=["output", "output2", "output3"],
-            labelCol="label",
-            targetType="continuous",
+        output = model.transform(df)
+        self.assertEqual(
+            output.columns,
+            ["input1", "input2", "input3", "label", "output", "output2", "output3"],
+        )
-        model = encoder.fit(df)
-        te = model.transform(df)
-        actual = te.drop("label").collect()
-        expected = [
-            Row(input1=0, input2=3, input3=5.0, output1=40.0, output2=50.0, output3=20.0),
-            Row(input1=1, input2=4, input3=5.0, output1=50.0, output2=50.0, output3=20.0),
-            Row(input1=2, input2=3, input3=5.0, output1=60.0, output2=50.0, output3=20.0),
-            Row(input1=0, input2=4, input3=6.0, output1=40.0, output2=50.0, output3=50.0),
-            Row(input1=1, input2=3, input3=6.0, output1=50.0, output2=50.0, output3=50.0),
-            Row(input1=2, input2=4, input3=6.0, output1=60.0, output2=50.0, output3=50.0),
-            Row(input1=0, input2=3, input3=7.0, output1=40.0, output2=50.0, output3=70.0),
-            Row(input1=1, input2=4, input3=8.0, output1=50.0, output2=50.0, output3=80.0),
-            Row(input1=2, input2=3, input3=9.0, output1=60.0, output2=50.0, output3=90.0),
-        ]
-        self.assertEqual(actual, expected)
-        te = model.setSmoothing(1.0).transform(df)
-        actual = te.drop("label").collect()
-        expected = [
-            Row(input1=0, input2=3, input3=5.0, output1=42.5, output2=50.0, output3=27.5),
-            Row(input1=1, input2=4, input3=5.0, output1=50.0, output2=50.0, output3=27.5),
-            Row(input1=2, input2=3, input3=5.0, output1=57.5, output2=50.0, output3=27.5),
-            Row(input1=0, input2=4, input3=6.0, output1=42.5, output2=50.0, output3=50.0),
-            Row(input1=1, input2=3, input3=6.0, output1=50.0, output2=50.0, output3=50.0),
-            Row(input1=2, input2=4, input3=6.0, output1=57.5, output2=50.0, output3=50.0),
-            Row(input1=0, input2=3, input3=7.0, output1=42.5, output2=50.0, output3=60.0),
-            Row(input1=1, input2=4, input3=8.0, output1=50.0, output2=50.0, output3=65.0),
-            Row(input1=2, input2=3, input3=9.0, output1=57.5, output2=50.0, output3=70.0),
-        ]
-        self.assertEqual(actual, expected)
+        self.assertEqual(output.count(), 9)

+        # save & load
+        with tempfile.TemporaryDirectory(prefix="target_encoder") as d:
+            encoder.write().overwrite().save(d)
+            encoder2 = TargetEncoder.load(d)
+            self.assertEqual(str(encoder), str(encoder2))
+
+            model.write().overwrite().save(d)
+            model2 = TargetEncoderModel.load(d)
+            self.assertEqual(str(model), str(model2))

     def test_vector_size_hint(self):
         df = self.spark.createDataFrame(
