Skip to content

Commit 8a52d8c

Browse files
committed
resolve comments
1 parent 3966321 commit 8a52d8c

26 files changed

+560
-428
lines changed

flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ public Row map(Row dataPoint) {
155155
servable = new LogisticRegressionModelServable(modelData.get(0));
156156
} else {
157157
LogisticRegressionModelData mergedModel =
158-
LogisticRegressionModelServable.mergePieces(modelData);
158+
LogisticRegressionModelData.mergeSegments(modelData);
159159
servable = new LogisticRegressionModelServable(mergedModel);
160160
}
161161
ParamUtils.updateExistingParams(servable, params);

flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java

Lines changed: 16 additions & 214 deletions
Original file line numberDiff line numberDiff line change
@@ -20,32 +20,26 @@
2020

2121
import org.apache.flink.api.common.functions.MapFunction;
2222
import org.apache.flink.api.common.functions.ReduceFunction;
23-
import org.apache.flink.api.common.state.ListState;
24-
import org.apache.flink.api.common.state.ListStateDescriptor;
2523
import org.apache.flink.api.common.typeinfo.TypeInformation;
2624
import org.apache.flink.api.java.tuple.Tuple2;
2725
import org.apache.flink.api.java.tuple.Tuple3;
2826
import org.apache.flink.ml.api.Estimator;
2927
import org.apache.flink.ml.common.datastream.DataStreamUtils;
3028
import org.apache.flink.ml.common.feature.LabeledLargePointWithWeight;
3129
import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss;
32-
import org.apache.flink.ml.common.lossfunc.LossFunc;
30+
import org.apache.flink.ml.common.ps.training.ComputeGradients;
31+
import org.apache.flink.ml.common.ps.training.ComputeIndices;
3332
import org.apache.flink.ml.common.ps.training.IterationStageList;
34-
import org.apache.flink.ml.common.ps.training.ProcessStage;
33+
import org.apache.flink.ml.common.ps.training.MiniBatchMLSession;
3534
import org.apache.flink.ml.common.ps.training.PullStage;
3635
import org.apache.flink.ml.common.ps.training.PushStage;
3736
import org.apache.flink.ml.common.ps.training.SerializableConsumer;
38-
import org.apache.flink.ml.common.ps.training.TrainingContext;
3937
import org.apache.flink.ml.common.ps.training.TrainingUtils;
40-
import org.apache.flink.ml.common.updater.FTRL;
41-
import org.apache.flink.ml.linalg.BLAS;
38+
import org.apache.flink.ml.common.ps.updater.FTRL;
4239
import org.apache.flink.ml.linalg.Vectors;
4340
import org.apache.flink.ml.param.Param;
4441
import org.apache.flink.ml.util.ParamUtils;
4542
import org.apache.flink.ml.util.ReadWriteUtils;
46-
import org.apache.flink.runtime.state.StateInitializationContext;
47-
import org.apache.flink.runtime.state.StateSnapshotContext;
48-
import org.apache.flink.runtime.util.ResettableIterator;
4943
import org.apache.flink.streaming.api.datastream.DataStream;
5044
import org.apache.flink.table.api.Table;
5145
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
@@ -55,16 +49,8 @@
5549
import org.apache.flink.util.function.SerializableFunction;
5650
import org.apache.flink.util.function.SerializableSupplier;
5751

58-
import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
59-
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
60-
import org.apache.commons.collections.IteratorUtils;
61-
6252
import java.io.IOException;
63-
import java.util.ArrayList;
64-
import java.util.Arrays;
6553
import java.util.HashMap;
66-
import java.util.Iterator;
67-
import java.util.List;
6854
import java.util.Map;
6955

