@@ -34,6 +34,7 @@
 import org.apache.flink.ml.common.ps.training.ProcessStage;
 import org.apache.flink.ml.common.ps.training.PullStage;
 import org.apache.flink.ml.common.ps.training.PushStage;
+import org.apache.flink.ml.common.ps.training.SerializableConsumer;
 import org.apache.flink.ml.common.ps.training.TrainingContext;
 import org.apache.flink.ml.common.ps.training.TrainingUtils;
 import org.apache.flink.ml.common.updater.FTRL;
@@ -51,6 +52,7 @@
 import org.apache.flink.types.Row;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.function.SerializableFunction;
+import org.apache.flink.util.function.SerializableSupplier;
 
 import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
 import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
@@ -128,15 +130,25 @@ public LogisticRegressionModel fit(Table... inputs) {
                             .map((MapFunction<Long, Long>) value -> value + 1);
         }
 
-        IterationStageList<LogisticRegressionTrainingContext> iterationStages =
-                new IterationStageList<>(new LogisticRegressionTrainingContext(getParamMap()));
+        LogisticRegressionWithFtrlTrainingContext trainingContext =
+                new LogisticRegressionWithFtrlTrainingContext(getParamMap());
+
+        IterationStageList<LogisticRegressionWithFtrlTrainingContext> iterationStages =
+                new IterationStageList<>(trainingContext);
         iterationStages
                 .addTrainingStage(new ComputeIndices())
-                .addTrainingStage(new PullStage("pullIndices"))
+                .addTrainingStage(
+                        new PullStage(
+                                (SerializableSupplier<long[]>) () -> trainingContext.pullIndices,
+                                (SerializableConsumer<double[]>)
+                                        x -> trainingContext.pulledValues = x))
                 .addTrainingStage(new ComputeGradients(BinaryLogisticLoss.INSTANCE))
-                .addTrainingStage(new PushStage("pushGradient"))
+                .addTrainingStage(
+                        new PushStage(
+                                (SerializableSupplier<long[]>) () -> trainingContext.pushIndices,
+                                (SerializableSupplier<double[]>) () -> trainingContext.pushValues))
                 .setTerminationCriteria(
-                        (SerializableFunction<LogisticRegressionTrainingContext, Boolean>)
+                        (SerializableFunction<LogisticRegressionWithFtrlTrainingContext, Boolean>)
                                 o -> o.iterationId >= getMaxIter());
         FTRL ftrl = new FTRL(getAlpha(), getBeta(), getReg(), getElasticNet());
 
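The new PullStage/PushStage constructors take typed suppliers and consumers bound to fields on the shared training context, replacing the old string-keyed lookups ("pullIndices", "pushGradient") and their unchecked casts. The FTRL updater wired in at the end of the hunk is outside this diff; for orientation, here is a minimal sketch of the standard per-coordinate FTRL-Proximal update (McMahan et al.). The method and parameter names are illustrative, and the mapping of `getReg()`/`getElasticNet()` onto `l1`/`l2` (`l1 = reg * elasticNet`, `l2 = reg * (1 - elasticNet)`) is an assumption, not code from this PR.

```java
// Sketch of one FTRL-Proximal coordinate update; illustrative only, not the
// FTRL class from this PR. z and n are the per-coordinate accumulators,
// g is the fresh gradient, w is the materialized weight vector.
static void ftrlUpdate(
        double[] z, double[] n, double[] w, int i,
        double g, double alpha, double beta, double l1, double l2) {
    // Per-coordinate learning-rate schedule term.
    double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
    z[i] += g - sigma * w[i];
    n[i] += g * g;
    if (Math.abs(z[i]) <= l1) {
        w[i] = 0.0; // the l1 term zeroes out weakly supported weights (sparsity)
    } else {
        w[i] = -(z[i] - Math.signum(z[i]) * l1) / ((beta + Math.sqrt(n[i])) / alpha + l2);
    }
}
```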
@@ -183,16 +195,15 @@ public Map<Param<?>, Object> getParamMap() {
 }
 
 /**
- * A stage that samples a batch of training data and computes the indices needed to compute
- * gradients.
+ * An iteration stage that samples a batch of training data and computes the indices needed to
+ * compute gradients.
  */
-class ComputeIndices extends ProcessStage<LogisticRegressionTrainingContext> {
+class ComputeIndices extends ProcessStage<LogisticRegressionWithFtrlTrainingContext> {
 
     @Override
-    public void process(LogisticRegressionTrainingContext context) throws Exception {
+    public void process(LogisticRegressionWithFtrlTrainingContext context) throws Exception {
         context.readInNextBatchData();
-        long[] indices = computeIndices(context.batchData);
-        context.put("pullIndices", indices);
+        context.pullIndices = computeIndices(context.batchData);
     }
 
     public static long[] computeIndices(List<LabeledLargePointWithWeight> dataPoints) {
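The hunk cuts off at the signature of `computeIndices`; its body is unchanged by this PR and not shown. A hedged sketch of what such a method typically does, consistent with the `LongOpenHashSet` import above: deduplicate the feature indices touched by the batch so each worker pulls only the model coordinates it needs. The sparse accessor used here (`features.f0` as the index array) is a guess, not the PR's actual layout.

```java
// Hedged sketch, not the PR's computeIndices body.
public static long[] computeIndices(List<LabeledLargePointWithWeight> dataPoints) {
    LongOpenHashSet distinct = new LongOpenHashSet();
    for (LabeledLargePointWithWeight dataPoint : dataPoints) {
        for (long index : dataPoint.features.f0) { // assumed sparse layout
            distinct.add(index);
        }
    }
    long[] indices = distinct.toLongArray();
    java.util.Arrays.sort(indices); // deterministic order for the pull request
    return indices;
}
```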
@@ -216,24 +227,24 @@ public static long[] computeIndices(List<LabeledLargePointWithWeight> dataPoints
 }
 
 /**
- * A Stage that uses the pulled model parameters and batch data to compute the gradients. The
- * gradients are stored in context for later push.
+ * An iteration stage that uses the pulled model values and sampled batch data to compute the
+ * gradients.
  */
-class ComputeGradients extends ProcessStage<LogisticRegressionTrainingContext> {
+class ComputeGradients extends ProcessStage<LogisticRegressionWithFtrlTrainingContext> {
     private final LossFunc lossFunc;
 
     public ComputeGradients(LossFunc lossFunc) {
        this.lossFunc = lossFunc;
     }
 
     @Override
-    public void process(LogisticRegressionTrainingContext context) throws IOException {
+    public void process(LogisticRegressionWithFtrlTrainingContext context) throws IOException {
         long[] indices = ComputeIndices.computeIndices(context.batchData);
-        double[] pulledModelValues = (double[]) context.get("pullIndices");
+        double[] pulledModelValues = context.pulledValues;
         double[] gradients = computeGradient(context.batchData, indices, pulledModelValues);
 
-        // Saves the gradient for push.
-        context.put("pushGradient", Tuple2.of(indices, gradients));
+        context.pushIndices = indices;
+        context.pushValues = gradients;
     }
 
     private double[] computeGradient(
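The body of `computeGradient` likewise falls outside the hunk. For orientation, a hedged sketch of a sparse logistic-gradient pass, consistent with the `Long2DoubleOpenHashMap` import: the point layout (parallel index/value arrays, labels in {0, 1}, a per-point weight) is assumed, and the PR's pluggable `lossFunc` is inlined for brevity.

```java
// Hedged sketch, not the PR's computeGradient: map the pulled values back to
// their global indices, then accumulate weighted logistic-loss gradients.
private static double[] sketchGradient(
        List<LabeledLargePointWithWeight> batch, long[] indices, double[] pulled) {
    Long2DoubleOpenHashMap model = new Long2DoubleOpenHashMap(indices.length);
    for (int i = 0; i < indices.length; i++) {
        model.put(indices[i], pulled[i]);
    }
    Long2DoubleOpenHashMap grads = new Long2DoubleOpenHashMap(indices.length);
    for (LabeledLargePointWithWeight point : batch) {
        long[] idx = point.features.f0;   // assumed sparse layout
        double[] val = point.features.f1; // assumed sparse layout
        double margin = 0;
        for (int i = 0; i < idx.length; i++) {
            margin += model.get(idx[i]) * val[i];
        }
        // Cross-entropy gradient: (sigmoid(margin) - label) * x, scaled by weight.
        double multiplier = point.weight * (1.0 / (1.0 + Math.exp(-margin)) - point.label);
        for (int i = 0; i < idx.length; i++) {
            grads.addTo(idx[i], multiplier * val[i]);
        }
    }
    double[] gradients = new double[indices.length];
    for (int i = 0; i < indices.length; i++) {
        gradients[i] = grads.get(indices[i]);
    }
    return gradients;
}
```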
@@ -267,37 +278,34 @@ private static double dot(
     }
 }
 
-class LogisticRegressionTrainingContext
+/** The context information of local computing process. */
+class LogisticRegressionWithFtrlTrainingContext
         implements TrainingContext,
-                LogisticRegressionWithFtrlParams<LogisticRegressionTrainingContext> {
+                LogisticRegressionWithFtrlParams<LogisticRegressionWithFtrlTrainingContext> {
+    /** Parameters of LogisticRegressionWithFtrl. */
+    private final Map<Param<?>, Object> paramMap;
+    /** Current iteration id. */
     int iterationId;
-    int workerId;
-    private int numWorkers = -1;
+    /** The local batch size. */
     private int localBatchSize = -1;
-
-    ResettableIterator<LabeledLargePointWithWeight> trainData;
-    private final Map<String, Object> contextObjs = new HashMap<>();
-
-    ListState<LabeledLargePointWithWeight> batchDataState;
+    /** The training data. */
+    private ResettableIterator<LabeledLargePointWithWeight> trainData;
+    /** The batch of training data for computing gradients. */
     List<LabeledLargePointWithWeight> batchData;
 
-    private final Map<Param<?>, Object> paramMap;
+    private ListState<LabeledLargePointWithWeight> batchDataState;
 
-    public LogisticRegressionTrainingContext(Map<Param<?>, Object> paramMap) {
-        this.paramMap = paramMap;
-    }
+    /** The indices to pull for each iteration. */
+    long[] pullIndices;
+    /** The placeholder for the pulled values for each iteration. */
+    double[] pulledValues;
+    /** The indices to push for each iteration. */
+    long[] pushIndices;
+    /** The values to push for each iteration. */
+    double[] pushValues;
 
-    /** Reads in next batch of training data. */
-    public void readInNextBatchData() throws IOException {
-        batchData.clear();
-        int i = 0;
-        while (i < localBatchSize && trainData.hasNext()) {
-            batchData.add(trainData.next());
-            i++;
-        }
-        if (!trainData.hasNext()) {
-            trainData.reset();
-        }
+    public LogisticRegressionWithFtrlTrainingContext(Map<Param<?>, Object> paramMap) {
+        this.paramMap = paramMap;
     }
 
     @Override
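The refactor above swaps the untyped `contextObjs` map for dedicated scratch fields, so stages exchange arrays without casts or key typos. An illustrative walkthrough of one iteration's flow through those fields (not PR code; placeholder arrays stand in for the parameter-server round trips):

```java
// `context` is a LogisticRegressionWithFtrlTrainingContext inside one iteration.
context.readInNextBatchData();                                  // ComputeIndices samples a batch
context.pullIndices = ComputeIndices.computeIndices(context.batchData);
context.pulledValues = new double[context.pullIndices.length];  // filled in by PullStage
context.pushIndices = context.pullIndices;                      // ComputeGradients sets both
context.pushValues = new double[context.pushIndices.length];    //   (the gradients)
// PushStage then forwards pushIndices/pushValues to the servers, where the
// FTRL updater applies them.
```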
@@ -307,8 +315,6 @@ public void setIterationId(int iterationId) {
 
     @Override
     public void setWorldInfo(int workerId, int numWorkers) {
-        this.workerId = workerId;
-        this.numWorkers = numWorkers;
         int globalBatchSize = getGlobalBatchSize();
         this.localBatchSize = globalBatchSize / numWorkers;
         if (globalBatchSize % numWorkers > workerId) {
@@ -322,16 +328,6 @@ public void setTrainData(ResettableIterator<?> trainData) {
         this.trainData = (ResettableIterator<LabeledLargePointWithWeight>) trainData;
     }
 
-    @Override
-    public void put(String key, Object value) {
-        contextObjs.put(key, value);
-    }
-
-    @Override
-    public Object get(String key) {
-        return contextObjs.get(key);
-    }
-
     @Override
     public void initializeState(StateInitializationContext context) throws Exception {
         batchDataState =
@@ -343,7 +339,7 @@ public void initializeState(StateInitializationContext context) throws Exception
 
         Iterator<LabeledLargePointWithWeight> batchDataIterator = batchDataState.get().iterator();
         if (batchDataIterator.hasNext()) {
-            contextObjs.put("batchData", IteratorUtils.toList(batchDataIterator));
+            batchData = IteratorUtils.toList(batchDataIterator);
         }
     }
 
@@ -359,4 +355,17 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
     public Map<Param<?>, Object> getParamMap() {
         return paramMap;
     }
+
+    /** Reads in next batch of training data. */
+    public void readInNextBatchData() throws IOException {
+        batchData.clear();
+        int i = 0;
+        while (i < localBatchSize && trainData.hasNext()) {
+            batchData.add(trainData.next());
+            i++;
+        }
+        if (!trainData.hasNext()) {
+            trainData.reset();
+        }
+    }
 }
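One behavioral note on `readInNextBatchData`, now grouped with the other context methods: when the iterator runs out mid-batch the loop simply stops, so the last batch of an epoch can hold fewer than `localBatchSize` points, and the subsequent `reset()` makes the next iteration start a fresh pass over the training data.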