diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/Message.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/Message.java
new file mode 100644
index 000000000..36ca83269
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/Message.java
@@ -0,0 +1,242 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.util.Bits;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.lang.reflect.Array;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * {@link Message} is responsible for encoding all information exchanged between {@link
+ * WorkerOperator} and {@link ServerOperator}. The message format follows this structure:
+ *
+ * <pre>{@code workerId serverId stageId keyLength keys valuesLength values}</pre>
+ *
+ * <p>where the message fields include the worker ID, server ID, stage ID, length of the keys, keys
+ * themselves, length of the values, and the values.
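+ *
+ * <p>For illustration, a message carrying two key-value pairs can be encoded and decoded as
+ * follows (a sketch; the literal values are arbitrary):
+ *
+ * <pre>{@code
+ * Message m = new Message(0, 1, 2, new long[] {1L, 3L}, new double[] {0.1, 0.2});
+ * long[] keys = m.getKeys(); // {1L, 3L}
+ * double[] values = m.getValuesInDoubleArray(); // {0.1, 0.2}
+ * }</pre>
+ */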
+public class Message {
+ private static final int WORKER_ID_OFFSET = 0;
+ private static final int SERVER_ID_OFFSET = Integer.BYTES;
+ private static final int STAGE_ID_OFFSET = Integer.BYTES + SERVER_ID_OFFSET;
+ private static final int KVS_OFFSET = Integer.BYTES + STAGE_ID_OFFSET;
+
+ /** The storage of the message in bytes. */
+ public final byte[] bytes;
+
+ /** Constructs a message instance from the bytes. */
+ public Message(byte[] bytes) {
+ this.bytes = bytes;
+ }
+
+ /** Constructs a message instance from long keys and double values. */
+ public Message(int workerId, int serverId, int stageId, long[] keys, double[] values) {
+ int sizeInBytes =
+ KVS_OFFSET
+ + Bits.getLongArraySizeInBytes(keys)
+ + Bits.getDoubleArraySizeInBytes(values);
+ bytes = new byte[sizeInBytes];
+ Bits.putInt(bytes, WORKER_ID_OFFSET, workerId);
+ Bits.putInt(bytes, SERVER_ID_OFFSET, serverId);
+ Bits.putInt(bytes, STAGE_ID_OFFSET, stageId);
+ int offset = Bits.putLongArray(keys, bytes, KVS_OFFSET);
+ Bits.putDoubleArray(values, bytes, offset);
+ }
+
+ /** Constructs a message instance from long keys and generic values. */
+ public <V> Message(
+ int workerId,
+ int serverId,
+ int stageId,
+ long[] keys,
+ V[] values,
+ TypeSerializer<V> serializer)
+ throws IOException {
+ ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
+ DataOutputViewStreamWrapper dataOutputViewStreamWrapper =
+ new DataOutputViewStreamWrapper(byteArrayOutputStream);
+ dataOutputViewStreamWrapper.writeInt(workerId);
+ dataOutputViewStreamWrapper.writeInt(serverId);
+ dataOutputViewStreamWrapper.writeInt(stageId);
+
+ dataOutputViewStreamWrapper.writeInt(keys.length);
+ for (long key : keys) {
+ dataOutputViewStreamWrapper.writeLong(key);
+ }
+ dataOutputViewStreamWrapper.writeInt(values.length);
+ for (V value : values) {
+ serializer.serialize(value, dataOutputViewStreamWrapper);
+ }
+ bytes = byteArrayOutputStream.toByteArray();
+ }
+
+ /** Retrieves the keys. */
+ public long[] getKeys() {
+ return Bits.getLongArray(bytes, KVS_OFFSET);
+ }
+
+ /** Retrieves the values using the given serializer. */
+ @SuppressWarnings("unchecked")
+ public <V> V[] getValues(TypeSerializer<V> serializer) throws IOException {
+ int numIndices = Bits.getInt(bytes, KVS_OFFSET);
+ int offset = KVS_OFFSET + Integer.BYTES + numIndices * Long.BYTES;
+ int numValues = Bits.getInt(bytes, offset);
+ offset += Integer.BYTES;
+
+ // Since generics are erased at runtime, we use reflection to create the array.
+ V[] result = (V[]) Array.newInstance(serializer.createInstance().getClass(), numValues);
+ ByteArrayInputStream byteArrayInputStream =
+ new ByteArrayInputStream(bytes, offset, bytes.length - offset);
+ DataInputViewStreamWrapper dataInputViewStreamWrapper =
+ new DataInputViewStreamWrapper(byteArrayInputStream);
+ for (int i = 0; i < numValues; i++) {
+ result[i] = serializer.deserialize(dataInputViewStreamWrapper);
+ }
+ return result;
+ }
+
+ /**
+ * Retrieves the values as a double array.
+ *
+ * <p>Note that getting the double array in this function using {@link Bits#getDoubleArray(byte[],
+ * int)} is faster than {@link Message#getValues} by up to 2.3X.
+ */
+ public double[] getValuesInDoubleArray() {
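+ // Skips the keys section: the key count (one int) followed by that many longs.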
+ int offset = KVS_OFFSET + Bits.getInt(bytes, KVS_OFFSET) * Long.BYTES + Integer.BYTES;
+ return Bits.getDoubleArray(bytes, offset);
+ }
+
+ /** Retrieves the worker id. */
+ public int getWorkerId() {
+ return Bits.getInt(bytes, WORKER_ID_OFFSET);
+ }
+
+ /** Sets the worker id. */
+ public void setWorkerId(int workerId) {
+ Bits.putInt(bytes, WORKER_ID_OFFSET, workerId);
+ }
+
+ /** Retrieves the server id. */
+ public int getServerId() {
+ return Bits.getInt(bytes, SERVER_ID_OFFSET);
+ }
+
+ /** Sets the server id. */
+ public void setServerId(int serverId) {
+ Bits.putInt(bytes, SERVER_ID_OFFSET, serverId);
+ }
+
+ /** Retrieves the stage id. */
+ public int getStageId() {
+ return Bits.getInt(bytes, STAGE_ID_OFFSET);
+ }
+
+ /**
+ * Assembles the received messages from servers according to the server id. Note that these
+ * messages should be the responses from the same stage.
+ */
+ public static Message assembleMessages(Iterator<byte[]> messageIterator) {
+ List<Message> messages = new ArrayList<>();
+ while (messageIterator.hasNext()) {
+ messages.add(new Message(messageIterator.next()));
+ }
+ messages.sort(Comparator.comparingInt(Message::getServerId));
+
+ int numMessages = messages.size();
+ int numKeys = 0, numValues = 0;
+ int numAssembledBytes = 0;
+ int workerId = -1;
+ int stageId = -1;
+ for (Message message : messages) {
+ if (workerId == -1) {
+ workerId = message.getWorkerId();
+ stageId = message.getStageId();
+ }
+ numKeys += message.getNumKeys();
+ numValues += message.getNumValues();
+ numAssembledBytes += message.bytes.length;
+ }
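+ // All but one copy of the per-message header (workerId, serverId, stageId and the two
+ // length ints) collapse into a single header in the assembled message.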
+ numAssembledBytes -= (numMessages - 1) * (KVS_OFFSET + Integer.BYTES * 2);
+ byte[] assembledBytes = new byte[numAssembledBytes];
+ Bits.putInt(assembledBytes, WORKER_ID_OFFSET, workerId);
+ Bits.putInt(assembledBytes, STAGE_ID_OFFSET, stageId);
+ int keysOffset = KVS_OFFSET;
+ Bits.putInt(assembledBytes, keysOffset, numKeys);
+ keysOffset += Integer.BYTES;
+ int valuesOffset = keysOffset + numKeys * Long.BYTES;
+ Bits.putInt(assembledBytes, valuesOffset, numValues);
+ valuesOffset += Integer.BYTES;
+
+ for (Message message : messages) {
+ Tuple2<Integer, Integer> keysOffsetAndLength = message.getKeysOffsetAndLength();
+ System.arraycopy(
+ message.bytes,
+ keysOffsetAndLength.f0,
+ assembledBytes,
+ keysOffset,
+ keysOffsetAndLength.f1);
+ keysOffset += keysOffsetAndLength.f1;
+ Tuple2<Integer, Integer> valuesOffsetAndLength = message.getValuesOffsetAndLength();
+ System.arraycopy(
+ message.bytes,
+ valuesOffsetAndLength.f0,
+ assembledBytes,
+ valuesOffset,
+ valuesOffsetAndLength.f1);
+ valuesOffset += valuesOffsetAndLength.f1;
+ }
+
+ Message message = new Message(assembledBytes);
+ message.setServerId(-1);
+ return message;
+ }
+
+ private Tuple2<Integer, Integer> getKeysOffsetAndLength() {
+ int start = KVS_OFFSET + Integer.BYTES;
+ int numBytes = Bits.getInt(bytes, KVS_OFFSET) * Long.BYTES;
+ return Tuple2.of(start, numBytes);
+ }
+
+ private Tuple2<Integer, Integer> getValuesOffsetAndLength() {
+ int start =
+ Bits.getInt(bytes, KVS_OFFSET) * Long.BYTES
+ + KVS_OFFSET
+ + Integer.BYTES
+ + Integer.BYTES;
+ return Tuple2.of(start, bytes.length - start);
+ }
+
+ private int getNumKeys() {
+ return Bits.getInt(bytes, KVS_OFFSET);
+ }
+
+ private int getNumValues() {
+ return Bits.getInt(bytes, KVS_OFFSET + Integer.BYTES + Long.BYTES * getNumKeys());
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java
new file mode 100644
index 000000000..d8d3c095c
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.common.ps.sarray.SharedDoubleArray;
+import org.apache.flink.ml.common.ps.sarray.SharedLongArray;
+import org.apache.flink.streaming.api.operators.Output;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Preconditions;
+
+import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
+import it.unimi.dsi.fastutil.longs.LongArrayList;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.function.Function;
+
+/**
+ * ServerAgent resides on each worker. It serves as an agent for {@link WorkerOperator} to talk with
+ * {@link ServerOperator}.
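+ *
+ * <p>For example, with two servers and a modulo hash function, a push of keys {1, 2, 3, 4} is
+ * sliced into {2, 4} for server-0 and {1, 3} for server-1 before being sent out.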
+ */
+class ServerAgent {
+ /** Index of the worker that this agent resides on. */
+ private final int workerId;
+ /** Number of servers to talk to. */
+ private final int numServers;
+ /** Hash function to partition keys to different servers. */
+ private final Function<Long, Integer> hashFunc;
+ /** The collector on this worker. */
+ private final Output<StreamRecord<byte[]>> output;
+
+ ServerAgent(
+ int workerId,
+ int numServers,
+ Function<Long, Integer> hashFunc,
+ Output<StreamRecord<byte[]>> output) {
+ this.workerId = workerId;
+ this.numServers = numServers;
+ this.output = output;
+ this.hashFunc = hashFunc;
+ }
+
+ /** Pushes the key-value arrays to servers. */
+ void push(SharedLongArray keys, SharedDoubleArray values, int stageId) {
+ Tuple2<LongArrayList[], DoubleArrayList[]> slicedRequests = sliceRequest(keys, values);
+ LongArrayList[] splitKeys = slicedRequests.f0;
+ DoubleArrayList[] splitValues = slicedRequests.f1;
+ for (int serverId = 0; serverId < splitKeys.length; serverId++) {
+ Message message =
+ new Message(
+ workerId,
+ serverId,
+ stageId,
+ splitKeys[serverId].toLongArray(),
+ splitValues[serverId].toDoubleArray());
+ output.collect(new StreamRecord<>(message.bytes));
+ }
+ }
+
+ /** Pulls the values from servers with the specified keys. */
+ void pull(SharedLongArray keys, int stageId) {
+ Tuple2<LongArrayList[], DoubleArrayList[]> slicedRequests = sliceRequest(keys, null);
+ LongArrayList[] splitKeys = slicedRequests.f0;
+ for (int serverId = 0; serverId < splitKeys.length; serverId++) {
+ Message message =
+ new Message(
+ workerId,
+ serverId,
+ stageId,
+ splitKeys[serverId].toLongArray(),
+ new double[0]);
+ output.collect(new StreamRecord<>(message.bytes));
+ }
+ }
+
+ /**
+ * Pushes the values to servers to apply the all-reduce/reduce-scatter operation.
+ *
+ * <p>Note that the values pushed by this function are not going to update the model; they are
+ * only used to perform a reduce operation.
+ */
+ <V> void reduce(V[] values, TypeSerializer<V> typeSerializer, int stageId) throws IOException {
+ int shardSize = values.length / numServers + 1;
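+ // Example: 10 values on 3 servers gives shardSize = 4, so server-0 receives the slice
+ // [0, 4), server-1 receives [4, 8) and server-2 receives [8, 10).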
+ for (int serverId = 0; serverId < numServers; serverId++) {
+ int s = Math.min(serverId * shardSize, values.length);
+ int e = Math.min(s + shardSize, values.length);
+ V[] segment = Arrays.copyOfRange(values, s, e);
+ Message message =
+ new Message(workerId, serverId, stageId, new long[0], segment, typeSerializer);
+ output.collect(new StreamRecord<>(message.bytes));
+ }
+ }
+
+ /**
+ * Splits the push/pull request according to the given sorted keys and the corresponding values.
+ *
+ * @param keys keys of push/pull request.
+ * @param values the push values if not null.
+ * @return the split requests for each server.
+ */
+ private Tuple2<LongArrayList[], DoubleArrayList[]> sliceRequest(
+ SharedLongArray keys, @Nullable SharedDoubleArray values) {
+ LongArrayList[] splitKeys = new LongArrayList[numServers];
+ DoubleArrayList[] splitValues = new DoubleArrayList[numServers];
+ for (int i = 0; i < numServers; i++) {
+ splitKeys[i] = new LongArrayList();
+ splitValues[i] = new DoubleArrayList();
+ }
+
+ int numDoublesPerKey = 0;
+ if (values != null) {
+ Preconditions.checkState(
+ values.size() % keys.size() == 0,
+ "The number of values per key should be the same.");
+ numDoublesPerKey = values.size() / keys.size();
+ }
+
+ long[] keyArray = keys.elements();
+ for (int i = 0; i < keys.size(); i++) {
+ int serverId = hashFunc.apply(keyArray[i]);
+ splitKeys[serverId].add(keyArray[i]);
+ if (values != null) {
+ for (int j = 0; j < numDoublesPerKey; j++) {
+ splitValues[serverId].add(values.get(i * numDoublesPerKey + j));
+ }
+ }
+ }
+
+ return Tuple2.of(splitKeys, splitValues);
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java
new file mode 100644
index 000000000..4edbbf0c4
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java
@@ -0,0 +1,534 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps;
+
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.ml.common.ps.iterations.AllReduceStage;
+import org.apache.flink.ml.common.ps.iterations.IterationStage;
+import org.apache.flink.ml.common.ps.iterations.PullStage;
+import org.apache.flink.ml.common.ps.iterations.PullStage.Aggregator;
+import org.apache.flink.ml.common.ps.iterations.PushStage;
+import org.apache.flink.ml.common.ps.iterations.ReduceScatterStage;
+import org.apache.flink.ml.common.ps.updater.ModelUpdater;
+import org.apache.flink.ml.util.Bits;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import it.unimi.dsi.fastutil.longs.Long2ObjectMap;
+import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap;
+import it.unimi.dsi.fastutil.objects.ObjectIterator;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.LinkedBlockingDeque;
+
+/**
+ * The server operator maintains the shared parameters. The shared parameters can be modeled as a
+ * collection of {key:value} pairs. By default, the keys are evenly distributed across servers
+ * through hash partitioning. For example, if there are two servers and the keys are {1,2,3,4,5,6},
+ * then server-0 maintains keys {1,3,5} and server-1 maintains keys {2,4,6}.
+ *
+ * <p>The server receives push/pull/all-reduce/reduce-scatter requests from {@link WorkerOperator}
+ * and sends the answers back to {@link WorkerOperator}. It works closely with {@link ModelUpdater}
+ * in the following way:
+ *
+ * <ul>
+ * <li>The server operator deals with the messages from workers and decides when to process the
+ * received messages.
+ * <li>The server operator calls {@link ModelUpdater#update(long[], double[])} and {@link
+ * ModelUpdater#get(long[])} to process the messages in detail.
+ * <li>The server operator triggers checkpoints for {@link ModelUpdater}.
+ * <li>The server operator outputs the final model parameters by calling {@link
+ * ModelUpdater#getModelSegments()}.
+ * </ul>
+ *
+ * <p>Moreover, it accepts all-reduce/reduce-scatter requests from workers and returns the reduced
+ * result to all workers. Note that the input of the all-reduce/reduce-scatter operation is not
+ * going to be used in {@link ModelUpdater}.
+ *
+ * @param <MT> output format of model data.
+ */
+public class ServerOperator<MT> extends AbstractStreamOperator<byte[]>
+ implements OneInputStreamOperator<byte[], byte[]>, IterationListener<byte[]> {
+ /** The iterationStage list. */
+ private final List<IterationStage> stageList;
+ /** Number of workers to communicate with. */
+ private final int numWorkers;
+ /** The logic to answer push/pull request from workers. */
+ private final ModelUpdater<MT> modelUpdater;
+ /** Output tag of model data. */
+ private final OutputTag<MT> modelOutputTag;
+ /** Index of the current server task. */
+ private transient int serverId;
+ /** The thread pool that answers push/pull requests, decoupling network I/O from computation. */
+ private transient ExecutorService singleThreadExecutor;
+ /** The future objects of thread calls in one epoch. */
+ private transient List<Future<?>> futuresInEpoch;
+ /**
+ * The pending requests (pull, all-reduce, reduce-scatter) for which the server needs to send
+ * out responses.
+ */
+ private ListState<byte[]> pendingRequests;
+ /**
+ * The push requests merged by stage id. We use a map to store the merged push requests since
+ * there may be consecutive pushes.
+ */
+ private transient TreeMap<Integer, Long2ObjectOpenHashMap<Object>> accPushesByStage;
+
+ private ListState<byte[]> accPushesByStageState;
+
+ public ServerOperator(
+ List<IterationStage> stageList,
+ int numWorkers,
+ ModelUpdater<MT> modelUpdater,
+ OutputTag<MT> modelOutputTag) {
+ this.stageList = stageList;
+ this.numWorkers = numWorkers;
+ this.modelUpdater = modelUpdater;
+ this.modelOutputTag = modelOutputTag;
+ }
+
+ @Override
+ public void open() throws Exception {
+ super.open();
+ this.serverId = getRuntimeContext().getIndexOfThisSubtask();
+ this.singleThreadExecutor = Executors.newSingleThreadExecutor();
+ this.futuresInEpoch = new ArrayList<>();
+ }
+
+ @Override
+ public void processElement(StreamRecord<byte[]> element) throws Exception {
+ Message message = new Message(element.getValue());
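+ // The stage id grows monotonically across iterations, so modulo maps it onto the
+ // corresponding stage in the stage list.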
+ IterationStage stage = stageList.get(message.getStageId() % stageList.size());
+ if (stage instanceof PushStage) {
+ futuresInEpoch.add(singleThreadExecutor.submit(() -> processPushRequest(message)));
+ } else if (stage instanceof PullStage
+ || stage instanceof AllReduceStage
+ || stage instanceof ReduceScatterStage) {
+ pendingRequests.add(message.bytes);
+ } else {
+ throw new IllegalStateException(
+ "Illegal iteration stage: " + stage.getClass().getSimpleName() + ".");
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public void onEpochWatermarkIncremented(
+ int epochWatermark, Context context, Collector<byte[]> collector) throws Exception {
+ // Waits until the pushes are processed.
+ for (Future> future : futuresInEpoch) {
+ future.get();
+ }
+ futuresInEpoch.clear();
+ // Uses the merged pushes to update the model.
+ for (Long2ObjectOpenHashMap<Object> currentAccPush : accPushesByStage.values()) {
+ if (currentAccPush.size() > 0) {
+ // The push is not empty.
+ int numDoublesPerKey;
+ Object object = currentAccPush.values().iterator().next();
+ if (object instanceof Double) {
+ numDoublesPerKey = 1;
+ } else {
+ numDoublesPerKey = ((double[]) object).length;
+ }
+
+ ObjectIterator<Long2ObjectMap.Entry<Object>> objectIterator =
+ currentAccPush.long2ObjectEntrySet().fastIterator();
+
+ long[] assembledKeys = new long[currentAccPush.size()];
+ double[] assembledValues = new double[currentAccPush.size() * numDoublesPerKey];
+
+ int idx = 0;
+ if (numDoublesPerKey == 1) {
+ while (objectIterator.hasNext()) {
+ Map.Entry<Long, Object> entry = objectIterator.next();
+ assembledKeys[idx] = entry.getKey();
+ assembledValues[idx] = (Double) entry.getValue();
+ idx++;
+ }
+ } else {
+ while (objectIterator.hasNext()) {
+ Map.Entry<Long, Object> entry = objectIterator.next();
+ assembledKeys[idx] = entry.getKey();
+ System.arraycopy(
+ entry.getValue(),
+ 0,
+ assembledValues,
+ idx * numDoublesPerKey,
+ numDoublesPerKey);
+ idx++;
+ }
+ }
+ currentAccPush.clear();
+ modelUpdater.update(assembledKeys, assembledValues);
+ }
+ }
+
+ // Deals with the pending requests, which should be one of Pull, AllReduce, ReduceScatter.
+ Iterator<byte[]> requestIterator = pendingRequests.get().iterator();
+ if (requestIterator.hasNext()) {
+ Message message = new Message(requestIterator.next());
+ int stageId = message.getStageId();
+ IterationStage stage = stageList.get(stageId % stageList.size());
+ requestIterator = pendingRequests.get().iterator();
+ if (stage instanceof PullStage) {
+ final int blockingQueueCapacity = 20;
+ LinkedBlockingDeque<byte[]> pullsResponse =
+ new LinkedBlockingDeque<>(blockingQueueCapacity);
+ for (byte[] bytes : pendingRequests.get()) {
+ singleThreadExecutor.submit(
+ () -> processPullRequest(new Message(bytes), pullsResponse));
+ }
+ int numResponsesSent = 0;
+ while (numResponsesSent < numWorkers) {
+ Message response = new Message(pullsResponse.take());
+ output.collect(new StreamRecord<>(response.bytes));
+ numResponsesSent++;
+ }
+ } else if (stage instanceof AllReduceStage) {
+ processAllReduceRequest(requestIterator);
+ } else if (stage instanceof ReduceScatterStage) {
+ processReduceScatterRequest(requestIterator);
+ } else {
+ throw new IllegalStateException(
+ "Illegal iteration stage: " + stage.getClass().getSimpleName() + ".");
+ }
+
+ pendingRequests.clear();
+ }
+ }
+
+ @Override
+ public void onIterationTerminated(Context context, Collector<byte[]> collector) {
+ Iterator<MT> modelSegments = modelUpdater.getModelSegments();
+ while (modelSegments.hasNext()) {
+ MT modelSegment = modelSegments.next();
+ output.collect(modelOutputTag, new StreamRecord<>(modelSegment));
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public void initializeState(StateInitializationContext context) throws Exception {
+ super.initializeState(context);
+ pendingRequests =
+ context.getOperatorStateStore()
+ .getListState(
+ new ListStateDescriptor<>(
+ "pendingRequests",
+ PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO));
+
+ modelUpdater.initializeState(context);
+
+ accPushesByStageState =
+ context.getOperatorStateStore()
+ .getListState(
+ new ListStateDescriptor<>(
+ "accPushesByStageState",
+ PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO));
+
+ // Recovers accPushesByStage from a byte[] stream.
+ Iterator<byte[]> accPushesInBytes = accPushesByStageState.get().iterator();
+ accPushesByStage = new TreeMap<>();
+
+ if (accPushesInBytes.hasNext()) {
+ // 4 bytes for number of stages.
+ byte[] meta = accPushesInBytes.next();
+ int offset = 0;
+ int numberOfStages = Bits.getInt(meta, offset);
+ for (int i = 0; i < numberOfStages; i++) {
+ byte[] oneStageMeta = accPushesInBytes.next();
+ offset = 0;
+ int stageId = Bits.getInt(oneStageMeta, offset);
+ offset += Integer.BYTES;
+ int sizeOfLong2ObjectMap = Bits.getInt(oneStageMeta, offset);
+ offset += Integer.BYTES;
+ int arrayLengthPerObject = Bits.getInt(oneStageMeta, offset);
+ Long2ObjectOpenHashMap<Object> pushes =
+ new Long2ObjectOpenHashMap<>(sizeOfLong2ObjectMap);
+ accPushesByStage.put(stageId, pushes);
+ for (int entryId = 0; entryId < sizeOfLong2ObjectMap; entryId++) {
+ byte[] kvInBytes = accPushesInBytes.next();
+ long key = Bits.getLong(kvInBytes, 0);
+ if (arrayLengthPerObject == 0) {
+ Double value = Bits.getDouble(kvInBytes, Long.BYTES);
+ pushes.put(key, value);
+ } else {
+ double[] value = Bits.getDoubleArray(kvInBytes, Long.BYTES);
+ pushes.put(key, value);
+ }
+ }
+ }
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public void snapshotState(StateSnapshotContext context) throws Exception {
+ super.snapshotState(context);
+ // Waits for the futures to finish.
+ for (Future> future : futuresInEpoch) {
+ future.get();
+ }
+ futuresInEpoch.clear();
+ modelUpdater.snapshotState(context);
+
+ accPushesByStageState.clear();
+ // Writes accPushesByStage to state in the following format:
+ // numberOfStagesInInt,
+ // stageIdInInt, sizeOfLong2ObjectMapInInt, arrayLengthPerObject, key-value-long-obj...
+ // stageIdInInt, sizeOfLong2ObjectMapInInt, arrayLengthPerObject, key-value-long-obj...
+ if (accPushesByStage.size() > 0) {
+ int numberOfStages = accPushesByStage.size();
+ byte[] meta = new byte[Integer.BYTES];
+ Bits.putInt(meta, 0, numberOfStages);
+ accPushesByStageState.add(meta);
+
+ for (Map.Entry<Integer, Long2ObjectOpenHashMap<Object>> entry :
+ accPushesByStage.entrySet()) {
+ byte[] oneStageMeta = new byte[Integer.BYTES * 3];
+ int offset = 0;
+ int stageId = entry.getKey();
+ Bits.putInt(oneStageMeta, offset, stageId);
+ offset += Integer.BYTES;
+ int sizeOfLong2ObjectMap = entry.getValue().size();
+ Bits.putInt(oneStageMeta, offset, sizeOfLong2ObjectMap);
+ offset += Integer.BYTES;
+ // 0 stands for Double, a non-zero value represents the array length.
+ int arrayLengthPerObject = 0;
+
+ ObjectIterator<Long2ObjectMap.Entry<Object>> objectIterator =
+ entry.getValue().long2ObjectEntrySet().fastIterator();
+
+ if (objectIterator.hasNext()) {
+ Map.Entry<Long, Object> oneEntry = objectIterator.next();
+ if (oneEntry.getValue() instanceof double[]) {
+ arrayLengthPerObject = ((double[]) (oneEntry.getValue())).length;
+ }
+ Bits.putInt(oneStageMeta, offset, arrayLengthPerObject);
+ accPushesByStageState.add(oneStageMeta);
+
+ accPushesByStageState.add(kvToBytes(oneEntry));
+ while (objectIterator.hasNext()) {
+ accPushesByStageState.add(kvToBytes(objectIterator.next()));
+ }
+ }
+ }
+ }
+ }
+
+ private static byte[] kvToBytes(Map.Entry<Long, Object> obj) {
+ byte[] bytes;
+ if (obj.getValue() instanceof double[]) {
+ double[] value = (double[]) obj.getValue();
+ bytes = new byte[Long.BYTES + Bits.getDoubleArraySizeInBytes(value)];
+ Bits.putLong(bytes, 0, obj.getKey());
+ Bits.putDoubleArray(value, bytes, Long.BYTES);
+ } else {
+ bytes = new byte[Long.BYTES + Double.BYTES];
+ Bits.putLong(bytes, 0, obj.getKey());
+ Bits.putDouble(bytes, Long.BYTES, (Double) obj.getValue());
+ }
+ return bytes;
+ }
+
+ @SuppressWarnings("unchecked")
+ private Object processPushRequest(Message message) throws Exception {
+ long[] keys = message.getKeys();
+ int stageId = message.getStageId();
+ double[] values = message.getValuesInDoubleArray();
+
+ accPushesByStage.putIfAbsent(stageId, new Long2ObjectOpenHashMap<>());
+ Long2ObjectOpenHashMap<Object> currentAccKvs = accPushesByStage.get(stageId);
+
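+ // Example: two pushes {key=5, value=0.1} and {key=5, value=0.3} under a SUM reduce
+ // function are merged into {key=5, value=0.4} before the model update.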
+ if (keys.length != 0) {
+ ReduceFunction<Double> reduceFunc =
+ ((PushStage) stageList.get(stageId % stageList.size())).reduceFunc;
+ if (values.length == keys.length) {
+ for (int i = 0; i < keys.length; i++) {
+ if (currentAccKvs.containsKey(keys[i])) {
+ double currentVal = (Double) currentAccKvs.get(keys[i]);
+ currentAccKvs.put(keys[i], reduceFunc.reduce(currentVal, values[i]));
+ } else {
+ currentAccKvs.put(keys[i], (Double) values[i]);
+ }
+ }
+ } else {
+ int numDoublesPerKey = values.length / keys.length;
+ for (int i = 0; i < keys.length; i++) {
+ if (currentAccKvs.containsKey(keys[i])) {
+ double[] currentVal = (double[]) currentAccKvs.get(keys[i]);
+ for (int j = 0; j < numDoublesPerKey; j++) {
+ currentVal[j] =
+ reduceFunc.reduce(
+ currentVal[j], values[i * numDoublesPerKey + j]);
+ }
+ } else {
+ currentAccKvs.put(
+ keys[i],
+ Arrays.copyOfRange(
+ values,
+ i * numDoublesPerKey,
+ i * numDoublesPerKey + numDoublesPerKey));
+ }
+ }
+ }
+ }
+ return new Object();
+ }
+
+ private void processPullRequest(Message message, LinkedBlockingDeque<byte[]> pullsResponse) {
+ int workerId = message.getWorkerId();
+ long[] keys = message.getKeys();
+ Message response;
+
+ if (keys.length == 0) {
+ // No request on this server.
+ response =
+ new Message(
+ workerId, serverId, message.getStageId(), new long[0], new double[0]);
+ } else {
+ double[] pulledValues = modelUpdater.get(keys);
+ Preconditions.checkState(pulledValues.length % keys.length == 0);
+ int numDoublesPerKey = pulledValues.length / keys.length;
+
+ double[] aggregatedPullValues = null;
+ Aggregator<double[], double[]> aggregator =
+ ((PullStage) (stageList.get(message.getStageId() % stageList.size())))
+ .aggregator;
+ if (aggregator != null) {
+ // Processes the pulled values if the aggregator is not null.
+ double[] tmp = new double[numDoublesPerKey];
+ for (int i = 0; i < keys.length; i++) {
+ System.arraycopy(pulledValues, i * numDoublesPerKey, tmp, 0, numDoublesPerKey);
+ aggregatedPullValues = aggregator.add(tmp, aggregatedPullValues);
+ }
+ } else {
+ aggregatedPullValues = pulledValues;
+ }
+
+ response =
+ new Message(
+ workerId,
+ serverId,
+ message.getStageId(),
+ new long[0],
+ aggregatedPullValues);
+ }
+ while (!pullsResponse.offer(response.bytes)) {
+ try {
+ Thread.sleep(10);
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ private <V> void processAllReduceRequest(Iterator<byte[]> requests) throws Exception {
+ byte[] request = requests.next();
+ Message message = new Message(request);
+ int stageId = message.getStageId();
+ AllReduceStage<V> stage = (AllReduceStage<V>) stageList.get(stageId % stageList.size());
+ V[] reducedResult = message.getValues(stage.typeSerializer);
+ ReduceFunction<V[]> reduceFunction = stage.reducer;
+
+ while (requests.hasNext()) {
+ message = new Message(requests.next());
+ reducedResult =
+ reduceFunction.reduce(message.getValues(stage.typeSerializer), reducedResult);
+ }
+ message =
+ new Message(
+ -1, serverId, stageId, new long[0], reducedResult, stage.typeSerializer);
+
+ for (int workerId = 0; workerId < numWorkers; workerId++) {
+ message.setWorkerId(workerId);
+ output.collect(new StreamRecord<>(message.bytes));
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ private <V> void processReduceScatterRequest(Iterator<byte[]> requests) throws Exception {
+ byte[] request = requests.next();
+ Message message = new Message(request);
+ int stageId = message.getStageId();
+ ReduceScatterStage<V> stage =
+ (ReduceScatterStage<V>) stageList.get(stageId % stageList.size());
+ V[] reducedResult = message.getValues(stage.typeSerializer);
+ ReduceFunction<V[]> reduceFunction = stage.reducer;
+
+ while (requests.hasNext()) {
+ message = new Message(requests.next());
+ reducedResult =
+ reduceFunction.reduce(message.getValues(stage.typeSerializer), reducedResult);
+ }
+
+ int[] recvCounts = stage.recvCounts;
+ int totalCnt = Arrays.stream(recvCounts).sum();
+ int shardSize = totalCnt / getRuntimeContext().getNumberOfParallelSubtasks() + 1;
+ int sliceStart = Math.min(serverId * shardSize, totalCnt);
+ int sliceEnd = Math.min(sliceStart + shardSize, totalCnt);
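+ // Example: totalCnt = 10 and three server subtasks give shardSize = 4, so server-1 owns
+ // the reduced slice [4, 8), which is scattered back to the workers below.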
+
+ int s = 0;
+ int e;
+ for (int workerId = 0; workerId < numWorkers; workerId++) {
+ e = recvCounts[workerId] + s;
+
+ int intersectionStart = Math.max(s, sliceStart);
+ int intersectionEnd = Math.min(e, sliceEnd);
+ int copyStart = 0, copyEnd = 0;
+ if (intersectionEnd > intersectionStart) {
+ copyStart = intersectionStart - sliceStart;
+ copyEnd = intersectionEnd - sliceStart;
+ }
+ message =
+ new Message(
+ workerId,
+ serverId,
+ stageId,
+ new long[0],
+ Arrays.copyOfRange(reducedResult, copyStart, copyEnd),
+ stage.typeSerializer);
+ output.collect(new StreamRecord<>(message.bytes));
+ }
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java
new file mode 100644
index 000000000..cf935a6bc
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java
@@ -0,0 +1,420 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.ps.iterations.AllReduceStage;
+import org.apache.flink.ml.common.ps.iterations.IterationStage;
+import org.apache.flink.ml.common.ps.iterations.IterationStageList;
+import org.apache.flink.ml.common.ps.iterations.MLSession;
+import org.apache.flink.ml.common.ps.iterations.ProcessStage;
+import org.apache.flink.ml.common.ps.iterations.PullStage;
+import org.apache.flink.ml.common.ps.iterations.PushStage;
+import org.apache.flink.ml.common.ps.iterations.ReduceScatterStage;
+import org.apache.flink.ml.common.ps.sarray.SharedDoubleArray;
+import org.apache.flink.ml.common.ps.sarray.SharedLongArray;
+import org.apache.flink.ml.common.ps.utils.ProxySideOutput;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.util.ResettableIterator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+
+import java.util.Iterator;
+import java.util.function.Function;
+
+/**
+ * The worker operator that executes the iterative machine learning process following {@link
+ * IterationStageList}.
+ *
+ * <p>In detail, the worker operator is responsible for the following:
+ *
+ * <ul>
+ * <li>Caching the training data.
+ * <li>Initializing the {@link MLSession}.
+ * <li>Splitting the {@link IterationStageList} by {@link PullStage}, {@link AllReduceStage} and
+ * {@link ReduceScatterStage} into multiple sequences and mapping them onto flink-ml-iterations.
+ * <li>Executing the process function in each {@link ProcessStage}.
+ * <li>Executing the push/pull/all-reduce/reduce-scatter requests in {@link PushStage}, {@link
+ * PullStage}, {@link AllReduceStage} and {@link ReduceScatterStage}, which talk to servers by
+ * reading/writing {@link MLSession}.
+ * </ul>
+ */
+public class WorkerOperator<DT, SessionT extends MLSession> extends AbstractStreamOperator<byte[]>
+ implements TwoInputStreamOperator<DT, byte[], byte[]>, IterationListener<byte[]> {
+ /** The user defined iteration logic. */
+ private final IterationStageList<SessionT> iterationStages;
+ /**
+ * Iteration id in terms of {@link IterationStageList}. When all stages in the stage list have
+ * been processed, the iteration id is incremented by one.
+ */
+ private int iterationId;
+
+ /** The id of the next stage to execute in iterationStages. */
+ private int nextStageToExecute = 0;
+
+ private ListState<Integer> nextStageToExecuteState;
+
+ /** The agent for each worker to talk with servers. */
+ private transient ServerAgent serverAgent;
+ /** Number of servers that this worker needs to talk to. */
+ private final int numServers;
+ /** The hash function to distribute keys to servers. */
+ private transient Function<Long, Integer> hashFunc;
+
+ /** The cached training data. */
+ private ListStateWithCache<DT> trainDataState;
+
+ /**
+ * Number of segments received from servers for the current request. For each request, a worker
+ * should receive one segment from each server.
+ */
+ private int numSegmentsReceived = 0;
+
+ private ListState<Integer> numSegmentsReceivedState;
+
+ /**
+ * The buffer for the pull answer. For a pull request, each received segment is copied into the
+ * user-provided buffer.
+ */
+ private double[] pulledResult;
+
+ private ListState<double[]> pulledResultState;
+
+ /** The state store for the all-reduce/reduce-scatter results. */
+ private ListState<byte[]> reducedResult;
+
+ public WorkerOperator(IterationStageList<SessionT> iterationStages, int numServers) {
+ this.iterationStages = iterationStages;
+ this.numServers = numServers;
+ }
+
+ @Override
+ public void open() {
+ int workerId = getRuntimeContext().getIndexOfThisSubtask();
+ int numWorkers = getRuntimeContext().getNumberOfParallelSubtasks();
+ this.hashFunc = key -> (int) (Math.abs(key % numServers));
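+ // The same hash function is handed to the ServerAgent below so that request slicing
+ // agrees with the key partitioning across servers.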
+ this.serverAgent = new ServerAgent(workerId, numServers, hashFunc, output);
+ iterationStages.session.setWorldInfo(workerId, numWorkers);
+ iterationStages.session.setOutput(new ProxySideOutput(output));
+ }
+
+ @Override
+ public void onEpochWatermarkIncremented(
+ int epochWatermark, Context context, Collector<byte[]> collector) throws Exception {
+ if (epochWatermark == 0) {
+ iterationStages.session.setInputData(new ResettableTrainDataIterator<>(trainDataState));
+ nextStageToExecute = processIterationStages(nextStageToExecute, iterationStages);
+ }
+ }
+
+ @Override
+ public void onIterationTerminated(Context context, Collector<byte[]> collector) {
+ trainDataState.clear();
+ }
+
+ @Override
+ public void processElement1(StreamRecord<DT> streamRecord) throws Exception {
+ trainDataState.add(streamRecord.getValue());
+ }
+
+ @Override
+ public void processElement2(StreamRecord<byte[]> streamRecord) throws Exception {
+ Message message = new Message(streamRecord.getValue());
+ IterationStage stage =
+ iterationStages.stageList.get(
+ nextStageToExecute % iterationStages.stageList.size());
+
+ boolean proceedToNextStage;
+ if (stage instanceof PullStage) {
+ proceedToNextStage = onPullResponse(message, (PullStage) stage);
+ } else if (stage instanceof AllReduceStage) {
+ proceedToNextStage = onAllReduceResponse(message, (AllReduceStage<?>) stage);
+ } else if (stage instanceof ReduceScatterStage) {
+ proceedToNextStage = onReduceScatterResponse(message, (ReduceScatterStage<?>) stage);
+ } else {
+ throw new IllegalStateException(
+ "Illegal stage type: " + stage.getClass().getSimpleName() + ".");
+ }
+
+ if (proceedToNextStage) {
+ nextStageToExecute++;
+ nextStageToExecute = processIterationStages(nextStageToExecute, iterationStages);
+ }
+ }
+
+ private boolean onPullResponse(Message message, PullStage pullStage) {
+ numSegmentsReceived++;
+ double[] segment = message.getValuesInDoubleArray();
+ if (segment.length != 0) {
+ if (pullStage.aggregator != null) {
+ if (pulledResult.length == 0) {
+ pulledResult = segment;
+ } else {
+ pulledResult = pullStage.aggregator.merge(segment, pulledResult);
+ }
+ } else {
+ SharedLongArray keys = pullStage.keys.get();
+ SharedDoubleArray values = pullStage.values.get();
+ int serverId = message.getServerId();
+ long[] keysArray = keys.elements();
+
+ if (pulledResult.length == 0) {
+ pulledResult = values.elements();
+ }
+
+ int numDoublesPerKey = values.size() / keys.size();
+ // Copy the response from one server to the result array.
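+ // Example: keys {1, 2, 3} with numDoublesPerKey = 2 and two servers. The segment
+ // from server-1 holds the doubles for keys {1, 3}; each chunk of two doubles is
+ // copied back to the position of its key in the user-provided buffer.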
+ int idxInLocalPull = 0;
+ for (int i = 0; i < keys.size(); i++) {
+ if (hashFunc.apply(keysArray[i]) == serverId) {
+ System.arraycopy(
+ segment,
+ idxInLocalPull * numDoublesPerKey,
+ pulledResult,
+ i * numDoublesPerKey,
+ numDoublesPerKey);
+ idxInLocalPull++;
+ }
+ }
+ }
+ }
+
+ if (numSegmentsReceived == numServers) {
+ SharedDoubleArray pullPlaceHolder = pullStage.values.get();
+ System.arraycopy(
+ pulledResult, 0, pullPlaceHolder.elements(), 0, pullPlaceHolder.size());
+
+ pulledResult = new double[0];
+ numSegmentsReceived = 0;
+ return true;
+ }
+ return false;
+ }
+
+ private <V> boolean onAllReduceResponse(Message message, AllReduceStage<V> allReduceStage)
+ throws Exception {
+ numSegmentsReceived++;
+ reducedResult.add(message.bytes);
+
+ if (numSegmentsReceived == numServers) {
+ Message assembled = Message.assembleMessages(reducedResult.get().iterator());
+ V[] reduceResult = assembled.getValues(allReduceStage.typeSerializer);
+ System.arraycopy(reduceResult, 0, allReduceStage.recvBuf.get(), 0, reduceResult.length);
+ reducedResult.clear();
+ numSegmentsReceived = 0;
+ return true;
+ }
+ return false;
+ }
+
+ private <V> boolean onReduceScatterResponse(
+ Message message, ReduceScatterStage<V> reduceScatterStage) throws Exception {
+ numSegmentsReceived++;
+ reducedResult.add(message.bytes);
+
+ if (numSegmentsReceived == numServers) {
+ Message assembled = Message.assembleMessages(reducedResult.get().iterator());
+ V[] reduceResult = assembled.getValues(reduceScatterStage.typeSerializer);
+ System.arraycopy(
+ reduceResult, 0, reduceScatterStage.recvBuf.get(), 0, reduceResult.length);
+ reducedResult.clear();
+ numSegmentsReceived = 0;
+ return true;
+ }
+ return false;
+ }
+
+ @Override
+ public void initializeState(StateInitializationContext context) throws Exception {
+ super.initializeState(context);
+ trainDataState =
+ new ListStateWithCache<>(
+ (getOperatorConfig().getTypeSerializerIn(0, getClass().getClassLoader())),
+ getContainingTask(),
+ getRuntimeContext(),
+ context,
+ config.getOperatorID());
+
+ numSegmentsReceivedState =
+ context.getOperatorStateStore()
+ .getListState(
+ new ListStateDescriptor<>("numSegmentsReceivedState", Types.INT));
+ numSegmentsReceived =
+ OperatorStateUtils.getUniqueElement(
+ numSegmentsReceivedState, "numSegmentsReceivedState")
+ .orElse(0);
+
+ nextStageToExecuteState =
+ context.getOperatorStateStore()
+ .getListState(
+ new ListStateDescriptor<>("nextStageToExecuteState", Types.INT));
+
+ nextStageToExecute =
+ OperatorStateUtils.getUniqueElement(
+ nextStageToExecuteState, "nextStageToExecuteState")
+ .orElse(0);
+
+ iterationStages.session.initializeState(context);
+
+ pulledResultState =
+ context.getOperatorStateStore()
+ .getListState(
+ new ListStateDescriptor<>(
+ "pulledResultState",
+ PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO));
+ pulledResult =
+ OperatorStateUtils.getUniqueElement(pulledResultState, "pulledResultState")
+ .orElse(new double[0]);
+
+ reducedResult =
+ context.getOperatorStateStore()
+ .getListState(
+ new ListStateDescriptor<>(
+ "reducedResult",
+ PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO));
+ }
+
+ @Override
+ public void snapshotState(StateSnapshotContext context) throws Exception {
+ super.snapshotState(context);
+
+ numSegmentsReceivedState.clear();
+ numSegmentsReceivedState.add(numSegmentsReceived);
+
+ nextStageToExecuteState.clear();
+ nextStageToExecuteState.add(nextStageToExecute);
+
+ trainDataState.snapshotState(context);
+ iterationStages.session.snapshotState(context);
+
+ pulledResultState.clear();
+ pulledResultState.add(pulledResult);
+ }
+
+ /**
+ * Processes the stages described in the given iterationStages from the given nextStage id. This
+ * function processes the stages until it meets a {@link PullStage}, {@link AllReduceStage} or
+ * {@link ReduceScatterStage}.
+ *
+ * @param nextStageToExecute id of the next stage to execute in the given iteration stages.
+ * @param iterationStages iteration stages used to describe the training logic.
+ * @return the id of the next stage to execute.
+ */
+ @SuppressWarnings("unchecked")
+ private int processIterationStages(
+ int nextStageToExecute, IterationStageList<SessionT> iterationStages) throws Exception {
+ while (true) {
+ if (nextStageToExecute > 0
+ && nextStageToExecute % iterationStages.stageList.size() == 0) {
+ iterationId = nextStageToExecute / iterationStages.stageList.size();
+ iterationStages.session.setIterationId(iterationId);
+ if (iterationStages.shouldTerminate.apply(iterationStages.session)) {
+ return -1;
+ }
+ }
+ IterationStage stage =
+ iterationStages.stageList.get(
+ nextStageToExecute % iterationStages.stageList.size());
+
+ // We are not incrementing nextStageToExecute for
+ // PullStage/AllReduceStage/ReduceScatterStage, since we
+ // need to wait for response from servers.
+ if (stage instanceof PullStage) {
+ PullStage pullStage = ((PullStage) stage);
+ serverAgent.pull(pullStage.keys.get(), nextStageToExecute);
+ return nextStageToExecute;
+
+ } else if (stage instanceof AllReduceStage) {
+ AllReduceStage<Object> allReduceStage = (AllReduceStage<Object>) stage;
+ if (iterationId % allReduceStage.executionInterval == 0) {
+ serverAgent.reduce(
+ allReduceStage.sendBuf.get(),
+ allReduceStage.typeSerializer,
+ nextStageToExecute);
+ return nextStageToExecute;
+ } else {
+ nextStageToExecute++;
+ }
+
+ } else if (stage instanceof ReduceScatterStage) {
+ ReduceScatterStage<Object> reduceScatterStage = (ReduceScatterStage<Object>) stage;
+ if (iterationId % reduceScatterStage.executionInterval == 0) {
+ serverAgent.reduce(
+ reduceScatterStage.sendBuf.get(),
+ reduceScatterStage.typeSerializer,
+ nextStageToExecute);
+ return nextStageToExecute;
+ } else {
+ nextStageToExecute++;
+ }
+ } else if (stage instanceof PushStage) {
+ PushStage pushStage = (PushStage) stage;
+ serverAgent.push(pushStage.keys.get(), pushStage.values.get(), nextStageToExecute);
+ nextStageToExecute++;
+ } else if (stage instanceof ProcessStage) {
+ ((ProcessStage<SessionT>) stage).process(iterationStages.session);
+ nextStageToExecute++;
+ } else {
+ throw new IllegalStateException(
+ "Illegal type of IterationStage: "
+ + stage.getClass().getSimpleName()
+ + ".");
+ }
+ }
+ }
+
+ /** A resettable iterator for {@link ListStateWithCache}. */
+ private static class ResettableTrainDataIterator<T> implements ResettableIterator<T> {
+ private final ListStateWithCache<T> data;
+ private Iterator<T> dataIterator;
+
+ public ResettableTrainDataIterator(ListStateWithCache<T> data) throws Exception {
+ this.data = data;
+ this.dataIterator = data.get().iterator();
+ }
+
+ @Override
+ public void reset() {
+ try {
+ this.dataIterator = data.get().iterator();
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public boolean hasNext() {
+ return dataIterator.hasNext();
+ }
+
+ @Override
+ public T next() {
+ return dataIterator.next();
+ }
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/AllReduceStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/AllReduceStage.java
new file mode 100644
index 000000000..c0d3cdf0b
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/AllReduceStage.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.iterations;
+
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.util.function.SerializableSupplier;
+
+import java.util.function.Supplier;
+
+/**
+ * This iteration stage is designed to perform an all-reduce operation on the specified array in a
+ * distributed setting.
+ *
+ * <p>Users can specify how often this operation is conducted by setting the value of the
+ * "executionInterval" parameter, which determines the frequency of the all-reduce stage. For
+ * example, if the value of executionInterval is set to 5, the all-reduce stage will be executed
+ * every 5 iterations.
+ */
+public final class AllReduceStage<V> implements IterationStage {
+ public final Supplier<V[]> sendBuf;
+ public final Supplier<V[]> recvBuf;
+ public final ReduceFunction<V[]> reducer;
+ public final TypeSerializer<V> typeSerializer;
+ public final int executionInterval;
+
+ public AllReduceStage(
+ SerializableSupplier<V[]> sendBuf,
+ SerializableSupplier<V[]> recvBuf,
+ ReduceFunction<V[]> reducer,
+ TypeSerializer<V> typeSerializer,
+ int executionInterval) {
+ this.sendBuf = sendBuf;
+ this.recvBuf = recvBuf;
+ this.reducer = reducer;
+ this.typeSerializer = typeSerializer;
+ this.executionInterval = executionInterval;
+ }
+
+ public AllReduceStage(
+ SerializableSupplier<V[]> sendBuf,
+ SerializableSupplier<V[]> recvBuf,
+ ReduceFunction<V[]> reducer,
+ TypeSerializer<V> typeSerializer) {
+ this(sendBuf, recvBuf, reducer, typeSerializer, 1);
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/IterationStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/IterationStage.java
new file mode 100644
index 000000000..d0f23a774
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/IterationStage.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.iterations;
+
+import java.io.Serializable;
+
+/**
+ * Iterative machine learning training usually involves a local computation step (e.g., computing
+ * gradients) and a global communication step (e.g., all-reduce, or parameter servers aggregating
+ * the updates from workers).
+ *
+ * <p>To describe this training process, we model it as a sequence of iteration stages. An
+ * iteration stage is either a local computation or a global communication.
+public interface IterationStage extends Serializable {}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/IterationStageList.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/IterationStageList.java
new file mode 100644
index 000000000..e1cd23b7b
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/IterationStageList.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.iterations;
+
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.function.SerializableFunction;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Function;
+
+/**
+ * A list of iteration stages to express the logic of an iterative machine learning process.
+ *
+ * <p>Note that there should be at least one stage (e.g., {@link PullStage}, {@link AllReduceStage}
+ * or {@link ReduceScatterStage}) that needs to wait for responses from servers.
+ */
+public class IterationStageList<T extends MLSession> implements Serializable {
+ /** The session on each worker. */
+ public final T session;
+ /** The termination criteria. */
+ public Function<T, Boolean> shouldTerminate;
+ /** The stage list that describes the iterative process. */
+ public List<IterationStage> stageList;
+
+ public IterationStageList(T session) {
+ this.stageList = new ArrayList<>();
+ this.session = session;
+ }
+
+ /** Sets the criteria of termination. */
+ public IterationStageList<T> setTerminationCriteria(
+ SerializableFunction<T, Boolean> shouldTerminate) {
+ boolean waitServer = false;
+ for (IterationStage stage : stageList) {
+ if (stage instanceof PullStage
+ || stage instanceof AllReduceStage
+ || stage instanceof ReduceScatterStage) {
+ waitServer = true;
+ break;
+ }
+ }
+ Preconditions.checkState(
+ waitServer,
+ String.format(
+ "There should be at least one stage that needs to receive response from servers (i.e., %s, %s, %s).\n",
+ PullStage.class.getSimpleName(),
+ AllReduceStage.class.getSimpleName(),
+ ReduceScatterStage.class.getSimpleName()));
+ this.shouldTerminate = shouldTerminate;
+ return this;
+ }
+
+ /** Adds an iteration stage into the stage list. */
+ public IterationStageList<T> addStage(IterationStage stage) {
+ stageList.add(stage);
+ return this;
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/MLSession.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/MLSession.java
new file mode 100644
index 000000000..799a5cb6a
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/MLSession.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.iterations;
+
+import org.apache.flink.ml.common.ps.WorkerOperator;
+import org.apache.flink.ml.common.ps.utils.ProxySideOutput;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.util.ResettableIterator;
+import org.apache.flink.util.OutputTag;
+
+import java.io.Serializable;
+import java.util.List;
+
+/**
+ * Stores the session information that is alive during the training process on {@link
+ * WorkerOperator}. Note that the session information will be updated by each {@link
+ * IterationStage}.
+ *
+ * <p>Subclasses should take care of snapshotting an object stored in {@link MLSession} if the
+ * object satisfies the following: its write-process is followed by a {@link PullStage}/{@link
+ * AllReduceStage}/{@link ReduceScatterStage} and it is later read again by other stages.
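+ *
+ * <p>A minimal custom session sketch (the field names are illustrative):
+ *
+ * <pre>{@code
+ * public class MySession implements MLSession {
+ *     public int iterationId;
+ *     public SharedLongArray keys = new SharedLongArray();
+ *     public SharedDoubleArray grads = new SharedDoubleArray();
+ *
+ *     public void setIterationId(int iterationId) {
+ *         this.iterationId = iterationId;
+ *     }
+ * }
+ * }</pre>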
+ */
+public interface MLSession extends Serializable {
+ /** Sets the current iteration ID. */
+ default void setIterationId(int iterationId) {}
+
+ /** Sets the worker id and total number of workers. */
+ default void setWorldInfo(int workerId, int numWorkers) {}
+
+ /** Sets the training data. */
+ default void setInputData(ResettableIterator<?> inputData) {}
+
+ /** Sets the collector that users can output records to downstream tasks. */
+ default void setOutput(ProxySideOutput collector) {}
+
+ /**
+ * Retrieves the output tags from the {@link MLSession} which can be used to output records from
+ * the worker operator.
+ */
+ default List<OutputTag<?>> getOutputTags() {
+ return null;
+ }
+
+ /** Recovers from state. */
+ default void initializeState(StateInitializationContext context) throws Exception {}
+
+ /** Snapshots to state. */
+ default void snapshotState(StateSnapshotContext context) throws Exception {}
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/MLSessionImpl.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/MLSessionImpl.java
new file mode 100644
index 000000000..317f7cb66
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/MLSessionImpl.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.iterations;
+
+import org.apache.flink.runtime.util.ResettableIterator;
+import org.apache.flink.util.OutputTag;
+
+import java.util.List;
+
+/**
+ * The default implementation of {@link MLSession}.
+ *
+ * @param <DT> Data type of input data.
+ */
+public class MLSessionImpl<DT> implements MLSession {
+ /** Current iteration id. */
+ public int iterationId;
+ /** Index of this worker. */
+ public int workerId;
+ /** Number of workers in total for this distributed ML job. */
+ public int numWorkers;
+ /** The input data. */
+ public ResettableIterator<DT> inputData;
+
+ public List<OutputTag<?>> outputTags;
+
+ /** Constructs an instance with side outputs. */
+ public MLSessionImpl(List<OutputTag<?>> outputTags) {
+ this.outputTags = outputTags;
+ }
+
+ /** Constructs an instance without side outputs. */
+ public MLSessionImpl() {
+ this(null);
+ }
+
+ @Override
+ public List<OutputTag<?>> getOutputTags() {
+ return outputTags;
+ }
+
+ @Override
+ public void setIterationId(int iterationId) {
+ this.iterationId = iterationId;
+ }
+
+ @Override
+ public void setWorldInfo(int workerId, int numWorkers) {
+ this.workerId = workerId;
+ this.numWorkers = numWorkers;
+ }
+
+ @Override
+ public void setInputData(ResettableIterator<?> inputData) {
+ this.inputData = (ResettableIterator<DT>) inputData;
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/ProcessStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/ProcessStage.java
new file mode 100644
index 000000000..8c8810699
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/ProcessStage.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.iterations;
+
+/**
+ * A local computation stage of the training process. The input and output of {@link ProcessStage}
+ * can be accessed via {@link MLSession}.
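+ *
+ * <p>A sketch of a stage that rescales a gradient buffer stored in the session (the session type
+ * and its fields are illustrative):
+ *
+ * <pre>{@code
+ * public class ScaleGradients extends ProcessStage<MySession> {
+ *     public void process(MySession session) {
+ *         double[] grads = session.grads.elements();
+ *         for (int i = 0; i < session.grads.size(); i++) {
+ *             grads[i] /= session.numTrainData;
+ *         }
+ *     }
+ * }
+ * }</pre>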
+ *
+ * @param <T> Type of the {@link MLSession} used by this stage.
+ */
+public abstract class ProcessStage<T extends MLSession> implements IterationStage {
+ /**
+ * Performs the local computation using the information from the session. Example stages
+ * include computing gradients locally.
+ */
+ public abstract void process(T session) throws Exception;
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/PullStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/PullStage.java
new file mode 100644
index 000000000..8f86c5e5c
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/PullStage.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.iterations;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.ps.sarray.SharedDoubleArray;
+import org.apache.flink.ml.common.ps.sarray.SharedLongArray;
+import org.apache.flink.util.function.SerializableSupplier;
+
+import java.io.Serializable;
+import java.util.function.Supplier;
+
+/**
+ * An iteration stage that aggregates data from servers using the keys returned by {@code
+ * PullStage#keys#get()} and stores the aggregated values in the buffer returned by {@code
+ * PullStage#values#get()}.
+ *
+ * <p>If the aggregator is null, the values specified by the keys are simply pulled.
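+ *
+ * <p>A construction sketch (the session fields referenced by the suppliers are illustrative):
+ *
+ * <pre>{@code
+ * PullStage pull = new PullStage(() -> session.keys, () -> session.values);
+ * }</pre>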
+ */
+public final class PullStage implements IterationStage {
+ public final Supplier<SharedLongArray> keys;
+ public final Supplier<SharedDoubleArray> values;
+ public final Aggregator<double[], double[]> aggregator;
+
+ public PullStage(
+ SerializableSupplier<SharedLongArray> keys,
+ SerializableSupplier<SharedDoubleArray> values) {
+ this(keys, values, null);
+ }
+
+ public PullStage(
+ SerializableSupplier<SharedLongArray> keys,
+ SerializableSupplier<SharedDoubleArray> values,
+ Aggregator<double[], double[]> aggregator) {
+ this.keys = keys;
+ this.values = values;
+ this.aggregator = aggregator;
+ }
+
+ /**
+ * An Aggregator is used to aggregate a set of input elements into a single accumulator.
+ *
+ * @param <IN> The type of the input elements.
+ * @param <ACC> The type of the accumulator.
+ */
+ @Internal
+ public interface Aggregator<IN, ACC> extends Serializable {
+
+ /**
+ * Adds a new input element to the given accumulator and returns the updated accumulator.
+ *
+ * @param in The input element to add.
+ * @param acc The accumulator to update.
+ * @return The updated accumulator.
+ */
+ ACC add(IN in, ACC acc);
+
+ /**
+ * Merges two accumulators and returns the result.
+ *
+ * @param acc1 The first accumulator to merge.
+ * @param acc2 The second accumulator to merge.
+ * @return The merged accumulator.
+ */
+ ACC merge(ACC acc1, ACC acc2);
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/PushStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/PushStage.java
new file mode 100644
index 000000000..3abec6190
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/PushStage.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.iterations;
+
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.ml.common.ps.sarray.SharedDoubleArray;
+import org.apache.flink.ml.common.ps.sarray.SharedLongArray;
+import org.apache.flink.util.function.SerializableSupplier;
+
+import java.util.function.Supplier;
+
+/**
+ * An iteration stage that pushes (keys, values) to servers. Users can specify how values from
+ * different workers are merged via {@code PushStage#reduceFunc}. By default, the values from
+ * different workers are summed.
+ *
+ * <p>Note that the length of the values array must be divisible by the length of the keys array.
+ * Additionally, each value corresponding to a given key must have the same length. For instance,
+ * given the keys {1, 4} and the values {1,2,3,4,5,6}, the value for key 1 would be {1,2,3}, and
+ * the value for key 4 would be {4,5,6}.
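+ *
+ * <p>A construction sketch (the session fields referenced by the suppliers are illustrative):
+ *
+ * <pre>{@code
+ * // Sums the pushed values from all workers, which is the default reduce function.
+ * PushStage push = new PushStage(() -> session.keys, () -> session.grads);
+ * }</pre>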
+ */
+public class PushStage implements IterationStage {
+ public final Supplier<SharedLongArray> keys;
+ public final Supplier<SharedDoubleArray> values;
+
+ /**
+ * The function to reduce the pushes from all workers. For gradient descent based methods, this
+ * is typically a summation of the pushed values.
+ */
+ public final ReduceFunction<Double> reduceFunc;
+
+ public PushStage(
+ SerializableSupplier<SharedLongArray> keys,
+ SerializableSupplier<SharedDoubleArray> values) {
+ this(keys, values, Double::sum);
+ }
+
+ public PushStage(
+ SerializableSupplier<SharedLongArray> keys,
+ SerializableSupplier<SharedDoubleArray> values,
+ ReduceFunction<Double> reduceFunc) {
+ this.keys = keys;
+ this.values = values;
+ this.reduceFunc = reduceFunc;
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/ReduceScatterStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/ReduceScatterStage.java
new file mode 100644
index 000000000..c660de285
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/ReduceScatterStage.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.iterations;
+
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.function.SerializableSupplier;
+
+import java.util.function.Supplier;
+
+/**
+ * This iteration stage is designed to perform a reduce-scatter operation on the specified array in
+ * a distributed setting.
+ *
+ * <p>Users can specify how often this operation is conducted by setting the value of the
+ * "executionInterval" parameter, which determines the frequency of the reduce-scatter stage. For
+ * example, if the value of executionInterval is set to 5, the reduce-scatter stage will be executed
+ * every 5 iterations.
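+ *
+ * <p>A construction sketch (the buffers and receive counts are illustrative):
+ *
+ * <pre>{@code
+ * ReduceScatterStage<Double> stage =
+ *         new ReduceScatterStage<>(
+ *                 () -> session.sendBuf,
+ *                 () -> session.recvBuf,
+ *                 recvCounts,
+ *                 (a, b) -> {
+ *                     for (int i = 0; i < a.length; i++) {
+ *                         a[i] += b[i];
+ *                     }
+ *                     return a;
+ *                 },
+ *                 DoubleSerializer.INSTANCE);
+ * }</pre>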
+ */
+public final class ReduceScatterStage<V> implements IterationStage {
+ public final Supplier<V[]> sendBuf;
+ public final Supplier<V[]> recvBuf;
+ /** The number of elements each worker receives. */
+ public int[] recvCounts;
+
+ public final ReduceFunction<V[]> reducer;
+ public final TypeSerializer<V> typeSerializer;
+
+ public final int executionInterval;
+
+ public ReduceScatterStage(
+ SerializableSupplier<V[]> sendBuf,
+ SerializableSupplier<V[]> recvBuf,
+ int[] recvCounts,
+ ReduceFunction<V[]> reducer,
+ TypeSerializer<V> typeSerializer,
+ int executionInterval) {
+ this.sendBuf = sendBuf;
+ this.recvBuf = recvBuf;
+ this.recvCounts = Preconditions.checkNotNull(recvCounts);
+ this.reducer = reducer;
+ this.typeSerializer = typeSerializer;
+ this.executionInterval = executionInterval;
+ }
+
+ public ReduceScatterStage(
+ SerializableSupplier<V[]> sendBuf,
+ SerializableSupplier<V[]> recvBuf,
+ int[] recvCounts,
+ ReduceFunction<V[]> reducer,
+ TypeSerializer<V> typeSerializer) {
+ this(sendBuf, recvBuf, recvCounts, reducer, typeSerializer, 1);
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/sarray/SharedDoubleArray.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/sarray/SharedDoubleArray.java
new file mode 100644
index 000000000..4a7aa24b0
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/sarray/SharedDoubleArray.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.sarray;
+
+import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
+
+import java.io.Serializable;
+
+/**
+ * A resizable double array that can be shared among different iterations for memory efficiency.
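+ *
+ * <p>A minimal usage sketch (the values are illustrative):
+ *
+ * <pre>{@code
+ * SharedDoubleArray buffer = new SharedDoubleArray(16);
+ * buffer.addAll(new double[] {1.0, 2.0, 3.0});
+ * buffer.size(8); // grows the valid region; the new slots are zero-filled
+ * double[] raw = buffer.elements(); // only the first buffer.size() entries are valid
+ * }</pre>
+ */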
+public class SharedDoubleArray implements Serializable {
+
+ /** The underlying DoubleArrayList that holds the elements. */
+ private final DoubleArrayList doubles;
+
+ /**
+ * Constructs a new SharedDoubleArray from the given double array.
+ *
+ * @param array the double array to wrap
+ */
+ public SharedDoubleArray(double[] array) {
+ doubles = DoubleArrayList.wrap(array);
+ }
+
+ /**
+ * Constructs a new SharedDoubleArray with the given initial capacity.
+ *
+ * @param capacity the initial capacity
+ */
+ public SharedDoubleArray(int capacity) {
+ doubles = new DoubleArrayList(capacity);
+ }
+
+ /** Constructs a new empty SharedDoubleArray. */
+ public SharedDoubleArray() {
+ doubles = new DoubleArrayList();
+ }
+
+ /**
+ * Returns the element at the specified index.
+ *
+ * @param index the index of the element to return
+ * @return the element at the specified index
+ */
+ public double get(int index) {
+ return doubles.getDouble(index);
+ }
+
+ /**
+ * Appends the specified element to the end of this array.
+ *
+ * @param v the element to add
+ */
+ public void add(double v) {
+ doubles.add(v);
+ }
+
+ /**
+ * Appends all the elements from the specified double array to the end of this array.
+ *
+ * @param src the double array to append
+ */
+ public void addAll(double[] src) {
+ int sizeBefore = size();
+ doubles.size(sizeBefore + src.length);
+ System.arraycopy(src, 0, elements(), sizeBefore, src.length);
+ }
+
+ /**
+ * Returns the number of valid elements in this array.
+ *
+ * @return the number of valid elements in this array
+ */
+ public int size() {
+ return doubles.size();
+ }
+
+ /**
+ * Sets the size of the array to the provided size. If the new size is larger than the current
+ * size, the newly allocated memory is filled with zeros.
+ *
+ * @param size the new size of the array
+ */
+ public void size(int size) {
+ doubles.size(size);
+ }
+
+ /** Clears the elements in this array. Note that the memory is not recycled. */
+ public void clear() {
+ doubles.clear();
+ }
+
+ /**
+ * Returns a double array containing all the elements in this array. Only the first {@link
+ * SharedDoubleArray#size()} elements are valid.
+ *
+ * @return a double array containing all the elements in this array
+ */
+ public double[] elements() {
+ return doubles.elements();
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/sarray/SharedLongArray.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/sarray/SharedLongArray.java
new file mode 100644
index 000000000..d193890da
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/sarray/SharedLongArray.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.sarray;
+
+import it.unimi.dsi.fastutil.longs.LongArrayList;
+
+import java.io.Serializable;
+
+/** A resizable long array that can be shared among different iterations for memory efficiency. */
+public class SharedLongArray implements Serializable {
+
+ /** The underlying LongArrayList that holds the elements. */
+ private final LongArrayList longs;
+
+ /**
+ * Constructs a new SharedLongArray from the given long array.
+ *
+ * @param array the long array to wrap
+ */
+ public SharedLongArray(long[] array) {
+ longs = LongArrayList.wrap(array);
+ }
+
+ /**
+ * Constructs a new SharedLongArray with the given initial capacity.
+ *
+ * @param capacity the initial capacity
+ */
+ public SharedLongArray(int capacity) {
+ longs = new LongArrayList(capacity);
+ }
+
+ /** Constructs a new empty SharedLongArray. */
+ public SharedLongArray() {
+ longs = new LongArrayList();
+ }
+
+ /**
+ * Returns the element at the specified index.
+ *
+ * @param index the index of the element to return
+ * @return the element at the specified index
+ */
+ public long get(int index) {
+ return longs.getLong(index);
+ }
+
+ /**
+ * Appends the specified element to the end of this array.
+ *
+ * @param v the element to add
+ */
+ public void add(long v) {
+ longs.add(v);
+ }
+
+ /**
+ * Appends all the elements from the specified long array to the end of this array.
+ *
+ * @param src the long array to append
+ */
+ public void addAll(long[] src) {
+ int sizeBefore = size();
+ longs.size(sizeBefore + src.length);
+ System.arraycopy(src, 0, elements(), sizeBefore, src.length);
+ }
+
+ /**
+ * Returns the number of valid elements in this array.
+ *
+ * @return the number of valid elements in this array
+ */
+ public int size() {
+ return longs.size();
+ }
+
+ /**
+ * Sets the size of the array to the provided size. If the new size is larger than the current
+ * size, the newly allocated memory is filled with zeros.
+ *
+ * @param size the new size of the array
+ */
+ public void size(int size) {
+ longs.size(size);
+ }
+
+ /** Clears the elements in this array. Note that the memory is not recycled. */
+ public void clear() {
+ longs.clear();
+ }
+
+ /**
+ * Returns a long array containing the valid elements in this array. Only the first {@link
+ * SharedLongArray#size()} elements are valid.
+ *
+ * @return a long array containing the valid elements in this array
+ */
+ public long[] elements() {
+ return longs.elements();
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2DoubleOpenHashMapSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2DoubleOpenHashMapSerializer.java
new file mode 100644
index 000000000..3e2d3b920
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2DoubleOpenHashMapSerializer.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.typeinfo;
+
+import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+
+import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
+
+import java.io.IOException;
+import java.util.Map;
+
+/** TypeSerializer for {@link Long2DoubleOpenHashMap}. */
+public class Long2DoubleOpenHashMapSerializer extends TypeSerializer<Long2DoubleOpenHashMap> {
+
+ public static final Long2DoubleOpenHashMapSerializer INSTANCE =
+ new Long2DoubleOpenHashMapSerializer();
+
+ @Override
+ public boolean isImmutableType() {
+ return false;
+ }
+
+ @Override
+ public TypeSerializer<Long2DoubleOpenHashMap> duplicate() {
+ return INSTANCE;
+ }
+
+ @Override
+ public Long2DoubleOpenHashMap createInstance() {
+ return new Long2DoubleOpenHashMap();
+ }
+
+ @Override
+ public Long2DoubleOpenHashMap copy(Long2DoubleOpenHashMap from) {
+ return new Long2DoubleOpenHashMap(from);
+ }
+
+ @Override
+ public Long2DoubleOpenHashMap copy(Long2DoubleOpenHashMap from, Long2DoubleOpenHashMap reuse) {
+ return new Long2DoubleOpenHashMap(from);
+ }
+
+ @Override
+ public int getLength() {
+ return -1;
+ }
+
+ @Override
+ public void serialize(Long2DoubleOpenHashMap map, DataOutputView target) throws IOException {
+ target.writeInt(map.size());
+ for (Map.Entry<Long, Double> entry : map.entrySet()) {
+ target.writeLong(entry.getKey());
+ target.writeDouble(entry.getValue());
+ }
+ }
+
+ @Override
+ public Long2DoubleOpenHashMap deserialize(DataInputView source) throws IOException {
+ int numEntries = source.readInt();
+ Long2DoubleOpenHashMap map = new Long2DoubleOpenHashMap(numEntries);
+ for (int i = 0; i < numEntries; i++) {
+ long k = source.readLong();
+ double v = source.readDouble();
+ map.put(k, v);
+ }
+ return map;
+ }
+
+ @Override
+ public Long2DoubleOpenHashMap deserialize(Long2DoubleOpenHashMap reuse, DataInputView source)
+ throws IOException {
+ return deserialize(source);
+ }
+
+ @Override
+ public void copy(DataInputView source, DataOutputView target) throws IOException {
+ int numEntries = source.readInt();
+ target.writeInt(numEntries);
+ for (int i = 0; i < numEntries; ++i) {
+ target.writeLong(source.readLong());
+ target.writeDouble(source.readDouble());
+ }
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ return true;
+ }
+
+ @Override
+ public int hashCode() {
+ return 0;
+ }
+
+ @Override
+ public TypeSerializerSnapshot<Long2DoubleOpenHashMap> snapshotConfiguration() {
+ return new Long2DoubleOpenHashMapSnapshot();
+ }
+
+ private static final class Long2DoubleOpenHashMapSnapshot
+ extends SimpleTypeSerializerSnapshot<Long2DoubleOpenHashMap> {
+ public Long2DoubleOpenHashMapSnapshot() {
+ super(() -> Long2DoubleOpenHashMapSerializer.INSTANCE);
+ }
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2DoubleOpenHashMapTypeInfo.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2DoubleOpenHashMapTypeInfo.java
new file mode 100644
index 000000000..4acfafdc1
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2DoubleOpenHashMapTypeInfo.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.typeinfo;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+
+import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
+
+/** TypeInformation for {@link Long2DoubleOpenHashMap}. */
+public class Long2DoubleOpenHashMapTypeInfo extends TypeInformation<Long2DoubleOpenHashMap> {
+
+ public static final Long2DoubleOpenHashMapTypeInfo instance =
+ new Long2DoubleOpenHashMapTypeInfo();
+
+ @Override
+ public boolean isBasicType() {
+ return false;
+ }
+
+ @Override
+ public boolean isTupleType() {
+ return false;
+ }
+
+ @Override
+ public int getArity() {
+ return 1;
+ }
+
+ @Override
+ public int getTotalFields() {
+ return 1;
+ }
+
+ @Override
+ public Class<Long2DoubleOpenHashMap> getTypeClass() {
+ return Long2DoubleOpenHashMap.class;
+ }
+
+ @Override
+ public boolean isKeyType() {
+ return false;
+ }
+
+ @Override
+ public TypeSerializer<Long2DoubleOpenHashMap> createSerializer(ExecutionConfig config) {
+ return Long2DoubleOpenHashMapSerializer.INSTANCE;
+ }
+
+ @Override
+ public String toString() {
+ return "Long2DoubleOpenHashMap Type";
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+
+ if (obj == null || getClass() != obj.getClass()) {
+ return false;
+ }
+
+ return true;
+ }
+
+ @Override
+ public int hashCode() {
+ return 0;
+ }
+
+ @Override
+ public boolean canEqual(Object obj) {
+ return obj instanceof Long2DoubleOpenHashMapTypeInfo;
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2ObjectOpenHashMapSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2ObjectOpenHashMapSerializer.java
new file mode 100644
index 000000000..12f250d7d
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2ObjectOpenHashMapSerializer.java
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.typeinfo;
+
+import org.apache.flink.api.common.typeutils.CompositeTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.util.Preconditions;
+
+import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * TypeSerializer for {@link Long2ObjectOpenHashMap}.
+ *
+ * @param <T> The type of elements in the Long2ObjectOpenHashMap.
+ */
+public class Long2ObjectOpenHashMapSerializer<T> extends TypeSerializer<Long2ObjectOpenHashMap<T>> {
+
+ private final TypeSerializer<T> elementSerializer;
+
+ public Long2ObjectOpenHashMapSerializer(TypeSerializer<T> elementSerializer) {
+ this.elementSerializer = Preconditions.checkNotNull(elementSerializer);
+ }
+
+ @Override
+ public boolean isImmutableType() {
+ return false;
+ }
+
+ @Override
+ public TypeSerializer<Long2ObjectOpenHashMap<T>> duplicate() {
+ return new Long2ObjectOpenHashMapSerializer<>(elementSerializer.duplicate());
+ }
+
+ @Override
+ public Long2ObjectOpenHashMap<T> createInstance() {
+ return new Long2ObjectOpenHashMap<>();
+ }
+
+ @Override
+ public Long2ObjectOpenHashMap<T> copy(Long2ObjectOpenHashMap<T> from) {
+ return new Long2ObjectOpenHashMap<>(from);
+ }
+
+ @Override
+ public Long2ObjectOpenHashMap<T> copy(
+ Long2ObjectOpenHashMap<T> from, Long2ObjectOpenHashMap<T> reuse) {
+ return new Long2ObjectOpenHashMap<>(from);
+ }
+
+ @Override
+ public int getLength() {
+ return -1;
+ }
+
+ @Override
+ public void serialize(Long2ObjectOpenHashMap<T> map, DataOutputView target) throws IOException {
+ target.writeInt(map.size());
+ for (Map.Entry<Long, T> entry : map.entrySet()) {
+ target.writeLong(entry.getKey());
+ elementSerializer.serialize(entry.getValue(), target);
+ }
+ }
+
+ @Override
+ public Long2ObjectOpenHashMap<T> deserialize(DataInputView source) throws IOException {
+ int numEntries = source.readInt();
+ Long2ObjectOpenHashMap<T> map = new Long2ObjectOpenHashMap<>(numEntries);
+ for (int i = 0; i < numEntries; i++) {
+ long k = source.readLong();
+ T v = elementSerializer.deserialize(source);
+ map.put(k, v);
+ }
+ return map;
+ }
+
+ @Override
+ public Long2ObjectOpenHashMap<T> deserialize(
+ Long2ObjectOpenHashMap<T> reuse, DataInputView source) throws IOException {
+ return deserialize(source);
+ }
+
+ @Override
+ public void copy(DataInputView source, DataOutputView target) throws IOException {
+ int numEntries = source.readInt();
+ target.writeInt(numEntries);
+ for (int i = 0; i < numEntries; ++i) {
+ target.writeLong(source.readLong());
+ elementSerializer.copy(source, target);
+ }
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+
+ Long2ObjectOpenHashMapSerializer<?> that = (Long2ObjectOpenHashMapSerializer<?>) o;
+ return Objects.equals(elementSerializer, that.elementSerializer);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(elementSerializer != null ? elementSerializer.hashCode() : 0);
+ }
+
+ @Override
+ public TypeSerializerSnapshot<Long2ObjectOpenHashMap<T>> snapshotConfiguration() {
+ return new Long2ObjectOpenHashMapSnapshot<>(this);
+ }
+
+ private static final class Long2ObjectOpenHashMapSnapshot<T>
+ extends CompositeTypeSerializerSnapshot<
+ Long2ObjectOpenHashMap<T>, Long2ObjectOpenHashMapSerializer<T>> {
+
+ private static final int CURRENT_VERSION = 1;
+
+ public Long2ObjectOpenHashMapSnapshot() {
+ super(Long2ObjectOpenHashMapSerializer.class);
+ }
+
+ public Long2ObjectOpenHashMapSnapshot(Long2ObjectOpenHashMapSerializer<T> serializer) {
+ super(serializer);
+ }
+
+ @Override
+ protected int getCurrentOuterSnapshotVersion() {
+ return CURRENT_VERSION;
+ }
+
+ @Override
+ protected TypeSerializer<?>[] getNestedSerializers(
+ Long2ObjectOpenHashMapSerializer<T> outerSerializer) {
+ return new TypeSerializer[] {outerSerializer.elementSerializer};
+ }
+
+ @Override
+ protected Long2ObjectOpenHashMapSerializer<T> createOuterSerializerWithNestedSerializers(
+ TypeSerializer<?>[] nestedSerializers) {
+ TypeSerializer<T> elementSerializer = (TypeSerializer<T>) nestedSerializers[0];
+ return new Long2ObjectOpenHashMapSerializer<>(elementSerializer);
+ }
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2ObjectOpenHashMapTypeInfo.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2ObjectOpenHashMapTypeInfo.java
new file mode 100644
index 000000000..d80079cfc
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2ObjectOpenHashMapTypeInfo.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.typeinfo;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+
+import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap;
+
+import java.util.Objects;
+
+/**
+ * TypeInformation for {@link Long2ObjectOpenHashMap}.
+ *
+ * @param <T> The type of elements in the Long2ObjectOpenHashMap.
+ */
+public class Long2ObjectOpenHashMapTypeInfo<T> extends TypeInformation<Long2ObjectOpenHashMap<T>> {
+
+ private final TypeInformation<T> elementTypeInfo;
+
+ public Long2ObjectOpenHashMapTypeInfo(TypeInformation<T> elementTypeInfo) {
+ this.elementTypeInfo = elementTypeInfo;
+ }
+
+ public TypeInformation<T> getElementTypeInfo() {
+ return elementTypeInfo;
+ }
+
+ @Override
+ public boolean isBasicType() {
+ return false;
+ }
+
+ @Override
+ public boolean isTupleType() {
+ return false;
+ }
+
+ @Override
+ public int getArity() {
+ return 1;
+ }
+
+ @Override
+ public int getTotalFields() {
+ return 1;
+ }
+
+ @Override
+ public Class<Long2ObjectOpenHashMap<T>> getTypeClass() {
+ return (Class) Long2ObjectOpenHashMap.class;
+ }
+
+ @Override
+ public boolean isKeyType() {
+ return false;
+ }
+
+ @Override
+ public TypeSerializer<Long2ObjectOpenHashMap<T>> createSerializer(ExecutionConfig config) {
+ return new Long2ObjectOpenHashMapSerializer<>(elementTypeInfo.createSerializer(config));
+ }
+
+ @Override
+ public String toString() {
+ return "Long2ObjectOpenHashMap Type";
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+
+ if (obj == null || getClass() != obj.getClass()) {
+ return false;
+ }
+
+ Long2ObjectOpenHashMapTypeInfo<?> that = (Long2ObjectOpenHashMapTypeInfo<?>) obj;
+ return Objects.equals(elementTypeInfo, that.elementTypeInfo);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(elementTypeInfo != null ? elementTypeInfo.hashCode() : 0);
+ }
+
+ @Override
+ public boolean canEqual(Object obj) {
+ return obj instanceof Long2ObjectOpenHashMapTypeInfo;
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java
new file mode 100644
index 000000000..dbb4dd3ce
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.updater;
+
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+/**
+ * A model updater that can be used to update and retrieve model data.
+ *
+ * <p>Note that the model updater should also ensure that the model data is robust to failures,
+ * by writing the model data to snapshots.
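+ *
+ * <p>A simplified in-memory sketch that applies gradients to a dense model (state snapshotting
+ * is omitted and all names are illustrative):
+ *
+ * <pre>{@code
+ * public class DenseModelUpdater implements ModelUpdater<double[]> {
+ *     private double[] model; // initialized, e.g., in initializeState
+ *
+ *     public void update(long[] keys, double[] values) {
+ *         for (int i = 0; i < keys.length; i++) {
+ *             model[(int) keys[i]] -= values[i];
+ *         }
+ *     }
+ *
+ *     public double[] get(long[] keys) {
+ *         double[] result = new double[keys.length];
+ *         for (int i = 0; i < keys.length; i++) {
+ *             result[i] = model[(int) keys[i]];
+ *         }
+ *         return result;
+ *     }
+ *     // getModelSegments, initializeState and snapshotState are omitted in this sketch.
+ * }
+ * }</pre>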
+ *
+ * @param <MT> data type of the model.
+ */
+public interface ModelUpdater<MT> extends Serializable {
+ /** Applies the push to update the model data, e.g., using gradient to update model. */
+ void update(long[] keys, double[] values);
+
+ /** Retrieves the model data of the given keys. */
+ double[] get(long[] keys);
+
+ /**
+ * Returns the model segments. The model segments are continuously updated/retrieved by
+ * push/pull (i.e., {@link ModelUpdater#update(long[], double[])} and {@link
+ * ModelUpdater#get(long[])}).
+ */
+ Iterator<MT> getModelSegments();
+
+ /** Recovers the model data from state. */
+ void initializeState(StateInitializationContext context) throws Exception;
+
+ /** Snapshots the model data to state. */
+ void snapshotState(StateSnapshotContext context) throws Exception;
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/utils/ProxySideOutput.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/utils/ProxySideOutput.java
new file mode 100644
index 000000000..9cba95d0f
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/utils/ProxySideOutput.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.utils;
+
+import org.apache.flink.streaming.api.operators.Output;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+/** A collector that can only output using {@link OutputTag}. */
+public final class ProxySideOutput {
+ private final Output<StreamRecord<byte[]>> output;
+
+ public ProxySideOutput(Output<StreamRecord<byte[]>> output) {
+ this.output = output;
+ }
+
+ public <OUT> void output(OutputTag<OUT> outputTag, StreamRecord<OUT> record) {
+ Preconditions.checkNotNull(outputTag);
+ output.collect(outputTag, record);
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/utils/TrainingUtils.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/utils/TrainingUtils.java
new file mode 100644
index 000000000..261dba6a8
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/utils/TrainingUtils.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.utils;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.Partitioner;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.ps.Message;
+import org.apache.flink.ml.common.ps.ServerOperator;
+import org.apache.flink.ml.common.ps.WorkerOperator;
+import org.apache.flink.ml.common.ps.iterations.IterationStageList;
+import org.apache.flink.ml.common.ps.iterations.MLSession;
+import org.apache.flink.ml.common.ps.updater.ModelUpdater;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.util.OutputTag;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/** Utility class to describe the iterative training process. */
+public final class TrainingUtils {
+ /**
+ * Executes the iterative machine learning logic described in {@link IterationStageList} and
+ * returns the fitted model data as well as the outputs from the worker operator. The outputs
+ * from the worker operator are specified via {@link MLSession#getOutputTags()}.
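+ *
+ * <p>A call sketch (the stage list and the model updater are illustrative):
+ *
+ * <pre>{@code
+ * DataStreamList resultList =
+ *         TrainingUtils.train(trainData, stages, modelTypeInfo, modelUpdater, 2);
+ * DataStream<double[]> modelData = resultList.get(0); // assuming double[] model segments
+ * }</pre>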
+ *
+ * @param inputData the input data.
+ * @param iterationStages the iterative processing logic.
+ * @param modelDataType output type information of model data.
+ * @param modelUpdater the logic to update model on servers.
+ * @param numServers number of servers.
+ * @return the fitted model data as well as the outputs from the worker operator. The order is
+ * {modelData, side outputs from workers}. Note that the outputs from workers share the same
+ * order as {@link MLSession#getOutputTags()}.
+ * @param <DT> type information of input data.
+ * @param <M> type information of the output model data.
+ */
+ public static <DT, M> DataStreamList train(
+ DataStream<DT> inputData,
+ IterationStageList<? extends MLSession> iterationStages,
+ TypeInformation<M> modelDataType,
+ ModelUpdater<M> modelUpdater,
+ int numServers) {
+ DataStream<byte[]> variableStream =
+ inputData.getExecutionEnvironment().fromElements(new byte[0]).filter(x -> false);
+
+ return Iterations.iterateBoundedStreamsUntilTermination(
+ DataStreamList.of(variableStream),
+ ReplayableDataStreamList.notReplay(inputData),
+ IterationConfig.newBuilder().build(),
+ new TrainIterationBody<>(modelUpdater, modelDataType, iterationStages, numServers));
+ }
+
+ /** The iteration implementation for training process. */
+ private static class TrainIterationBody<M> implements IterationBody {
+ private final ModelUpdater<M> modelUpdater;
+ private final TypeInformation<M> modelType;
+ private final IterationStageList<? extends MLSession> iterationStages;
+ private final int numServers;
+
+ public TrainIterationBody(
+ ModelUpdater<M> modelUpdater,
+ TypeInformation<M> modelType,
+ IterationStageList<? extends MLSession> iterationStages,
+ int numServers) {
+ this.iterationStages = iterationStages;
+ this.modelType = modelType;
+ this.modelUpdater = modelUpdater;
+ this.numServers = numServers;
+ }
+
+ @Override
+ @SuppressWarnings("unchecked")
+ public IterationBodyResult process(
+ DataStreamList variableStreams, DataStreamList dataStreams) {
+ DataStream variableStream = variableStreams.get(0);
+ DataStream trainData = dataStreams.get(0);
+ final OutputTag<M> modelDataOutputTag = new OutputTag<>("MODEL_OUTPUT", modelType);
+
+ SingleOutputStreamOperator<byte[]> messageToServer =
+ trainData
+ .connect(variableStream)
+ .transform(
+ "WorkerOp",
+ PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO,
+ new WorkerOperator(iterationStages, numServers));
+ int numWorkers = messageToServer.getParallelism();
+
+ SingleOutputStreamOperator<byte[]> messageToWorker =
+ messageToServer
+ .partitionCustom(
+ (Partitioner<Integer>)
+ (key, numPartitions) -> key % numPartitions,
+ (KeySelector<byte[], Integer>)
+ value -> new Message(value).getServerId())
+ .transform(
+ "ServerOp",
+ PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO,
+ new ServerOperator<>(
+ iterationStages.stageList,
+ numWorkers,
+ modelUpdater,
+ modelDataOutputTag));
+ messageToWorker.setParallelism(numServers);
+
+ DataStream<byte[]> feedback =
+ messageToWorker
+ .partitionCustom(
+ (Partitioner<Integer>)
+ (key, numPartitions) -> key % numPartitions,
+ (KeySelector<byte[], Integer>)
+ value -> new Message(value).getWorkerId())
+ .map(
+ (MapFunction<byte[], byte[]>) message -> message,
+ PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO)
+ .setParallelism(numWorkers);
+
+ DataStream<M> model = messageToWorker.getSideOutput(modelDataOutputTag);
+
+ List<DataStream<?>> result = new ArrayList<>();
+ result.add(model);
+
+ List<OutputTag<?>> sideOutputTags = iterationStages.session.getOutputTags();
+ if (sideOutputTags != null) {
+ for (OutputTag<?> outputTag : sideOutputTags) {
+ result.add(messageToServer.getSideOutput(outputTag));
+ }
+ }
+
+ return new IterationBodyResult(
+ DataStreamList.of(feedback), new DataStreamList(result), null);
+ }
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/Als.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/Als.java
new file mode 100644
index 000000000..5ba788f1b
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/Als.java
@@ -0,0 +1,491 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.recommendation.als;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple5;
+import org.apache.flink.api.java.typeutils.GenericTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.common.ps.iterations.AllReduceStage;
+import org.apache.flink.ml.common.ps.iterations.IterationStageList;
+import org.apache.flink.ml.common.ps.iterations.PullStage;
+import org.apache.flink.ml.common.ps.iterations.PushStage;
+import org.apache.flink.ml.common.ps.utils.TrainingUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.ProcessWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.function.SerializableFunction;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the ALS algorithm.
+ *
+ * <p>ALS tries to decompose a matrix R as R = X * Yt. Here X and Y are called factor matrices.
+ * Matrix R is usually a sparse matrix representing ratings given from users to items. ALS tries to
+ * find X and Y that minimize || R - X * Yt ||^2. This is done by iterations. At each step, X is
+ * fixed and Y is solved, then Y is fixed and X is solved.
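+ *
+ * <p>Concretely, with Y fixed, each user factor x_u has a closed-form least-squares solution of
+ * the form x_u = (Yu' * Yu + lambda * I)^-1 * Yu' * r(u), where Yu stacks the factors of the
+ * items rated by user u, r(u) holds the corresponding ratings and lambda is the regularization
+ * parameter (the implicit-preference variant uses a weighted version of this solve).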
+ *
+ * <p>The algorithm is described in "Large-scale Parallel Collaborative Filtering for the Netflix
+ * Prize, 2007". This algorithm also supports the implicit preference model described in
+ * "Collaborative Filtering for Implicit Feedback Datasets, 2008".
+ */
+public class Als implements Estimator<Als, AlsModel>, AlsParams<Als> {
+ private final Map, Object> paramMap = new HashMap<>();
+ private static final int THRESHOLD = 100000;
+
+ public Als() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public AlsModel fit(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+ final String userCol = getUserCol();
+ final String itemCol = getItemCol();
+ final String ratingCol = getRatingCol();
+
+ DataStream<Row> trainData = tEnv.toDataStream(inputs[0]);
+
+ DataStream<Tuple3<Long, Long, Double>> alsInput =
+ trainData
+ .map(
+ (MapFunction<Row, Tuple3<Long, Long, Double>>)
+ value -> {
+ Long user = value.getFieldAs(userCol);
+ Long item = value.getFieldAs(itemCol);
+ user = 2L * user;
+ item = 2L * item + 1L;
+ Number rating =
+ ratingCol == null
+ ? 0.0F
+ : value.getFieldAs(ratingCol);
+
+ return new Tuple3<>(user, item, rating.doubleValue());
+ })
+ .name("generateInputALsData")
+ .returns(Types.TUPLE(Types.LONG, Types.LONG, Types.DOUBLE));
+
+ /* Initializes variables before iteration. */
+ DataStream ratingData = initRatings(alsInput);
+ int parallelism = ratingData.getParallelism();
+ AlsMLSession mlSession = new AlsMLSession(getImplicitPrefs(), getRank(), parallelism);
+ ExecutionConfig executionConfig = ratingData.getExecutionConfig();
+ TypeSerializer