
Commit f281f94

remove context object and use class members instead
1 parent: 013ba5f

13 files changed: +188 -197 lines changed

flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java

Lines changed: 64 additions & 55 deletions
@@ -34,6 +34,7 @@
 import org.apache.flink.ml.common.ps.training.ProcessStage;
 import org.apache.flink.ml.common.ps.training.PullStage;
 import org.apache.flink.ml.common.ps.training.PushStage;
+import org.apache.flink.ml.common.ps.training.SerializableConsumer;
 import org.apache.flink.ml.common.ps.training.TrainingContext;
 import org.apache.flink.ml.common.ps.training.TrainingUtils;
 import org.apache.flink.ml.common.updater.FTRL;
@@ -51,6 +52,7 @@
 import org.apache.flink.types.Row;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.function.SerializableFunction;
+import org.apache.flink.util.function.SerializableSupplier;
 
 import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
 import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
@@ -128,15 +130,25 @@ public LogisticRegressionModel fit(Table... inputs) {
                     .map((MapFunction<Long, Long>) value -> value + 1);
         }
 
-        IterationStageList<LogisticRegressionTrainingContext> iterationStages =
-                new IterationStageList<>(new LogisticRegressionTrainingContext(getParamMap()));
+        LogisticRegressionWithFtrlTrainingContext trainingContext =
+                new LogisticRegressionWithFtrlTrainingContext(getParamMap());
+
+        IterationStageList<LogisticRegressionWithFtrlTrainingContext> iterationStages =
+                new IterationStageList<>(trainingContext);
         iterationStages
                 .addTrainingStage(new ComputeIndices())
-                .addTrainingStage(new PullStage("pullIndices"))
+                .addTrainingStage(
+                        new PullStage(
+                                (SerializableSupplier<long[]>) () -> trainingContext.pullIndices,
+                                (SerializableConsumer<double[]>)
+                                        x -> trainingContext.pulledValues = x))
                 .addTrainingStage(new ComputeGradients(BinaryLogisticLoss.INSTANCE))
-                .addTrainingStage(new PushStage("pushGradient"))
+                .addTrainingStage(
+                        new PushStage(
+                                (SerializableSupplier<long[]>) () -> trainingContext.pushIndices,
+                                (SerializableSupplier<double[]>) () -> trainingContext.pushValues))
                 .setTerminationCriteria(
-                        (SerializableFunction<LogisticRegressionTrainingContext, Boolean>)
+                        (SerializableFunction<LogisticRegressionWithFtrlTrainingContext, Boolean>)
                                 o -> o.iterationId >= getMaxIter());
         FTRL ftrl = new FTRL(getAlpha(), getBeta(), getReg(), getElasticNet());
 
@@ -183,16 +195,15 @@ public Map<Param<?>, Object> getParamMap() {
 }
 
 /**
- * A stage that samples a batch of training data and computes the indices needed to compute
- * gradients.
+ * An iteration stage that samples a batch of training data and computes the indices needed to
+ * compute gradients.
  */
-class ComputeIndices extends ProcessStage<LogisticRegressionTrainingContext> {
+class ComputeIndices extends ProcessStage<LogisticRegressionWithFtrlTrainingContext> {
 
     @Override
-    public void process(LogisticRegressionTrainingContext context) throws Exception {
+    public void process(LogisticRegressionWithFtrlTrainingContext context) throws Exception {
         context.readInNextBatchData();
-        long[] indices = computeIndices(context.batchData);
-        context.put("pullIndices", indices);
+        context.pullIndices = computeIndices(context.batchData);
     }
 
     public static long[] computeIndices(List<LabeledLargePointWithWeight> dataPoints) {
@@ -216,24 +227,24 @@ public static long[] computeIndices(List<LabeledLargePointWithWeight> dataPoints) {
 }
 
 /**
- * A Stage that uses the pulled model parameters and batch data to compute the gradients. The
- * gradients are stored in context for later push.
+ * An iteration stage that uses the pulled model values and sampled batch data to compute the
+ * gradients.
  */
-class ComputeGradients extends ProcessStage<LogisticRegressionTrainingContext> {
+class ComputeGradients extends ProcessStage<LogisticRegressionWithFtrlTrainingContext> {
     private final LossFunc lossFunc;
 
     public ComputeGradients(LossFunc lossFunc) {
         this.lossFunc = lossFunc;
     }
 
     @Override
-    public void process(LogisticRegressionTrainingContext context) throws IOException {
+    public void process(LogisticRegressionWithFtrlTrainingContext context) throws IOException {
         long[] indices = ComputeIndices.computeIndices(context.batchData);
-        double[] pulledModelValues = (double[]) context.get("pullIndices");
+        double[] pulledModelValues = context.pulledValues;
         double[] gradients = computeGradient(context.batchData, indices, pulledModelValues);
 
-        // Saves the gradient for push.
-        context.put("pushGradient", Tuple2.of(indices, gradients));
+        context.pushIndices = indices;
+        context.pushValues = gradients;
     }
 
     private double[] computeGradient(
@@ -267,37 +278,34 @@ private static double dot(
     }
 }
 
-class LogisticRegressionTrainingContext
+/** The context information of local computing process. */
+class LogisticRegressionWithFtrlTrainingContext
         implements TrainingContext,
-                LogisticRegressionWithFtrlParams<LogisticRegressionTrainingContext> {
+                LogisticRegressionWithFtrlParams<LogisticRegressionWithFtrlTrainingContext> {
+    /** Parameters of LogisticRegressionWithFtrl. */
+    private final Map<Param<?>, Object> paramMap;
+    /** Current iteration id. */
     int iterationId;
-    int workerId;
-    private int numWorkers = -1;
+    /** The local batch size. */
     private int localBatchSize = -1;
-
-    ResettableIterator<LabeledLargePointWithWeight> trainData;
-    private final Map<String, Object> contextObjs = new HashMap<>();
-
-    ListState<LabeledLargePointWithWeight> batchDataState;
+    /** The training data. */
+    private ResettableIterator<LabeledLargePointWithWeight> trainData;
+    /** The batch of training data for computing gradients. */
     List<LabeledLargePointWithWeight> batchData;
 
-    private final Map<Param<?>, Object> paramMap;
+    private ListState<LabeledLargePointWithWeight> batchDataState;
 
-    public LogisticRegressionTrainingContext(Map<Param<?>, Object> paramMap) {
-        this.paramMap = paramMap;
-    }
+    /** The indices to pull for each iteration. */
+    long[] pullIndices;
+    /** The placeholder for the pulled values for each iteration. */
+    double[] pulledValues;
+    /** The indices to push for each iteration. */
+    long[] pushIndices;
+    /** The values to push for each iteration. */
+    double[] pushValues;
 
-    /** Reads in next batch of training data. */
-    public void readInNextBatchData() throws IOException {
-        batchData.clear();
-        int i = 0;
-        while (i < localBatchSize && trainData.hasNext()) {
-            batchData.add(trainData.next());
-            i++;
-        }
-        if (!trainData.hasNext()) {
-            trainData.reset();
-        }
+    public LogisticRegressionWithFtrlTrainingContext(Map<Param<?>, Object> paramMap) {
+        this.paramMap = paramMap;
     }
 
     @Override
@@ -307,8 +315,6 @@ public void setIterationId(int iterationId) {
 
     @Override
     public void setWorldInfo(int workerId, int numWorkers) {
-        this.workerId = workerId;
-        this.numWorkers = numWorkers;
         int globalBatchSize = getGlobalBatchSize();
         this.localBatchSize = globalBatchSize / numWorkers;
         if (globalBatchSize % numWorkers > workerId) {
@@ -322,16 +328,6 @@ public void setTrainData(ResettableIterator<?> trainData) {
         this.trainData = (ResettableIterator<LabeledLargePointWithWeight>) trainData;
     }
 
-    @Override
-    public void put(String key, Object value) {
-        contextObjs.put(key, value);
-    }
-
-    @Override
-    public Object get(String key) {
-        return contextObjs.get(key);
-    }
-
     @Override
     public void initializeState(StateInitializationContext context) throws Exception {
         batchDataState =
@@ -343,7 +339,7 @@ public void initializeState(StateInitializationContext context) throws Exception {
 
         Iterator<LabeledLargePointWithWeight> batchDataIterator = batchDataState.get().iterator();
         if (batchDataIterator.hasNext()) {
-            contextObjs.put("batchData", IteratorUtils.toList(batchDataIterator));
+            batchData = IteratorUtils.toList(batchDataIterator);
        }
     }
 
@@ -359,4 +355,17 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
     public Map<Param<?>, Object> getParamMap() {
         return paramMap;
     }
+
+    /** Reads in next batch of training data. */
+    public void readInNextBatchData() throws IOException {
+        batchData.clear();
+        int i = 0;
+        while (i < localBatchSize && trainData.hasNext()) {
+            batchData.add(trainData.next());
+            i++;
+        }
+        if (!trainData.hasNext()) {
+            trainData.reset();
+        }
+    }
 }
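
The key change in this file: PullStage and PushStage no longer read and write a string-keyed context map; they capture serializable getters and setters over typed fields of the training context. Below is a minimal sketch of the constructor shapes this implies, inferred only from the call sites in the diff above. Flink does ship org.apache.flink.util.function.SerializableSupplier with this shape; the SerializableConsumer interface and the field names inside the stages are assumptions, not the library's actual code.

// Sketch only: shapes inferred from the call sites in the diff above.
import java.io.Serializable;
import java.util.function.Consumer;
import java.util.function.Supplier;

interface SerializableSupplier<T> extends Supplier<T>, Serializable {}

interface SerializableConsumer<T> extends Consumer<T>, Serializable {}

/** Pulls the model values at the supplied indices, then hands them back. */
class PullStage {
    final SerializableSupplier<long[]> keysSupplier; // e.g. () -> context.pullIndices
    final SerializableConsumer<double[]> valuesConsumer; // e.g. x -> context.pulledValues = x

    PullStage(
            SerializableSupplier<long[]> keysSupplier,
            SerializableConsumer<double[]> valuesConsumer) {
        this.keysSupplier = keysSupplier;
        this.valuesConsumer = valuesConsumer;
    }
}

/** Pushes the supplied values at the supplied indices to the servers. */
class PushStage {
    final SerializableSupplier<long[]> keysSupplier; // e.g. () -> context.pushIndices
    final SerializableSupplier<double[]> valuesSupplier; // e.g. () -> context.pushValues

    PushStage(
            SerializableSupplier<long[]> keysSupplier,
            SerializableSupplier<double[]> valuesSupplier) {
        this.keysSupplier = keysSupplier;
        this.valuesSupplier = valuesSupplier;
    }
}

Compared with the string keys, this trades untyped lookups and runtime casts for compile-time typed access; note the removed (double[]) context.get("pullIndices") relied on the framework overwriting the pull-indices key with the pulled values, a coupling the typed pullIndices/pulledValues fields make explicit.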

flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangeModelPartitioner.java renamed to flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java

Lines changed: 19 additions & 49 deletions
@@ -18,24 +18,27 @@
 
 package org.apache.flink.ml.common.ps;
 
-import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.tuple.Tuple3;
-import org.apache.flink.ml.util.Bits;
 import org.apache.flink.util.Preconditions;
 
 import javax.annotation.Nullable;
 
 import java.util.Arrays;
 import java.util.Iterator;
 
-/** Range partitioner for vector. */
-public class RangeModelPartitioner {
-    public long dim;
-    public int numServers;
-    private long[] ranges;
+/** Range partitioner for model data. */
+public class RangePartitioner {
+    public final long dim;
+    public final int numServers;
+    public final long[] ranges;
+
+    public RangePartitioner(long dim, int numServers) {
+        Preconditions.checkArgument(
+                dim > 0,
+                String.format(
+                        "Illegal dimension when using %s: %d",
+                        RangePartitioner.class.getSimpleName(), dim));
 
-    public RangeModelPartitioner(long dim, int numServers) {
-        Preconditions.checkArgument(dim > 0 && numServers > 0);
         this.dim = dim;
         this.numServers = numServers;
         this.ranges = new long[numServers + 1];
@@ -47,38 +50,6 @@ public RangeModelPartitioner(long dim, int numServers) {
         ranges[numServers] = dim;
     }
 
-    public RangeModelPartitioner() {}
-
-    public Tuple2<Long, Long> getStartAndEnd(int serverId) {
-        return Tuple2.of(ranges[serverId], ranges[serverId + 1]);
-    }
-
-    public static int getNumBytes() {
-        return Long.BYTES + Integer.BYTES + Integer.BYTES;
-    }
-
-    public int writeToBytes(byte[] bytesData, int offset) {
-        Bits.putLong(bytesData, offset, dim);
-        offset += Long.BYTES;
-        Bits.putInt(bytesData, offset, numServers);
-        offset += Integer.BYTES;
-        return offset;
-    }
-
-    public byte[] toBytes() {
-        byte[] buffer = new byte[getNumBytes()];
-        writeToBytes(buffer, 0);
-        return buffer;
-    }
-
-    public static RangeModelPartitioner readFromBytes(byte[] bytesData, int offset) {
-        long dim = Bits.getLong(bytesData, offset);
-        offset += Long.BYTES;
-        int numPss = Bits.getInt(bytesData, offset);
-        offset += Integer.BYTES;
-        return new RangeModelPartitioner(dim, numPss);
-    }
-
     /**
      * Splits the push/pull request according to the given sorted indices and the corresponding
      * values.
@@ -93,45 +64,44 @@ public Iterator<Tuple3<Integer, long[], double[]>> splitRequest(
     }
 
     private static class RequestsIterator implements Iterator<Tuple3<Integer, long[], double[]>> {
-        private final int numPss;
+        private final int numServers;
         private final long[] indices;
         private final double[] values;
        private final long[] ranges;
 
-        private int psId = 0;
+        private int serverId = 0;
 
         private int s = 0;
 
         public RequestsIterator(
                 int numPss, long[] indices, @Nullable double[] values, long[] ranges) {
-            this.numPss = numPss;
+            this.numServers = numPss;
            this.indices = indices;
            this.values = values;
            this.ranges = ranges;
         }
 
         @Override
         public boolean hasNext() {
-            return psId < numPss;
+            return serverId < numServers;
         }
 
         @Override
         public Tuple3<Integer, long[], double[]> next() {
             int e = s;
-            while (e < indices.length && indices[e] < ranges[psId + 1]) {
+            while (e < indices.length && indices[e] < ranges[serverId + 1]) {
                 e++;
             }
 
-            // Also pushes the empty message for atomic of push/pull in async setting.
             long[] splitIndices = new long[0];
             double[] splitValues = values == null ? null : new double[0];
             if (s < e) {
                 splitIndices = Arrays.copyOfRange(indices, s, e);
                 splitValues = values == null ? null : Arrays.copyOfRange(values, s, e);
             }
             s = e;
-            psId++;
-            return Tuple3.of(psId - 1, splitIndices, splitValues);
+            serverId++;
+            return Tuple3.of(serverId - 1, splitIndices, splitValues);
         }
     }
 }
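
For orientation, here is a small self-contained demo of how this partitioner behaves. The loop that fills ranges is elided in the hunk above, so the near-even split below is an assumption consistent with ranges = new long[numServers + 1] and ranges[numServers] = dim; the request-splitting scan mirrors RequestsIterator.

import java.util.Arrays;

public class RangePartitionerDemo {
    public static void main(String[] args) {
        long dim = 10;
        int numServers = 3;

        // Assumed range construction: split [0, dim) into near-even contiguous ranges.
        long[] ranges = new long[numServers + 1];
        long shard = dim / numServers;
        long remainder = dim % numServers;
        for (int i = 0; i < numServers; i++) {
            ranges[i + 1] = ranges[i] + shard + (i < remainder ? 1 : 0);
        }
        System.out.println(Arrays.toString(ranges)); // [0, 4, 7, 10]

        // Same scan as RequestsIterator.next(): each server gets the (possibly
        // empty) slice of the sorted indices that falls inside its range.
        long[] indices = {1, 3, 8, 9};
        int s = 0;
        for (int serverId = 0; serverId < numServers; serverId++) {
            int e = s;
            while (e < indices.length && indices[e] < ranges[serverId + 1]) {
                e++;
            }
            System.out.println(
                    "server " + serverId + " -> "
                            + Arrays.toString(Arrays.copyOfRange(indices, s, e)));
            s = e;
        }
        // server 0 -> [1, 3]; server 1 -> []; server 2 -> [8, 9]
    }
}

Note that empty slices are still emitted, one message per server per request; the comment removed in the diff hinted at why, namely keeping push/pull atomic in the asynchronous setting.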
