Skip to content

Commit a92671c

Browse files
committed
Split LogisticRegressionModelData into multiple pieces
1 parent 9d58fc7 commit a92671c

File tree

10 files changed

+94
-18
lines changed

10 files changed

+94
-18
lines changed

docs/content/docs/operators/classification/logisticregression.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ public class OnlineLogisticRegressionExample {
323323

324324
// Creates an online LogisticRegression object and initializes its parameters and initial
325325
// model data.
326-
Row initModelData = Row.of(Vectors.dense(0.41233679404769874, -0.18088118293232122), 0L);
326+
Row initModelData = Row.of(Vectors.dense(0.41233679404769874, -0.18088118293232122), 0L, 2L, 0L);
327327
Table initModelDataTable = tEnv.fromDataStream(env.fromElements(initModelData));
328328
OnlineLogisticRegression olr =
329329
new OnlineLogisticRegression()

flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/OnlineLogisticRegressionExample.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ public static void main(String[] args) {
9696

9797
// Creates an online LogisticRegression object and initializes its parameters and initial
9898
// model data.
99-
Row initModelData = Row.of(Vectors.dense(0.41233679404769874, -0.18088118293232122), 0L);
99+
Row initModelData =
100+
Row.of(Vectors.dense(0.41233679404769874, -0.18088118293232122), 0L, 2L, 0L);
100101
Table initModelDataTable = tEnv.fromDataStream(env.fromElements(initModelData));
101102
OnlineLogisticRegression olr =
102103
new OnlineLogisticRegression()

flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import java.io.IOException;
4444
import java.util.Collections;
4545
import java.util.HashMap;
46+
import java.util.List;
4647
import java.util.Map;
4748

4849
/** A Model which classifies data using the model data computed by {@link LogisticRegression}. */
@@ -147,10 +148,16 @@ public PredictLabelFunction(String broadcastModelKey, Map<Param<?>, Object> para
147148
@Override
148149
public Row map(Row dataPoint) {
149150
if (servable == null) {
150-
LogisticRegressionModelData modelData =
151-
(LogisticRegressionModelData)
152-
getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
153-
servable = new LogisticRegressionModelServable(modelData);
151+
List<LogisticRegressionModelData> modelData =
152+
getRuntimeContext().getBroadcastVariable(broadcastModelKey);
153+
154+
if (modelData.size() == 1) {
155+
servable = new LogisticRegressionModelServable(modelData.get(0));
156+
} else {
157+
LogisticRegressionModelData mergedModel =
158+
LogisticRegressionModelServable.mergePieces(modelData);
159+
servable = new LogisticRegressionModelServable(mergedModel);
160+
}
154161
ParamUtils.updateExistingParams(servable, params);
155162
}
156163
Vector features = (Vector) dataPoint.getField(servable.getFeaturesCol());

flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataUtil.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,13 @@ public static DataStream<LogisticRegressionModelData> getModelDataStream(Table m
8989
StreamTableEnvironment tEnv =
9090
(StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
9191
return tEnv.toDataStream(modelData)
92-
.map(x -> new LogisticRegressionModelData(x.getFieldAs(0), x.getFieldAs(1)));
92+
.map(
93+
x ->
94+
new LogisticRegressionModelData(
95+
x.getFieldAs(0),
96+
x.getFieldAs(1),
97+
x.getFieldAs(2),
98+
x.getFieldAs(3)));
9399
}
94100

95101
/**
@@ -107,7 +113,10 @@ public static DataStream<byte[]> getModelDataByteStream(Table modelDataTable) {
107113
x -> {
108114
LogisticRegressionModelData modelData =
109115
new LogisticRegressionModelData(
110-
x.getFieldAs(0), x.getFieldAs(1));
116+
x.getFieldAs(0),
117+
x.getFieldAs(1),
118+
x.getFieldAs(2),
119+
x.getFieldAs(3));
111120

112121
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
113122
modelData.encode(outputStream);

flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ public void testSaveLoadAndPredict() throws Exception {
316316
tempFolder.newFolder().getAbsolutePath(),
317317
LogisticRegressionModel::load);
318318
assertEquals(
319-
Arrays.asList("coefficient", "modelVersion"),
319+
Arrays.asList("coefficient", "startIndex", "endIndex", "modelVersion"),
320320
model.getModelData()[0].getResolvedSchema().getColumnNames());
321321
Table output = model.transform(binomialDataTable)[0];
322322
verifyPredictionResult(

flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,8 @@ public void before() throws Exception {
256256
new double[] {
257257
0.41233679404769874, -0.18088118293232122
258258
}),
259+
0L,
260+
2L,
259261
0L)));
260262
initSparseModel =
261263
tEnv.fromDataStream(
@@ -266,6 +268,8 @@ public void before() throws Exception {
266268
0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01,
267269
0.01, 0.01
268270
}),
271+
0L,
272+
10L,
269273
0L)));
270274
}
271275

flink-ml-python/pyflink/ml/classification/tests/test_logisticregression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def test_save_load_and_predict(self):
120120
model = regression.fit(self.binomial_data_table)
121121
self.assertEqual(
122122
model.get_model_data()[0].get_schema().get_field_names(),
123-
['coefficient', 'modelVersion'])
123+
['coefficient', "startIndex", "endIndex", 'modelVersion'])
124124
output = model.transform(self.binomial_data_table)[0]
125125
field_names = output.get_schema().get_field_names()
126126
self.verify_predict_result(

flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/ServableReadWriteUtils.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,14 @@
2525
import org.apache.flink.ml.servable.api.TransformerServable;
2626
import org.apache.flink.ml.servable.builder.PipelineModelServable;
2727
import org.apache.flink.util.InstantiationUtil;
28-
import org.apache.flink.util.Preconditions;
2928

3029
import java.io.IOException;
3130
import java.io.InputStream;
31+
import java.io.SequenceInputStream;
3232
import java.lang.reflect.InvocationTargetException;
3333
import java.lang.reflect.Method;
3434
import java.util.ArrayList;
35+
import java.util.Collections;
3536
import java.util.HashMap;
3637
import java.util.List;
3738
import java.util.Map;
@@ -143,10 +144,10 @@ public static InputStream loadModelData(String path) throws IOException {
143144
FileSystem fileSystem = modelDataPath.getFileSystem();
144145

145146
FileStatus[] files = fileSystem.listStatus(modelDataPath);
146-
Preconditions.checkState(
147-
files.length == 1,
148-
"Only one model data file is expected in the directory %s.",
149-
path);
150-
return fileSystem.open(files[0].getPath());
147+
List<InputStream> inputStreams = new ArrayList<>();
148+
for (FileStatus file : files) {
149+
inputStreams.add(fileSystem.open(file.getPath()));
150+
}
151+
return new SequenceInputStream(Collections.enumeration(inputStreams));
151152
}
152153
}

flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,23 @@ public class LogisticRegressionModelData {
3333

3434
public DenseVector coefficient;
3535

36+
public long startIndex;
37+
38+
public long endIndex;
39+
3640
public long modelVersion;
3741

3842
public LogisticRegressionModelData() {}
3943

4044
public LogisticRegressionModelData(DenseVector coefficient, long modelVersion) {
45+
this(coefficient, 0L, coefficient.size(), modelVersion);
46+
}
47+
48+
public LogisticRegressionModelData(
49+
DenseVector coefficient, long startIndex, long endIndex, long modelVersion) {
4150
this.coefficient = coefficient;
51+
this.startIndex = startIndex;
52+
this.endIndex = endIndex;
4253
this.modelVersion = modelVersion;
4354
}
4455

@@ -54,6 +65,8 @@ public void encode(OutputStream outputStream) throws IOException {
5465

5566
DenseVectorSerializer serializer = new DenseVectorSerializer();
5667
serializer.serialize(coefficient, dataOutputViewStreamWrapper);
68+
dataOutputViewStreamWrapper.writeLong(startIndex);
69+
dataOutputViewStreamWrapper.writeLong(endIndex);
5770
dataOutputViewStreamWrapper.writeLong(modelVersion);
5871
}
5972

@@ -69,8 +82,10 @@ static LogisticRegressionModelData decode(InputStream inputStream) throws IOExce
6982

7083
DenseVectorSerializer serializer = new DenseVectorSerializer();
7184
DenseVector coefficient = serializer.deserialize(dataInputViewStreamWrapper);
85+
long startIndex = dataInputViewStreamWrapper.readLong();
86+
long endIndex = dataInputViewStreamWrapper.readLong();
7287
long modelVersion = dataInputViewStreamWrapper.readLong();
7388

74-
return new LogisticRegressionModelData(coefficient, modelVersion);
89+
return new LogisticRegressionModelData(coefficient, startIndex, endIndex, modelVersion);
7590
}
7691
}

flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
package org.apache.flink.ml.classification.logisticregression;
2020

21+
import org.apache.flink.annotation.VisibleForTesting;
2122
import org.apache.flink.api.java.tuple.Tuple2;
2223
import org.apache.flink.ml.linalg.BLAS;
2324
import org.apache.flink.ml.linalg.DenseVector;
@@ -81,11 +82,49 @@ public DataFrame transform(DataFrame input) {
8182
public LogisticRegressionModelServable setModelData(InputStream... modelDataInputs)
8283
throws IOException {
8384
Preconditions.checkArgument(modelDataInputs.length == 1);
85+
List<LogisticRegressionModelData> modelPieces = new ArrayList<>();
86+
while (true) {
87+
try {
88+
LogisticRegressionModelData piece =
89+
LogisticRegressionModelData.decode(modelDataInputs[0]);
90+
modelPieces.add(piece);
91+
} catch (IOException e) {
92+
// Any IOException thrown by decode() is treated as the end of the model stream.
93+
break;
94+
}
95+
}
8496

85-
modelData = LogisticRegressionModelData.decode(modelDataInputs[0]);
97+
modelData = mergePieces(modelPieces);
8698
return this;
8799
}
88100

101+
@VisibleForTesting
102+
public static LogisticRegressionModelData mergePieces(
103+
List<LogisticRegressionModelData> pieces) {
104+
long dim = 0;
105+
for (LogisticRegressionModelData piece : pieces) {
106+
dim = Math.max(dim, piece.endIndex);
107+
}
108+
// TODO: Add distributed inference for very large models.
109+
Preconditions.checkState(
110+
dim < Integer.MAX_VALUE,
111+
"The dimension of logistic regression model is larger than INT.MAX. Please consider using distributed inference.");
112+
int intDim = (int) dim;
113+
DenseVector mergedCoefficient = new DenseVector(intDim);
114+
for (LogisticRegressionModelData piece : pieces) {
115+
int startIndex = (int) piece.startIndex;
116+
int endIndex = (int) piece.endIndex;
117+
System.arraycopy(
118+
piece.coefficient.values,
119+
0,
120+
mergedCoefficient.values,
121+
startIndex,
122+
endIndex - startIndex);
123+
}
124+
return new LogisticRegressionModelData(
125+
mergedCoefficient, 0, mergedCoefficient.size(), pieces.get(0).modelVersion);
126+
}
127+
89128
public static LogisticRegressionModelServable load(String path) throws IOException {
90129
LogisticRegressionModelServable servable =
91130
ServableReadWriteUtils.loadServableParam(

0 commit comments

Comments
 (0)