Skip to content

Commit e70c3fe

Browse files
committed
Average the gradient from workers
1 parent 74b4b7c commit e70c3fe

File tree

4 files changed

+30
-15
lines changed

4 files changed

+30
-15
lines changed

flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import org.apache.flink.ml.common.ps.training.TrainingContext;
3939
import org.apache.flink.ml.common.ps.training.TrainingUtils;
4040
import org.apache.flink.ml.common.updater.FTRL;
41+
import org.apache.flink.ml.linalg.BLAS;
4142
import org.apache.flink.ml.linalg.Vectors;
4243
import org.apache.flink.ml.param.Param;
4344
import org.apache.flink.ml.util.ParamUtils;
@@ -150,7 +151,13 @@ public LogisticRegressionModel fit(Table... inputs) {
150151
.setTerminationCriteria(
151152
(SerializableFunction<LogisticRegressionWithFtrlTrainingContext, Boolean>)
152153
o -> o.iterationId >= getMaxIter());
153-
FTRL ftrl = new FTRL(getAlpha(), getBeta(), getReg(), getElasticNet());
154+
FTRL ftrl =
155+
new FTRL(
156+
getAlpha(),
157+
getBeta(),
158+
getReg(),
159+
getElasticNet(),
160+
trainData.getParallelism());
154161

155162
DataStream<Tuple3<Long, Long, double[]>> rawModelData =
156163
TrainingUtils.train(
@@ -274,6 +281,7 @@ private double[] computeGradient(
274281
for (int i = 0; i < sortedBatchIndices.length; i++) {
275282
cumGradientValues[i] = cumGradients.get(sortedBatchIndices[i]);
276283
}
284+
BLAS.scal(1.0 / batchData.size(), Vectors.dense(cumGradientValues));
277285
return cumGradientValues;
278286
}
279287

flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ private static class RequestsIterator implements Iterator<Tuple3<Integer, long[]
7575

7676
public RequestsIterator(
7777
int numPss, long[] indices, @Nullable double[] values, long[] ranges) {
78-
Preconditions.checkArgument(values == null || values.length % indices.length == 0);
78+
// Preconditions.checkArgument(values == null || values.length % indices.length == 0);
7979
this.numServers = numPss;
8080
this.indices = indices;
8181
this.values = values;

flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/FTRL.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ public class FTRL implements ModelUpdater {
3737
private final double lambda1;
3838
private final double lambda2;
3939

40+
private final int numWorkers;
41+
4042
// ------ Model data of FTRL optimizer. -----
4143
private long startIndex;
4244
private long endIndex;
@@ -48,11 +50,12 @@ public class FTRL implements ModelUpdater {
4850
private ListState<Long> boundaryState;
4951
private ListState<double[]> modelDataState;
5052

51-
public FTRL(double alpha, double beta, double lambda1, double lambda2) {
53+
public FTRL(double alpha, double beta, double lambda1, double lambda2, int numWorkers) {
5254
this.alpha = alpha;
5355
this.beta = beta;
5456
this.lambda1 = lambda1;
5557
this.lambda2 = lambda2;
58+
this.numWorkers = numWorkers;
5659
}
5760

5861
@Override
@@ -70,7 +73,7 @@ public void open(long startFeatureIndex, long endFeatureIndex) {
7073
public void handlePush(long[] keys, double[] values) {
7174
for (int i = 0; i < keys.length; i++) {
7275
int index = (int) (keys[i] - startIndex);
73-
double gi = values[i];
76+
double gi = values[i] / numWorkers;
7477
updateModelOnOneDim(gi, index, weight);
7578
}
7679
}

flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,9 @@ public class LogisticRegressionWithFtrlTest {
6969

7070
@Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
7171
private final double[] expectedCoefficient =
72-
new double[] {0.3140991, -0.6776634, -0.5825635, -0.4035519};
73-
72+
new double[] {0.5287258, -1.2163098, -1.0710997, -0.8591691};
73+
private static final int MAX_ITER = 100;
74+
private static final int NUM_SERVERS = 2;
7475
private static final double TOLERANCE = 1e-7;
7576

7677
private static final List<Row> trainRows =
@@ -100,14 +101,16 @@ public class LogisticRegressionWithFtrlTest {
100101
Row.of(Vectors.sparse(4, new int[] {0, 2}, new double[] {15, 4}), 1., 5.));
101102

102103
private StreamTableEnvironment tEnv;
104+
private StreamExecutionEnvironment env;
103105
private Table trainTable;
104106
private Table testTable;
105107
private DataFrame testDataFrame;
106108

107109
@Before
108110
public void before() {
109-
StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment();
111+
env = TestUtils.getExecutionEnvironment();
110112
tEnv = StreamTableEnvironment.create(env);
113+
111114
trainTable =
112115
tEnv.fromDataStream(
113116
env.fromCollection(
@@ -227,16 +230,17 @@ public void testOutputSchema() {
227230
@Test
228231
@SuppressWarnings("unchecked")
229232
public void testGetModelData() throws Exception {
230-
int numServers = 2;
233+
// Fix the parallelism as one for stability tests.
234+
env.setParallelism(1);
231235
LogisticRegressionWithFtrl logisticRegressionWithFtrl =
232-
new LogisticRegressionWithFtrl().setNumServers(numServers).setNumServerCores(1);
236+
new LogisticRegressionWithFtrl().setMaxIter(MAX_ITER).setNumServers(NUM_SERVERS);
233237
LogisticRegressionModel model = logisticRegressionWithFtrl.fit(trainTable);
234238
List<LogisticRegressionModelData> modelData =
235239
IteratorUtils.toList(
236240
LogisticRegressionModelDataUtil.getModelDataStream(model.getModelData()[0])
237241
.executeAndCollect());
238242

239-
assertEquals(numServers, modelData.size());
243+
assertEquals(NUM_SERVERS, modelData.size());
240244

241245
modelData.sort(Comparator.comparingLong(o -> o.startIndex));
242246

@@ -252,7 +256,7 @@ public void testGetModelData() throws Exception {
252256
@Test
253257
public void testFitAndPredict() throws Exception {
254258
LogisticRegressionWithFtrl logisticRegressionWithFtrl =
255-
new LogisticRegressionWithFtrl().setNumServers(2);
259+
new LogisticRegressionWithFtrl().setMaxIter(MAX_ITER).setNumServers(NUM_SERVERS);
256260
Table output = logisticRegressionWithFtrl.fit(trainTable).transform(testTable)[0];
257261
verifyPredictionResult(
258262
output,
@@ -264,7 +268,7 @@ public void testFitAndPredict() throws Exception {
264268
@Test
265269
public void testSaveLoadAndPredict() throws Exception {
266270
LogisticRegressionWithFtrl logisticRegressionWithFtrl =
267-
new LogisticRegressionWithFtrl().setNumServers(2);
271+
new LogisticRegressionWithFtrl().setMaxIter(MAX_ITER).setNumServers(NUM_SERVERS);
268272
logisticRegressionWithFtrl =
269273
TestUtils.saveAndReload(
270274
tEnv,
@@ -292,7 +296,7 @@ public void testSaveLoadAndPredict() throws Exception {
292296
@Test
293297
public void testSetModelData() throws Exception {
294298
LogisticRegressionWithFtrl logisticRegressionWithFtrl =
295-
new LogisticRegressionWithFtrl().setNumServers(2);
299+
new LogisticRegressionWithFtrl().setMaxIter(MAX_ITER).setNumServers(NUM_SERVERS);
296300
LogisticRegressionModel model = logisticRegressionWithFtrl.fit(trainTable);
297301

298302
LogisticRegressionModel newModel = new LogisticRegressionModel();
@@ -309,7 +313,7 @@ public void testSetModelData() throws Exception {
309313
@Test
310314
public void testSaveLoadServableAndPredict() throws Exception {
311315
LogisticRegressionWithFtrl logisticRegressionWithFtrl =
312-
new LogisticRegressionWithFtrl().setNumServers(2);
316+
new LogisticRegressionWithFtrl().setMaxIter(MAX_ITER).setNumServers(NUM_SERVERS);
313317
LogisticRegressionModel model = logisticRegressionWithFtrl.fit(trainTable);
314318

315319
LogisticRegressionModelServable servable =
@@ -330,7 +334,7 @@ public void testSaveLoadServableAndPredict() throws Exception {
330334
@Test
331335
public void testSetModelDataToServable() throws Exception {
332336
LogisticRegressionWithFtrl logisticRegressionWithFtrl =
333-
new LogisticRegressionWithFtrl().setNumServers(2);
337+
new LogisticRegressionWithFtrl().setMaxIter(MAX_ITER).setNumServers(NUM_SERVERS);
334338
LogisticRegressionModel model = logisticRegressionWithFtrl.fit(trainTable);
335339
List<byte[]> serializedModelData =
336340
IteratorUtils.toList(

0 commit comments

Comments
 (0)