Skip to content

Commit

Permalink
[SPARK-51214][ML][PYTHON][CONNECT] Don't eagerly remove the cached mo…
Browse files Browse the repository at this point in the history
…dels for `fit_transform`

### What changes were proposed in this pull request?
Don't eagerly remove the cached models for `fit_transform`:
1, still keep the `Delete` ml command protobuf, but no longer call it in `__del__` in the  python client side;
2, build the ml cache with guava CacheBuilder and soft references, and specify the maximum size and time out.

### Why are the changes needed?
a common ml pipeline pattern is `fit_transform`:
```
def fit_transform(df):
    model = estimator.fit(df)
    return model.transform(df)

df2 = fit_transform(df)
df2.count()
```

existing implementation eagerly deletes the intermediate model from the ml cache, right after `fit_transform`, and thus causes NPE

```
pyspark.errors.exceptions.connect.SparkConnectGrpcException: (java.lang.NullPointerException) Cannot invoke "org.apache.spark.ml.Model.copy(org.apache.spark.ml.param.ParamMap)" because "model" is null

JVM stacktrace:
java.lang.NullPointerException
	at org.apache.spark.sql.connect.ml.ModelAttributeHelper.transform(MLHandler.scala:68)
	at org.apache.spark.sql.connect.ml.MLHandler$.transformMLRelation(MLHandler.scala:313)
	at org.apache.spark.sql.connect.planner.SparkConnectPlanner.$anonfun$transformRelation$1(SparkConnectPlanner.scala:231)
	at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$usePlanCache$3(SessionHolder.scala:477)
	at scala.Option.getOrElse(Option.scala:201)
	at org.apache.spark.sql.connect.service.SessionHolder.usePlanCache(SessionHolder.scala:476)
	at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformRelation(SparkConnectPlanner.scala:147)
	at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformRelation(SparkConnectPlanner.scala:133)
	at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformRelationalGroupedAggregate(SparkConnectPlanner.scala:2318)
	at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformAggregate(SparkConnectPlanner.scala:2299)
	at org.apache.spark.sql.connect.planner.SparkConnectPlanner.$anonfun$transformRelation$1(SparkConnectPlanner.scala:165)
	at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$usePlanCache$3(SessionHolder.scala:477)
```

### Does this PR introduce _any_ user-facing change?
yes

### How was this patch tested?
added tests

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #49948 from zhengruifeng/ml_connect_del.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed Feb 15, 2025
1 parent d6ad779 commit 09b93bd
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 16 deletions.
5 changes: 5 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,11 @@
"<attribute> in <className> is not allowed to be accessed."
]
},
"CACHE_INVALID" : {
"message" : [
"Cannot retrieve <objectName> from the ML cache. It is probably because the entry has been evicted."
]
},
"UNSUPPORTED_EXCEPTION" : {
"message" : [
"<message>"
Expand Down
20 changes: 20 additions & 0 deletions python/pyspark/ml/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import tempfile
import unittest

from pyspark.sql import Row
from pyspark.ml.pipeline import Pipeline, PipelineModel
from pyspark.ml.feature import (
VectorAssembler,
Expand All @@ -26,6 +27,7 @@
MinMaxScaler,
MinMaxScalerModel,
)
from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel
from pyspark.ml.clustering import KMeans, KMeansModel
from pyspark.testing.mlutils import MockDataset, MockEstimator, MockTransformer
Expand Down Expand Up @@ -172,6 +174,24 @@ def test_clustering_pipeline(self):
self.assertEqual(str(model), str(model2))
self.assertEqual(str(model.stages), str(model2.stages))

def test_model_gc(self):
spark = self.spark
df = spark.createDataFrame(
[
Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),
Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),
Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0])),
]
)

def fit_transform(df):
lr = LogisticRegression(maxIter=1, regParam=0.01, weightCol="weight")
model = lr.fit(df)
return model.transform(df)

output = fit_transform(df)
self.assertEqual(output.count(), 3)


class PipelineTests(PipelineTestsMixin, ReusedSQLTestCase):
pass
Expand Down
40 changes: 28 additions & 12 deletions python/pyspark/ml/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,21 @@ def wrapped(self: "JavaWrapper", name: str, *args: Any) -> Any:
return cast(FuncT, wrapped)


# delete the object from the ml cache eagerly
def del_remote_cache(ref_id: str) -> None:
if ref_id is not None and "." not in ref_id:
try:
from pyspark.sql.connect.session import SparkSession

session = SparkSession.getActiveSession()
if session is not None:
session.client.remove_ml_cache(ref_id)
return
except Exception:
# SparkSession's down.
return


def try_remote_del(f: FuncT) -> FuncT:
"""Mark the function/property to delete a model on the server side."""

Expand All @@ -261,18 +276,19 @@ def wrapped(self: "JavaWrapper") -> Any:

