Skip to content

Commit

Permalink
[SPARK-51217][ML][CONNECT] ML model helper constructor clean up
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
ML model helper constructor clean up:
1, add comments;
2, set invalid values, e.g. empty uid, NaN efficient

### Why are the changes needed?
1, to avoid unintentionally incorrect usage;
2, to differentiate from normal models;

### Does this PR introduce _any_ user-facing change?
no, internal change

### How was this patch tested?
existing tests

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #49956 from zhengruifeng/ml_connect_const.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed Feb 17, 2025
1 parent b3dac88 commit d75a7d6
Show file tree
Hide file tree
Showing 41 changed files with 73 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,7 @@ class DecisionTreeClassificationModel private[ml] (
this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses)

// For ml connect only
@Since("4.0.0")
private[ml] def this() = this(Node.dummyNode, 0, 0)
private[ml] def this() = this("", Node.dummyNode, -1, -1)

override def predict(features: Vector): Double = {
rootNode.predictImpl(features).prediction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,8 @@ class FMClassificationModel private[classification] (
with FMClassifierParams with MLWritable
with HasTrainingSummary[FMClassificationTrainingSummary]{

private[ml] def this() = this(Identifiable.randomUID("fmc"),
Double.NaN, Vectors.empty, Matrices.empty)
// For ml connect only
private[ml] def this() = this("", Double.NaN, Vectors.empty, Matrices.empty)

@Since("3.0.0")
override val numClasses: Int = 2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,8 @@ class GBTClassificationModel private[ml](
this(uid, _trees, _treeWeights, -1, 2)

// For ml connect only
@Since("4.0.0")
private[ml] def this() = this(Identifiable.randomUID("gbtc"),
Array(new DecisionTreeRegressionModel), Array(0.0))
private[ml] def this() = this("",
Array(new DecisionTreeRegressionModel), Array(Double.NaN), -1, -1)

@Since("1.4.0")
override def trees: Array[DecisionTreeRegressionModel] = _trees
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,8 @@ class LinearSVCModel private[classification] (
extends ClassificationModel[Vector, LinearSVCModel]
with LinearSVCParams with MLWritable with HasTrainingSummary[LinearSVCTrainingSummary] {

private[ml] def this() = this(Identifiable.randomUID("linearsvc"), Vectors.empty, 0.0)
// For ml connect only
private[ml] def this() = this("", Vectors.empty, Double.NaN)

@Since("2.2.0")
override val numClasses: Int = 2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1077,8 +1077,7 @@ class LogisticRegressionModel private[spark] (
Vectors.dense(intercept), 2, isMultinomial = false)

// For ml connect only
@Since("4.0.0")
private[ml] def this() = this(Identifiable.randomUID("logreg"), Vectors.empty, 0)
private[ml] def this() = this("", Matrices.empty, Vectors.empty, -1, false)

/**
* A vector of model coefficients for "binomial" logistic regression. If this model was trained
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,8 @@ class MultilayerPerceptronClassificationModel private[ml] (
with MultilayerPerceptronParams with Serializable with MLWritable
with HasTrainingSummary[MultilayerPerceptronClassificationTrainingSummary]{

private[ml] def this() = this(Identifiable.randomUID("mlpc"), Vectors.empty)
// For ml connect only
private[ml] def this() = this("", Vectors.empty)

@Since("1.6.0")
override lazy val numFeatures: Int = $(layers).head
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,8 +401,8 @@ class NaiveBayesModel private[ml] (

import NaiveBayes._

private[ml] def this() = this(Identifiable.randomUID("nb"),
Vectors.empty, Matrices.empty, Matrices.empty)
// For ml connect only
private[ml] def this() = this("", Vectors.empty, Matrices.empty, Matrices.empty)

/**
* mllib NaiveBayes is a wrapper of ml implementation currently.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,7 @@ class RandomForestClassificationModel private[ml] (
this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses)

// For ml connect only
@Since("4.0.0")
private[ml] def this() = this(Array(new DecisionTreeClassificationModel), 0, 0)
private[ml] def this() = this("", Array(new DecisionTreeClassificationModel), -1, -1)

@Since("1.4.0")
override def trees: Array[DecisionTreeClassificationModel] = _trees
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,8 @@ class BisectingKMeansModel private[ml] (
extends Model[BisectingKMeansModel] with BisectingKMeansParams with MLWritable
with HasTrainingSummary[BisectingKMeansSummary] {

@Since("4.0.0")
private[ml] def this() = this(Identifiable.randomUID("bisecting-kmeans"),
new MLlibBisectingKMeansModel(null))
// For ml connect only
private[ml] def this() = this("", null)

@Since("3.0.0")
lazy val numFeatures: Int = parentModel.clusterCenters.head.size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ class GaussianMixtureModel private[ml] (
extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable
with HasTrainingSummary[GaussianMixtureSummary] {

private[ml] def this() = this(Identifiable.randomUID("gmm"),
Array.emptyDoubleArray, Array.empty)
// For ml connect only
private[ml] def this() = this("", Array.emptyDoubleArray, Array.empty)

@Since("3.0.0")
lazy val numFeatures: Int = gaussians.head.mean.size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,7 @@ class KMeansModel private[ml] (
with HasTrainingSummary[KMeansSummary] {

// For ml connect only
@Since("4.0.0")
private[ml] def this() = this(Identifiable.randomUID("kmeans"),
new MLlibKMeansModel(clusterCenters = null))
private[ml] def this() = this("", null)

@Since("3.0.0")
lazy val numFeatures: Int = parentModel.clusterCenters.head.size
Expand Down
6 changes: 4 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,8 @@ class LocalLDAModel private[ml] (
sparkSession: SparkSession)
extends LDAModel(uid, vocabSize, sparkSession) {

private[ml] def this() = this(Identifiable.randomUID("lda"), -1, null, null)
// For ml connect only
private[ml] def this() = this("", -1, null, null)

oldLocalModel.setSeed(getSeed)

Expand Down Expand Up @@ -715,7 +716,8 @@ class DistributedLDAModel private[ml] (
private var oldLocalModelOption: Option[OldLocalLDAModel])
extends LDAModel(uid, vocabSize, sparkSession) {

private[ml] def this() = this(Identifiable.randomUID("lda"), -1, null, null, None)
// For ml connect only
private[ml] def this() = this("", -1, null, null, None)

override private[clustering] def oldLocalModel: OldLocalLDAModel = {
if (oldLocalModelOption.isEmpty) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ class BucketedRandomProjectionLSHModel private[ml](
private[ml] val randMatrix: Matrix)
extends LSHModel[BucketedRandomProjectionLSHModel] with BucketedRandomProjectionLSHParams {

private[ml] def this() = this(Identifiable.randomUID("brp-lsh"), Matrices.empty)
// For ml connect only
private[ml] def this() = this("", Matrices.empty)

private[ml] def this(uid: String, randUnitVectors: Array[Vector]) = {
this(uid, Matrices.fromVectors(randUnitVectors.toImmutableArraySeq))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ final class ChiSqSelectorModel private[ml] (

import ChiSqSelectorModel._

private[ml] def this() = this(
Identifiable.randomUID("chiSqSelector"), Array.emptyIntArray)
// For ml connect only
private[ml] def this() = this("", Array.emptyIntArray)

override protected def isNumericAttribute = false

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,8 @@ class CountVectorizerModel(

import CountVectorizerModel._

private[ml] def this() = this(Identifiable.randomUID("cntVecModel"), Array.empty)
// For ml connect only
private[ml] def this() = this("", Array.empty)

@Since("1.5.0")
def this(vocabulary: Array[String]) = {
Expand Down
3 changes: 2 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ class IDFModel private[ml] (

import IDFModel._

private[ml] def this() = this(Identifiable.randomUID("idf"), null)
// For ml connect only
private[ml] def this() = this("", null)

/** @group setParam */
@Since("1.4.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,8 @@ class ImputerModel private[ml] (

import ImputerModel._

private[ml] def this() = this(Identifiable.randomUID("imputer"), null)
// For ml connect only
private[ml] def this() = this("", null)

/** @group setParam */
@Since("3.0.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ class MaxAbsScalerModel private[ml] (

import MaxAbsScalerModel._

private[ml] def this() = this(Identifiable.randomUID("maxAbsScal"), Vectors.empty)
// For ml connect only
private[ml] def this() = this("", Vectors.empty)

/** @group setParam */
@Since("2.0.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ class MinHashLSHModel private[ml](
private[ml] val randCoefficients: Array[(Int, Int)])
extends LSHModel[MinHashLSHModel] {

private[ml] def this() = this(Identifiable.randomUID("mh-lsh"), Array.empty)
// For ml connect only
private[ml] def this() = this("", Array.empty)

/** @group setParam */
@Since("2.4.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ class MinMaxScalerModel private[ml] (

import MinMaxScalerModel._

private[ml] def this() = this(Identifiable.randomUID("minMaxScal"), Vectors.empty, Vectors.empty)
// For ml connect only
private[ml] def this() = this("", Vectors.empty, Vectors.empty)

/** @group setParam */
@Since("1.5.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,8 @@ class OneHotEncoderModel private[ml] (

import OneHotEncoderModel._

private[ml] def this() = this(Identifiable.randomUID("oneHotEncoder)"), Array.emptyIntArray)
// For ml connect only
private[ml] def this() = this("", Array.emptyIntArray)

// Returns the category size for each index with `dropLast` and `handleInvalid`
// taken into account.
Expand Down
4 changes: 1 addition & 3 deletions mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,7 @@ class PCAModel private[ml] (
import PCAModel._

// For ml connect only
@Since("4.0.0")
private[ml] def this() = this(Identifiable.randomUID("pca"),
DenseMatrix.zeros(1, 1), Vectors.empty)
private[ml] def this() = this("", Matrices.empty, Vectors.empty)

/** @group setParam */
@Since("1.5.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,8 @@ class RFormulaModel private[feature](
private[ml] val pipelineModel: PipelineModel)
extends Model[RFormulaModel] with RFormulaBase with MLWritable {

private[ml] def this() = this(Identifiable.randomUID("rFormula"), null, null)
// For ml connect only
private[ml] def this() = this("", null, null)

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,8 @@ class RobustScalerModel private[ml] (

import RobustScalerModel._

private[ml] def this() = this(Identifiable.randomUID("robustScal"), Vectors.empty, Vectors.empty)
// For ml connect only
private[ml] def this() = this("", Vectors.empty, Vectors.empty)

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ class StandardScalerModel private[ml] (

import StandardScalerModel._

private[ml] def this() = this(Identifiable.randomUID("stdScal"), Vectors.empty, Vectors.empty)
// For ml connect only
private[ml] def this() = this("", Vectors.empty, Vectors.empty)

/** @group setParam */
@Since("1.2.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,7 @@ class StringIndexerModel (
def this(labelsArray: Array[Array[String]]) = this(Identifiable.randomUID("strIdx"), labelsArray)

// For ml connect only
@Since("4.0.0")
private[ml] def this() = this(labels = Array.empty)
private[ml] def this() = this("", Array.empty[Array[String]])

@deprecated("`labels` is deprecated and will be removed in 3.1.0. Use `labelsArray` " +
"instead.", "3.0.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,8 @@ class TargetEncoderModel private[ml] (
@Since("4.0.0") private[ml] val stats: Array[Map[Double, (Double, Double)]])
extends Model[TargetEncoderModel] with TargetEncoderBase with MLWritable {

private[ml] def this() = this(Identifiable.randomUID("TargetEncoder"), Array.empty)
// For ml connect only
private[ml] def this() = this("", Array.empty)

/** @group setParam */
@Since("4.0.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,8 @@ class UnivariateFeatureSelectorModel private[ml](
extends Model[UnivariateFeatureSelectorModel] with UnivariateFeatureSelectorParams
with MLWritable {

private[ml] def this() = this(
Identifiable.randomUID("UnivariateFeatureSelector"), Array.emptyIntArray)
// For ml connect only
private[ml] def this() = this("", Array.emptyIntArray)

/** @group setParam */
@Since("3.1.1")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ class VarianceThresholdSelectorModel private[ml](
extends Model[VarianceThresholdSelectorModel] with VarianceThresholdSelectorParams
with MLWritable {

private[ml] def this() = this(
Identifiable.randomUID("VarianceThresholdSelector"), Array.emptyIntArray)
// For ml connect only
private[ml] def this() = this("", Array.emptyIntArray)

if (selectedFeatures.length >= 2) {
require(selectedFeatures.sliding(2).forall(l => l(0) < l(1)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,8 @@ class VectorIndexerModel private[ml] (

import VectorIndexerModel._

private[ml] def this() = this(Identifiable.randomUID("vecIdx"), -1, Map.empty)
// For ml connect only
private[ml] def this() = this("", -1, Map.empty)

/** Java-friendly version of [[categoryMaps]] */
@Since("1.4.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,8 @@ class Word2VecModel private[ml] (

import Word2VecModel._

private[ml] def this() = this(Identifiable.randomUID("w2v"), null)
// For ml connect only
private[ml] def this() = this("", null)

/**
* Returns a dataframe with two fields, "word" and "vector", with "word" being a String and
Expand Down
3 changes: 2 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,8 @@ class FPGrowthModel private[ml] (
private val numTrainingRecords: Long)
extends Model[FPGrowthModel] with FPGrowthParams with MLWritable {

private[ml] def this() = this(Identifiable.randomUID("fpgrowth"), null, Map.empty, 0L)
// For ml connect only
private[ml] def this() = this("", null, Map.empty, -1L)

/** @group setParam */
@Since("2.2.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,7 @@ class ALSModel private[ml] (
extends Model[ALSModel] with ALSModelParams with MLWritable {

// For ml connect only
@Since("4.0.0")
private[ml] def this() = this(Identifiable.randomUID("als"), 0, null, null)
private[ml] def this() = this("", -1, null, null)

/** @group setParam */
@Since("1.4.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,8 @@ class AFTSurvivalRegressionModel private[ml] (
extends RegressionModel[Vector, AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams
with MLWritable {

private[ml] def this() = this(Identifiable.randomUID("aftSurvReg"),
Vectors.empty, Double.NaN, Double.NaN)
// For ml connect only
private[ml] def this() = this("", Vectors.empty, Double.NaN, Double.NaN)

@Since("3.0.0")
override def numFeatures: Int = coefficients.size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,7 @@ class DecisionTreeRegressionModel private[ml] (
this(Identifiable.randomUID("dtr"), rootNode, numFeatures)

// For ml connect only
@Since("4.0.0")
private[ml] def this() = this(Node.dummyNode, 0)
private[ml] def this() = this("", Node.dummyNode, -1)

override def predict(features: Vector): Double = {
rootNode.predictImpl(features).prediction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,8 @@ class FMRegressionModel private[regression] (
extends RegressionModel[Vector, FMRegressionModel]
with FMRegressorParams with MLWritable {

private[ml] def this() = this(Identifiable.randomUID("fmr"),
Double.NaN, Vectors.empty, Matrices.empty)
// For ml connect only
private[ml] def this() = this("", Double.NaN, Vectors.empty, Matrices.empty)

@Since("3.0.0")
override val numFeatures: Int = linear.size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,7 @@ class GBTRegressionModel private[ml](
this(uid, _trees, _treeWeights, -1)

// For ml connect only
@Since("4.0.0")
private[ml] def this() = this(Identifiable.randomUID("gbtr"),
Array(new DecisionTreeRegressionModel), Array(0.0))
private[ml] def this() = this("", Array(new DecisionTreeRegressionModel), Array(Double.NaN), -1)

@Since("1.4.0")
override def trees: Array[DecisionTreeRegressionModel] = _trees
Expand Down
Loading

0 comments on commit d75a7d6

Please sign in to comment.