@@ -20,32 +20,26 @@
 
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.functions.ReduceFunction;
-import org.apache.flink.api.common.state.ListState;
-import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.tuple.Tuple3;
 import org.apache.flink.ml.api.Estimator;
 import org.apache.flink.ml.common.datastream.DataStreamUtils;
 import org.apache.flink.ml.common.feature.LabeledLargePointWithWeight;
 import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss;
-import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.common.ps.training.ComputeGradients;
+import org.apache.flink.ml.common.ps.training.ComputeIndices;
 import org.apache.flink.ml.common.ps.training.IterationStageList;
-import org.apache.flink.ml.common.ps.training.ProcessStage;
+import org.apache.flink.ml.common.ps.training.MiniBatchMLSession;
 import org.apache.flink.ml.common.ps.training.PullStage;
 import org.apache.flink.ml.common.ps.training.PushStage;
 import org.apache.flink.ml.common.ps.training.SerializableConsumer;
-import org.apache.flink.ml.common.ps.training.TrainingContext;
 import org.apache.flink.ml.common.ps.training.TrainingUtils;
-import org.apache.flink.ml.common.updater.FTRL;
-import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.common.ps.updater.FTRL;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
-import org.apache.flink.runtime.state.StateInitializationContext;
-import org.apache.flink.runtime.state.StateSnapshotContext;
-import org.apache.flink.runtime.util.ResettableIterator;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
@@ -55,16 +49,8 @@
 import org.apache.flink.util.function.SerializableFunction;
 import org.apache.flink.util.function.SerializableSupplier;
 
-import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
-import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
-import org.apache.commons.collections.IteratorUtils;
-
 import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.HashMap;
-import java.util.Iterator;
-import java.util.List;
 import java.util.Map;
 
 /**
@@ -131,25 +117,27 @@ public LogisticRegressionModel fit(Table... inputs) {
                             .map((MapFunction<Long, Long>) value -> value + 1);
         }
 
-        LogisticRegressionWithFtrlTrainingContext trainingContext =
-                new LogisticRegressionWithFtrlTrainingContext(getParamMap());
+        MiniBatchMLSession<LabeledLargePointWithWeight> mlSession =
+                new MiniBatchMLSession<>(
+                        getGlobalBatchSize(),
+                        TypeInformation.of(LabeledLargePointWithWeight.class));
 
-        IterationStageList<LogisticRegressionWithFtrlTrainingContext> iterationStages =
-                new IterationStageList<>(trainingContext);
+        IterationStageList<MiniBatchMLSession<LabeledLargePointWithWeight>> iterationStages =
+                new IterationStageList<>(mlSession);
         iterationStages
                 .addStage(new ComputeIndices())
                 .addStage(
                         new PullStage(
-                                (SerializableSupplier<long[]>) () -> trainingContext.pullIndices,
-                                (SerializableConsumer<double[]>)
-                                        x -> trainingContext.pulledValues = x))
+                                (SerializableSupplier<long[]>) () -> mlSession.pullIndices,
+                                (SerializableConsumer<double[]>) x -> mlSession.pulledValues = x))
                 .addStage(new ComputeGradients(BinaryLogisticLoss.INSTANCE))
                 .addStage(
                         new PushStage(
-                                (SerializableSupplier<long[]>) () -> trainingContext.pushIndices,
-                                (SerializableSupplier<double[]>) () -> trainingContext.pushValues))
+                                (SerializableSupplier<long[]>) () -> mlSession.pushIndices,
+                                (SerializableSupplier<double[]>) () -> mlSession.pushValues))
                 .setTerminationCriteria(
-                        (SerializableFunction<LogisticRegressionWithFtrlTrainingContext, Boolean>)
+                        (SerializableFunction<
+                                        MiniBatchMLSession<LabeledLargePointWithWeight>, Boolean>)
                                 o -> o.iterationId >= getMaxIter());
         FTRL ftrl =
                 new FTRL(
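
The stage list in the hunk above is the entire per-worker training loop. Below is a condensed, self-contained sketch of the same wiring, using only the types and public session fields visible in this diff (constructor and field signatures are taken from the hunks above, not from separately verified library docs), with comments spelling out what each stage contributes to one FTRL iteration:

```java
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.ml.common.feature.LabeledLargePointWithWeight;
import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss;
import org.apache.flink.ml.common.ps.training.ComputeGradients;
import org.apache.flink.ml.common.ps.training.ComputeIndices;
import org.apache.flink.ml.common.ps.training.IterationStageList;
import org.apache.flink.ml.common.ps.training.MiniBatchMLSession;
import org.apache.flink.ml.common.ps.training.PullStage;
import org.apache.flink.ml.common.ps.training.PushStage;
import org.apache.flink.ml.common.ps.training.SerializableConsumer;
import org.apache.flink.util.function.SerializableFunction;
import org.apache.flink.util.function.SerializableSupplier;

class FtrlStageListSketch {
    static IterationStageList<MiniBatchMLSession<LabeledLargePointWithWeight>> build(
            int globalBatchSize, int maxIter) {
        // Shared per-worker state: the mini-batch plus the pull/push buffers.
        MiniBatchMLSession<LabeledLargePointWithWeight> session =
                new MiniBatchMLSession<>(
                        globalBatchSize,
                        TypeInformation.of(LabeledLargePointWithWeight.class));

        IterationStageList<MiniBatchMLSession<LabeledLargePointWithWeight>> stages =
                new IterationStageList<>(session);
        stages
                // 1. Sample a mini-batch and collect the distinct feature indices
                //    it touches into session.pullIndices.
                .addStage(new ComputeIndices())
                // 2. Fetch the current model values for exactly those indices from
                //    the parameter servers into session.pulledValues.
                .addStage(
                        new PullStage(
                                (SerializableSupplier<long[]>) () -> session.pullIndices,
                                (SerializableConsumer<double[]>) x -> session.pulledValues = x))
                // 3. Compute the averaged mini-batch gradient of the logistic loss
                //    against the pulled values; results land in session.pushIndices
                //    and session.pushValues.
                .addStage(new ComputeGradients(BinaryLogisticLoss.INSTANCE))
                // 4. Send the sparse gradient back; the FTRL updater applies it on
                //    the server side.
                .addStage(
                        new PushStage(
                                (SerializableSupplier<long[]>) () -> session.pushIndices,
                                (SerializableSupplier<double[]>) () -> session.pushValues))
                // Stop once every worker has run maxIter iterations.
                .setTerminationCriteria(
                        (SerializableFunction<
                                        MiniBatchMLSession<LabeledLargePointWithWeight>, Boolean>)
                                o -> o.iterationId >= maxIter);
        return stages;
    }
}
```

Splitting the loop into ComputeIndices → Pull → ComputeGradients → Push keeps everything an iteration needs on the shared MiniBatchMLSession, which is what lets this refactor delete the hand-written training context removed below.
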
@@ -194,189 +182,3 @@ public Map<Param<?>, Object> getParamMap() {
         return paramMap;
     }
 }
-
-/**
- * An iteration stage that samples a batch of training data and computes the indices needed to
- * compute gradients.
- */
-class ComputeIndices extends ProcessStage<LogisticRegressionWithFtrlTrainingContext> {
-
-    @Override
-    public void process(LogisticRegressionWithFtrlTrainingContext context) throws Exception {
-        context.readInNextBatchData();
-        context.pullIndices = computeIndices(context.batchData);
-    }
-
-    public static long[] computeIndices(List<LabeledLargePointWithWeight> dataPoints) {
-        LongOpenHashSet indices = new LongOpenHashSet();
-        for (LabeledLargePointWithWeight dataPoint : dataPoints) {
-            long[] notZeros = dataPoint.features.f0;
-            for (long index : notZeros) {
-                indices.add(index);
-            }
-        }
-
-        long[] sortedIndices = new long[indices.size()];
-        Iterator<Long> iterator = indices.iterator();
-        int i = 0;
-        while (iterator.hasNext()) {
-            sortedIndices[i++] = iterator.next();
-        }
-        Arrays.sort(sortedIndices);
-        return sortedIndices;
-    }
-}
-
-/**
- * An iteration stage that uses the pulled model values and sampled batch data to compute the
- * gradients.
- */
-class ComputeGradients extends ProcessStage<LogisticRegressionWithFtrlTrainingContext> {
-    private final LossFunc lossFunc;
-
-    public ComputeGradients(LossFunc lossFunc) {
-        this.lossFunc = lossFunc;
-    }
-
-    @Override
-    public void process(LogisticRegressionWithFtrlTrainingContext context) throws IOException {
-        long[] indices = ComputeIndices.computeIndices(context.batchData);
-        double[] pulledModelValues = context.pulledValues;
-        double[] gradients = computeGradient(context.batchData, indices, pulledModelValues);
-
-        context.pushIndices = indices;
-        context.pushValues = gradients;
-    }
-
-    private double[] computeGradient(
-            List<LabeledLargePointWithWeight> batchData,
-            long[] sortedBatchIndices,
-            double[] pulledModelValues) {
-        Long2DoubleOpenHashMap coefficient = new Long2DoubleOpenHashMap(sortedBatchIndices.length);
-        for (int i = 0; i < sortedBatchIndices.length; i++) {
-            coefficient.put(sortedBatchIndices[i], pulledModelValues[i]);
-        }
-        Long2DoubleOpenHashMap cumGradients = new Long2DoubleOpenHashMap(sortedBatchIndices.length);
-
-        for (LabeledLargePointWithWeight dataPoint : batchData) {
-            double dot = dot(dataPoint.features, coefficient);
-            double multiplier = lossFunc.computeGradient(dataPoint.label, dot) * dataPoint.weight;
-
-            long[] featureIndices = dataPoint.features.f0;
-            double[] featureValues = dataPoint.features.f1;
-            double z;
-            for (int i = 0; i < featureIndices.length; i++) {
-                long currentIndex = featureIndices[i];
-                z = featureValues[i] * multiplier + cumGradients.getOrDefault(currentIndex, 0.);
-                cumGradients.put(currentIndex, z);
-            }
-        }
-        double[] cumGradientValues = new double[sortedBatchIndices.length];
-        for (int i = 0; i < sortedBatchIndices.length; i++) {
-            cumGradientValues[i] = cumGradients.get(sortedBatchIndices[i]);
-        }
-        BLAS.scal(1.0 / batchData.size(), Vectors.dense(cumGradientValues));
-        return cumGradientValues;
-    }
-
-    private static double dot(
-            Tuple2<long[], double[]> features, Long2DoubleOpenHashMap coefficient) {
-        double dot = 0;
-        for (int i = 0; i < features.f0.length; i++) {
-            dot += features.f1[i] * coefficient.get(features.f0[i]);
-        }
-        return dot;
-    }
-}
-
-/** The context information of local computing process. */
-class LogisticRegressionWithFtrlTrainingContext
-        implements TrainingContext,
-                LogisticRegressionWithFtrlParams<LogisticRegressionWithFtrlTrainingContext> {
-    /** Parameters of LogisticRegressionWithFtrl. */
-    private final Map<Param<?>, Object> paramMap;
-    /** Current iteration id. */
-    int iterationId;
-    /** The local batch size. */
-    private int localBatchSize = -1;
-    /** The training data. */
-    private ResettableIterator<LabeledLargePointWithWeight> trainData;
-    /** The batch of training data for computing gradients. */
-    List<LabeledLargePointWithWeight> batchData;
-
-    private ListState<LabeledLargePointWithWeight> batchDataState;
-
-    /** The placeholder for indices to pull for each iteration. */
-    long[] pullIndices;
-    /** The placeholder for the pulled values for each iteration. */
-    double[] pulledValues;
-    /** The placeholder for indices to push for each iteration. */
-    long[] pushIndices;
-    /** The placeholder for values to push for each iteration. */
-    double[] pushValues;
-
-    public LogisticRegressionWithFtrlTrainingContext(Map<Param<?>, Object> paramMap) {
-        this.paramMap = paramMap;
-    }
-
-    @Override
-    public void setIterationId(int iterationId) {
-        this.iterationId = iterationId;
-    }
-
-    @Override
-    public void setWorldInfo(int workerId, int numWorkers) {
-        int globalBatchSize = getGlobalBatchSize();
-        this.localBatchSize = globalBatchSize / numWorkers;
-        if (globalBatchSize % numWorkers > workerId) {
-            localBatchSize++;
-        }
-        this.batchData = new ArrayList<>(localBatchSize);
-    }
-
-    @Override
-    public void setInputData(ResettableIterator<?> inputData) {
-        this.trainData = (ResettableIterator<LabeledLargePointWithWeight>) inputData;
-    }
-
-    @Override
-    public void initializeState(StateInitializationContext context) throws Exception {
-        batchDataState =
-                context.getOperatorStateStore()
-                        .getListState(
-                                new ListStateDescriptor<>(
-                                        "batchDataState",
-                                        TypeInformation.of(LabeledLargePointWithWeight.class)));
-
-        Iterator<LabeledLargePointWithWeight> batchDataIterator = batchDataState.get().iterator();
-        if (batchDataIterator.hasNext()) {
-            batchData = IteratorUtils.toList(batchDataIterator);
-        }
-    }
-
-    @Override
-    public void snapshotState(StateSnapshotContext context) throws Exception {
-        batchDataState.clear();
-        if (batchData.size() > 0) {
-            batchDataState.addAll(batchData);
-        }
-    }
-
-    @Override
-    public Map<Param<?>, Object> getParamMap() {
-        return paramMap;
-    }
-
-    /** Reads in next batch of training data. */
-    public void readInNextBatchData() throws IOException {
-        batchData.clear();
-        int i = 0;
-        while (i < localBatchSize && trainData.hasNext()) {
-            batchData.add(trainData.next());
-            i++;
-        }
-        if (!trainData.hasNext()) {
-            trainData.reset();
-        }
-    }
-}
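
The FTRL updater itself is referenced but not shown in this diff (it moves to org.apache.flink.ml.common.ps.updater). For orientation only, here is the standard per-coordinate FTRL-Proximal update (McMahan et al., 2013) that such a server-side updater applies; the class below is an illustrative sketch, not Flink ML's actual API:

```java
/**
 * Illustrative sketch of one FTRL-Proximal coordinate update; not Flink ML's API.
 * z and n are the per-coordinate accumulators the server keeps alongside the model;
 * alpha/beta are the learning-rate parameters and l1/l2 the regularization strengths.
 */
final class FtrlUpdateSketch {
    /** Applies gradient g to coordinate i and returns the new weight. */
    static double update(
            double g, double[] z, double[] n, int i,
            double alpha, double beta, double l1, double l2) {
        // Weight currently implied by the accumulators (the one used for prediction).
        double w = weight(z[i], n[i], alpha, beta, l1, l2);
        // Per-coordinate learning-rate schedule term.
        double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
        z[i] += g - sigma * w;
        n[i] += g * g;
        return weight(z[i], n[i], alpha, beta, l1, l2);
    }

    /** Closed-form weight with L1-induced sparsity: exactly zero while |z| <= l1. */
    private static double weight(
            double z, double n, double alpha, double beta, double l1, double l2) {
        if (Math.abs(z) <= l1) {
            return 0.0;
        }
        return -(z - Math.signum(z) * l1) / ((beta + Math.sqrt(n)) / alpha + l2);
    }
}
```

This division of labor is also why the diff pushes raw averaged gradients rather than weight deltas: with FTRL the server, not the worker, owns the update rule, so swapping the optimizer never touches the worker-side stages.
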