Skip to content

Commit 3966321

Browse files
committed
Support pull/push value as array
1 parent e70c3fe commit 3966321

File tree

15 files changed

+234
-207
lines changed

15 files changed

+234
-207
lines changed

flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,14 @@ public LogisticRegressionModel fit(Table... inputs) {
137137
IterationStageList<LogisticRegressionWithFtrlTrainingContext> iterationStages =
138138
new IterationStageList<>(trainingContext);
139139
iterationStages
140-
.addTrainingStage(new ComputeIndices())
141-
.addTrainingStage(
140+
.addStage(new ComputeIndices())
141+
.addStage(
142142
new PullStage(
143143
(SerializableSupplier<long[]>) () -> trainingContext.pullIndices,
144144
(SerializableConsumer<double[]>)
145145
x -> trainingContext.pulledValues = x))
146-
.addTrainingStage(new ComputeGradients(BinaryLogisticLoss.INSTANCE))
147-
.addTrainingStage(
146+
.addStage(new ComputeGradients(BinaryLogisticLoss.INSTANCE))
147+
.addStage(
148148
new PushStage(
149149
(SerializableSupplier<long[]>) () -> trainingContext.pushIndices,
150150
(SerializableSupplier<double[]>) () -> trainingContext.pushValues))
@@ -160,13 +160,7 @@ public LogisticRegressionModel fit(Table... inputs) {
160160
trainData.getParallelism());
161161

162162
DataStream<Tuple3<Long, Long, double[]>> rawModelData =
163-
TrainingUtils.train(
164-
modelDim,
165-
trainData,
166-
ftrl,
167-
iterationStages,
168-
getNumServers(),
169-
getNumServerCores());
163+
TrainingUtils.train(modelDim, trainData, ftrl, iterationStages, getNumServers());
170164

171165
final long modelVersion = 0L;
172166

@@ -341,8 +335,8 @@ public void setWorldInfo(int workerId, int numWorkers) {
341335
}
342336

343337
@Override
344-
public void setTrainData(ResettableIterator<?> trainData) {
345-
this.trainData = (ResettableIterator<LabeledLargePointWithWeight>) trainData;
338+
public void setInputData(ResettableIterator<?> inputData) {
339+
this.trainData = (ResettableIterator<LabeledLargePointWithWeight>) inputData;
346340
}
347341

348342
@Override

flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrlParams.java

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,6 @@ public interface LogisticRegressionWithFtrlParams<T>
5151
1,
5252
ParamValidators.gtEq(1));
5353

54-
Param<Integer> NUM_SERVER_CORES =
55-
new IntParam(
56-
"numServerCores",
57-
"number of cores that a server can use.",
58-
1,
59-
ParamValidators.gtEq(1));
60-
6154
Param<Double> ALPHA =
6255
new DoubleParam(
6356
"alpha",
@@ -81,14 +74,6 @@ default T setNumServers(Integer value) {
8174
return set(NUM_SERVERS, value);
8275
}
8376

84-
default int getNumServerCores() {
85-
return get(NUM_SERVER_CORES);
86-
}
87-
88-
default T setNumServerCores(int value) {
89-
return set(NUM_SERVER_CORES, value);
90-
}
91-
9277
default double getAlpha() {
9378
return get(ALPHA);
9479
}

flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ private static class RequestsIterator implements Iterator<Tuple3<Integer, long[]
6767
private final int numServers;
6868
private final long[] indices;
6969
private final double[] values;
70+
/**
71+
* Number of values per key. If the model data is a vector, numValuesPerKey is one. If the
72+
* model data is a matrix, numValuesPerKey is the number of columns.
73+
*/
74+
private final int numValuesPerKey;
75+
7076
private final long[] ranges;
7177

7278
private int serverId = 0;
@@ -75,11 +81,18 @@ private static class RequestsIterator implements Iterator<Tuple3<Integer, long[]
7581

7682
public RequestsIterator(
7783
int numPss, long[] indices, @Nullable double[] values, long[] ranges) {
78-
// Preconditions.checkArgument(values == null || values.length % indices.length == 0);
7984
this.numServers = numPss;
8085
this.indices = indices;
8186
this.values = values;
8287
this.ranges = ranges;
88+
if (indices.length != 0 && values != null) {
89+
numValuesPerKey = values.length / indices.length;
90+
Preconditions.checkArgument(
91+
numValuesPerKey * indices.length == values.length,
92+
"The size of values cannot be divided by size of keys.");
93+
} else {
94+
numValuesPerKey = 1;
95+
}
8396
}
8497

8598
@Override
@@ -98,7 +111,11 @@ public Tuple3<Integer, long[], double[]> next() {
98111
double[] splitValues = values == null ? null : new double[0];
99112
if (s < e) {
100113
splitIndices = Arrays.copyOfRange(indices, s, e);
101-
splitValues = values == null ? null : Arrays.copyOfRange(values, s, e);
114+
splitValues =
115+
values == null
116+
? null
117+
: Arrays.copyOfRange(
118+
values, s * numValuesPerKey, e * numValuesPerKey);
102119
}
103120
s = e;
104121
serverId++;

0 commit comments

Comments
 (0)