Skip to content

Commit 74b4b7c

Browse files
committed
[FLINK-27826] Support training very high dimensional logisticRegression
1 parent a92671c commit 74b4b7c

32 files changed

+3020
-0
lines changed

flink-ml-lib/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,12 @@ under the License.
138138
<scope>test</scope>
139139
<type>test-jar</type>
140140
</dependency>
141+
<dependency>
142+
<groupId>fastutil</groupId>
143+
<artifactId>fastutil</artifactId>
144+
<version>5.0.9</version>
145+
</dependency>
146+
141147
</dependencies>
142148

143149
<build>

flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java

Lines changed: 380 additions & 0 deletions
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.ml.classification.logisticregression;
20+
21+
import org.apache.flink.ml.common.param.HasElasticNet;
22+
import org.apache.flink.ml.common.param.HasGlobalBatchSize;
23+
import org.apache.flink.ml.common.param.HasLabelCol;
24+
import org.apache.flink.ml.common.param.HasMaxIter;
25+
import org.apache.flink.ml.common.param.HasMultiClass;
26+
import org.apache.flink.ml.common.param.HasReg;
27+
import org.apache.flink.ml.common.param.HasTol;
28+
import org.apache.flink.ml.common.param.HasWeightCol;
29+
import org.apache.flink.ml.param.DoubleParam;
30+
import org.apache.flink.ml.param.IntParam;
31+
import org.apache.flink.ml.param.LongParam;
32+
import org.apache.flink.ml.param.Param;
33+
import org.apache.flink.ml.param.ParamValidators;
34+
35+
/** Params for {@link LogisticRegressionWithFtrl}. */
36+
public interface LogisticRegressionWithFtrlParams<T>
37+
extends HasLabelCol<T>,
38+
HasWeightCol<T>,
39+
HasGlobalBatchSize<T>,
40+
HasReg<T>,
41+
HasElasticNet<T>,
42+
HasMultiClass<T>,
43+
HasMaxIter<T>,
44+
HasTol<T>,
45+
LogisticRegressionModelParams<T> {
46+
47+
Param<Integer> NUM_SERVERS =
48+
new IntParam(
49+
"numServers",
50+
"Number of servers to store model parameters.",
51+
1,
52+
ParamValidators.gtEq(1));
53+
54+
Param<Integer> NUM_SERVER_CORES =
55+
new IntParam(
56+
"numServerCores",
57+
"number of cores that a server can use.",
58+
1,
59+
ParamValidators.gtEq(1));
60+
61+
Param<Double> ALPHA =
62+
new DoubleParam(
63+
"alpha",
64+
"The alpha parameter of FTRL optimizer.",
65+
0.1,
66+
ParamValidators.gt(0.0));
67+
68+
Param<Double> BETA =
69+
new DoubleParam(
70+
"beta", "The beta parameter of FTRL optimizer.", 0.1, ParamValidators.gt(0.0));
71+
72+
Param<Long> MODEL_DIM =
73+
new LongParam(
74+
"modelDim", "number of features of input data.", 0L, ParamValidators.gtEq(0));
75+
76+
default int getNumServers() {
77+
return get(NUM_SERVERS);
78+
}
79+
80+
default T setNumServers(Integer value) {
81+
return set(NUM_SERVERS, value);
82+
}
83+
84+
default int getNumServerCores() {
85+
return get(NUM_SERVER_CORES);
86+
}
87+
88+
default T setNumServerCores(int value) {
89+
return set(NUM_SERVER_CORES, value);
90+
}
91+
92+
default double getAlpha() {
93+
return get(ALPHA);
94+
}
95+
96+
default T setAlpha(Double value) {
97+
return set(ALPHA, value);
98+
}
99+
100+
default double getBeta() {
101+
return get(BETA);
102+
}
103+
104+
default T setBeta(Double value) {
105+
return set(BETA, value);
106+
}
107+
108+
default long getModelDim() {
109+
return get(MODEL_DIM);
110+
}
111+
112+
default T setModelDim(long value) {
113+
return set(MODEL_DIM, value);
114+
}
115+
}

flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLoss.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,4 +47,16 @@ public void computeGradient(
4747
dataPoint.getWeight() * (-labelScaled / (Math.exp(dot * labelScaled) + 1));
4848
BLAS.axpy(multiplier, dataPoint.getFeatures(), cumGradient, dataPoint.getFeatures().size());
4949
}
50+
51+
@Override
52+
public double computeLoss(double label, double prediction) {
53+
double labelScaled = 2 * label - 1;
54+
return Math.log(1 + Math.exp(-prediction * labelScaled));
55+
}
56+
57+
@Override
58+
public double computeGradient(double label, double prediction) {
59+
double labelScaled = 2 * label - 1;
60+
return -labelScaled / (Math.exp(prediction * labelScaled) + 1);
61+
}
5062
}

flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,14 @@ public interface LossFunc extends Serializable {
4848
*/
4949
void computeGradient(
5050
LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient);
51+
52+
/** Computes loss using the label and the prediction. */
53+
default double computeLoss(double label, double prediction) {
54+
throw new UnsupportedOperationException("Not supported yet.");
55+
}
56+
57+
/** Computes gradient using the label and the prediction. */
58+
default double computeGradient(double label, double prediction) {
59+
throw new UnsupportedOperationException("Not supported yet.");
60+
}
5161
}
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.ml.common.ps;
20+
21+
import org.apache.flink.api.common.state.ListState;
22+
import org.apache.flink.api.common.state.ListStateDescriptor;
23+
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
24+
import org.apache.flink.api.java.tuple.Tuple2;
25+
import org.apache.flink.ml.common.ps.message.ValuesPulledM;
26+
import org.apache.flink.runtime.state.StateInitializationContext;
27+
import org.apache.flink.runtime.state.StateSnapshotContext;
28+
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
29+
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
30+
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
31+
import org.apache.flink.util.Preconditions;
32+
33+
import java.util.ArrayList;
34+
import java.util.Comparator;
35+
import java.util.Iterator;
36+
import java.util.List;
37+
38+
/**
39+
* Merges the message from different servers for one pull request.
40+
*
41+
* <p>Note that for each single-thread worker, there are at exactly #numServers pieces for each pull
42+
* request in the feedback edge.
43+
*/
44+
public class MirrorWorkerOperator extends AbstractStreamOperator<byte[]>
45+
implements OneInputStreamOperator<Tuple2<Integer, byte[]>, byte[]> {
46+
private final int numServers;
47+
private int workerId;
48+
49+
/** The received messages from servers for the current pull request. */
50+
private List<ValuesPulledM> messageReceived;
51+
52+
private ListState<byte[]> messageReceivedState;
53+
54+
public MirrorWorkerOperator(int numServers) {
55+
this.numServers = numServers;
56+
}
57+
58+
@Override
59+
public void open() throws Exception {
60+
super.open();
61+
this.workerId = getRuntimeContext().getIndexOfThisSubtask();
62+
}
63+
64+
@Override
65+
public void processElement(StreamRecord<Tuple2<Integer, byte[]>> element) throws Exception {
66+
Preconditions.checkState(element.getValue().f0 == workerId);
67+
ValuesPulledM pulledModelM = ValuesPulledM.fromBytes(element.getValue().f1);
68+
messageReceived.add(pulledModelM);
69+
trySendingPulls(numServers);
70+
}
71+
72+
private void trySendingPulls(int numPieces) {
73+
if (messageReceived.size() == numPieces) {
74+
Comparator<ValuesPulledM> comparator = Comparator.comparingInt(o -> o.serverId);
75+
messageReceived.sort(comparator);
76+
int size = 0;
77+
for (ValuesPulledM pulledModelM : messageReceived) {
78+
size += pulledModelM.valuesPulled.length;
79+
}
80+
double[] answer = new double[size];
81+
int offset = 0;
82+
for (ValuesPulledM pulledModelM : messageReceived) {
83+
double[] values = pulledModelM.valuesPulled;
84+
System.arraycopy(values, 0, answer, offset, values.length);
85+
offset += values.length;
86+
}
87+
ValuesPulledM pulledModelM = new ValuesPulledM(-1, workerId, answer);
88+
output.collect(new StreamRecord<>(pulledModelM.toBytes()));
89+
messageReceived.clear();
90+
}
91+
}
92+
93+
@Override
94+
public void initializeState(StateInitializationContext context) throws Exception {
95+
super.initializeState(context);
96+
messageReceivedState =
97+
context.getOperatorStateStore()
98+
.getListState(
99+
new ListStateDescriptor<>(
100+
"messageReceivedState",
101+
PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO));
102+
messageReceived = new ArrayList<>();
103+
104+
Iterator<byte[]> iterator = messageReceivedState.get().iterator();
105+
if (iterator.hasNext()) {
106+
while (iterator.hasNext()) {
107+
messageReceived.add(ValuesPulledM.fromBytes(iterator.next()));
108+
}
109+
}
110+
}
111+
112+
@Override
113+
public void snapshotState(StateSnapshotContext context) throws Exception {
114+
super.snapshotState(context);
115+
messageReceivedState.clear();
116+
if (messageReceived.size() > 0) {
117+
for (ValuesPulledM valuesPulled : messageReceived) {
118+
messageReceivedState.add(valuesPulled.toBytes());
119+
}
120+
}
121+
}
122+
}
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.ml.common.ps;
20+
21+
import org.apache.flink.api.java.tuple.Tuple3;
22+
import org.apache.flink.util.Preconditions;
23+
24+
import javax.annotation.Nullable;
25+
26+
import java.util.Arrays;
27+
import java.util.Iterator;
28+
29+
/** Range partitioner for model data. */
30+
public class RangePartitioner {
31+
public final long dim;
32+
public final int numServers;
33+
public final long[] ranges;
34+
35+
public RangePartitioner(long dim, int numServers) {
36+
Preconditions.checkArgument(
37+
dim > 0,
38+
String.format(
39+
"Illegal dimension when using %s: %d",
40+
RangePartitioner.class.getSimpleName(), dim));
41+
42+
this.dim = dim;
43+
this.numServers = numServers;
44+
this.ranges = new long[numServers + 1];
45+
long shardSize = dim / numServers;
46+
47+
for (int serverId = 0; serverId < numServers; serverId++) {
48+
ranges[serverId] = shardSize * serverId;
49+
}
50+
ranges[numServers] = dim;
51+
}
52+
53+
/**
54+
* Splits the push/pull request according to the given sorted indices and the corresponding
55+
* values.
56+
*
57+
* @param indices Sorted indices of push/pull request.
58+
* @param values The push values if not null.
59+
* @return The split requests for each server task.
60+
*/
61+
public Iterator<Tuple3<Integer, long[], double[]>> splitRequest(
62+
long[] indices, @Nullable double[] values) {
63+
return new RequestsIterator(numServers, indices, values, ranges);
64+
}
65+
66+
private static class RequestsIterator implements Iterator<Tuple3<Integer, long[], double[]>> {
67+
private final int numServers;
68+
private final long[] indices;
69+
private final double[] values;
70+
private final long[] ranges;
71+
72+
private int serverId = 0;
73+
74+
private int s = 0;
75+
76+
public RequestsIterator(
77+
int numPss, long[] indices, @Nullable double[] values, long[] ranges) {
78+
Preconditions.checkArgument(values == null || values.length % indices.length == 0);
79+
this.numServers = numPss;
80+
this.indices = indices;
81+
this.values = values;
82+
this.ranges = ranges;
83+
}
84+
85+
@Override
86+
public boolean hasNext() {
87+
return serverId < numServers;
88+
}
89+
90+
@Override
91+
public Tuple3<Integer, long[], double[]> next() {
92+
int e = s;
93+
while (e < indices.length && indices[e] < ranges[serverId + 1]) {
94+
e++;
95+
}
96+
97+
long[] splitIndices = new long[0];
98+
double[] splitValues = values == null ? null : new double[0];
99+
if (s < e) {
100+
splitIndices = Arrays.copyOfRange(indices, s, e);
101+
splitValues = values == null ? null : Arrays.copyOfRange(values, s, e);
102+
}
103+
s = e;
104+
serverId++;
105+
return Tuple3.of(serverId - 1, splitIndices, splitValues);
106+
}
107+
}
108+
}

0 commit comments

Comments
 (0)