Skip to content

Commit 81b8391

Browse files
committed
All infra ready except worker operator
1 parent 9d58fc7 commit 81b8391

37 files changed

+3037
-16
lines changed

flink-ml-lib/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,12 @@ under the License.
138138
<scope>test</scope>
139139
<type>test-jar</type>
140140
</dependency>
141+
<dependency>
142+
<groupId>fastutil</groupId>
143+
<artifactId>fastutil</artifactId>
144+
<version>5.0.9</version>
145+
</dependency>
146+
141147
</dependencies>
142148

143149
<build>

flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import java.io.IOException;
4444
import java.util.Collections;
4545
import java.util.HashMap;
46+
import java.util.List;
4647
import java.util.Map;
4748

4849
/** A Model which classifies data using the model data computed by {@link LogisticRegression}. */
@@ -147,10 +148,16 @@ public PredictLabelFunction(String broadcastModelKey, Map<Param<?>, Object> para
147148
@Override
148149
public Row map(Row dataPoint) {
149150
if (servable == null) {
150-
LogisticRegressionModelData modelData =
151-
(LogisticRegressionModelData)
152-
getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
153-
servable = new LogisticRegressionModelServable(modelData);
151+
List<LogisticRegressionModelData> modelData =
152+
getRuntimeContext().getBroadcastVariable(broadcastModelKey);
153+
154+
if (modelData.size() == 1) {
155+
servable = new LogisticRegressionModelServable(modelData.get(0));
156+
} else {
157+
LogisticRegressionModelData mergedModel =
158+
LogisticRegressionModelServable.mergePieces(modelData);
159+
servable = new LogisticRegressionModelServable(mergedModel);
160+
}
154161
ParamUtils.updateExistingParams(servable, params);
155162
}
156163
Vector features = (Vector) dataPoint.getField(servable.getFeaturesCol());

flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataUtil.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,13 @@ public static DataStream<LogisticRegressionModelData> getModelDataStream(Table m
8989
StreamTableEnvironment tEnv =
9090
(StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
9191
return tEnv.toDataStream(modelData)
92-
.map(x -> new LogisticRegressionModelData(x.getFieldAs(0), x.getFieldAs(1)));
92+
.map(
93+
x ->
94+
new LogisticRegressionModelData(
95+
x.getFieldAs(0),
96+
x.getFieldAs(1),
97+
x.getFieldAs(2),
98+
x.getFieldAs(3)));
9399
}
94100

95101
/**
@@ -107,7 +113,10 @@ public static DataStream<byte[]> getModelDataByteStream(Table modelDataTable) {
107113
x -> {
108114
LogisticRegressionModelData modelData =
109115
new LogisticRegressionModelData(
110-
x.getFieldAs(0), x.getFieldAs(1));
116+
x.getFieldAs(0),
117+
x.getFieldAs(1),
118+
x.getFieldAs(2),
119+
x.getFieldAs(3));
111120

112121
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
113122
modelData.encode(outputStream);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.ml.classification.logisticregression;
20+
21+
import org.apache.flink.api.common.functions.MapFunction;
22+
import org.apache.flink.api.common.functions.ReduceFunction;
23+
import org.apache.flink.api.common.typeinfo.TypeInformation;
24+
import org.apache.flink.api.java.tuple.Tuple2;
25+
import org.apache.flink.api.java.tuple.Tuple3;
26+
import org.apache.flink.api.java.typeutils.ListTypeInfo;
27+
import org.apache.flink.ml.api.Estimator;
28+
import org.apache.flink.ml.common.datastream.DataStreamUtils;
29+
import org.apache.flink.ml.common.feature.LabeledLargePointWithWeight;
30+
import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss;
31+
import org.apache.flink.ml.common.lossfunc.LossFunc;
32+
import org.apache.flink.ml.common.ps.training.IterationStageList;
33+
import org.apache.flink.ml.common.ps.training.ProcessStage;
34+
import org.apache.flink.ml.common.ps.training.PullStage;
35+
import org.apache.flink.ml.common.ps.training.PushStage;
36+
import org.apache.flink.ml.common.ps.training.TrainingContext;
37+
import org.apache.flink.ml.common.ps.training.TrainingUtils;
38+
import org.apache.flink.ml.common.updater.FTRL;
39+
import org.apache.flink.ml.linalg.Vectors;
40+
import org.apache.flink.ml.param.Param;
41+
import org.apache.flink.ml.util.ParamUtils;
42+
import org.apache.flink.ml.util.ReadWriteUtils;
43+
import org.apache.flink.streaming.api.datastream.DataStream;
44+
import org.apache.flink.table.api.Table;
45+
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
46+
import org.apache.flink.table.api.internal.TableImpl;
47+
import org.apache.flink.types.Row;
48+
import org.apache.flink.util.Preconditions;
49+
50+
import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
51+
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
52+
53+
import java.io.IOException;
54+
import java.util.Arrays;
55+
import java.util.HashMap;
56+
import java.util.Iterator;
57+
import java.util.List;
58+
import java.util.Map;
59+
60+
/**
61+
* An Estimator which implements the large scale logistic regression algorithm using FTRL optimizer.
62+
*
63+
* <p>See https://en.wikipedia.org/wiki/Logistic_regression.
64+
*/
65+
public class LogisticRegressionWithFtrl
66+
implements Estimator<LogisticRegressionWithFtrl, LogisticRegressionModel>,
67+
LogisticRegressionWithFtrlParams<LogisticRegressionWithFtrl> {
68+
69+
private final Map<Param<?>, Object> paramMap = new HashMap<>();
70+
71+
public LogisticRegressionWithFtrl() {
72+
ParamUtils.initializeMapWithDefaultValues(paramMap, this);
73+
}
74+
75+
@Override
76+
public LogisticRegressionModel fit(Table... inputs) {
77+
Preconditions.checkArgument(inputs.length == 1);
78+
String classificationType = getMultiClass();
79+
Preconditions.checkArgument(
80+
"auto".equals(classificationType) || "binomial".equals(classificationType),
81+
"Multinomial classification is not supported yet. Supported options: [auto, binomial].");
82+
StreamTableEnvironment tEnv =
83+
(StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
84+
85+
DataStream<LabeledLargePointWithWeight> trainData =
86+
tEnv.toDataStream(inputs[0])
87+
.map(
88+
(MapFunction<Row, LabeledLargePointWithWeight>)
89+
dataPoint -> {
90+
double weight =
91+
getWeightCol() == null
92+
? 1.0
93+
: ((Number)
94+
dataPoint.getField(
95+
getWeightCol()))
96+
.doubleValue();
97+
double label =
98+
((Number) dataPoint.getField(getLabelCol()))
99+
.doubleValue();
100+
boolean isBinomial =
101+
Double.compare(0., label) == 0
102+
|| Double.compare(1., label) == 0;
103+
if (!isBinomial) {
104+
throw new RuntimeException(
105+
"Multinomial classification is not supported yet. Supported options: [auto, binomial].");
106+
}
107+
Tuple2<long[], double[]> features =
108+
dataPoint.getFieldAs(getFeaturesCol());
109+
return new LabeledLargePointWithWeight(
110+
features, label, weight);
111+
});
112+
113+
DataStream<Long> modelDim;
114+
if (getModelDim() > 0) {
115+
modelDim = trainData.getExecutionEnvironment().fromElements(getModelDim());
116+
} else {
117+
modelDim =
118+
DataStreamUtils.reduce(
119+
trainData.map(x -> x.features.f0[x.features.f0.length - 1]),
120+
(ReduceFunction<Long>) Math::max)
121+
.map((MapFunction<Long, Long>) value -> value + 1);
122+
}
123+
124+
IterationStageList<LabeledLargePointWithWeight> iterationStages =
125+
new IterationStageList<>();
126+
iterationStages
127+
.addTrainingStage(new ComputeIndices())
128+
.addTrainingStage(new PullStage("pullIndices"))
129+
.addTrainingStage(new ComputeGradients(BinaryLogisticLoss.INSTANCE))
130+
.addTrainingStage(new PushStage("pushGradient"))
131+
.setTerminationCriteria(context -> context.getCurrentIterationId() >= getMaxIter());
132+
133+
FTRL ftrl = new FTRL(getAlpha(), getBeta(), getReg(), getElasticNet());
134+
135+
DataStream<Tuple3<Long, Long, double[]>> rawModelData =
136+
TrainingUtils.<LabeledLargePointWithWeight>train(
137+
modelDim,
138+
trainData,
139+
ftrl,
140+
iterationStages,
141+
getGlobalBatchSize(),
142+
getNumServers(),
143+
getNumServerCores());
144+
145+
final long modelVersion = 0L;
146+
147+
DataStream<LogisticRegressionModelData> modelData =
148+
rawModelData.map(
149+
tuple3 ->
150+
new LogisticRegressionModelData(
151+
Vectors.dense(tuple3.f2),
152+
tuple3.f0,
153+
tuple3.f1,
154+
modelVersion));
155+
156+
LogisticRegressionModel model =
157+
new LogisticRegressionModel().setModelData(tEnv.fromDataStream(modelData));
158+
ParamUtils.updateExistingParams(model, paramMap);
159+
return model;
160+
}
161+
162+
@Override
163+
public void save(String path) throws IOException {
164+
ReadWriteUtils.saveMetadata(this, path);
165+
}
166+
167+
public static LogisticRegressionWithFtrl load(StreamTableEnvironment tEnv, String path)
168+
throws IOException {
169+
return ReadWriteUtils.loadStageParam(path);
170+
}
171+
172+
@Override
173+
public Map<Param<?>, Object> getParamMap() {
174+
return paramMap;
175+
}
176+
}
177+
178+
/**
179+
* A stage that samples a batch of training data and computes the indices needed to compute
180+
* gradients.
181+
*/
182+
class ComputeIndices extends ProcessStage<LabeledLargePointWithWeight> {
183+
@Override
184+
public void process(TrainingContext<LabeledLargePointWithWeight> context) throws Exception {
185+
List<LabeledLargePointWithWeight> batchData = context.getNextBatchData();
186+
long[] indices = computeIndices(batchData);
187+
188+
context.put(
189+
"batchData",
190+
batchData,
191+
new ListTypeInfo<>(TypeInformation.of(LabeledLargePointWithWeight.class)));
192+
// Saves the indices for pull.
193+
context.put("pullIndices", indices);
194+
}
195+
196+
public static long[] computeIndices(List<LabeledLargePointWithWeight> dataPoints) {
197+
LongOpenHashSet indices = new LongOpenHashSet();
198+
for (LabeledLargePointWithWeight dataPoint : dataPoints) {
199+
long[] notZeros = dataPoint.features.f0;
200+
for (long index : notZeros) {
201+
indices.add(index);
202+
}
203+
}
204+
205+
long[] sortedIndices = new long[indices.size()];
206+
Iterator<Long> iterator = indices.iterator();
207+
int i = 0;
208+
while (iterator.hasNext()) {
209+
sortedIndices[i++] = iterator.next();
210+
}
211+
Arrays.sort(sortedIndices);
212+
return sortedIndices;
213+
}
214+
}
215+
216+
/**
217+
* A Stage that uses the pulled model parameters and batch data to compute the gradients. The
218+
* gradients are stored in context for later push.
219+
*/
220+
class ComputeGradients extends ProcessStage<LabeledLargePointWithWeight> {
221+
222+
private final LossFunc lossFunc;
223+
224+
public ComputeGradients(LossFunc lossFunc) {
225+
this.lossFunc = lossFunc;
226+
}
227+
228+
@Override
229+
@SuppressWarnings("unchecked")
230+
public void process(TrainingContext<LabeledLargePointWithWeight> context) {
231+
List<LabeledLargePointWithWeight> batchData =
232+
(List<LabeledLargePointWithWeight>) context.get("batchData");
233+
234+
long[] indices = ComputeIndices.computeIndices(batchData);
235+
double[] pulledModelValues = (double[]) context.get("pullIndices");
236+
double[] gradients = computeGradient(batchData, indices, pulledModelValues);
237+
238+
// Saves the gradient for push.
239+
context.put("pushGradient", Tuple2.of(indices, gradients));
240+
}
241+
242+
private double[] computeGradient(
243+
List<LabeledLargePointWithWeight> batchData,
244+
long[] sortedBatchIndices,
245+
double[] pulledModelValues) {
246+
Long2DoubleOpenHashMap coefficient = new Long2DoubleOpenHashMap(sortedBatchIndices.length);
247+
for (int i = 0; i < sortedBatchIndices.length; i++) {
248+
coefficient.put(sortedBatchIndices[i], pulledModelValues[i]);
249+
}
250+
Long2DoubleOpenHashMap cumGradients = new Long2DoubleOpenHashMap(sortedBatchIndices.length);
251+
252+
for (LabeledLargePointWithWeight dataPoint : batchData) {
253+
double dot = dot(dataPoint.features, coefficient);
254+
lossFunc.computeGradientWithDot(dataPoint, coefficient, cumGradients, dot);
255+
}
256+
double[] cumGradientValues = new double[sortedBatchIndices.length];
257+
for (int i = 0; i < sortedBatchIndices.length; i++) {
258+
cumGradientValues[i] = cumGradients.get(sortedBatchIndices[i]);
259+
}
260+
return cumGradientValues;
261+
}
262+
263+
private static double dot(
264+
Tuple2<long[], double[]> features, Long2DoubleOpenHashMap coefficient) {
265+
double dot = 0;
266+
for (int i = 0; i < features.f0.length; i++) {
267+
dot += features.f1[i] * coefficient.get(features.f0[i]);
268+
}
269+
return dot;
270+
}
271+
}

0 commit comments

Comments
 (0)