7056
/**
@@ -131,25 +117,27 @@ public LogisticRegressionModel fit(Table... inputs) {
131117
.map((MapFunction<Long, Long>) value -> value + 1);
132118
}
133119

134-
LogisticRegressionWithFtrlTrainingContext trainingContext =
135-
new LogisticRegressionWithFtrlTrainingContext(getParamMap());
120+
MiniBatchMLSession<LabeledLargePointWithWeight> mlSession =
121+
new MiniBatchMLSession<>(
122+
getGlobalBatchSize(),
123+
TypeInformation.of(LabeledLargePointWithWeight.class));
136124

137-
IterationStageList<LogisticRegressionWithFtrlTrainingContext> iterationStages =
138-
new IterationStageList<>(trainingContext);
125+
IterationStageList<MiniBatchMLSession<LabeledLargePointWithWeight>> iterationStages =
126+
new IterationStageList<>(mlSession);
139127
iterationStages
140128
.addStage(new ComputeIndices())
141129
.addStage(
142130
new PullStage(
143-
(SerializableSupplier<long[]>) () -> trainingContext.pullIndices,
144-
(SerializableConsumer<double[]>)
145-
x -> trainingContext.pulledValues = x))
131+
(SerializableSupplier<long[]>) () -> mlSession.pullIndices,
132+
(SerializableConsumer<double[]>) x -> mlSession.pulledValues = x))
146133
.addStage(new ComputeGradients(BinaryLogisticLoss.INSTANCE))
147134
.addStage(
148135
new PushStage(
149-
(SerializableSupplier<long[]>) () -> trainingContext.pushIndices,
150-
(SerializableSupplier<double[]>) () -> trainingContext.pushValues))
136+
(SerializableSupplier<long[]>) () -> mlSession.pushIndices,
137+
(SerializableSupplier<double[]>) () -> mlSession.pushValues))
151138
.setTerminationCriteria(
152-
(SerializableFunction<LogisticRegressionWithFtrlTrainingContext, Boolean>)
139+
(SerializableFunction<
140+
MiniBatchMLSession<LabeledLargePointWithWeight>, Boolean>)
153141
o -> o.iterationId >= getMaxIter());
154142
FTRL ftrl =
155143
new FTRL(
@@ -194,189 +182,3 @@ public Map<Param<?>, Object> getParamMap() {
194182
return paramMap;
195183
}
196184
}
197-
198-
/**
199-
* An iteration stage that samples a batch of training data and computes the indices needed to
200-
* compute gradients.
201-
*/
202-
class ComputeIndices extends ProcessStage<LogisticRegressionWithFtrlTrainingContext> {
203-
204-
@Override
205-
public void process(LogisticRegressionWithFtrlTrainingContext context) throws Exception {
206-
context.readInNextBatchData();
207-
context.pullIndices = computeIndices(context.batchData);
208-
}
209-
210-
public static long[] computeIndices(List<LabeledLargePointWithWeight> dataPoints) {
211-
LongOpenHashSet indices = new LongOpenHashSet();
212-
for (LabeledLargePointWithWeight dataPoint : dataPoints) {
213-
long[] notZeros = dataPoint.features.f0;
214-
for (long index : notZeros) {
215-
indices.add(index);
216-
}
217-
}
218-
219-
long[] sortedIndices = new long[indices.size()];
220-
Iterator<Long> iterator = indices.iterator();
221-
int i = 0;
222-
while (iterator.hasNext()) {
223-
sortedIndices[i++] = iterator.next();
224-
}
225-
Arrays.sort(sortedIndices);
226-
return sortedIndices;
227-
}
228-
}
229-
230-
/**
231-
* An iteration stage that uses the pulled model values and sampled batch data to compute the
232-
* gradients.
233-
*/
234-
class ComputeGradients extends ProcessStage<LogisticRegressionWithFtrlTrainingContext> {
235-
private final LossFunc lossFunc;
236-
237-
public ComputeGradients(LossFunc lossFunc) {
238-
this.lossFunc = lossFunc;
239-
}
240-
241-
@Override
242-
public void process(LogisticRegressionWithFtrlTrainingContext context) throws IOException {
243-
long[] indices = ComputeIndices.computeIndices(context.batchData);
244-
double[] pulledModelValues = context.pulledValues;
245-
double[] gradients = computeGradient(context.batchData, indices, pulledModelValues);
246-
247-
context.pushIndices = indices;
248-
context.pushValues = gradients;
249-
}
250-
251-
private double[] computeGradient(
252-
List<LabeledLargePointWithWeight> batchData,
253-
long[] sortedBatchIndices,
254-
double[] pulledModelValues) {
255-
Long2DoubleOpenHashMap coefficient = new Long2DoubleOpenHashMap(sortedBatchIndices.length);
256-
for (int i = 0; i < sortedBatchIndices.length; i++) {
257-
coefficient.put(sortedBatchIndices[i], pulledModelValues[i]);
258-
}
259-
Long2DoubleOpenHashMap cumGradients = new Long2DoubleOpenHashMap(sortedBatchIndices.length);
260-
261-
for (LabeledLargePointWithWeight dataPoint : batchData) {
262-
double dot = dot(dataPoint.features, coefficient);
263-
double multiplier = lossFunc.computeGradient(dataPoint.label, dot) * dataPoint.weight;
264-
265-
long[] featureIndices = dataPoint.features.f0;
266-
double[] featureValues = dataPoint.features.f1;
267-
double z;
268-
for (int i = 0; i < featureIndices.length; i++) {
269-
long currentIndex = featureIndices[i];
270-
z = featureValues[i] * multiplier + cumGradients.getOrDefault(currentIndex, 0.);
271-
cumGradients.put(currentIndex, z);
272-
}
273-
}
274-
double[] cumGradientValues = new double[sortedBatchIndices.length];
275-
for (int i = 0; i < sortedBatchIndices.length; i++) {
276-
cumGradientValues[i] = cumGradients.get(sortedBatchIndices[i]);
277-
}
278-
BLAS.scal(1.0 / batchData.size(), Vectors.dense(cumGradientValues));
279-
return cumGradientValues;
280-
}
281-
282-
private static double dot(
283-
Tuple2<long[], double[]> features, Long2DoubleOpenHashMap coefficient) {
284-
double dot = 0;
285-
for (int i = 0; i < features.f0.length; i++) {
286-
dot += features.f1[i] * coefficient.get(features.f0[i]);
287-
}
288-
return dot;
289-
}
290-
}
291-
292-
/** The context information of local computing process. */
293-
class LogisticRegressionWithFtrlTrainingContext
294-
implements TrainingContext,
295-
LogisticRegressionWithFtrlParams<LogisticRegressionWithFtrlTrainingContext> {
296-
/** Parameters of LogisticRegressionWithFtrl. */
297-
private final Map<Param<?>, Object> paramMap;
298-
/** Current iteration id. */
299-
int iterationId;
300-
/** The local batch size. */
301-
private int localBatchSize = -1;
302-
/** The training data. */
303-
private ResettableIterator<LabeledLargePointWithWeight> trainData;
304-
/** The batch of training data for computing gradients. */
305-
List<LabeledLargePointWithWeight> batchData;
306-
307-
private ListState<LabeledLargePointWithWeight> batchDataState;
308-
309-
/** The placeholder for indices to pull for each iteration. */
310-
long[] pullIndices;
311-
/** The placeholder for the pulled values for each iteration. */
312-
double[] pulledValues;
313-
/** The placeholder for indices to push for each iteration. */
314-
long[] pushIndices;
315-
/** The placeholder for values to push for each iteration. */
316-
double[] pushValues;
317-
318-
public LogisticRegressionWithFtrlTrainingContext(Map<Param<?>, Object> paramMap) {
319-
this.paramMap = paramMap;
320-
}
321-
322-
@Override
323-
public void setIterationId(int iterationId) {
324-
this.iterationId = iterationId;
325-
}
326-
327-
@Override
328-
public void setWorldInfo(int workerId, int numWorkers) {
329-
int globalBatchSize = getGlobalBatchSize();
330-
this.localBatchSize = globalBatchSize / numWorkers;
331-
if (globalBatchSize % numWorkers > workerId) {
332-
localBatchSize++;
333-
}
334-
this.batchData = new ArrayList<>(localBatchSize);
335-
}
336-
337-
@Override
338-
public void setInputData(ResettableIterator<?> inputData) {
339-
this.trainData = (ResettableIterator<LabeledLargePointWithWeight>) inputData;
340-
}
341-
342-
@Override
343-
public void initializeState(StateInitializationContext context) throws Exception {
344-
batchDataState =
345-
context.getOperatorStateStore()
346-
.getListState(
347-
new ListStateDescriptor<>(
348-
"batchDataState",
349-
TypeInformation.of(LabeledLargePointWithWeight.class)));
350-
351-
Iterator<LabeledLargePointWithWeight> batchDataIterator = batchDataState.get().iterator();
352-
if (batchDataIterator.hasNext()) {
353-
batchData = IteratorUtils.toList(batchDataIterator);
354-
}
355-
}
356-
357-
@Override
358-
public void snapshotState(StateSnapshotContext context) throws Exception {
359-
batchDataState.clear();
360-
if (batchData.size() > 0) {
361-
batchDataState.addAll(batchData);
362-
}
363-
}
364-
365-
@Override
366-
public Map<Param<?>, Object> getParamMap() {
367-
return paramMap;
368-
}
369-
370-
/** Reads in next batch of training data. */
371-
public void readInNextBatchData() throws IOException {
372-
batchData.clear();
373-
int i = 0;
374-
while (i < localBatchSize && trainData.hasNext()) {
375-
batchData.add(trainData.next());
376-
i++;
377-
}
378-
if (!trainData.hasNext()) {
379-
trainData.reset();
380-
}
381-
}
382-
}

flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/MirrorWorkerOperator.java

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import org.apache.flink.api.common.state.ListStateDescriptor;
2323
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
2424
import org.apache.flink.api.java.tuple.Tuple2;
25-
import org.apache.flink.ml.common.ps.message.ValuesPulledM;
25+
import org.apache.flink.ml.common.ps.message.PulledValueM;
2626
import org.apache.flink.runtime.state.StateInitializationContext;
2727
import org.apache.flink.runtime.state.StateSnapshotContext;
2828
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
@@ -38,16 +38,16 @@
3838
/**
3939
* Merges the messages from different servers for one pull request.
4040
*
41-
* <p>Note that for each single-thread worker, there are at exactly #numServers pieces for each pull
42-
* request in the feedback edge.
41+
* <p>Note that for each single-thread worker, there are exactly #numServers segments for each
42+
* pull request in the feedback edge.
4343
*/
4444
public class MirrorWorkerOperator extends AbstractStreamOperator<byte[]>
4545
implements OneInputStreamOperator<Tuple2<Integer, byte[]>, byte[]> {
4646
private final int numServers;
4747
private int workerId;
4848

4949
/** The received messages from servers for the current pull request. */
50-
private List<ValuesPulledM> messageReceived;
50+
private List<PulledValueM> messageReceived;
5151

5252
private ListState<byte[]> messageReceivedState;
5353

@@ -64,28 +64,28 @@ public void open() throws Exception {
6464
@Override
6565
public void processElement(StreamRecord<Tuple2<Integer, byte[]>> element) throws Exception {
6666
Preconditions.checkState(element.getValue().f0 == workerId);
67-
ValuesPulledM pulledModelM = ValuesPulledM.fromBytes(element.getValue().f1);
68-
messageReceived.add(pulledModelM);
67+
PulledValueM pulledValueM = PulledValueM.fromBytes(element.getValue().f1);
68+
messageReceived.add(pulledValueM);
6969
trySendingPulls(numServers);
7070
}
7171

72-
private void trySendingPulls(int numPieces) {
73-
if (messageReceived.size() == numPieces) {
74-
Comparator<ValuesPulledM> comparator = Comparator.comparingInt(o -> o.serverId);
72+
private void trySendingPulls(int numSegments) {
73+
if (messageReceived.size() == numSegments) {
74+
Comparator<PulledValueM> comparator = Comparator.comparingInt(o -> o.serverId);
7575
messageReceived.sort(comparator);
7676
int size = 0;
77-
for (ValuesPulledM pulledModelM : messageReceived) {
78-
size += pulledModelM.valuesPulled.length;
77+
for (PulledValueM pulledValueM : messageReceived) {
78+
size += pulledValueM.values.length;
7979
}
8080
double[] answer = new double[size];
8181
int offset = 0;
82-
for (ValuesPulledM pulledModelM : messageReceived) {
83-
double[] values = pulledModelM.valuesPulled;
82+
for (PulledValueM pulledValueM : messageReceived) {
83+
double[] values = pulledValueM.values;
8484
System.arraycopy(values, 0, answer, offset, values.length);
8585
offset += values.length;
8686
}
87-
ValuesPulledM pulledModelM = new ValuesPulledM(-1, workerId, answer);
88-
output.collect(new StreamRecord<>(pulledModelM.toBytes()));
87+
PulledValueM pulledValueM = new PulledValueM(-1, workerId, answer);
88+
output.collect(new StreamRecord<>(pulledValueM.toBytes()));
8989
messageReceived.clear();
9090
}
9191
}
@@ -104,7 +104,7 @@ public void initializeState(StateInitializationContext context) throws Exception
104104
Iterator<byte[]> iterator = messageReceivedState.get().iterator();
105105
if (iterator.hasNext()) {
106106
while (iterator.hasNext()) {
107-
messageReceived.add(ValuesPulledM.fromBytes(iterator.next()));
107+
messageReceived.add(PulledValueM.fromBytes(iterator.next()));
108108
}
109109
}
110110
}
@@ -114,7 +114,7 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
114114
super.snapshotState(context);
115115
messageReceivedState.clear();
116116
if (messageReceived.size() > 0) {
117-
for (ValuesPulledM valuesPulled : messageReceived) {
117+
for (PulledValueM valuesPulled : messageReceived) {
118118
messageReceivedState.add(valuesPulled.toBytes());
119119
}
120120
}

0 commit comments

Comments
 (0)