@@ -69,8 +69,9 @@ public class LogisticRegressionWithFtrlTest {
69
69
70
70
@ Rule public final TemporaryFolder tempFolder = new TemporaryFolder ();
71
71
private final double [] expectedCoefficient =
72
- new double [] {0.3140991 , -0.6776634 , -0.5825635 , -0.4035519 };
73
-
72
+ new double [] {0.5287258 , -1.2163098 , -1.0710997 , -0.8591691 };
73
+ private static final int MAX_ITER = 100 ;
74
+ private static final int NUM_SERVERS = 2 ;
74
75
private static final double TOLERANCE = 1e-7 ;
75
76
76
77
private static final List <Row > trainRows =
@@ -100,14 +101,16 @@ public class LogisticRegressionWithFtrlTest {
100
101
Row .of (Vectors .sparse (4 , new int [] {0 , 2 }, new double [] {15 , 4 }), 1. , 5. ));
101
102
102
103
private StreamTableEnvironment tEnv ;
104
+ private StreamExecutionEnvironment env ;
103
105
private Table trainTable ;
104
106
private Table testTable ;
105
107
private DataFrame testDataFrame ;
106
108
107
109
@ Before
108
110
public void before () {
109
- StreamExecutionEnvironment env = TestUtils .getExecutionEnvironment ();
111
+ env = TestUtils .getExecutionEnvironment ();
110
112
tEnv = StreamTableEnvironment .create (env );
113
+
111
114
trainTable =
112
115
tEnv .fromDataStream (
113
116
env .fromCollection (
@@ -227,16 +230,17 @@ public void testOutputSchema() {
227
230
@ Test
228
231
@ SuppressWarnings ("unchecked" )
229
232
public void testGetModelData () throws Exception {
230
- int numServers = 2 ;
233
+ // Fix the parallelism as one for stability tests.
234
+ env .setParallelism (1 );
231
235
LogisticRegressionWithFtrl logisticRegressionWithFtrl =
232
- new LogisticRegressionWithFtrl ().setNumServers ( numServers ). setNumServerCores ( 1 );
236
+ new LogisticRegressionWithFtrl ().setMaxIter ( MAX_ITER ). setNumServers ( NUM_SERVERS );
233
237
LogisticRegressionModel model = logisticRegressionWithFtrl .fit (trainTable );
234
238
List <LogisticRegressionModelData > modelData =
235
239
IteratorUtils .toList (
236
240
LogisticRegressionModelDataUtil .getModelDataStream (model .getModelData ()[0 ])
237
241
.executeAndCollect ());
238
242
239
- assertEquals (numServers , modelData .size ());
243
+ assertEquals (NUM_SERVERS , modelData .size ());
240
244
241
245
modelData .sort (Comparator .comparingLong (o -> o .startIndex ));
242
246
@@ -252,7 +256,7 @@ public void testGetModelData() throws Exception {
252
256
@ Test
253
257
public void testFitAndPredict () throws Exception {
254
258
LogisticRegressionWithFtrl logisticRegressionWithFtrl =
255
- new LogisticRegressionWithFtrl ().setNumServers (2 );
259
+ new LogisticRegressionWithFtrl ().setMaxIter ( MAX_ITER ). setNumServers (NUM_SERVERS );
256
260
Table output = logisticRegressionWithFtrl .fit (trainTable ).transform (testTable )[0 ];
257
261
verifyPredictionResult (
258
262
output ,
@@ -264,7 +268,7 @@ public void testFitAndPredict() throws Exception {
264
268
@ Test
265
269
public void testSaveLoadAndPredict () throws Exception {
266
270
LogisticRegressionWithFtrl logisticRegressionWithFtrl =
267
- new LogisticRegressionWithFtrl ().setNumServers (2 );
271
+ new LogisticRegressionWithFtrl ().setMaxIter ( MAX_ITER ). setNumServers (NUM_SERVERS );
268
272
logisticRegressionWithFtrl =
269
273
TestUtils .saveAndReload (
270
274
tEnv ,
@@ -292,7 +296,7 @@ public void testSaveLoadAndPredict() throws Exception {
292
296
@ Test
293
297
public void testSetModelData () throws Exception {
294
298
LogisticRegressionWithFtrl logisticRegressionWithFtrl =
295
- new LogisticRegressionWithFtrl ().setNumServers (2 );
299
+ new LogisticRegressionWithFtrl ().setMaxIter ( MAX_ITER ). setNumServers (NUM_SERVERS );
296
300
LogisticRegressionModel model = logisticRegressionWithFtrl .fit (trainTable );
297
301
298
302
LogisticRegressionModel newModel = new LogisticRegressionModel ();
@@ -309,7 +313,7 @@ public void testSetModelData() throws Exception {
309
313
@ Test
310
314
public void testSaveLoadServableAndPredict () throws Exception {
311
315
LogisticRegressionWithFtrl logisticRegressionWithFtrl =
312
- new LogisticRegressionWithFtrl ().setNumServers (2 );
316
+ new LogisticRegressionWithFtrl ().setMaxIter ( MAX_ITER ). setNumServers (NUM_SERVERS );
313
317
LogisticRegressionModel model = logisticRegressionWithFtrl .fit (trainTable );
314
318
315
319
LogisticRegressionModelServable servable =
@@ -330,7 +334,7 @@ public void testSaveLoadServableAndPredict() throws Exception {
330
334
@ Test
331
335
public void testSetModelDataToServable () throws Exception {
332
336
LogisticRegressionWithFtrl logisticRegressionWithFtrl =
333
- new LogisticRegressionWithFtrl ().setNumServers (2 );
337
+ new LogisticRegressionWithFtrl ().setMaxIter ( MAX_ITER ). setNumServers (NUM_SERVERS );
334
338
LogisticRegressionModel model = logisticRegressionWithFtrl .fit (trainTable );
335
339
List <byte []> serializedModelData =
336
340
IteratorUtils .toList (
0 commit comments