if in_remote:
# Delete the model if possible
model_id = self._java_obj
if model_id is not None and "." not in model_id:
try:
from pyspark.sql.connect.session import SparkSession

session = SparkSession.getActiveSession()
if session is not None:
session.client.remove_ml_cache(model_id)
return
except Exception:
# SparkSession's down.
return
# model_id = self._java_obj
# del_remote_cache(model_id)
#
# Above codes delete the model from the ml cache eagerly, and may cause
# NPE in the server side in the case of 'fit_transform':
#
# def fit_transform(df):
# model = estimator.fit(df)
# return model.transform(df)
#
# output = fit_transform(df)
# output.show()
return
else:
return f(self)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
package org.apache.spark.sql.connect.ml

import java.util.UUID
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.{ConcurrentMap, TimeUnit}

import com.google.common.cache.CacheBuilder

import org.apache.spark.internal.Logging
import org.apache.spark.ml.util.ConnectHelper
Expand All @@ -29,8 +31,13 @@ private[connect] class MLCache extends Logging {
private val helper = new ConnectHelper()
private val helperID = "______ML_CONNECT_HELPER______"

private val cachedModel: ConcurrentHashMap[String, Object] =
new ConcurrentHashMap[String, Object]()
private val cachedModel: ConcurrentMap[String, Object] = CacheBuilder
.newBuilder()
.softValues()
.maximumSize(MLCache.MAX_CACHED_ITEMS)
.expireAfterAccess(MLCache.CACHE_TIMEOUT_MINUTE, TimeUnit.MINUTES)
.build[String, Object]()
.asMap()

/**
* Cache an object into a map of MLCache, and return its key
Expand Down Expand Up @@ -76,3 +83,11 @@ private[connect] class MLCache extends Logging {
cachedModel.clear()
}
}

private[connect] object MLCache {
// The maximum number of distinct items in the cache.
private val MAX_CACHED_ITEMS = 100

// The maximum time for an item to stay in the cache.
private val CACHE_TIMEOUT_MINUTE = 60
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,9 @@ private[spark] case class MLAttributeNotAllowedException(className: String, attr
errorClass = "CONNECT_ML.ATTRIBUTE_NOT_ALLOWED",
messageParameters = Map("className" -> className, "attribute" -> attribute),
cause = null)

private[spark] case class MLCacheInvalidException(objectName: String)
extends SparkException(
errorClass = "CONNECT_ML.CACHE_INVALID",
messageParameters = Map("objectName" -> objectName),
cause = null)
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,13 @@ private class AttributeHelper(
val sessionHolder: SessionHolder,
val objRef: String,
val methods: Array[Method]) {
protected lazy val instance = sessionHolder.mlCache.get(objRef)
protected lazy val instance = {
val obj = sessionHolder.mlCache.get(objRef)
if (obj == null) {
throw MLCacheInvalidException(s"object $objRef")
}
obj
}
// Get the attribute by reflection
def getAttribute: Any = {
assert(methods.length >= 1)
Expand Down Expand Up @@ -181,6 +187,9 @@ private[connect] object MLHandler extends Logging {
case proto.MlCommand.Write.TypeCase.OBJ_REF => // save a model
val objId = mlCommand.getWrite.getObjRef.getId
val model = mlCache.get(objId).asInstanceOf[Model[_]]
if (model == null) {
throw MLCacheInvalidException(s"model $objId")
}
val copiedModel = model.copy(ParamMap.empty).asInstanceOf[Model[_]]
MLUtils.setInstanceParams(copiedModel, mlCommand.getWrite.getParams)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,35 @@ class MLSuite extends MLHelper {
}
}

test("Exception: cannot retrieve object") {
val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
val modelId = trainLogisticRegressionModel(sessionHolder)

// Fetch summary attribute
val accuracyCommand = proto.MlCommand
.newBuilder()
.setFetch(
proto.Fetch
.newBuilder()
.setObjRef(proto.ObjectRef.newBuilder().setId(modelId))
.addMethods(proto.Fetch.Method.newBuilder().setMethod("summary"))
.addMethods(proto.Fetch.Method.newBuilder().setMethod("accuracy")))
.build()

// Successfully fetch summary.accuracy from the cached model
MLHandler.handleMlCommand(sessionHolder, accuracyCommand)

// Remove the model from cache
sessionHolder.mlCache.clear()

// No longer able to retrieve the model from cache
val e = intercept[MLCacheInvalidException] {
MLHandler.handleMlCommand(sessionHolder, accuracyCommand)
}
val msg = e.getMessage
assert(msg.contains(s"$modelId from the ML cache"))
}

test("access the attribute which is not in allowed list") {
val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
val modelId = trainLogisticRegressionModel(sessionHolder)
Expand Down

0 comments on commit 09b93bd

Please sign in to comment.