Skip to content

Commit 3138661

Browse files
committed
[FLINK-27826] Support training very high dimensional logisticRegression
1 parent: 9d58fc7 · commit: 3138661

37 files changed: 3131 additions and 16 deletions

flink-ml-lib/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -138,6 +138,12 @@ under the License.
138138
<scope>test</scope>
139139
<type>test-jar</type>
140140
</dependency>
141+
<dependency>
142+
<groupId>fastutil</groupId>
143+
<artifactId>fastutil</artifactId>
144+
<version>5.0.9</version>
145+
</dependency>
146+
141147
</dependencies>
142148

143149
<build>

flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,7 @@
4343
import java.io.IOException;
4444
import java.util.Collections;
4545
import java.util.HashMap;
46+
import java.util.List;
4647
import java.util.Map;
4748

4849
/** A Model which classifies data using the model data computed by {@link LogisticRegression}. */
@@ -147,10 +148,16 @@ public PredictLabelFunction(String broadcastModelKey, Map<Param<?>, Object> para
147148
@Override
148149
public Row map(Row dataPoint) {
149150
if (servable == null) {
150-
LogisticRegressionModelData modelData =
151-
(LogisticRegressionModelData)
152-
getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
153-
servable = new LogisticRegressionModelServable(modelData);
151+
List<LogisticRegressionModelData> modelData =
152+
getRuntimeContext().getBroadcastVariable(broadcastModelKey);
153+
154+
if (modelData.size() == 1) {
155+
servable = new LogisticRegressionModelServable(modelData.get(0));
156+
} else {
157+
LogisticRegressionModelData mergedModel =
158+
LogisticRegressionModelServable.mergePieces(modelData);
159+
servable = new LogisticRegressionModelServable(mergedModel);
160+
}
154161
ParamUtils.updateExistingParams(servable, params);
155162
}
156163
Vector features = (Vector) dataPoint.getField(servable.getFeaturesCol());

flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataUtil.java

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,13 @@ public static DataStream<LogisticRegressionModelData> getModelDataStream(Table m
8989
StreamTableEnvironment tEnv =
9090
(StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
9191
return tEnv.toDataStream(modelData)
92-
.map(x -> new LogisticRegressionModelData(x.getFieldAs(0), x.getFieldAs(1)));
92+
.map(
93+
x ->
94+
new LogisticRegressionModelData(
95+
x.getFieldAs(0),
96+
x.getFieldAs(1),
97+
x.getFieldAs(2),
98+
x.getFieldAs(3)));
9399
}
94100

95101
/**
@@ -107,7 +113,10 @@ public static DataStream<byte[]> getModelDataByteStream(Table modelDataTable) {
107113
x -> {
108114
LogisticRegressionModelData modelData =
109115
new LogisticRegressionModelData(
110-
x.getFieldAs(0), x.getFieldAs(1));
116+
x.getFieldAs(0),
117+
x.getFieldAs(1),
118+
x.getFieldAs(2),
119+
x.getFieldAs(3));
111120

112121
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
113122
modelData.encode(outputStream);

0 commit comments

Comments
 (0)