diff --git a/flink-ml-lib/pom.xml b/flink-ml-lib/pom.xml index 1773fc7d4..0e2cd9287 100644 --- a/flink-ml-lib/pom.xml +++ b/flink-ml-lib/pom.xml @@ -138,6 +138,11 @@ under the License. test test-jar + + it.unimi.dsi + fastutil + 8.5.12 + diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/Message.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/Message.java new file mode 100644 index 000000000..36ca83269 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/Message.java @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.ml.util.Bits; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.lang.reflect.Array; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; + +/** + * {@link Message} is responsible for encoding all information exchanged between {@link + * WorkerOperator} and {@link ServerOperator}. The message format follows this structure: + * + *

`workerId serverId stageId keyLength keys valuesLength values` + * + *

where the message fields include the worker ID, server ID, stage ID, length of the keys, keys + * themselves, length of the values, and the values. + */ +public class Message { + private static final int WORKER_ID_OFFSET = 0; + private static final int SERVER_ID_OFFSET = Integer.BYTES; + private static final int STAGE_ID_OFFSET = Integer.BYTES + SERVER_ID_OFFSET; + private static final int KVS_OFFSET = Integer.BYTES + STAGE_ID_OFFSET; + + /** The storage of message in bytes. */ + public final byte[] bytes; + + /** Constructs a message instance from the bytes. */ + public Message(byte[] bytes) { + this.bytes = bytes; + } + + /** Constructs a message instance from long keys and double values. */ + public Message(int workerId, int serverId, int stageId, long[] keys, double[] values) { + int sizeInBytes = + KVS_OFFSET + + Bits.getLongArraySizeInBytes(keys) + + Bits.getDoubleArraySizeInBytes(values); + bytes = new byte[sizeInBytes]; + Bits.putInt(bytes, WORKER_ID_OFFSET, workerId); + Bits.putInt(bytes, SERVER_ID_OFFSET, serverId); + Bits.putInt(bytes, STAGE_ID_OFFSET, stageId); + int offset = Bits.putLongArray(keys, bytes, KVS_OFFSET); + Bits.putDoubleArray(values, bytes, offset); + } + + /** Constructs a message instance from long keys and generics values. */ + public Message( + int workerId, + int serverId, + int stageId, + long[] keys, + V[] values, + TypeSerializer serializer) + throws IOException { + ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + DataOutputViewStreamWrapper dataOutputViewStreamWrapper = + new DataOutputViewStreamWrapper(byteArrayOutputStream); + dataOutputViewStreamWrapper.writeInt(workerId); + dataOutputViewStreamWrapper.writeInt(serverId); + dataOutputViewStreamWrapper.writeInt(stageId); + + dataOutputViewStreamWrapper.writeInt(keys.length); + for (long key : keys) { + dataOutputViewStreamWrapper.writeLong(key); + } + dataOutputViewStreamWrapper.writeInt(values.length); + for (V value : values) { + serializer.serialize(value, dataOutputViewStreamWrapper); + } + bytes = byteArrayOutputStream.toByteArray(); + } + + /** Retrieves the keys. */ + public long[] getKeys() { + return Bits.getLongArray(bytes, KVS_OFFSET); + } + + /** Retrieves the values using the given serializer. */ + public V[] getValues(TypeSerializer serializer) throws IOException { + int numIndices = Bits.getInt(bytes, KVS_OFFSET); + int offset = KVS_OFFSET + Integer.BYTES + numIndices * Long.BYTES; + int numValues = Bits.getInt(bytes, offset); + offset += Integer.BYTES; + + // Since the generics got erased, we use reflections to create the array. + V[] result = (V[]) Array.newInstance(serializer.createInstance().getClass(), numValues); + ByteArrayInputStream byteArrayInputStream = + new ByteArrayInputStream(bytes, offset, bytes.length - offset); + DataInputViewStreamWrapper dataInputViewStreamWrapper = + new DataInputViewStreamWrapper(byteArrayInputStream); + for (int i = 0; i < numValues; i++) { + result[i] = serializer.deserialize(dataInputViewStreamWrapper); + } + return result; + } + + /** + * Retrieves the values in double array. + * + *

Note that getting double array in this function using {@link Bits#getDoubleArray(byte[], + * int)} is faster than {@link Message#getValues} by up to 2.3X. + */ + public double[] getValuesInDoubleArray() { + int offset = KVS_OFFSET + Bits.getInt(bytes, KVS_OFFSET) * Long.BYTES + Integer.BYTES; + return Bits.getDoubleArray(bytes, offset); + } + + /** Retrieves the worker id. */ + public int getWorkerId() { + return Bits.getInt(bytes, WORKER_ID_OFFSET); + } + + /** Sets the worker id. */ + public void setWorkerId(int workerId) { + Bits.putInt(bytes, WORKER_ID_OFFSET, workerId); + } + + /** Retrieves the server id. */ + public int getServerId() { + return Bits.getInt(bytes, SERVER_ID_OFFSET); + } + + /** Sets the server id. */ + public void setServerId(int serverId) { + Bits.putInt(bytes, SERVER_ID_OFFSET, serverId); + } + + /** Retrieves the stage id. */ + public int getStageId() { + return Bits.getInt(bytes, STAGE_ID_OFFSET); + } + + /** + * Assembles the received messages from servers according to the server id. Note that these + * message should be the responses from the same stage. + */ + public static Message assembleMessages(Iterator messageIterator) { + List messages = new ArrayList<>(); + while (messageIterator.hasNext()) { + messages.add(new Message(messageIterator.next())); + } + messages.sort(Comparator.comparingInt(Message::getServerId)); + + int numMessages = messages.size(); + int numKeys = 0, numValues = 0; + int numAssembledBytes = 0; + int workerId = -1; + int stageId = -1; + for (Message message : messages) { + if (workerId == -1) { + workerId = message.getWorkerId(); + stageId = message.getStageId(); + } + numKeys += message.getNumKeys(); + numValues += message.getNumValues(); + numAssembledBytes += message.bytes.length; + } + numAssembledBytes -= (numMessages - 1) * (KVS_OFFSET + Integer.BYTES * 2); + byte[] assembledBytes = new byte[numAssembledBytes]; + Bits.putInt(assembledBytes, WORKER_ID_OFFSET, workerId); + Bits.putInt(assembledBytes, STAGE_ID_OFFSET, stageId); + int keysOffset = KVS_OFFSET; + Bits.putInt(assembledBytes, keysOffset, numKeys); + keysOffset += Integer.BYTES; + int valuesOffset = keysOffset + numKeys * Long.BYTES; + Bits.putInt(assembledBytes, valuesOffset, numValues); + valuesOffset += Integer.BYTES; + + for (Message message : messages) { + Tuple2 keysOffsetAndLength = message.getKeysOffsetAndLength(); + System.arraycopy( + message.bytes, + keysOffsetAndLength.f0, + assembledBytes, + keysOffset, + keysOffsetAndLength.f1); + keysOffset += keysOffsetAndLength.f1; + Tuple2 valuesOffsetAndLength = message.getValuesOffSetAndLength(); + System.arraycopy( + message.bytes, + valuesOffsetAndLength.f0, + assembledBytes, + valuesOffset, + valuesOffsetAndLength.f1); + valuesOffset += valuesOffsetAndLength.f1; + } + + Message message = new Message(assembledBytes); + message.setServerId(-1); + return message; + } + + private Tuple2 getKeysOffsetAndLength() { + int start = KVS_OFFSET + Integer.BYTES; + int numBytes = Bits.getInt(bytes, KVS_OFFSET) * Long.BYTES; + return Tuple2.of(start, numBytes); + } + + private Tuple2 getValuesOffSetAndLength() { + int start = + Bits.getInt(bytes, KVS_OFFSET) * Long.BYTES + + KVS_OFFSET + + Integer.BYTES + + Integer.BYTES; + return Tuple2.of(start, bytes.length - start); + } + + private int getNumKeys() { + return Bits.getInt(bytes, KVS_OFFSET); + } + + private int getNumValues() { + return Bits.getInt(bytes, KVS_OFFSET + Integer.BYTES + Long.BYTES * getNumKeys()); + } +} diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java new file mode 100644 index 000000000..d8d3c095c --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.common.ps.sarray.SharedDoubleArray; +import org.apache.flink.ml.common.ps.sarray.SharedLongArray; +import org.apache.flink.streaming.api.operators.Output; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Preconditions; + +import it.unimi.dsi.fastutil.doubles.DoubleArrayList; +import it.unimi.dsi.fastutil.longs.LongArrayList; + +import javax.annotation.Nullable; + +import java.io.IOException; +import java.util.Arrays; +import java.util.function.Function; + +/** + * ServerAgent resides on each worker. It serves as an agent for {@link WorkerOperator} to talk with + * {@link ServerOperator}. + */ +class ServerAgent { + /** Index of the worker that this agent resides on. */ + private final int workerId; + /** Number of servers to talk to. */ + private final int numServers; + /** Hash function to partition keys to different servers. */ + private final Function hashFunc; + /** The collector on this worker. */ + private final Output> output; + + ServerAgent( + int workerId, + int numServers, + Function hashFunc, + Output> output) { + this.workerId = workerId; + this.numServers = numServers; + this.output = output; + this.hashFunc = hashFunc; + } + + /** Pushes a key-value arrays to servers. */ + void push(SharedLongArray keys, SharedDoubleArray values, int stageId) { + Tuple2 slicedRequests = sliceRequest(keys, values); + LongArrayList[] splitKeys = slicedRequests.f0; + DoubleArrayList[] splitValues = slicedRequests.f1; + for (int serverId = 0; serverId < splitKeys.length; serverId++) { + Message message = + new Message( + workerId, + serverId, + stageId, + splitKeys[serverId].toLongArray(), + splitValues[serverId].toDoubleArray()); + output.collect(new StreamRecord<>(message.bytes)); + } + } + + /** Pulls the values from servers with the specified keys. 
*/ + void pull(SharedLongArray keys, int stageId) { + Tuple2 slicedRequests = sliceRequest(keys, null); + LongArrayList[] splitKeys = slicedRequests.f0; + for (int serverId = 0; serverId < splitKeys.length; serverId++) { + Message message = + new Message( + workerId, + serverId, + stageId, + splitKeys[serverId].toLongArray(), + new double[0]); + output.collect(new StreamRecord<>(message.bytes)); + } + } + + /** + * Pushes the values to servers to apply all-reduce/reduce-scatter operation. + * + *

Note that the values pushed by this function are not going to update the model, but just + * perform an reduce operation. + */ + void reduce(V[] values, TypeSerializer typeSerializer, int stageId) throws IOException { + int shardSize = values.length / numServers + 1; + for (int serverId = 0; serverId < numServers; serverId++) { + int s = Math.min(serverId * shardSize, values.length); + int e = Math.min(s + shardSize, values.length); + V[] segment = Arrays.copyOfRange(values, s, e); + Message message = + new Message(workerId, serverId, stageId, new long[0], segment, typeSerializer); + output.collect(new StreamRecord<>(message.bytes)); + } + } + + /** + * Splits the push/pull request according to the given sorted keys and the corresponding values. + * + * @param keys keys of push/pull request. + * @param values the push values if not null. + * @return the split requests for each server. + */ + private Tuple2 sliceRequest( + SharedLongArray keys, @Nullable SharedDoubleArray values) { + LongArrayList[] splitKeys = new LongArrayList[numServers]; + DoubleArrayList[] splitValues = new DoubleArrayList[numServers]; + for (int i = 0; i < numServers; i++) { + splitKeys[i] = new LongArrayList(); + splitValues[i] = new DoubleArrayList(); + } + + int numDoublesPerKey = 0; + if (values != null) { + Preconditions.checkState( + values.size() % keys.size() == 0, "The length of each key should be the same."); + numDoublesPerKey = values.size() / keys.size(); + } + + long[] keyArray = keys.elements(); + for (int i = 0; i < keys.size(); i++) { + int serverId = hashFunc.apply(keyArray[i]); + splitKeys[serverId].add(keyArray[i]); + if (values != null) { + for (int j = 0; j < numDoublesPerKey; j++) { + splitValues[serverId].add(values.get(i * numDoublesPerKey + j)); + } + } + } + + return Tuple2.of(splitKeys, splitValues); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java new file mode 100644 index 000000000..4edbbf0c4 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java @@ -0,0 +1,534 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.ps; + +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.ml.common.ps.iterations.AllReduceStage; +import org.apache.flink.ml.common.ps.iterations.IterationStage; +import org.apache.flink.ml.common.ps.iterations.PullStage; +import org.apache.flink.ml.common.ps.iterations.PullStage.Aggregator; +import org.apache.flink.ml.common.ps.iterations.PushStage; +import org.apache.flink.ml.common.ps.iterations.ReduceScatterStage; +import org.apache.flink.ml.common.ps.updater.ModelUpdater; +import org.apache.flink.ml.util.Bits; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; + +import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap; +import it.unimi.dsi.fastutil.objects.ObjectIterator; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.LinkedBlockingDeque; + +/** + * The server operator maintains the shared parameters. The shared parameters can be modeled as a + * collection of {key:value} pairs. By default, the keys are evenly distributed across servers + * through hash partitioning. For example, if there are two servers and the keys are {1,2,3,4,5,6}, + * then server-0 maintains keys {1,3,5} and server-1 maintains keys {2,4,6}. + * + *
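For reference, a small standalone sketch of the modulo-based partitioning described above; it mirrors the hash function installed in WorkerOperator#open, and the class name is illustrative:

    import java.util.function.Function;

    public class KeyPartitioningSketch {
        public static void main(String[] args) {
            int numServers = 2;
            // Same partitioner that WorkerOperator#open installs: a key lives on server |key| % numServers.
            Function<Long, Integer> hashFunc = key -> (int) Math.abs(key % numServers);
            for (long key = 1; key <= 6; key++) {
                System.out.println("key " + key + " -> server " + hashFunc.apply(key));
            }
        }
    }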

The server receives push/pull/all-reduce/reduce-scatter requests from {@link WorkerOperator} + * and sends the answer request to {@link WorkerOperator}. It works closely with {@link + * ModelUpdater} in the following way: + * + *

+ * <ul>
+ *   <li>The server accumulates the push requests from all workers, merging values for the same key with the {@link PushStage#reduceFunc}, and applies them to the model via {@link ModelUpdater#update}.
+ *   <li>The server answers each pull request by reading the requested keys through {@link ModelUpdater#get} and sending the values back to the requesting worker.
+ *   <li>The server delegates checkpointing of the model to {@link ModelUpdater#initializeState} and {@link ModelUpdater#snapshotState}.
+ *   <li>When the iteration terminates, the server emits the model pieces obtained from {@link ModelUpdater#getModelSegments} to the model output tag.
+ * </ul>
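For intuition, a minimal self-contained sketch of a dense in-memory store offering the operations the server relies on above (update, get, getModelSegments), assuming a single double per key; it is an illustration, not the ModelUpdater interface shipped in this patch:

    import java.util.Collections;
    import java.util.Iterator;

    public class DenseModelSketch {
        private final double[] model;

        public DenseModelSketch(int dim) {
            this.model = new double[dim];
        }

        // Applies the merged push values; here every pushed value is simply added to the model.
        public void update(long[] keys, double[] values) {
            for (int i = 0; i < keys.length; i++) {
                model[(int) keys[i]] += values[i];
            }
        }

        // Answers a pull request for the given keys.
        public double[] get(long[] keys) {
            double[] values = new double[keys.length];
            for (int i = 0; i < keys.length; i++) {
                values[i] = model[(int) keys[i]];
            }
            return values;
        }

        // Emits the final model as a single segment once the iteration terminates.
        public Iterator<double[]> getModelSegments() {
            return Collections.singletonList(model).iterator();
        }
    }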
Moreover, it accepts all-reduce/reduce-scatter request from workers and returns the reduced + * result to all workers. Note that the input of all-reduce/reduce-scatter operation is not going to + * be used in {@link ModelUpdater}. + * + * @param output format of model data. + */ +public class ServerOperator extends AbstractStreamOperator + implements OneInputStreamOperator, IterationListener { + /** The iterationStage list. */ + private final List stageList; + /** Number of workers to communicate with. */ + private final int numWorkers; + /** The logic to answer push/pull request from workers. */ + private final ModelUpdater modelUpdater; + /** Output tag of model data. */ + private final OutputTag modelOutputTag; + /** Index of the current server task. */ + private transient int serverId; + /** Thread pool to answer push/pull requests, to decouple the network and computation. */ + private transient ExecutorService singleThreadExecutor; + /** The future objects of thread calls in one epoch. */ + private transient List> futuresInEpoch; + /** + * The pending requests that server needs to send out responses (pull, all-reduce, + * reduce-scatter). + */ + private ListState pendingRequests; + /** + * The push request merged by stage id. We use map to store the merged push request since there + * may be consecutive pushes. + */ + private transient TreeMap accPushesByStage; + + private ListState accPushesByStageState; + + public ServerOperator( + List stageList, + int numWorkers, + ModelUpdater modelUpdater, + OutputTag modelOutputTag) { + this.stageList = stageList; + this.numWorkers = numWorkers; + this.modelUpdater = modelUpdater; + this.modelOutputTag = modelOutputTag; + } + + @Override + public void open() throws Exception { + super.open(); + this.serverId = getRuntimeContext().getIndexOfThisSubtask(); + this.singleThreadExecutor = Executors.newSingleThreadExecutor(); + this.futuresInEpoch = new ArrayList<>(); + } + + @Override + public void processElement(StreamRecord element) throws Exception { + Message message = new Message(element.getValue()); + IterationStage stage = stageList.get(message.getStageId() % stageList.size()); + if (stage instanceof PushStage) { + futuresInEpoch.add(singleThreadExecutor.submit(() -> processPushRequest(message))); + } else if (stage instanceof PullStage + || stage instanceof AllReduceStage + || stage instanceof ReduceScatterStage) { + pendingRequests.add(message.bytes); + } else { + throw new IllegalStateException( + "Illegal iteration stage: " + stage.getClass().getSimpleName() + "."); + } + } + + @SuppressWarnings("unchecked") + @Override + public void onEpochWatermarkIncremented( + int epochWatermark, Context context, Collector collector) throws Exception { + // Waits until the pushes are processed. + for (Future future : futuresInEpoch) { + future.get(); + } + futuresInEpoch.clear(); + // Uses the merged pushes to update model. + for (Long2ObjectOpenHashMap currentAccPush : accPushesByStage.values()) { + if (currentAccPush.size() > 0) { + // The push is not empty. 
+ int numDoublesPerKey; + Object object = currentAccPush.values().iterator().next(); + if (object instanceof Double) { + numDoublesPerKey = 1; + } else { + numDoublesPerKey = ((double[]) object).length; + } + + ObjectIterator> objectIterator = + currentAccPush.long2ObjectEntrySet().fastIterator(); + + long[] assembledKeys = new long[currentAccPush.size()]; + double[] assembledValues = new double[currentAccPush.size() * numDoublesPerKey]; + + int idx = 0; + if (numDoublesPerKey == 1) { + while (objectIterator.hasNext()) { + Map.Entry entry = + (Map.Entry) objectIterator.next(); + assembledKeys[idx] = entry.getKey(); + assembledValues[idx] = entry.getValue(); + idx++; + } + } else { + while (objectIterator.hasNext()) { + Map.Entry entry = + (Map.Entry) objectIterator.next(); + assembledKeys[idx] = entry.getKey(); + System.arraycopy( + entry.getValue(), + 0, + assembledValues, + idx * numDoublesPerKey, + numDoublesPerKey); + idx++; + } + } + currentAccPush.clear(); + modelUpdater.update(assembledKeys, assembledValues); + } + } + + // Deals with the pending requests, which should be one of Pull, AllReduce, ReduceScatter. + Iterator requestIterator = pendingRequests.get().iterator(); + if (requestIterator.hasNext()) { + Message message = new Message(requestIterator.next()); + int stageId = message.getStageId(); + IterationStage stage = stageList.get(stageId % stageList.size()); + requestIterator = pendingRequests.get().iterator(); + if (stage instanceof PullStage) { + final int blockingQueueCapacity = 20; + LinkedBlockingDeque pullsResponse = + new LinkedBlockingDeque<>(blockingQueueCapacity); + for (byte[] bytes : pendingRequests.get()) { + singleThreadExecutor.submit( + () -> processPullRequest(new Message(bytes), pullsResponse)); + } + int numResponsesSent = 0; + while (numResponsesSent < numWorkers) { + Message response = new Message(pullsResponse.take()); + output.collect(new StreamRecord<>(response.bytes)); + numResponsesSent++; + } + } else if (stage instanceof AllReduceStage) { + processAllReduceRequest(requestIterator); + } else if (stage instanceof ReduceScatterStage) { + processReduceScatterRequest(requestIterator); + } else { + throw new IllegalStateException( + "Illegal iteration stage: " + stage.getClass().getSimpleName() + "."); + } + + pendingRequests.clear(); + } + } + + @Override + public void onIterationTerminated(Context context, Collector collector) { + Iterator modelSegments = modelUpdater.getModelSegments(); + while (modelSegments.hasNext()) { + MT modelSegment = modelSegments.next(); + output.collect(modelOutputTag, new StreamRecord<>(modelSegment)); + } + } + + @SuppressWarnings("unchecked") + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + pendingRequests = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "pendingRequests", + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO)); + + modelUpdater.initializeState(context); + + accPushesByStageState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "accPushesByStageState", + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO)); + + // Recovers accPushesByStage from a byte[] stream. + Iterator accPushesInBytes = accPushesByStageState.get().iterator(); + accPushesByStage = new TreeMap<>(); + + if (accPushesInBytes.hasNext()) { + // 4 bytes for number of stages. 
+ byte[] meta = accPushesInBytes.next(); + int offset = 0; + int numberOfStages = Bits.getInt(meta, offset); + for (int i = 0; i < numberOfStages; i++) { + byte[] oneStageMeta = accPushesInBytes.next(); + offset = 0; + int stageId = Bits.getInt(oneStageMeta, offset); + offset += Integer.BYTES; + int sizeOfLong2ObjectMap = Bits.getInt(oneStageMeta, offset); + offset += Integer.BYTES; + int arrayLengthPerObject = Bits.getInt(oneStageMeta, offset); + Long2ObjectOpenHashMap pushes; + if (arrayLengthPerObject == 0) { + pushes = new Long2ObjectOpenHashMap(sizeOfLong2ObjectMap); + } else { + pushes = new Long2ObjectOpenHashMap(sizeOfLong2ObjectMap); + } + accPushesByStage.put(stageId, pushes); + for (int entryId = 0; entryId < sizeOfLong2ObjectMap; entryId++) { + byte[] kvInBytes = accPushesInBytes.next(); + long key = Bits.getLong(kvInBytes, 0); + if (arrayLengthPerObject == 0) { + Double value = Bits.getDouble(kvInBytes, Long.BYTES); + pushes.put(key, value); + } else { + double[] value = Bits.getDoubleArray(kvInBytes, Long.BYTES); + pushes.put(key, value); + } + } + } + } + } + + @SuppressWarnings("unchecked") + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + // Waits until the futures to finish. + for (Future future : futuresInEpoch) { + future.get(); + } + futuresInEpoch.clear(); + modelUpdater.snapshotState(context); + + accPushesByStageState.clear(); + // Writes accPushesByStage to state in the following format: + // numberOfStagesInInt, + // stageIdInInt, sizeOfLong2ObjectMapInInt, arrayLengthPerObject, key-value-long-obj... + // stageIdInInt, sizeOfLong2ObjectMapInInt, arrayLengthPerObject, key-value-long-obj... + if (accPushesByStage.size() > 0) { + int numberOfStages = accPushesByStage.size(); + byte[] meta = new byte[Integer.BYTES]; + Bits.putInt(meta, 0, numberOfStages); + accPushesByStageState.add(meta); + + for (Map.Entry entry : accPushesByStage.entrySet()) { + byte[] oneStageMeta = new byte[Integer.BYTES * 3]; + int offset = 0; + int stageId = entry.getKey(); + Bits.putInt(oneStageMeta, offset, stageId); + offset += Integer.BYTES; + int sizeOfLong2ObjectMap = entry.getValue().size(); + Bits.putInt(oneStageMeta, offset, sizeOfLong2ObjectMap); + offset += Integer.BYTES; + // 0 stands for Double, a non-zero value represents the array length. 
+ int arrayLengthPerObject = 0; + + ObjectIterator> objectIterator = + entry.getValue().long2ObjectEntrySet().fastIterator(); + + if (objectIterator.hasNext()) { + Map.Entry oneEntry = objectIterator.next(); + if (oneEntry.getValue() instanceof double[]) { + arrayLengthPerObject = ((double[]) (oneEntry.getValue())).length; + } + Bits.putInt(oneStageMeta, offset, arrayLengthPerObject); + accPushesByStageState.add(oneStageMeta); + + accPushesByStageState.add(kvToBytes(oneEntry)); + while (objectIterator.hasNext()) { + accPushesByStageState.add(kvToBytes(objectIterator.next())); + } + } + } + } + } + + private static byte[] kvToBytes(Map.Entry obj) { + byte[] bytes; + if (obj.getValue() instanceof double[]) { + double[] value = (double[]) obj.getValue(); + bytes = new byte[Long.BYTES + Bits.getDoubleArraySizeInBytes(value)]; + Bits.putLong(bytes, 0, obj.getKey()); + Bits.putDoubleArray(value, bytes, Long.BYTES); + } else { + bytes = new byte[Long.BYTES + Double.BYTES]; + Bits.putLong(bytes, 0, obj.getKey()); + Bits.putDouble(bytes, Long.BYTES, (Double) obj.getValue()); + } + return bytes; + } + + @SuppressWarnings("unchecked") + private Object processPushRequest(Message message) throws Exception { + long[] keys = message.getKeys(); + int stageId = message.getStageId(); + double[] values = message.getValuesInDoubleArray(); + + accPushesByStage.putIfAbsent(stageId, new Long2ObjectOpenHashMap()); + Long2ObjectOpenHashMap currentAccKvs = accPushesByStage.get(stageId); + + if (keys.length != 0) { + ReduceFunction reduceFunc = + ((PushStage) stageList.get(stageId % stageList.size())).reduceFunc; + if (values.length == keys.length) { + for (int i = 0; i < keys.length; i++) { + if (currentAccKvs.containsKey(keys[i])) { + double currentVal = (Double) currentAccKvs.get(keys[i]); + currentAccKvs.put(keys[i], reduceFunc.reduce(currentVal, values[i])); + } else { + currentAccKvs.put(keys[i], (Double) values[i]); + } + } + } else { + int numDoublesPerKey = values.length / keys.length; + for (int i = 0; i < keys.length; i++) { + if (currentAccKvs.containsKey(keys[i])) { + double[] currentVal = (double[]) currentAccKvs.get(keys[i]); + for (int j = 0; j < numDoublesPerKey; j++) { + currentVal[j] = + reduceFunc.reduce( + currentVal[j], values[i * numDoublesPerKey + j]); + } + } else { + currentAccKvs.put( + keys[i], + Arrays.copyOfRange( + values, + i * numDoublesPerKey, + i * numDoublesPerKey + numDoublesPerKey)); + } + } + } + } + return new Object(); + } + + private void processPullRequest(Message message, LinkedBlockingDeque pullsResponse) { + int workerId = message.getWorkerId(); + long[] keys = message.getKeys(); + Message response; + + if (keys.length == 0) { + // No request on this server. + response = + new Message( + workerId, serverId, message.getStageId(), new long[0], new double[0]); + } else { + double[] pulledValues = modelUpdater.get(keys); + Preconditions.checkState(pulledValues.length % keys.length == 0); + int numDoublesPerKey = pulledValues.length / keys.length; + + double[] aggregatedPullValues = null; + Aggregator aggregator = + ((PullStage) (stageList.get(message.getStageId() % stageList.size()))) + .aggregator; + if (aggregator != null) { + // Processes the pulled values if the aggregator is not null. 
+ double[] tmp = new double[numDoublesPerKey]; + for (int i = 0; i < keys.length; i++) { + System.arraycopy(pulledValues, i * numDoublesPerKey, tmp, 0, numDoublesPerKey); + aggregatedPullValues = aggregator.add(tmp, aggregatedPullValues); + } + } else { + aggregatedPullValues = pulledValues; + } + + response = + new Message( + workerId, + serverId, + message.getStageId(), + new long[0], + aggregatedPullValues); + } + while (!pullsResponse.offer(response.bytes)) { + try { + Thread.sleep(10); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } + + @SuppressWarnings("unchecked") + private void processAllReduceRequest(Iterator requests) throws Exception { + byte[] request = requests.next(); + Message message = new Message(request); + int stageId = message.getStageId(); + AllReduceStage stage = (AllReduceStage) stageList.get(stageId % stageList.size()); + V[] reducedResult = message.getValues(stage.typeSerializer); + ReduceFunction reduceFunction = stage.reducer; + + while (requests.hasNext()) { + message = new Message(requests.next()); + reducedResult = + reduceFunction.reduce(message.getValues(stage.typeSerializer), reducedResult); + } + message = + new Message( + -1, serverId, stageId, new long[0], reducedResult, stage.typeSerializer); + + for (int workerId = 0; workerId < numWorkers; workerId++) { + message.setWorkerId(workerId); + output.collect(new StreamRecord<>(message.bytes)); + } + } + + @SuppressWarnings("unchecked") + private void processReduceScatterRequest(Iterator requests) throws Exception { + byte[] request = requests.next(); + Message message = new Message(request); + int stageId = message.getStageId(); + ReduceScatterStage stage = + (ReduceScatterStage) stageList.get(stageId % stageList.size()); + V[] reducedResult = message.getValues(stage.typeSerializer); + ReduceFunction reduceFunction = stage.reducer; + + while (requests.hasNext()) { + message = new Message(requests.next()); + reducedResult = + reduceFunction.reduce(message.getValues(stage.typeSerializer), reducedResult); + } + + int[] recvCounts = stage.recvCounts; + int totalCnt = Arrays.stream(recvCounts).sum(); + int shardSize = totalCnt / getRuntimeContext().getNumberOfParallelSubtasks() + 1; + int sliceStart = Math.min(serverId * shardSize, totalCnt); + int sliceEnd = Math.min(sliceStart + shardSize, totalCnt); + + int s = 0; + int e; + for (int workerId = 0; workerId < numWorkers; workerId++) { + e = recvCounts[workerId] + s; + + int intersectionStart = Math.max(s, sliceStart); + int interSectionEnd = Math.min(e, sliceEnd); + int copyStart = 0, copyEnd = 0; + if (interSectionEnd > intersectionStart) { + copyStart = intersectionStart - sliceStart; + copyEnd = interSectionEnd - sliceStart; + } + message = + new Message( + workerId, + serverId, + stageId, + new long[0], + Arrays.copyOfRange(reducedResult, copyStart, copyEnd), + stage.typeSerializer); + output.collect(new StreamRecord<>(message.bytes)); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java new file mode 100644 index 000000000..cf935a6bc --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java @@ -0,0 +1,420 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.common.ps.iterations.AllReduceStage; +import org.apache.flink.ml.common.ps.iterations.IterationStage; +import org.apache.flink.ml.common.ps.iterations.IterationStageList; +import org.apache.flink.ml.common.ps.iterations.MLSession; +import org.apache.flink.ml.common.ps.iterations.ProcessStage; +import org.apache.flink.ml.common.ps.iterations.PullStage; +import org.apache.flink.ml.common.ps.iterations.PushStage; +import org.apache.flink.ml.common.ps.iterations.ReduceScatterStage; +import org.apache.flink.ml.common.ps.sarray.SharedDoubleArray; +import org.apache.flink.ml.common.ps.sarray.SharedLongArray; +import org.apache.flink.ml.common.ps.utils.ProxySideOutput; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.runtime.util.ResettableIterator; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Collector; + +import java.util.Iterator; +import java.util.function.Function; + +/** + * The worker operator that executes the iterative machine learning process following {@link + * IterationStageList}. + * + *

In detail, the worker operator is responsible for the following: + * + *
+ * <ul>
+ *   <li>Caches the training data received on its first input and feeds it to the {@link MLSession} at the start of each epoch.
+ *   <li>Initializes the {@link MLSession} with the worker index, the total number of workers and the side outputs.
+ *   <li>Executes the local computation of each {@link ProcessStage}.
+ *   <li>Sends push/pull/all-reduce/reduce-scatter requests to the servers through {@link ServerAgent}, waits for the responses and fills them into the buffers exposed by the corresponding stages.
+ *   <li>Checkpoints its progress (e.g., the id of the next stage to execute) so that training can resume after a failure.
+ * </ul>
+ */ +public class WorkerOperator extends AbstractStreamOperator + implements TwoInputStreamOperator, IterationListener { + /** The user defined iteration logic. */ + private final IterationStageList iterationStages; + /** + * Iteration id in terms of {@link IterationStageList}. When we finished processing all stages + * in stageList, the iteration id increments by one. + */ + private int iterationId; + + /** The id of the stages to execute in iterationStages. */ + private int nextStageToExecute = 0; + + private ListState nextStageToExecuteState; + + /** The agent for each worker to talk with servers. */ + private transient ServerAgent serverAgent; + /** Number of servers that this worker needs to talk to. */ + private final int numServers; + /** The hash function to distribute keys to servers. */ + private transient Function hashFunc; + + /** The cached training data. */ + private ListStateWithCache
trainDataState; + + /** + * Number of segments received from servers for the current request. For each request, a worker + * should receive one segment from each server. + */ + private int numSegmentsReceived = 0; + + private ListState numSegmentsReceivedState; + + /** + * The memory store for pull answer. For a pull request, each received segment will be filled to + * the user provided buffer. + */ + private double[] pulledResult; + + private ListState pulledResultState; + + /** The state store for the all-reduce/reduce-scatter results. */ + private ListState reducedResult; + + public WorkerOperator(IterationStageList iterationStages, int numServers) { + this.iterationStages = iterationStages; + this.numServers = numServers; + } + + @Override + public void open() { + int workerId = getRuntimeContext().getIndexOfThisSubtask(); + int numWorkers = getRuntimeContext().getNumberOfParallelSubtasks(); + this.hashFunc = key -> (int) (Math.abs(key % numServers)); + this.serverAgent = new ServerAgent(workerId, numServers, hashFunc, output); + iterationStages.session.setWorldInfo(workerId, numWorkers); + iterationStages.session.setOutput(new ProxySideOutput(output)); + } + + @Override + public void onEpochWatermarkIncremented( + int epochWatermark, Context context, Collector collector) throws Exception { + if (epochWatermark == 0) { + iterationStages.session.setInputData(new ResettableTrainDataIterator<>(trainDataState)); + nextStageToExecute = processIterationStages(nextStageToExecute, iterationStages); + } + } + + @Override + public void onIterationTerminated(Context context, Collector collector) { + trainDataState.clear(); + } + + @Override + public void processElement1(StreamRecord
streamRecord) throws Exception { + trainDataState.add(streamRecord.getValue()); + } + + @Override + public void processElement2(StreamRecord streamRecord) throws Exception { + Message message = new Message(streamRecord.getValue()); + IterationStage stage = + iterationStages.stageList.get( + nextStageToExecute % iterationStages.stageList.size()); + + boolean proceedToNextStage; + if (stage instanceof PullStage) { + proceedToNextStage = onPullResponse(message, (PullStage) stage); + } else if (stage instanceof AllReduceStage) { + proceedToNextStage = onAllReduceResponse(message, (AllReduceStage) stage); + } else if (stage instanceof ReduceScatterStage) { + proceedToNextStage = onReduceScatterResponse(message, (ReduceScatterStage) stage); + } else { + throw new IllegalStateException( + "Illegal stage type: %s" + stage.getClass().getSimpleName() + "."); + } + + if (proceedToNextStage) { + nextStageToExecute++; + nextStageToExecute = processIterationStages(nextStageToExecute, iterationStages); + } + } + + private boolean onPullResponse(Message message, PullStage pullStage) { + numSegmentsReceived++; + double[] segment = message.getValuesInDoubleArray(); + if (segment.length != 0) { + if (pullStage.aggregator != null) { + if (pulledResult.length == 0) { + pulledResult = segment; + } else { + pulledResult = pullStage.aggregator.merge(segment, pulledResult); + } + } else { + SharedLongArray keys = pullStage.keys.get(); + SharedDoubleArray values = pullStage.values.get(); + int serverId = message.getServerId(); + long[] keysArray = keys.elements(); + + if (pulledResult.length == 0) { + pulledResult = values.elements(); + } + + int numDoublesPerKey = values.size() / keys.size(); + // Copy the response from one server to the result array. + int idxInLocalPull = 0; + for (int i = 0; i < keys.size(); i++) { + if (hashFunc.apply(keysArray[i]) == serverId) { + System.arraycopy( + segment, + idxInLocalPull * numDoublesPerKey, + pulledResult, + i * numDoublesPerKey, + numDoublesPerKey); + idxInLocalPull++; + } + } + } + } + + if (numSegmentsReceived == numServers) { + SharedDoubleArray pullPlaceHolder = pullStage.values.get(); + System.arraycopy( + pulledResult, 0, pullPlaceHolder.elements(), 0, pullPlaceHolder.size()); + + pulledResult = new double[0]; + numSegmentsReceived = 0; + return true; + } + return false; + } + + private boolean onAllReduceResponse(Message message, AllReduceStage allReduceStage) + throws Exception { + numSegmentsReceived++; + reducedResult.add(message.bytes); + + if (numSegmentsReceived == numServers) { + Message assembled = Message.assembleMessages(reducedResult.get().iterator()); + V[] reduceResult = assembled.getValues(allReduceStage.typeSerializer); + System.arraycopy(reduceResult, 0, allReduceStage.recvBuf.get(), 0, reduceResult.length); + reducedResult.clear(); + numSegmentsReceived = 0; + return true; + } + return false; + } + + private boolean onReduceScatterResponse( + Message message, ReduceScatterStage reduceScatterStage) throws Exception { + numSegmentsReceived++; + reducedResult.add(message.bytes); + + if (numSegmentsReceived == numServers) { + Message assembled = Message.assembleMessages(reducedResult.get().iterator()); + V[] reduceResult = assembled.getValues(reduceScatterStage.typeSerializer); + System.arraycopy( + reduceResult, 0, reduceScatterStage.recvBuf.get(), 0, reduceResult.length); + reducedResult.clear(); + numSegmentsReceived = 0; + return true; + } + return false; + } + + @Override + public void initializeState(StateInitializationContext context) throws 
Exception { + super.initializeState(context); + trainDataState = + new ListStateWithCache<>( + (getOperatorConfig().getTypeSerializerIn(0, getClass().getClassLoader())), + getContainingTask(), + getRuntimeContext(), + context, + config.getOperatorID()); + + numSegmentsReceivedState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>("numSegmentsReceivedState", Types.INT)); + numSegmentsReceived = + OperatorStateUtils.getUniqueElement( + numSegmentsReceivedState, "numSegmentsReceivedState") + .orElse(0); + + nextStageToExecuteState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>("nextStageToExecuteState", Types.INT)); + + nextStageToExecute = + OperatorStateUtils.getUniqueElement( + nextStageToExecuteState, "nextStageToExecuteState") + .orElse(0); + + iterationStages.session.initializeState(context); + + pulledResultState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "pulledResultState", + PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO)); + pulledResult = + OperatorStateUtils.getUniqueElement(pulledResultState, "pulledResultState") + .orElse(new double[0]); + + reducedResult = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "reducedResult", + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO)); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + + numSegmentsReceivedState.clear(); + numSegmentsReceivedState.add(numSegmentsReceived); + + nextStageToExecuteState.clear(); + nextStageToExecuteState.add(nextStageToExecute); + + trainDataState.snapshotState(context); + iterationStages.session.snapshotState(context); + + pulledResultState.clear(); + pulledResultState.add(pulledResult); + } + + /** + * Processes the stages described in the given iterationStages from the given nextStage id. This + * function processes the stages until it meets a {@link PullStage}, {@link AllReduceStage} or + * {@link ReduceScatterStage}. + * + * @param nextStageToExecute id of the next stage to execute in the given iteration stages. + * @param iterationStages iteration stages used to describe the training logic. + * @return the id of the next stage to execute. + */ + @SuppressWarnings("unchecked") + private int processIterationStages( + int nextStageToExecute, IterationStageList iterationStages) throws Exception { + while (true) { + if (nextStageToExecute > 0 + && nextStageToExecute % iterationStages.stageList.size() == 0) { + iterationId = nextStageToExecute / iterationStages.stageList.size(); + iterationStages.session.setIterationId(iterationId); + if (iterationStages.shouldTerminate.apply(iterationStages.session)) { + return -1; + } + } + IterationStage stage = + iterationStages.stageList.get( + nextStageToExecute % iterationStages.stageList.size()); + + // We are not incrementing nextStageToExecute for + // PullStage/AllReduceStage/ReduceScatterStage, since we + // need to wait for response from servers. 
+ if (stage instanceof PullStage) { + PullStage pullStage = ((PullStage) stage); + serverAgent.pull(pullStage.keys.get(), nextStageToExecute); + return nextStageToExecute; + + } else if (stage instanceof AllReduceStage) { + AllReduceStage allReduceStage = (AllReduceStage) stage; + if (iterationId % allReduceStage.executionInterval == 0) { + serverAgent.reduce( + allReduceStage.sendBuf.get(), + allReduceStage.typeSerializer, + nextStageToExecute); + return nextStageToExecute; + } else { + nextStageToExecute++; + } + + } else if (stage instanceof ReduceScatterStage) { + ReduceScatterStage reduceScatterStage = (ReduceScatterStage) stage; + if (iterationId % reduceScatterStage.executionInterval == 0) { + serverAgent.reduce( + reduceScatterStage.sendBuf.get(), + reduceScatterStage.typeSerializer, + nextStageToExecute); + return nextStageToExecute; + } else { + nextStageToExecute++; + } + } else if (stage instanceof PushStage) { + PushStage pushStage = (PushStage) stage; + serverAgent.push(pushStage.keys.get(), pushStage.values.get(), nextStageToExecute); + nextStageToExecute++; + } else if (stage instanceof ProcessStage) { + ((ProcessStage) stage).process(iterationStages.session); + nextStageToExecute++; + } else { + throw new IllegalStateException( + "Illegal type of IterationStage: + " + + stage.getClass().getSimpleName() + + "."); + } + } + } + + /** A resettable iterator for {@link ListStateWithCache}. */ + private static class ResettableTrainDataIterator implements ResettableIterator { + private final ListStateWithCache data; + private Iterator dataIterator; + + public ResettableTrainDataIterator(ListStateWithCache data) throws Exception { + this.data = data; + this.dataIterator = data.get().iterator(); + } + + @Override + public void reset() { + try { + this.dataIterator = data.get().iterator(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public boolean hasNext() { + return dataIterator.hasNext(); + } + + @Override + public T next() { + return dataIterator.next(); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/AllReduceStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/AllReduceStage.java new file mode 100644 index 000000000..c0d3cdf0b --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/AllReduceStage.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.ps.iterations; + +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.util.function.SerializableSupplier; + +import java.util.function.Supplier; + +/** + * This iteration stage is designed to perform an all-reduce operation on the specified array in a + * distributed setting. + * + *
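As a usage illustration, an all-reduce that element-wise sums a Double[] buffer shared by all workers could be set up roughly as below; the element type, buffer and class names are assumptions, and the type parameters follow how the stage is consumed elsewhere in this patch:

    import org.apache.flink.api.common.functions.ReduceFunction;
    import org.apache.flink.api.common.typeutils.base.DoubleSerializer;
    import org.apache.flink.ml.common.ps.iterations.AllReduceStage;

    public class AllReduceStageSketch {
        // Buffer that every worker fills before the stage runs and that receives the reduced result.
        static final Double[] SHARED_BUFFER = new Double[] {0.0, 0.0, 0.0, 0.0};

        static AllReduceStage<Double> createStage() {
            // Element-wise sum of the contributions from all workers.
            ReduceFunction<Double[]> elementWiseSum =
                    (a, b) -> {
                        for (int i = 0; i < a.length; i++) {
                            b[i] += a[i];
                        }
                        return b;
                    };
            // executionInterval = 1: run the all-reduce in every iteration.
            return new AllReduceStage<>(
                    () -> SHARED_BUFFER,
                    () -> SHARED_BUFFER,
                    elementWiseSum,
                    DoubleSerializer.INSTANCE,
                    1);
        }
    }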

Users can specify how often this operation is conducted by setting the value of the + * "executionInterval" parameter, which determines the frequency of the all-reduce stage. For + * example, if the value of executionInterval is set to 5, the all-reduce stage will be executed + * every 5 iterations. + */ +public final class AllReduceStage implements IterationStage { + public final Supplier sendBuf; + public final Supplier recvBuf; + public final ReduceFunction reducer; + public final TypeSerializer typeSerializer; + public final int executionInterval; + + public AllReduceStage( + SerializableSupplier sendBuf, + SerializableSupplier recvBuf, + ReduceFunction reducer, + TypeSerializer typeSerializer, + int executionInterval) { + this.sendBuf = sendBuf; + this.recvBuf = recvBuf; + this.reducer = reducer; + this.typeSerializer = typeSerializer; + this.executionInterval = executionInterval; + } + + public AllReduceStage( + SerializableSupplier sendBuf, + SerializableSupplier recvBuf, + ReduceFunction reducer, + TypeSerializer typeSerializer) { + this(sendBuf, recvBuf, reducer, typeSerializer, 1); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/IterationStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/IterationStage.java new file mode 100644 index 000000000..d0f23a774 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/IterationStage.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.iterations; + +import java.io.Serializable; + +/** + * Iterative machine learning training usually incurs local computation step (e.g., computing + * gradients) and global communication step (e.g., all-reduce and parameter servers to aggregate the + * updates from workers). + * + *

To describe the above iteration training process, we model the training process as a sequence + * of iteration stages. An iteration stage could be either local computation or global + * communication. + */ +public interface IterationStage extends Serializable {} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/IterationStageList.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/IterationStageList.java new file mode 100644 index 000000000..e1cd23b7b --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/IterationStageList.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.iterations; + +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.function.SerializableFunction; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; + +/** + * A list of iteration stages to express the logic of an iterative machine learning process. + * + *
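To make the shape of such a list concrete, a hypothetical assembly is sketched below; the stage instances, the element type and the termination threshold are assumptions for illustration:

    import org.apache.flink.ml.common.ps.iterations.AllReduceStage;
    import org.apache.flink.ml.common.ps.iterations.IterationStageList;
    import org.apache.flink.ml.common.ps.iterations.MLSessionImpl;
    import org.apache.flink.ml.common.ps.iterations.ProcessStage;

    public class IterationStageListSketch {
        // Chains a local computation stage with an all-reduce and stops after 100 iterations.
        static IterationStageList<MLSessionImpl<double[]>> buildStages(
                ProcessStage<MLSessionImpl<double[]>> localCompute, AllReduceStage<Double> allReduce) {
            IterationStageList<MLSessionImpl<double[]>> stages =
                    new IterationStageList<>(new MLSessionImpl<>());
            stages.addStage(localCompute)
                    // A stage that waits for servers must be present before setting the termination criteria.
                    .addStage(allReduce)
                    .setTerminationCriteria(session -> session.iterationId >= 100);
            return stages;
        }
    }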

Note that there should be at least one stage (e.g., {@link PullStage}, {@link AllReduceStage} + * or {@link ReduceScatterStage}) that needs to wait for responses from servers. + */ +public class IterationStageList implements Serializable { + /** The session on each worker. */ + public final T session; + /** The termination criteria. */ + public Function shouldTerminate; + /** The stage list that describes the iterative process. */ + public List stageList; + + public IterationStageList(T session) { + this.stageList = new ArrayList<>(); + this.session = session; + } + + /** Sets the criteria of termination. */ + public IterationStageList setTerminationCriteria( + SerializableFunction shouldTerminate) { + boolean waitServer = false; + for (IterationStage stage : stageList) { + if (stage instanceof PullStage + || stage instanceof AllReduceStage + || stage instanceof ReduceScatterStage) { + waitServer = true; + break; + } + } + Preconditions.checkState( + waitServer, + String.format( + "There should be at least one stage that needs to receive response from servers (i.e., %s, %s, %s).\n", + PullStage.class.getSimpleName(), + AllReduceStage.class.getSimpleName(), + ReduceScatterStage.class.getSimpleName())); + this.shouldTerminate = shouldTerminate; + return this; + } + + /** Adds an iteration stage into the stage list. */ + public IterationStageList addStage(IterationStage stage) { + stageList.add(stage); + return this; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/MLSession.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/MLSession.java new file mode 100644 index 000000000..799a5cb6a --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/MLSession.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.iterations; + +import org.apache.flink.ml.common.ps.WorkerOperator; +import org.apache.flink.ml.common.ps.utils.ProxySideOutput; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.runtime.util.ResettableIterator; +import org.apache.flink.util.OutputTag; + +import java.io.Serializable; +import java.util.List; + +/** + * Stores the session information that is alive during the training process on {@link + * WorkerOperator}. Note that the session information will be updated by each {@link + * IterationStage}. + * + *
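A hypothetical session implementation is sketched below to illustrate the snapshot contract noted in the next paragraph; the field names are assumptions, and only the checkpointing hooks are overridden:

    import org.apache.flink.api.common.state.ListState;
    import org.apache.flink.api.common.state.ListStateDescriptor;
    import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
    import org.apache.flink.ml.common.ps.iterations.MLSession;
    import org.apache.flink.runtime.state.StateInitializationContext;
    import org.apache.flink.runtime.state.StateSnapshotContext;

    import java.util.Iterator;

    public class GradientSession implements MLSession {
        // Buffer written by a local stage and later pushed to / pulled from the servers.
        public double[] gradients = new double[0];

        private transient ListState<double[]> gradientsState;

        @Override
        public void initializeState(StateInitializationContext context) throws Exception {
            gradientsState =
                    context.getOperatorStateStore()
                            .getListState(
                                    new ListStateDescriptor<>(
                                            "gradients",
                                            PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO));
            Iterator<double[]> restored = gradientsState.get().iterator();
            if (restored.hasNext()) {
                gradients = restored.next();
            }
        }

        @Override
        public void snapshotState(StateSnapshotContext context) throws Exception {
            gradientsState.clear();
            gradientsState.add(gradients);
        }
    }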

Subclasses should take care of the snapshot of object stored in {@link MLSession} if the + * object satisfies that: the write-process is followed by a {@link PullStage}/{@link + * AllReduceStage}/{@link ReduceScatterStage}, which is later again read by other stages. + */ +public interface MLSession extends Serializable { + /** Sets the current iteration ID. */ + default void setIterationId(int iterationId) {} + + /** Sets the worker id and total number of workers. */ + default void setWorldInfo(int workerId, int numWorkers) {} + + /** Sets the training data. */ + default void setInputData(ResettableIterator inputData) {} + + /** Sets the collector that users can output records to downstream tasks. */ + default void setOutput(ProxySideOutput collector) {} + + /** + * Retrieves the output tags from the {@link MLSession} which can be used to output records from + * the worker operator. + */ + default List> getOutputTags() { + return null; + } + + /** Recovers from state. */ + default void initializeState(StateInitializationContext context) throws Exception {} + + /** Snapshots to state. */ + default void snapshotState(StateSnapshotContext context) throws Exception {} +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/MLSessionImpl.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/MLSessionImpl.java new file mode 100644 index 000000000..317f7cb66 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/MLSessionImpl.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.iterations; + +import org.apache.flink.runtime.util.ResettableIterator; +import org.apache.flink.util.OutputTag; + +import java.util.List; + +/** + * The default implementation of {@link MLSession}. + * + * @param

Data type of input data. + */ +public class MLSessionImpl
implements MLSession { + /** Current iteration id. */ + public int iterationId; + /** Index of this worker. */ + public int workerId; + /** Number of workers in total for this distributed ML job. */ + public int numWorkers; + /** The input data. */ + public ResettableIterator
inputData; + + public List> outputTags; + + /** Constructs an instance with side outputs. */ + public MLSessionImpl(List> outputTags) { + this.outputTags = outputTags; + } + + /** Constructs an instance without side outputs. */ + public MLSessionImpl() { + this(null); + } + + @Override + public List> getOutputTags() { + return outputTags; + } + + @Override + public void setIterationId(int iterationId) { + this.iterationId = iterationId; + } + + @Override + public void setWorldInfo(int workerId, int numWorkers) { + this.workerId = workerId; + this.numWorkers = numWorkers; + } + + @Override + public void setInputData(ResettableIterator inputData) { + this.inputData = (ResettableIterator
) inputData; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/ProcessStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/ProcessStage.java new file mode 100644 index 000000000..8c8810699 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/ProcessStage.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.iterations; + +/** + * A local computation stage of the training process. The input and output of {@link ProcessStage} + * can be accessed via {@link MLSession}. + * + * @param Type of the training data. + */ +public abstract class ProcessStage implements IterationStage { + /** + * Does a local computation logic using the information from session. Example stages could be + * computing gradients. + */ + public abstract void process(T session) throws Exception; +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/PullStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/PullStage.java new file mode 100644 index 000000000..8f86c5e5c --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/PullStage.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.iterations; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.ml.common.ps.sarray.SharedDoubleArray; +import org.apache.flink.ml.common.ps.sarray.SharedLongArray; +import org.apache.flink.util.function.SerializableSupplier; + +import java.io.Serializable; +import java.util.function.Supplier; + +/** + * An iteration stage that aggregates data from servers using keys as {@code PullStage#keys#get()} + * and stores the aggregated values by {@code PullStage#values#get()}. + * + *
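+ * <p>As a rough usage sketch (the {@code stages} list and the {@code session} fields are
+ * hypothetical placeholders, not part of this class):
+ *
+ * <pre>{@code
+ * // Pull the model values addressed by session.pullIndices into session.pullValues.
+ * stages.addStage(new PullStage(() -> session.pullIndices, () -> session.pullValues));
+ * }</pre>
+ *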

If the aggregator is null, we simply pull those values specified by the keys. + */ +public final class PullStage implements IterationStage { + public final Supplier keys; + public final Supplier values; + public final Aggregator aggregator; + + public PullStage( + SerializableSupplier keys, + SerializableSupplier values) { + this(keys, values, null); + } + + public PullStage( + SerializableSupplier keys, + SerializableSupplier values, + Aggregator aggregator) { + this.keys = keys; + this.values = values; + this.aggregator = aggregator; + } + + /** + * An Aggregator is used to aggregate a set of input elements into a single accumulator. + * + * @param The type of the input elements. + * @param The type of the accumulator. + */ + @Internal + public interface Aggregator extends Serializable { + + /** + * Adds a new input element to the given accumulator and returns the updated accumulator. + * + * @param in The input element to add. + * @param acc The accumulator to update. + * @return The updated accumulator. + */ + ACC add(IN in, ACC acc); + + /** + * Merges two accumulators and returns the result. + * + * @param acc1 The first accumulator to merge. + * @param acc2 The second accumulator to merge. + * @return The merged accumulator. + */ + ACC merge(ACC acc1, ACC acc2); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/PushStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/PushStage.java new file mode 100644 index 000000000..3abec6190 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/PushStage.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.iterations; + +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.ml.common.ps.sarray.SharedDoubleArray; +import org.apache.flink.ml.common.ps.sarray.SharedLongArray; +import org.apache.flink.util.function.SerializableSupplier; + +import java.util.function.Supplier; + +/** + * An iteration stage that push (indices, values) to servers. User can specify how values from + * different workers are merged via {@code PushStage#reduceFunc}. By default, the values are summed + * from different workers. + * + *
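+ * <p>For example, the following sketch pushes locally computed values keyed by model indices and
+ * lets the servers sum the contributions from all workers (the {@code stages} list and the
+ * {@code session} fields are hypothetical placeholders):
+ *
+ * <pre>{@code
+ * stages.addStage(
+ *         new PushStage(() -> session.pushIndices, () -> session.pushValues, Double::sum));
+ * }</pre>
+ *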

Note that the length of the values array must be divisible by the length of the keys array. + * Additionally, each value corresponding to a given key must have the same length. For instance, + * considering the keys {1, 4} and values {1,2,3,4,5,6}, the value at index 1 would be {1,2,3}, and + * the value at index 4 would be {4,5,6}. + */ +public class PushStage implements IterationStage { + public final Supplier keys; + public final Supplier values; + + /** The function to reduce the pushes from all workers. For gradient descent based methods, */ + public final ReduceFunction reduceFunc; + + public PushStage( + SerializableSupplier keys, + SerializableSupplier values) { + this(keys, values, Double::sum); + } + + public PushStage( + SerializableSupplier keys, + SerializableSupplier values, + ReduceFunction reduceFunc) { + this.keys = keys; + this.values = values; + this.reduceFunc = reduceFunc; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/ReduceScatterStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/ReduceScatterStage.java new file mode 100644 index 000000000..c660de285 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/ReduceScatterStage.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.iterations; + +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.function.SerializableSupplier; + +import java.util.function.Supplier; + +/** + * This iteration stage is designed to perform an reduce-scatter operation on the specified array in + * a distributed setting. + * + *
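+ * <p>A rough usage sketch (the buffers, counts and {@code stages} below are hypothetical
+ * placeholders):
+ *
+ * <pre>{@code
+ * // Every worker contributes a Double[] send buffer; worker i receives recvCounts[i] elements.
+ * stages.addStage(
+ *         new ReduceScatterStage<Double>(
+ *                 () -> session.sendBuf,
+ *                 () -> session.recvBuf,
+ *                 recvCounts,
+ *                 Double::sum,
+ *                 DoubleSerializer.INSTANCE));
+ * }</pre>
+ *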

Users can specify how often this operation is conducted by setting the value of the + * "executionInterval" parameter, which determines the frequency of the all-reduce stage. For + * example, if the value of executionInterval is set to 5, the all-reduce stage will be executed + * every 5 iterations. + */ +public final class ReduceScatterStage implements IterationStage { + public final Supplier sendBuf; + public final Supplier recvBuf; + /** The number of elements each worker receives. */ + public int[] recvCounts; + + public final ReduceFunction reducer; + public final TypeSerializer typeSerializer; + + public final int executionInterval; + + public ReduceScatterStage( + SerializableSupplier sendBuf, + SerializableSupplier recvBuf, + int[] recvCounts, + ReduceFunction reducer, + TypeSerializer typeSerializer, + int executionInterval) { + this.sendBuf = sendBuf; + this.recvBuf = recvBuf; + this.recvCounts = Preconditions.checkNotNull(recvCounts); + this.reducer = reducer; + this.typeSerializer = typeSerializer; + this.executionInterval = executionInterval; + } + + public ReduceScatterStage( + SerializableSupplier sendBuf, + SerializableSupplier recvBuf, + int[] recvCounts, + ReduceFunction reducer, + TypeSerializer typeSerializer) { + this(sendBuf, recvBuf, recvCounts, reducer, typeSerializer, 1); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/sarray/SharedDoubleArray.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/sarray/SharedDoubleArray.java new file mode 100644 index 000000000..4a7aa24b0 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/sarray/SharedDoubleArray.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.sarray; + +import it.unimi.dsi.fastutil.doubles.DoubleArrayList; + +import java.io.Serializable; + +/** A resizable double array that can be shared among different iterations for memory efficiency. */ +public class SharedDoubleArray implements Serializable { + + /** The underlying DoubleArrayList that holds the elements. */ + private final DoubleArrayList doubles; + + /** + * Constructs a new SDArray from the given double array. + * + * @param array the double array to wrap + */ + public SharedDoubleArray(double[] array) { + doubles = DoubleArrayList.wrap(array); + } + + /** + * Constructs a new SDArray with the given initial capacity. + * + * @param capacity the initial capacity + */ + public SharedDoubleArray(int capacity) { + doubles = new DoubleArrayList(capacity); + } + + /** Constructs a new empty SDArray. */ + public SharedDoubleArray() { + doubles = new DoubleArrayList(); + } + + /** + * Returns the element at the specified index. 
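+ *
+ * <p>Illustrative example:
+ *
+ * <pre>{@code
+ * SharedDoubleArray buf = new SharedDoubleArray(4);
+ * buf.addAll(new double[] {1.0, 2.0});
+ * double v = buf.get(1); // 2.0
+ * buf.clear(); // keeps the backing memory so it can be reused in the next iteration
+ * }</pre>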
+ * + * @param index the index of the element to return + * @return the element at the specified index + */ + public double get(int index) { + return doubles.getDouble(index); + } + + /** + * Appends the specified element to the end of this array. + * + * @param v the element to add + */ + public void add(double v) { + doubles.add(v); + } + + /** + * Appends all the elements from the specified double array to the end of this array. + * + * @param src the double array to append + */ + public void addAll(double[] src) { + int sizeBefore = size(); + doubles.size(sizeBefore + src.length); + System.arraycopy(src, 0, elements(), sizeBefore, src.length); + } + + /** + * Returns the number of valid elements in this array. + * + * @return the number of valid elements in this array + */ + public int size() { + return doubles.size(); + } + + /** + * Sets the size of the array to the provided size. If the new size is larger than the current + * size, the new allocated memory are filled with zero. + * + * @param size the new size of the array + */ + public void size(int size) { + doubles.size(size); + } + + /** Clears the elements in this array. Note that the memory is not recycled. */ + public void clear() { + doubles.clear(); + } + + /** + * Returns a double array containing all the elements in this array. Only the first {@link + * SharedDoubleArray#size()} elements are valid. + * + * @return a double array containing the all the elements in this array + */ + public double[] elements() { + return doubles.elements(); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/sarray/SharedLongArray.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/sarray/SharedLongArray.java new file mode 100644 index 000000000..d193890da --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/sarray/SharedLongArray.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.sarray; + +import it.unimi.dsi.fastutil.longs.LongArrayList; + +import java.io.Serializable; + +/** A resizable long array that can be shared among different iterations for memory efficiency. */ +public class SharedLongArray implements Serializable { + + /** The underlying LongArrayList that holds the elements. */ + private final LongArrayList longs; + + /** + * Constructs a new SLArray from the given long array. + * + * @param array the long array to wrap + */ + public SharedLongArray(long[] array) { + longs = LongArrayList.wrap(array); + } + + /** + * Constructs a new SLArray with the given initial capacity. + * + * @param capacity the initial capacity + */ + public SharedLongArray(int capacity) { + longs = new LongArrayList(capacity); + } + + /** Constructs a new empty SLArray. 
*/ + public SharedLongArray() { + longs = new LongArrayList(); + } + + /** + * Returns the element at the specified index. + * + * @param index the index of the element to return + * @return the element at the specified index + */ + public long get(int index) { + return longs.getLong(index); + } + + /** + * Appends the specified element to the end of this array. + * + * @param v the element to add + */ + public void add(long v) { + longs.add(v); + } + + /** + * Appends all the elements from the specified long array to the end of this array. + * + * @param src the long array to append + */ + public void addAll(long[] src) { + int sizeBefore = size(); + longs.size(sizeBefore + src.length); + System.arraycopy(src, 0, elements(), sizeBefore, src.length); + } + + /** + * Returns the number of valid elements in this array. + * + * @return the number of valid elements in this array + */ + public int size() { + return longs.size(); + } + + /** + * Resizes this array to the specified size. Sets the size of the array to the provided size. If + * the new size is larger than the current size, the new allocated memory are filled with zero. + * + * @param size the new size of the array + */ + public void size(int size) { + longs.size(size); + } + + /** Clears the elements in this array. Note that the memory is not recycled. */ + public void clear() { + longs.clear(); + } + + /** + * Returns a long array containing the valid elements in this array. Only the first {@link + * SharedLongArray#size()} elements are valid. + * + * @return a long array containing the valid elements in this array + */ + public long[] elements() { + return longs.elements(); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2DoubleOpenHashMapSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2DoubleOpenHashMapSerializer.java new file mode 100644 index 000000000..3e2d3b920 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2DoubleOpenHashMapSerializer.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; + +import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap; + +import java.io.IOException; +import java.util.Map; + +/** TypeSerializer for {@link Long2DoubleOpenHashMap}. 
*/ +public class Long2DoubleOpenHashMapSerializer extends TypeSerializer { + + public static final Long2DoubleOpenHashMapSerializer INSTANCE = + new Long2DoubleOpenHashMapSerializer(); + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public TypeSerializer duplicate() { + return INSTANCE; + } + + @Override + public Long2DoubleOpenHashMap createInstance() { + return new Long2DoubleOpenHashMap(); + } + + @Override + public Long2DoubleOpenHashMap copy(Long2DoubleOpenHashMap from) { + return new Long2DoubleOpenHashMap(from); + } + + @Override + public Long2DoubleOpenHashMap copy(Long2DoubleOpenHashMap from, Long2DoubleOpenHashMap reuse) { + return new Long2DoubleOpenHashMap(from); + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(Long2DoubleOpenHashMap map, DataOutputView target) throws IOException { + target.writeInt(map.size()); + for (Map.Entry entry : map.entrySet()) { + target.writeLong(entry.getKey()); + target.writeDouble(entry.getValue()); + } + } + + @Override + public Long2DoubleOpenHashMap deserialize(DataInputView source) throws IOException { + int numEntries = source.readInt(); + Long2DoubleOpenHashMap map = new Long2DoubleOpenHashMap(numEntries); + for (int i = 0; i < numEntries; i++) { + long k = source.readLong(); + double v = source.readDouble(); + map.put(k, v); + } + return map; + } + + @Override + public Long2DoubleOpenHashMap deserialize(Long2DoubleOpenHashMap reuse, DataInputView source) + throws IOException { + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + int numEntries = source.readInt(); + target.writeInt(numEntries); + for (int i = 0; i < numEntries; ++i) { + target.writeLong(source.readLong()); + target.writeDouble(source.readDouble()); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + return true; + } + + @Override + public int hashCode() { + return 0; + } + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new Long2DoubleOpenHashMapSnapshot(); + } + + private static final class Long2DoubleOpenHashMapSnapshot + extends SimpleTypeSerializerSnapshot { + public Long2DoubleOpenHashMapSnapshot() { + super(() -> Long2DoubleOpenHashMapSerializer.INSTANCE); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2DoubleOpenHashMapTypeInfo.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2DoubleOpenHashMapTypeInfo.java new file mode 100644 index 000000000..4acfafdc1 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2DoubleOpenHashMapTypeInfo.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.typeinfo; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeSerializer; + +import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap; + +/** TypeInformation for {@link Long2DoubleOpenHashMap}. */ +public class Long2DoubleOpenHashMapTypeInfo extends TypeInformation { + + public static Long2DoubleOpenHashMapTypeInfo instance = new Long2DoubleOpenHashMapTypeInfo(); + + @Override + public boolean isBasicType() { + return false; + } + + @Override + public boolean isTupleType() { + return false; + } + + @Override + public int getArity() { + return 1; + } + + @Override + public int getTotalFields() { + return 1; + } + + @Override + public Class getTypeClass() { + return Long2DoubleOpenHashMap.class; + } + + @Override + public boolean isKeyType() { + return false; + } + + @Override + public TypeSerializer createSerializer(ExecutionConfig config) { + return Long2DoubleOpenHashMapSerializer.INSTANCE; + } + + @Override + public String toString() { + return "Long2DoubleOpenHashMap Type"; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + if (obj == null || getClass() != obj.getClass()) { + return false; + } + + return true; + } + + @Override + public int hashCode() { + return 0; + } + + @Override + public boolean canEqual(Object obj) { + return obj instanceof Long2DoubleOpenHashMapTypeInfo; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2ObjectOpenHashMapSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2ObjectOpenHashMapSerializer.java new file mode 100644 index 000000000..12f250d7d --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2ObjectOpenHashMapSerializer.java @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.typeinfo; + +import org.apache.flink.api.common.typeutils.CompositeTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.util.Preconditions; + +import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +/** + * TypeSerializer for {@link Long2ObjectOpenHashMap}. + * + * @param The type of elements in the Long2ObjectOpenHashMap. 
+ */ +public class Long2ObjectOpenHashMapSerializer extends TypeSerializer> { + + private final TypeSerializer elementSerializer; + + public Long2ObjectOpenHashMapSerializer(TypeSerializer elementSerializer) { + this.elementSerializer = Preconditions.checkNotNull(elementSerializer); + } + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public TypeSerializer> duplicate() { + return new Long2ObjectOpenHashMapSerializer<>(elementSerializer.duplicate()); + } + + @Override + public Long2ObjectOpenHashMap createInstance() { + return new Long2ObjectOpenHashMap<>(); + } + + @Override + public Long2ObjectOpenHashMap copy(Long2ObjectOpenHashMap from) { + return new Long2ObjectOpenHashMap<>(from); + } + + @Override + public Long2ObjectOpenHashMap copy( + Long2ObjectOpenHashMap from, Long2ObjectOpenHashMap reuse) { + return new Long2ObjectOpenHashMap<>(from); + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(Long2ObjectOpenHashMap map, DataOutputView target) throws IOException { + target.writeInt(map.size()); + for (Map.Entry entry : map.entrySet()) { + target.writeLong(entry.getKey()); + elementSerializer.serialize(entry.getValue(), target); + } + } + + @Override + public Long2ObjectOpenHashMap deserialize(DataInputView source) throws IOException { + int numEntries = source.readInt(); + Long2ObjectOpenHashMap map = new Long2ObjectOpenHashMap<>(numEntries); + for (int i = 0; i < numEntries; i++) { + long k = source.readLong(); + T v = elementSerializer.deserialize(source); + map.put(k, v); + } + return map; + } + + @Override + public Long2ObjectOpenHashMap deserialize( + Long2ObjectOpenHashMap reuse, DataInputView source) throws IOException { + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + int numEntries = source.readInt(); + target.writeInt(numEntries); + for (int i = 0; i < numEntries; ++i) { + target.writeLong(source.readLong()); + elementSerializer.copy(source, target); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + Long2ObjectOpenHashMapSerializer that = (Long2ObjectOpenHashMapSerializer) o; + return Objects.equals(elementSerializer, that.elementSerializer); + } + + @Override + public int hashCode() { + return Objects.hash(elementSerializer != null ? 
elementSerializer.hashCode() : 0); + } + + @Override + public TypeSerializerSnapshot> snapshotConfiguration() { + return new Long2ObjectOpenHashMapSnapshot<>(this); + } + + private static final class Long2ObjectOpenHashMapSnapshot + extends CompositeTypeSerializerSnapshot< + Long2ObjectOpenHashMap, Long2ObjectOpenHashMapSerializer> { + + private static final int CURRENT_VERSION = 1; + + public Long2ObjectOpenHashMapSnapshot() { + super(Long2ObjectOpenHashMapSerializer.class); + } + + public Long2ObjectOpenHashMapSnapshot(Long2ObjectOpenHashMapSerializer serializer) { + super(serializer); + } + + @Override + protected int getCurrentOuterSnapshotVersion() { + return CURRENT_VERSION; + } + + @Override + protected TypeSerializer[] getNestedSerializers( + Long2ObjectOpenHashMapSerializer outerSerializer) { + return new TypeSerializer[] {outerSerializer.elementSerializer}; + } + + @Override + protected Long2ObjectOpenHashMapSerializer createOuterSerializerWithNestedSerializers( + TypeSerializer[] nestedSerializers) { + TypeSerializer elementSerializer = (TypeSerializer) nestedSerializers[0]; + return new Long2ObjectOpenHashMapSerializer<>(elementSerializer); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2ObjectOpenHashMapTypeInfo.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2ObjectOpenHashMapTypeInfo.java new file mode 100644 index 000000000..d80079cfc --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2ObjectOpenHashMapTypeInfo.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.typeinfo; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeSerializer; + +import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap; + +import java.util.Objects; + +/** + * TypeInformation for {@link Long2ObjectOpenHashMap}. + * + * @param The type of elements in the Long2ObjectOpenHashMap. 
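+ *
+ * <p>For example (illustrative only), {@code new Long2ObjectOpenHashMapTypeInfo<>(TypeInformation.of(double[].class))}
+ * describes a map from long keys to {@code double[]} values.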
+ */ +public class Long2ObjectOpenHashMapTypeInfo extends TypeInformation> { + + private final TypeInformation elementTypeInfo; + + public Long2ObjectOpenHashMapTypeInfo(TypeInformation elementTypeInfo) { + this.elementTypeInfo = elementTypeInfo; + } + + public TypeInformation getElementTypeInfo() { + return elementTypeInfo; + } + + @Override + public boolean isBasicType() { + return false; + } + + @Override + public boolean isTupleType() { + return false; + } + + @Override + public int getArity() { + return 1; + } + + @Override + public int getTotalFields() { + return 1; + } + + @Override + public Class> getTypeClass() { + return (Class) Long2ObjectOpenHashMap.class; + } + + @Override + public boolean isKeyType() { + return false; + } + + @Override + public TypeSerializer> createSerializer(ExecutionConfig config) { + return new Long2ObjectOpenHashMapSerializer<>(elementTypeInfo.createSerializer(config)); + } + + @Override + public String toString() { + return "Long2ObjectOpenHashMap Type"; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + if (obj == null || getClass() != obj.getClass()) { + return false; + } + + Long2ObjectOpenHashMapTypeInfo that = (Long2ObjectOpenHashMapTypeInfo) obj; + return Objects.equals(elementTypeInfo, that.elementTypeInfo); + } + + @Override + public int hashCode() { + return Objects.hash(elementTypeInfo != null ? elementTypeInfo.hashCode() : 0); + } + + @Override + public boolean canEqual(Object obj) { + return obj instanceof Long2ObjectOpenHashMapTypeInfo; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java new file mode 100644 index 000000000..dbb4dd3ce --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.updater; + +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * A model updater that could be used to update and retrieve model data. + * + *

Note that model updater should also ensure that model data is robust to failures, by writing + * model data to snapshots. + * + * @param data type of model. + */ +public interface ModelUpdater extends Serializable { + /** Applies the push to update the model data, e.g., using gradient to update model. */ + void update(long[] keys, double[] values); + + /** Retrieves the model data of the given keys. */ + double[] get(long[] keys); + + /** + * Returns model segments. The model segments are continuously updated/retrieved by + * push/pull(i.e., {@link ModelUpdater#update(long[], double[])} and {@link + * ModelUpdater#get(long[])}). + */ + Iterator getModelSegments(); + + /** Recovers the model data from state. */ + void initializeState(StateInitializationContext context) throws Exception; + + /** Snapshots the model data to state. */ + void snapshotState(StateSnapshotContext context) throws Exception; +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/utils/ProxySideOutput.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/utils/ProxySideOutput.java new file mode 100644 index 000000000..9cba95d0f --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/utils/ProxySideOutput.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.utils; + +import org.apache.flink.streaming.api.operators.Output; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; + +/** A collector that can only output using {@link OutputTag}. */ +public final class ProxySideOutput { + private final Output output; + + public ProxySideOutput(Output output) { + this.output = output; + } + + public void output(OutputTag outputTag, StreamRecord record) { + Preconditions.checkNotNull(outputTag); + output.collect(outputTag, record); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/utils/TrainingUtils.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/utils/TrainingUtils.java new file mode 100644 index 000000000..261dba6a8 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/utils/TrainingUtils.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.utils; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.Partitioner; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.IterationConfig; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.ReplayableDataStreamList; +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; +import org.apache.flink.ml.common.ps.Message; +import org.apache.flink.ml.common.ps.ServerOperator; +import org.apache.flink.ml.common.ps.WorkerOperator; +import org.apache.flink.ml.common.ps.iterations.IterationStageList; +import org.apache.flink.ml.common.ps.iterations.MLSession; +import org.apache.flink.ml.common.ps.updater.ModelUpdater; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.util.OutputTag; + +import java.util.ArrayList; +import java.util.List; + +/** Utility function to describe iterative training process. */ +public final class TrainingUtils { + /** + * Executes the iterative machine learning logic described in {@link IterationStageList} and + * returns the fitted model data as well as the outputs from worker operator. The outputs from + * worker operator are specified via {@link MLSession#getOutputTags()}. + * + * @param inputData the input data. + * @param iterationStages the iterative processing logic. + * @param modelDataType output type information of model data. + * @param modelUpdater the logic to update model on servers. + * @param numServers number of servers. + * @return the fitted model data as well as the outputs from worker operator. The orders are + * {modelData, sideOutputs from workers}. Note that the outputs from workers shares the same + * order with the {@link MLSession#getOutputTags()}. + * @param

type information of input data. + * @param type information of the output model data. + */ + public static DataStreamList train( + DataStream
inputData, + IterationStageList iterationStages, + TypeInformation modelDataType, + ModelUpdater modelUpdater, + int numServers) { + DataStream variableStream = + inputData.getExecutionEnvironment().fromElements(new byte[0]).filter(x -> false); + + return Iterations.iterateBoundedStreamsUntilTermination( + DataStreamList.of(variableStream), + ReplayableDataStreamList.notReplay(inputData), + IterationConfig.newBuilder().build(), + new TrainIterationBody<>(modelUpdater, modelDataType, iterationStages, numServers)); + } + + /** The iteration implementation for training process. */ + private static class TrainIterationBody implements IterationBody { + private final ModelUpdater modelUpdater; + private final TypeInformation modelType; + private final IterationStageList iterationStages; + private final int numServers; + + public TrainIterationBody( + ModelUpdater modelUpdater, + TypeInformation modelType, + IterationStageList iterationStages, + int numServers) { + this.iterationStages = iterationStages; + this.modelType = modelType; + this.modelUpdater = modelUpdater; + this.numServers = numServers; + } + + @Override + @SuppressWarnings("unchecked") + public IterationBodyResult process( + DataStreamList variableStreams, DataStreamList dataStreams) { + DataStream variableStream = variableStreams.get(0); + DataStream trainData = dataStreams.get(0); + final OutputTag modelDataOutputTag = new OutputTag<>("MODEL_OUTPUT", modelType); + + SingleOutputStreamOperator messageToServer = + trainData + .connect(variableStream) + .transform( + "WorkerOp", + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO, + new WorkerOperator(iterationStages, numServers)); + int numWorkers = messageToServer.getParallelism(); + + SingleOutputStreamOperator messageToWorker = + messageToServer + .partitionCustom( + (Partitioner) + (key, numPartitions) -> key % numPartitions, + (KeySelector) + value -> new Message(value).getServerId()) + .transform( + "ServerOp", + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO, + new ServerOperator<>( + iterationStages.stageList, + numWorkers, + modelUpdater, + modelDataOutputTag)); + messageToWorker.setParallelism(numServers); + + DataStream feedback = + messageToWorker + .partitionCustom( + (Partitioner) + (key, numPartitions) -> key % numPartitions, + (KeySelector) + value -> new Message(value).getWorkerId()) + .map( + (MapFunction) message -> message, + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO) + .setParallelism(numWorkers); + + DataStream model = messageToWorker.getSideOutput(modelDataOutputTag); + + List> result = new ArrayList<>(); + result.add(model); + + List> sideOutputTags = iterationStages.session.getOutputTags(); + if (sideOutputTags != null) { + for (OutputTag outputTag : sideOutputTags) { + result.add(messageToServer.getSideOutput(outputTag)); + } + } + + return new IterationBodyResult( + DataStreamList.of(feedback), new DataStreamList(result), null); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/Als.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/Als.java new file mode 100644 index 000000000..5ba788f1b --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/Als.java @@ -0,0 +1,491 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.als; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple5; +import org.apache.flink.api.java.typeutils.GenericTypeInfo; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.common.datastream.EndOfStreamWindows; +import org.apache.flink.ml.common.ps.iterations.AllReduceStage; +import org.apache.flink.ml.common.ps.iterations.IterationStageList; +import org.apache.flink.ml.common.ps.iterations.PullStage; +import org.apache.flink.ml.common.ps.iterations.PushStage; +import org.apache.flink.ml.common.ps.utils.TrainingUtils; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.functions.windowing.ProcessWindowFunction; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.windowing.windows.TimeWindow; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.function.SerializableFunction; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * An Estimator which implements the Als algorithm. + * + *

ALS tries to decompose a matrix R as R = X * Yt, where X and Y are called factor matrices.
+ * Matrix R is usually a sparse matrix representing the ratings given by users to items. ALS tries to
+ * find X and Y that minimize || R - X * Yt ||^2. This is done iteratively: in each step, X is
+ * fixed and Y is solved for, then Y is fixed and X is solved for.
+ *
+ *
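+ * <p>Informally, with the regularization parameter {@code regParam}, the quantity being minimized
+ * for explicit feedback is a squared reconstruction error over the observed ratings plus an L2
+ * penalty on the factors, roughly
+ *
+ * <pre>
+ *   sum over observed (u, i) of (R(u, i) - X(u) * Y(i)^t)^2
+ *       + regParam * (sum_u |X(u)|^2 + sum_i |Y(i)|^2)
+ * </pre>
+ *
+ * <p>The exact weighting of the regularization term follows the papers referenced below.
+ *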

The algorithm is described in "Large-scale Parallel Collaborative Filtering for the Netflix + * Prize, 2007". This algorithm also supports implicit preference model described in "Collaborative + * Filtering for Implicit Feedback Datasets, 2008". + */ +public class Als implements Estimator, AlsParams { + private final Map, Object> paramMap = new HashMap<>(); + private static final int THRESHOLD = 100000; + + public Als() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public AlsModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + final String userCol = getUserCol(); + final String itemCol = getItemCol(); + final String ratingCol = getRatingCol(); + + DataStream trainData = tEnv.toDataStream(inputs[0]); + + DataStream> alsInput = + trainData + .map( + (MapFunction>) + value -> { + Long user = value.getFieldAs(userCol); + Long item = value.getFieldAs(itemCol); + user = 2L * user; + item = 2L * item + 1L; + Number rating = + ratingCol == null + ? 0.0F + : value.getFieldAs(ratingCol); + + return new Tuple3<>(user, item, rating.doubleValue()); + }) + .name("generateInputALsData") + .returns(Types.TUPLE(Types.LONG, Types.LONG, Types.DOUBLE)); + + /* Initializes variables before iteration. */ + DataStream ratingData = initRatings(alsInput); + int parallelism = ratingData.getParallelism(); + AlsMLSession mlSession = new AlsMLSession(getImplicitPrefs(), getRank(), parallelism); + ExecutionConfig executionConfig = ratingData.getExecutionConfig(); + TypeSerializer typeSerializer = + TypeInformation.of(double[].class).createSerializer(executionConfig); + + IterationStageList iterationStages = + constructIterationStage(mlSession, typeSerializer); + + AlsModelUpdater updater = new AlsModelUpdater(getRank()); + DataStreamList resultList = + TrainingUtils.train( + ratingData, + iterationStages, + new TupleTypeInfo<>( + Types.LONG, + PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO), + updater, + Math.max(1, parallelism / 2)); + + DataStream> returnData = resultList.get(0); + + DataStream modelData = + returnData + .transform( + "generateModelData", + TypeInformation.of(AlsModelData.class), + new GenerateModelData()) + .name("generateModelData"); + + AlsModel model = new AlsModel().setModelData(tEnv.fromDataStream(modelData)); + ParamUtils.updateExistingParams(model, paramMap); + return model; + } + + private IterationStageList constructIterationStage( + AlsMLSession mlSession, TypeSerializer typeSerializer) { + IterationStageList iterationStages = new IterationStageList<>(mlSession); + if (getImplicitPrefs()) { + /* + * If using implicit prefs, the whole yty matrix must be computed by all reduce stage. 
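+                     * (Here "yty" refers to the rank x rank matrix Y^t * Y accumulated from the
+                     * pulled factor vectors by the YtyAggregator below; the implicit-feedback
+                     * update needs this matrix on every worker.)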
+ */ + iterationStages + .addStage(new ComputeYtyIndices()) + .addStage( + new PullStage( + () -> mlSession.pullIndices, + () -> mlSession.aggregatorSDAArray, + new YtyAggregator())) + .addStage(new CopyAllReduceData(getRank())) + .addStage( + new AllReduceStage<>( + () -> mlSession.allReduceBuffer, + () -> mlSession.allReduceBuffer, + (ReduceFunction) Als::sumYty, + typeSerializer, + 1)); + } + + iterationStages + .addStage(new ComputeNeighborIndices(getRank())) + .addStage(new PullStage(() -> mlSession.pullIndices, () -> mlSession.pullValues)) + .addStage( + new UpdateCommonFactors( + getRank(), + getImplicitPrefs(), + getNonNegative(), + getRegParam(), + getAlpha())) + .addStage(new PushStage(() -> mlSession.pushIndices, () -> mlSession.pushValues)); + + iterationStages + .addStage( + new ComputeLsMatrixVector( + getRank(), getImplicitPrefs(), getRegParam(), getAlpha())) + .addStage(new PushStage(() -> mlSession.pushIndices, () -> mlSession.pushValues)) + .addStage(new PullStage(() -> mlSession.pullIndices, () -> mlSession.pullValues)) + .addStage(new UpdateHotPointFactors(getRank(), getNonNegative())) + .addStage(new PushStage(() -> mlSession.pushIndices, () -> mlSession.pushValues)); + + iterationStages.setTerminationCriteria( + (SerializableFunction) + o -> o.iterationId / (o.numItemBlocks + o.numUserBlocks) >= getMaxIter()); + return iterationStages; + } + + /** Generates the ModelData from the results of iteration. */ + private static class GenerateModelData extends AbstractStreamOperator + implements OneInputStreamOperator, AlsModelData>, + BoundedOneInput { + + private final List> userFactors = new ArrayList<>(); + private final List> itemFactors = new ArrayList<>(); + + @Override + public void endInput() throws Exception { + LOG.info("Generates model ... " + System.currentTimeMillis()); + output.collect(new StreamRecord<>(new AlsModelData(userFactors, itemFactors))); + } + + @Override + public void processElement(StreamRecord> streamRecord) + throws Exception { + Tuple2 t2 = streamRecord.getValue(); + + if (t2.f0 % 2L == 1L) { + long id = (t2.f0 - 1) / 2L; + float[] factor = new float[t2.f1.length]; + for (int i = 0; i < factor.length; ++i) { + factor[i] = (float) t2.f1[i]; + } + itemFactors.add(Tuple2.of(id, factor)); + } else { + long id = t2.f0 / 2L; + float[] factor = new float[t2.f1.length]; + for (int i = 0; i < factor.length; ++i) { + factor[i] = (float) t2.f1[i]; + } + userFactors.add(Tuple2.of(id, factor)); + } + } + } + + /** + * Initializes the ratings data with the input graph. + * + * @param alsInput The input graph. + * @return The ratings data. + */ + private DataStream initRatings(DataStream> alsInput) { + + DataStream ratings = + alsInput.flatMap( + new RichFlatMapFunction< + Tuple3, Tuple3>() { + + @Override + public void flatMap( + Tuple3 value, + Collector> out) { + out.collect(Tuple3.of(value.f0, value.f1, value.f2)); + out.collect(Tuple3.of(value.f1, value.f0, value.f2)); + } + }) + .keyBy((KeySelector, Long>) value -> value.f0) + .window(EndOfStreamWindows.get()) + .process( + new ProcessWindowFunction< + Tuple3, Ratings, Long, TimeWindow>() { + + @Override + public void process( + Long o, + Context context, + Iterable> iterable, + Collector collector) { + long srcNodeId = -1L; + List> neighbors = new ArrayList<>(); + + for (Tuple3 t4 : iterable) { + srcNodeId = t4.f0; + neighbors.add(Tuple2.of(t4.f1, t4.f2)); + } + if (neighbors.size() > THRESHOLD) { + int numBlock = + neighbors.size() / THRESHOLD + + (neighbors.size() % THRESHOLD == 0L + ? 
0 + : 1); + int blockSize = neighbors.size() / numBlock; + int startIndex = 0; + for (int i = 0; i < numBlock; ++i) { + Ratings tmpRating = new Ratings(); + int offset = + Math.min( + i + 1, neighbors.size() % numBlock); + int endIndex = + Math.min( + neighbors.size(), + (i + 1) * blockSize + offset); + int size = endIndex - startIndex; + tmpRating.neighbors = new long[size]; + tmpRating.scores = new double[size]; + for (int j = 0; j < size; j++) { + tmpRating.neighbors[j] = + neighbors.get(startIndex + j).f0; + tmpRating.scores[j] = + neighbors.get(startIndex + j).f1; + } + startIndex = endIndex; + tmpRating.nodeId = srcNodeId; + tmpRating.isMainNode = (i == 0); + tmpRating.isSplit = true; + tmpRating.numNeighbors = neighbors.size(); + collector.collect(tmpRating); + } + } else { + Ratings returnRatings = new Ratings(); + returnRatings.nodeId = srcNodeId; + returnRatings.neighbors = new long[neighbors.size()]; + returnRatings.scores = new double[neighbors.size()]; + returnRatings.isSplit = false; + returnRatings.numNeighbors = neighbors.size(); + returnRatings.isMainNode = false; + for (int i = 0; + i < returnRatings.neighbors.length; + i++) { + returnRatings.neighbors[i] = neighbors.get(i).f0; + returnRatings.scores[i] = neighbors.get(i).f1; + } + collector.collect(returnRatings); + } + } + }) + .returns(GenericTypeInfo.of(Ratings.class)) + .name("initRatings") + .rebalance(); + DataStream profile = generateDataProfile(ratings).broadcast(); + return ratings.union(profile); + } + + private DataStream generateDataProfile(DataStream ratingData) { + DataStream> localSummary = + DataStreamUtils.mapPartition( + ratingData, + new MapPartitionFunction< + Ratings, Tuple5>() { + private static final long serialVersionUID = -3529850335007040435L; + + @Override + public void mapPartition( + Iterable values, + Collector> out) { + long numUsers = 0L; + long numItems = 0L; + long numRatings = 0L; + int hottestUserPoint = 0; + int hottestItemPoint = 0; + for (Ratings ratings : values) { + if (ratings.nodeId % 2L == 0L) { + numUsers++; + numRatings += ratings.scores.length; + hottestUserPoint = + Math.max(hottestUserPoint, ratings.numNeighbors); + } else { + numItems++; + hottestItemPoint = + Math.max(hottestItemPoint, ratings.numNeighbors); + } + } + out.collect( + Tuple5.of( + numUsers, + numItems, + numRatings, + hottestUserPoint, + hottestItemPoint)); + } + }); + + return DataStreamUtils.reduce( + localSummary, + new ReduceFunction>() { + private static final long serialVersionUID = 3849683380245684843L; + + @Override + public Tuple5 reduce( + Tuple5 value1, + Tuple5 value2) { + value1.f0 += value2.f0; + value1.f1 += value2.f1; + value1.f2 += value2.f2; + value1.f3 = Math.max(value1.f3, value2.f3); + value1.f4 = Math.max(value1.f4, value2.f4); + return value1; + } + }) + .map( + new RichMapFunction, Ratings>() { + private static final long serialVersionUID = -2224348217053561771L; + + @Override + public Ratings map(Tuple5 value) { + Ratings profile = new Ratings(); + profile.neighbors = + new long[] { + value.f0, value.f1, value.f2, value.f3, value.f4 + }; + profile.scores = null; + return profile; + } + }) + .name("data_profiling"); + } + + private static class YtyAggregator implements PullStage.Aggregator { + @Override + public double[] add(double[] in, double[] acc) { + + if (acc == null) { + acc = new double[in.length * in.length]; + } + calcYty(in, acc); + return acc; + } + + @Override + public double[] merge(double[] acc1, double[] acc2) { + for (int i = 0; i < acc1.length; i++) { + acc2[i] += 
acc1[i]; + } + return acc2; + } + + private void calcYty(double[] vec, double[] result) { + for (int i = 0; i < vec.length; i++) { + for (int j = 0; j < vec.length; j++) { + result[i * vec.length + j] += vec[i] * vec[j]; + } + } + } + } + + private static double[][] sumYty(double[][] d1, double[][] d2) { + Preconditions.checkArgument(d1[0].length == d2[0].length); + for (int i = 0; i < d1[0].length; i++) { + d2[0][i] += d1[0][i]; + } + return d2; + } + + /** The whole ratings of a user or an item. */ + public static class Ratings { + + public Ratings() {} + + /** Current node is a split node or not. */ + public boolean isSplit; + + /** Current node is a main node in split nodes or not. */ + public boolean isMainNode; + + /** UserId or itemId decided by identity. */ + public long nodeId; + + /** Number of neighbors. */ + public int numNeighbors; + + /** Neighbors of this nodeId. */ + public long[] neighbors; + + /** Scores from neighbors to this nodeId. */ + public double[] scores; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static Als load(StreamTableEnvironment tEnv, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map, Object> getParamMap() { + return paramMap; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/AlsMLSession.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/AlsMLSession.java new file mode 100644 index 000000000..78b147836 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/AlsMLSession.java @@ -0,0 +1,334 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.als; + +import org.apache.flink.ml.common.ps.iterations.MLSessionImpl; +import org.apache.flink.ml.common.ps.sarray.SharedDoubleArray; +import org.apache.flink.ml.common.ps.sarray.SharedLongArray; +import org.apache.flink.ml.recommendation.als.Als.Ratings; + +import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; +import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +/** The ML session for als training with PS. */ +public class AlsMLSession extends MLSessionImpl { + + /** Indices for pulling the data. */ + public SharedLongArray pullIndices = new SharedLongArray(); + + /** Values for pulling the data. */ + public SharedDoubleArray pullValues = new SharedDoubleArray(); + + /** Indices for pushing the data. */ + public SharedLongArray pushIndices = new SharedLongArray(); + + /** Values for pushing the data. 
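+ *
+ * <p>The push/pull buffers are flat: for every key in {@code pushIndices} there is one
+ * fixed-length segment in {@code pushValues} (rank doubles for a factor, rank * (rank + 1)
+ * doubles for a hot-node normal-equation block). An illustrative layout with rank = 2 and
+ * two pushed keys:
+ *
+ * <pre>{@code
+ * pushIndices: [0, 2]
+ * pushValues:  [f0_0, f0_1,   f2_0, f2_1]   // one rank-sized factor per key
+ * }</pre>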
*/ + public SharedDoubleArray pushValues = new SharedDoubleArray(); + + /** The all reduce buffer for computing yty. */ + public double[][] allReduceBuffer; + + /** The aggregator array for computing yty. */ + public SharedDoubleArray aggregatorSDAArray; + + /** Ratings data for current iteration. */ + public BlockData batchData; + + /** The intermediate variable for updating factors. */ + public double[] yty; + + /** Ratings batch data list for user. */ + private List userRatingsList; + + /** Ratings batch data list for item. */ + private List itemRatingsList; + + /** Num blocks of user ratings data. */ + public int numUserBlocks; + + /** Num blocks of item ratings data. */ + public int numItemBlocks; + + /** Processing user batch index in current iteration. */ + public int currentUserIndex; + + /** Processing item batch index in current iteration. */ + public int currentItemIndex; + + /** Current iteration, updates user factors or item factors. */ + public boolean updateUserFactors = true; + + /** Initialized Rating or not. */ + public boolean isRatingsInitialized = false; + + public static final Logger LOG = LoggerFactory.getLogger(Als.class); + + public LongOpenHashSet reusedNeighborsSet; + public Long2IntOpenHashMap reusedNeighborIndexMapping; + + private final int parallelism; + private long timing = 0; + private final boolean implicit; + + public long[] userIds; + public long[] itemIds; + + private int maxNumNeighbors = 0; + private int maxNumNodes = 0; + private final int rank; + + public AlsMLSession(boolean implicit, int rank, int parallelism) { + this.rank = rank; + this.parallelism = parallelism; + this.implicit = implicit; + if (implicit) { + aggregatorSDAArray = new SharedDoubleArray(new double[rank * rank]); + allReduceBuffer = new double[][] {new double[rank * rank]}; + } + } + + @Override + public void setWorldInfo(int workerId, int numWorkers) { + this.workerId = workerId; + this.numWorkers = numWorkers; + } + + /** Initializes ratings data in minibatch list format. 
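+ *
+ * <p>The number of mini-batches per side is derived from the data profile: a single block is
+ * used when the other side's factors fit into one task buffer (e.g. {@code numItem * rank <
+ * taskCapacity} for the user side), otherwise roughly {@code numSamples * rank / (parallelism
+ * * taskCapacity) + 1} blocks. With illustrative numbers rank = 10, parallelism = 4,
+ * taskCapacity = 8 * 1024 * 1024 and 1e8 ratings this gives 1e9 / 33554432 + 1 = 30 blocks.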
*/ + public void initializeRatingsBatchData() throws IOException { + final int taskCapacity = 8 * 1024 * 1024; + final int defaultNumBlocks = 1; + + long numUser = 0L; + long numItem = 0L; + while (inputData.hasNext()) { + long numSamples; + Ratings ratings = inputData.next(); + if (ratings.scores == null) { + numUser = ratings.neighbors[0]; + numItem = ratings.neighbors[1]; + numSamples = ratings.neighbors[2]; + long hottestUserPoint = ratings.neighbors[3]; + long hottestItemPoint = ratings.neighbors[4]; + + if (numItem * rank < taskCapacity) { + this.numUserBlocks = defaultNumBlocks; + } else { + this.numUserBlocks = + (int) (numSamples * rank / (parallelism * taskCapacity)) + 1; + } + if (numUser * rank < taskCapacity) { + this.numItemBlocks = defaultNumBlocks; + } else { + this.numItemBlocks = + (int) (numSamples * rank / (parallelism * taskCapacity)) + 1; + } + + LOG.info("rank : " + rank); + + LOG.info("num total users : " + numUser); + LOG.info("num total items : " + numItem); + LOG.info("num total samples : " + numSamples); + + LOG.info("num user blocks : " + numUserBlocks); + LOG.info("num item blocks : " + numItemBlocks); + + LOG.info("hottest user point : " + hottestUserPoint); + LOG.info("hottest item point : " + hottestItemPoint); + break; + } + } + + this.userRatingsList = new ArrayList<>(numUserBlocks); + int userBlockSize = (int) numUser / (numUserBlocks * parallelism); + for (int i = 0; i < numUserBlocks; ++i) { + BlockData blockData = new BlockData(new ArrayList<>(userBlockSize), false); + this.userRatingsList.add(blockData); + } + this.itemRatingsList = new ArrayList<>(numItemBlocks); + int itemBlockSize = (int) numItem / (numItemBlocks * parallelism); + for (int i = 0; i < numItemBlocks; ++i) { + BlockData blockData = new BlockData(new ArrayList<>(itemBlockSize), false); + this.itemRatingsList.add(blockData); + } + + inputData.reset(); + + while (inputData.hasNext()) { + Ratings ratings = inputData.next(); + if (ratings.scores == null) { + continue; + } + if (ratings.nodeId % 2 == 0) { + int blockId = (int) (ratings.nodeId / 2) % numUserBlocks; + this.userRatingsList.get(blockId).add(ratings); + if (!this.userRatingsList.get(blockId).hasHotPoint) { + this.userRatingsList.get(blockId).hasHotPoint = ratings.isSplit; + } + } else { + int blockId = (int) (ratings.nodeId / 2) % numItemBlocks; + this.itemRatingsList.get(blockId).add(ratings); + if (!this.itemRatingsList.get(blockId).hasHotPoint) { + this.itemRatingsList.get(blockId).hasHotPoint = ratings.isSplit; + } + } + } + + for (BlockData blockData : userRatingsList) { + initializeBlockData(blockData); + } + + for (BlockData blockData : itemRatingsList) { + initializeBlockData(blockData); + } + + pullIndices.size(maxNumNeighbors); + pullValues.size(maxNumNeighbors * rank); + pushIndices.size(maxNumNodes); + pushValues.size(maxNumNodes * rank); + this.reusedNeighborIndexMapping = new Long2IntOpenHashMap(maxNumNeighbors); + this.reusedNeighborsSet = new LongOpenHashSet(maxNumNeighbors); + + if (this.implicit) { + LongOpenHashSet longOpenHashSet = new LongOpenHashSet(); + for (BlockData blockData : itemRatingsList) { + for (Ratings r : blockData.ratingsList) { + if (r.isMainNode && r.isSplit) { + longOpenHashSet.add(r.nodeId); + } else if (!r.isSplit) { + longOpenHashSet.add(r.nodeId); + } + } + } + itemIds = new long[longOpenHashSet.size()]; + Iterator iterator = longOpenHashSet.iterator(); + int it = 0; + while (iterator.hasNext()) { + itemIds[it++] = iterator.next(); + } + longOpenHashSet.clear(); + for (BlockData 
blockData : userRatingsList) { + for (Ratings r : blockData.ratingsList) { + if (r.isMainNode && r.isSplit) { + longOpenHashSet.add(r.nodeId); + } else if (!r.isSplit) { + longOpenHashSet.add(r.nodeId); + } + } + } + userIds = new long[longOpenHashSet.size()]; + iterator = longOpenHashSet.iterator(); + it = 0; + while (iterator.hasNext()) { + userIds[it++] = iterator.next(); + } + } + } + + private void initializeBlockData(BlockData blockData) { + LongOpenHashSet neighborsSet = new LongOpenHashSet(blockData.ratingsList.size()); + + for (Ratings dataPoint : blockData.ratingsList) { + for (long index : dataPoint.neighbors) { + neighborsSet.add(index); + } + if (!dataPoint.isSplit) { + blockData.numCommonNodeIds++; + } else { + if (dataPoint.isMainNode) { + blockData.numSplitNodeIds++; + } + } + } + maxNumNeighbors = Math.max(maxNumNeighbors, neighborsSet.size()); + maxNumNodes = Math.max(maxNumNodes, blockData.numCommonNodeIds); + } + + public void prepareNextRatingsBatchData() throws IOException { + if (!isRatingsInitialized) { + initializeRatingsBatchData(); + isRatingsInitialized = true; + } + + if (updateUserFactors) { + this.batchData = userRatingsList.get(currentUserIndex++); + if (currentUserIndex == numUserBlocks) { + currentUserIndex = 0; + updateUserFactors = false; + } + } else { + this.batchData = itemRatingsList.get(currentItemIndex++); + if (currentItemIndex == numItemBlocks) { + currentItemIndex = 0; + updateUserFactors = true; + } + } + } + + private long clock() { + long current = System.currentTimeMillis(); + long duration = current - timing; + timing = current; + return duration; + } + + public void log(String className, boolean start) { + if (start) { + LOG.info( + String.format( + "[Worker-%d, iteration-%d] starts %s, %d%n", + workerId, iterationId, className, clock())); + } else { + LOG.info( + String.format( + "[Worker-%d, iteration-%d] ends %s, %d%n", + workerId, iterationId, className, clock())); + } + } + + /** The computing block data in every iteration. */ + public static class BlockData { + public BlockData(List ratingsList, boolean hasHotPoint) { + this.ratingsList = ratingsList; + this.hasHotPoint = hasHotPoint; + } + + public List ratingsList; + + public boolean hasHotPoint; + + public int numCommonNodeIds; + + public int numSplitNodeIds; + + public Ratings get(int idx) { + return ratingsList.get(idx); + } + + public void add(Ratings ratings) { + ratingsList.add(ratings); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/AlsModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/AlsModel.java new file mode 100644 index 000000000..f44f15dda --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/AlsModel.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.als; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** A Model which predicts data using the model data computed by {@link Als}. */ +public class AlsModel implements Model, AlsModelParams { + + private final Map, Object> paramMap = new HashMap<>(); + + protected Table modelDataTable; + + public AlsModel() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream inputStream = tEnv.toDataStream(inputs[0]); + + final String broadcastModelKey = "broadcastModelKey"; + DataStream modelDataStream = AlsModelData.getModelDataStream(modelDataTable); + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), BasicTypeInfo.DOUBLE_TYPE_INFO), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol())); + + DataStream predictionResult = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(inputStream), + Collections.singletonMap(broadcastModelKey, modelDataStream), + inputList -> { + DataStream inputData = inputList.get(0); + return inputData.map( + new PredictLabelFunction( + broadcastModelKey, getUserCol(), getItemCol()), + outputTypeInfo); + }); + return new Table[] {tEnv.fromDataStream(predictionResult)}; + } + + @Override + public AlsModel setModelData(Table... 
inputs) { + modelDataTable = inputs[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + ReadWriteUtils.saveModelData( + AlsModelData.getModelDataStream(modelDataTable), + path, + new AlsModelData.ModelDataEncoder()); + } + + public static AlsModel load(StreamTableEnvironment tEnv, String path) throws IOException { + AlsModel model = ReadWriteUtils.loadStageParam(path); + Table modelDataTable = + ReadWriteUtils.loadModelData(tEnv, path, new AlsModelData.ModelDataDecoder()); + return model.setModelData(modelDataTable); + } + + @Override + public Map, Object> getParamMap() { + return paramMap; + } + + /** A utility function used for prediction. */ + private static class PredictLabelFunction extends RichMapFunction { + + private final String broadcastModelKey; + + private final String userCol; + private final String itemCol; + + public Map userFactors; + public Map itemFactors; + + public PredictLabelFunction(String broadcastModelKey, String userCol, String itemCol) { + this.broadcastModelKey = broadcastModelKey; + this.userCol = userCol; + this.itemCol = itemCol; + } + + @Override + public Row map(Row dataPoint) { + if (userFactors == null) { + List modelData = + getRuntimeContext().getBroadcastVariable(broadcastModelKey); + + List> uFactors = new ArrayList<>(); + List> iFactors = new ArrayList<>(); + for (AlsModelData data : modelData) { + uFactors.addAll(data.userFactors); + iFactors.addAll(data.itemFactors); + } + this.userFactors = new HashMap<>(uFactors.size()); + this.itemFactors = new HashMap<>(iFactors.size()); + for (Tuple2 t2 : uFactors) { + double[] values = new double[t2.f1.length]; + for (int i = 0; i < values.length; ++i) { + values[i] = t2.f1[i]; + } + this.userFactors.put(t2.f0, new DenseVector(values)); + } + for (Tuple2 t2 : iFactors) { + double[] values = new double[t2.f1.length]; + for (int i = 0; i < values.length; ++i) { + values[i] = t2.f1[i]; + } + this.itemFactors.put(t2.f0, new DenseVector(values)); + } + } + + Row predictionResult = + predictRating(dataPoint.getFieldAs(userCol), dataPoint.getFieldAs(itemCol)); + return Row.join(dataPoint, predictionResult); + } + + private Row predictRating(long userId, long itemId) { + DenseVector userFeat = userFactors.get(userId); + DenseVector itemFeat = itemFactors.get(itemId); + if (userFeat != null && itemFeat != null) { + return Row.of(BLAS.dot(userFeat, itemFeat)); + } else { + return Row.of(Double.NaN); + } + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/AlsModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/AlsModelData.java new file mode 100644 index 000000000..60a254566 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/AlsModelData.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.als; + +import org.apache.flink.api.common.serialization.Encoder; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; + +import java.io.EOFException; +import java.io.IOException; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.List; + +/** + * Model data of {@link AlsModel}. + * + *
<p>
This class also provides methods to convert model data from Table to Datastream, and classes + * to save/load model data. + */ +public class AlsModelData { + + public List> userFactors; + public List> itemFactors; + + public AlsModelData( + List> userFactors, List> itemFactors) { + this.userFactors = userFactors; + this.itemFactors = itemFactors; + } + + public AlsModelData(AlsModelData modelData) { + this.userFactors = modelData.userFactors; + this.itemFactors = modelData.itemFactors; + } + + /** + * Converts the table model to a data stream. + * + * @param modelData The table model data. + * @return The data stream model data. + */ + public static DataStream getModelDataStream(Table modelData) { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment(); + return tEnv.toDataStream(modelData).map(AlsModelData::parseModel); + } + + private static AlsModelData parseModel(Row modelRow) { + return new AlsModelData(modelRow.getFieldAs(0)); + } + + /** Data encoder for {@link AlsModel}. */ + public static class ModelDataEncoder implements Encoder { + + @Override + public void encode(AlsModelData modelData, OutputStream outputStream) throws IOException { + + DataOutputView dataOutputView = new DataOutputViewStreamWrapper(outputStream); + dataOutputView.writeInt(modelData.userFactors.size()); + if (modelData.userFactors.size() > 0) { + dataOutputView.writeInt(modelData.userFactors.get(0).f1.length); + for (int i = 0; i < modelData.userFactors.size(); ++i) { + dataOutputView.writeLong(modelData.userFactors.get(i).f0); + float[] values = modelData.userFactors.get(i).f1; + for (float value : values) { + dataOutputView.writeFloat(value); + } + } + } + dataOutputView.writeInt(modelData.itemFactors.size()); + if (modelData.itemFactors.size() > 0) { + dataOutputView.writeInt(modelData.itemFactors.get(0).f1.length); + for (int i = 0; i < modelData.itemFactors.size(); ++i) { + dataOutputView.writeLong(modelData.itemFactors.get(i).f0); + float[] values = modelData.itemFactors.get(i).f1; + for (float value : values) { + dataOutputView.writeFloat(value); + } + } + } + } + } + + /** Data decoder for {@link AlsModel}. 
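+ *
+ * <p>It reads the layout written by {@link ModelDataEncoder}: an {@code int} entry count, an
+ * {@code int} factor length (the rank, written only when the count is non-zero), then one
+ * {@code long} id followed by {@code rank} floats per entry; the user-factor block is followed
+ * by the item-factor block in the same format. A sketch of the stream layout:
+ *
+ * <pre>{@code
+ * [numUsers][rank] ([id][f_1 .. f_rank])*   [numItems][rank] ([id][f_1 .. f_rank])*
+ * }</pre>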
*/ + public static class ModelDataDecoder extends SimpleStreamFormat { + + @Override + public Reader createReader( + Configuration configuration, FSDataInputStream inputStream) { + return new Reader() { + + @Override + public AlsModelData read() throws IOException { + try { + DataInputViewStreamWrapper inputViewStreamWrapper = + new DataInputViewStreamWrapper(inputStream); + int sizeUser = inputViewStreamWrapper.readInt(); + List> userFactors = new ArrayList<>(sizeUser); + + if (sizeUser > 0) { + int rank = inputViewStreamWrapper.readInt(); + for (int i = 0; i < sizeUser; ++i) { + long id = inputViewStreamWrapper.readLong(); + float[] factors = new float[rank]; + for (int j = 0; j < rank; ++j) { + factors[j] = inputViewStreamWrapper.readFloat(); + } + userFactors.add(Tuple2.of(id, factors)); + } + } + int sizeItem = inputViewStreamWrapper.readInt(); + List> itemFactors = new ArrayList<>(sizeItem); + if (sizeItem > 0) { + int rank = inputViewStreamWrapper.readInt(); + for (int i = 0; i < sizeItem; ++i) { + long id = inputViewStreamWrapper.readLong(); + float[] factors = new float[rank]; + for (int j = 0; j < rank; ++j) { + factors[j] = inputViewStreamWrapper.readFloat(); + } + itemFactors.add(Tuple2.of(id, factors)); + } + } + return new AlsModelData(userFactors, itemFactors); + } catch (EOFException e) { + return null; + } + } + + @Override + public void close() throws IOException { + inputStream.close(); + } + }; + } + + @Override + public TypeInformation getProducedType() { + return TypeInformation.of(AlsModelData.class); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/AlsModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/AlsModelParams.java new file mode 100644 index 000000000..411f48e9d --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/AlsModelParams.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.als; + +import org.apache.flink.ml.common.param.HasPredictionCol; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; + +/** + * Params for {@link AlsModel}. + * + * @param The class type of this instance. 
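+ *
+ * <p>A minimal configuration sketch (hypothetical column names):
+ *
+ * <pre>{@code
+ * AlsModel model = new AlsModel()
+ *         .setUserCol("user_id")
+ *         .setItemCol("item_id")
+ *         .setPredictionCol("score");
+ * }</pre>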
+ */ +public interface AlsModelParams extends HasPredictionCol { + Param USER_COL = + new StringParam("userCol", "Name of user column.", "user", ParamValidators.notNull()); + + Param ITEM_COL = + new StringParam("itemCol", "Name of item column.", "item", ParamValidators.notNull()); + + default String getUserCol() { + return get(USER_COL); + } + + default T setUserCol(String value) { + return set(USER_COL, value); + } + + default String getItemCol() { + return get(ITEM_COL); + } + + default T setItemCol(String value) { + return set(ITEM_COL, value); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/AlsModelUpdater.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/AlsModelUpdater.java new file mode 100644 index 000000000..95691a101 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/AlsModelUpdater.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.als; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.common.ps.typeinfo.Long2ObjectOpenHashMapTypeInfo; +import org.apache.flink.ml.common.ps.updater.ModelUpdater; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; + +import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Random; + +/** Als model updater supports the API for updating model and getting model. 
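+ *
+ * <p>The key space is overloaded: non-negative keys address rank-length factor vectors
+ * (lazily initialized with a random vector seeded by the key), {@code Long.MIN_VALUE} and
+ * {@code Long.MIN_VALUE + 1} are placeholder keys answered with zero-filled buffers, and the
+ * remaining negative keys address rank * (rank + 1) length normal-equation blocks of hot
+ * nodes, which are removed from the model once pulled.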
*/ +public class AlsModelUpdater implements ModelUpdater> { + private final int rank; + private final Random random = new Random(); + Long2ObjectOpenHashMap model; + private ListState> modelDataState; + + public AlsModelUpdater(int rank) { + this.rank = rank; + } + + @Override + public void update(long[] keys, double[] values) { + int offset = rank * (rank + 1); + for (int i = 0; i < keys.length; ++i) { + if (keys[i] >= 0) { + model.put(keys[i], Arrays.copyOfRange(values, i * rank, (i + 1) * rank)); + } else { + if (keys[i] == Long.MIN_VALUE || keys[i] == Long.MIN_VALUE + 1) { + continue; + } + assert (!model.containsKey(keys[i])); + model.put(keys[i], Arrays.copyOfRange(values, i * offset, (i + 1) * offset)); + } + } + } + + @Override + public double[] get(long[] keys) { + if (keys[0] >= 0) { + double[] values = new double[keys.length * rank]; + for (int i = 0; i < keys.length; i++) { + if (!model.containsKey(keys[i])) { + double[] factor = new double[rank]; + random.setSeed(keys[i]); + for (int j = 0; j < rank; ++j) { + factor[j] = random.nextDouble(); + } + model.put(keys[i], factor); + } + System.arraycopy(model.get(keys[i]), 0, values, i * rank, rank); + } + return values; + } else if (keys[0] == Long.MIN_VALUE) { + return new double[rank]; + } else if (keys[0] == Long.MIN_VALUE + 1) { + return new double[rank * rank + rank]; + } else { + int offset = rank * (rank + 1); + double[] values = new double[keys.length * offset]; + for (int i = 0; i < keys.length; i++) { + if (keys[i] == Long.MIN_VALUE) { + continue; + } + System.arraycopy(model.get(keys[i]), 0, values, i * offset, offset); + model.remove(keys[i]); + } + return values; + } + } + + @Override + public Iterator> getModelSegments() { + List> modelSegments = new ArrayList<>(model.size()); + for (Long key : model.keySet()) { + if (key >= 0L) { + modelSegments.add(Tuple2.of(key, model.get(key))); + } + } + return modelSegments.iterator(); + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + modelDataState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "modelDataState", + new Long2ObjectOpenHashMapTypeInfo<>( + PrimitiveArrayTypeInfo + .DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO))); + model = + OperatorStateUtils.getUniqueElement(modelDataState, "modelDataState") + .orElse(new Long2ObjectOpenHashMap<>()); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + modelDataState.clear(); + modelDataState.add(model); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/AlsParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/AlsParams.java new file mode 100644 index 000000000..20b7c4774 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/AlsParams.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.als; + +import org.apache.flink.ml.param.BooleanParam; +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; + +/** + * Params for {@link AlsModel}. + * + * @param The class type of this instance. + */ +public interface AlsParams extends AlsModelParams { + + Param RATING_COL = + new StringParam( + "ratingCol", "Column name for rating.", "rating", ParamValidators.notNull()); + + Param ALPHA = + new DoubleParam( + "alpha", "Alpha for implicit preference.", 1.0, ParamValidators.gtEq(0)); + + Param REG_PARAM = + new DoubleParam("regParam", "Regularization parameter.", 0.1, ParamValidators.gtEq(0.)); + + Param IMPLICIT_PREFS = + new BooleanParam( + "implicitPrefs", + "Whether to use implicit preference.", + false, + ParamValidators.alwaysTrue()); + + Param NON_NEGATIVE = + new BooleanParam( + "nonNegative", + "Whether to use non negative constraint for least squares.", + false, + ParamValidators.alwaysTrue()); + + Param RANK = + new IntParam("rank", "Rank of the factorization.", 10, ParamValidators.gt(0)); + + Param MAX_ITER = + new IntParam("maxIter", "Maximum number of iterations.", 10, ParamValidators.gt(0)); + + default String getRatingCol() { + return get(RATING_COL); + } + + default T setRatingCol(String value) { + return set(RATING_COL, value); + } + + default double getAlpha() { + return get(ALPHA); + } + + default T setAlpha(Double value) { + return set(ALPHA, value); + } + + default double getRegParam() { + return get(REG_PARAM); + } + + default T setRegParam(Double value) { + return set(REG_PARAM, value); + } + + default Boolean getImplicitPrefs() { + return get(IMPLICIT_PREFS); + } + + default T setImplicitPrefs(Boolean value) { + return set(IMPLICIT_PREFS, value); + } + + default Boolean getNonNegative() { + return get(NON_NEGATIVE); + } + + default T setNonNegative(Boolean value) { + return set(NON_NEGATIVE, value); + } + + default int getRank() { + return get(RANK); + } + + default T setRank(int value) { + return set(RANK, value); + } + + default int getMaxIter() { + return get(MAX_ITER); + } + + default T setMaxIter(int value) { + return set(MAX_ITER, value); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/ComputeLsMatrixVector.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/ComputeLsMatrixVector.java new file mode 100644 index 000000000..34f00fb87 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/ComputeLsMatrixVector.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.als; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.common.ps.iterations.ProcessStage; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.NormalEquationSolver; +import org.apache.flink.ml.recommendation.als.Als.Ratings; + +import it.unimi.dsi.fastutil.longs.LongOpenHashSet; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +/** + * An iteration stage that uses the pulled model values and batch data to compute the least square + * matrices and vectors. + */ +public class ComputeLsMatrixVector extends ProcessStage { + private final int rank; + private final int matrixOffset; + private final boolean implicit; + private final double lambda; + private final double alpha; + + public ComputeLsMatrixVector(int rank, boolean implicit, double lambda, double alpha) { + this.rank = rank; + this.matrixOffset = rank * rank + rank; + this.implicit = implicit; + this.lambda = lambda; + this.alpha = alpha; + } + + @Override + public void process(AlsMLSession session) throws IOException { + session.log(getClass().getSimpleName(), true); + if (session.allReduceBuffer != null + && BLAS.norm2(new DenseVector(session.allReduceBuffer[0])) != 0.0) { + if (session.yty == null) { + session.yty = new double[rank * rank]; + } + System.arraycopy(session.allReduceBuffer[0], 0, session.yty, 0, rank * rank); + } + + List> matrices = + session.batchData.hasHotPoint ? computeMatrices(session, rank) : new ArrayList<>(); + + session.pushIndices.clear(); + session.pushValues.clear(); + session.pullIndices.clear(); + session.pullValues.clear(); + + if (matrices.size() == 0) { + session.pushIndices.add(Long.MIN_VALUE + 1); + session.pushValues.size(matrixOffset); + + session.pullIndices.add(Long.MIN_VALUE + 1); + session.pullValues.size(matrixOffset); + } else { + for (Tuple2 matrix : matrices) { + session.pushIndices.add(matrix.f0); + session.pushValues.addAll(matrix.f1); + } + LongOpenHashSet pullSet = new LongOpenHashSet(); + for (Ratings r : session.batchData.ratingsList) { + if (r.isMainNode) { + pullSet.add(r.nodeId); + } + } + session.pullIndices.size(pullSet.size()); + Iterator iter = pullSet.iterator(); + for (int i = 0; i < pullSet.size(); ++i) { + session.pullIndices.elements()[i] = -iter.next() - 1; + } + session.pullValues.size(session.pullIndices.size() * matrixOffset); + } + + AlsMLSession.LOG.info( + String.format( + "Worker-%d mat vec pull size %d%n", + session.workerId, session.pushIndices.elements()[0])); + + session.log(getClass().getSimpleName(), false); + } + + private List> computeMatrices(AlsMLSession session, int rank) { + NormalEquationSolver ls = new NormalEquationSolver(rank); + List> matvec = new ArrayList<>(); + double[] tmp = new double[rank]; + /* loops over local nodes. 
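+    For every hot (split) node this builds the normal equations from the pulled neighbor
+    factors: AtA accumulates the weighted outer products (plus the shared yty Gram matrix in
+    the implicit case) and Atb the weighted right-hand side. Both are packed into one
+    rank * (rank + 1) double block and pushed under the key (-1 - nodeId), so the server can
+    distinguish matrix blocks from plain factor vectors.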
*/ + for (Ratings ele : session.batchData.ratingsList) { + if (!ele.isSplit) { + continue; + } + double[] ret = new double[matrixOffset]; + /* solves an lease square problem. */ + ls.reset(); + + if (!implicit) { + long[] nb = ele.neighbors; + double[] rating = ele.scores; + for (int i = 0; i < nb.length; i++) { + long index = nb[i]; + int realIndex = session.reusedNeighborIndexMapping.get(index); + System.arraycopy(session.pullValues.elements(), realIndex * rank, tmp, 0, rank); + + ls.add(new DenseVector(tmp), rating[i], 1.0); + } + ls.regularize(nb.length * lambda); + } else { + if (ele.isMainNode) { + ls.merge(new DenseMatrix(rank, rank, session.yty)); + } + int numExplicit = 0; + long[] nb = ele.neighbors; + double[] rating = ele.scores; + + for (int i = 0; i < nb.length; i++) { + long index = nb[i]; + double r = rating[i]; + double c1 = 0.; + + if (r > 0) { + numExplicit++; + c1 = alpha * r; + } + int realIndex = session.reusedNeighborIndexMapping.get(index); + System.arraycopy(session.pullValues.elements(), realIndex * rank, tmp, 0, rank); + + ls.add(new DenseVector(tmp), ((r > 0.0) ? (1.0 + c1) : 0.0), c1); + } + + numExplicit = Math.max(numExplicit, 1); + ls.regularize(numExplicit * lambda); + } + System.arraycopy(ls.getAta().values, 0, ret, 0, rank * rank); + System.arraycopy(ls.getAtb().values, 0, ret, rank * rank, rank); + matvec.add(Tuple2.of(-1 - ele.nodeId, ret)); + } + return matvec; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/ComputeNeighborIndices.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/ComputeNeighborIndices.java new file mode 100644 index 000000000..2adc3ac0b --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/ComputeNeighborIndices.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.als; + +import org.apache.flink.ml.common.ps.iterations.ProcessStage; +import org.apache.flink.ml.recommendation.als.Als.Ratings; +import org.apache.flink.ml.recommendation.als.AlsMLSession.BlockData; + +/** An iteration stage that computes the indices needed to update factors. 
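+ *
+ * <p>It gathers the distinct neighbor ids of the current mini-batch into {@code pullIndices},
+ * records each id's position in {@code reusedNeighborIndexMapping} and sizes
+ * {@code pullValues} to {@code pullIndices.size() * rank}; an empty batch pulls the sentinel
+ * key {@code Long.MIN_VALUE} instead. For a batch whose ratings reference neighbors {7, 3, 7}
+ * this could yield {@code pullIndices = [7, 3]} and {@code mapping = {7 -> 0, 3 -> 1}}.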
*/ +public class ComputeNeighborIndices extends ProcessStage { + + private final int rank; + + public ComputeNeighborIndices(int rank) { + this.rank = rank; + } + + @Override + public void process(AlsMLSession session) throws Exception { + session.log(this.getClass().getSimpleName(), true); + + session.prepareNextRatingsBatchData(); + + BlockData blockData = session.batchData; + + session.reusedNeighborsSet.clear(); + session.reusedNeighborIndexMapping.clear(); + session.pullIndices.clear(); + + for (Ratings ratings : blockData.ratingsList) { + for (long neighbor : ratings.neighbors) { + session.reusedNeighborsSet.add(neighbor); + } + } + + if (session.reusedNeighborsSet.size() == 0) { + session.pullIndices.add(Long.MIN_VALUE); + } else { + int it = 0; + for (Long aLong : session.reusedNeighborsSet) { + session.pullIndices.add(aLong); + session.reusedNeighborIndexMapping.put(aLong.longValue(), it++); + } + } + session.pullValues.size(session.pullIndices.size() * rank); + session.log(this.getClass().getSimpleName(), false); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/ComputeYtyIndices.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/ComputeYtyIndices.java new file mode 100644 index 000000000..e9c3902f6 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/ComputeYtyIndices.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.als; + +import org.apache.flink.ml.common.ps.iterations.ProcessStage; + +/** An iteration stage that calculates the indices for yty matrix computing. 
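+ *
+ * <p>For implicit feedback every factor update needs the Gram matrix of the other side's
+ * factors, {@code yty[i * rank + j] = sum_k y_k[i] * y_k[j]} (accumulated by the
+ * {@code YtyAggregator} in {@link Als}). Before a user-factor update this stage pulls all item
+ * ids (and all user ids before an item-factor update); when no refresh is needed it pulls only
+ * the sentinel key {@code Long.MIN_VALUE}.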
*/ +public class ComputeYtyIndices extends ProcessStage { + + @Override + public void process(AlsMLSession session) throws Exception { + session.log(this.getClass().getSimpleName(), true); + if (!session.isRatingsInitialized) { + session.initializeRatingsBatchData(); + session.isRatingsInitialized = true; + } + session.pullIndices.clear(); + session.pullValues.clear(); + + if (session.updateUserFactors) { + if (session.itemIds.length == 0 || session.currentItemIndex != 0) { + session.pullIndices.addAll(new long[] {Long.MIN_VALUE}); + } else { + session.pullIndices.addAll(session.itemIds); + } + } else { + if (session.userIds.length == 0 || session.currentUserIndex != 0) { + session.pullIndices.addAll(new long[] {Long.MIN_VALUE}); + } else { + session.pullIndices.addAll(session.userIds); + } + } + session.log(this.getClass().getSimpleName(), false); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/CopyAllReduceData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/CopyAllReduceData.java new file mode 100644 index 000000000..59afd8254 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/CopyAllReduceData.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.als; + +import org.apache.flink.ml.common.ps.iterations.ProcessStage; + +/** An iteration stage that copy the aggregating data to the all reduce data. */ +public class CopyAllReduceData extends ProcessStage { + + private final int rank; + + public CopyAllReduceData(int rank) { + this.rank = rank; + } + + @Override + public void process(AlsMLSession session) throws Exception { + System.arraycopy( + session.aggregatorSDAArray.elements(), + 0, + session.allReduceBuffer[0], + 0, + rank * rank); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/UpdateCommonFactors.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/UpdateCommonFactors.java new file mode 100644 index 000000000..7619cd9ff --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/UpdateCommonFactors.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.als; + +import org.apache.flink.ml.common.ps.iterations.ProcessStage; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.NormalEquationSolver; +import org.apache.flink.ml.recommendation.als.Als.Ratings; + +import java.io.IOException; +import java.util.Arrays; + +/** An iteration stage that uses the pulled model values and batch data to compute the factors. */ +public class UpdateCommonFactors extends ProcessStage { + private final int rank; + private final boolean implicit; + private final boolean nonNegative; + private final double lambda; + private final double alpha; + + public UpdateCommonFactors( + int rank, boolean implicit, boolean nonNegative, double lambda, double alpha) { + this.rank = rank; + this.implicit = implicit; + this.nonNegative = nonNegative; + this.lambda = lambda; + this.alpha = alpha; + } + + @Override + public void process(AlsMLSession session) throws IOException { + session.log(this.getClass().getSimpleName(), true); + if (session.allReduceBuffer != null + && BLAS.norm2(new DenseVector(session.allReduceBuffer[0])) > 0.0) { + if (session.yty == null) { + session.yty = new double[rank * rank]; + } + System.arraycopy(session.allReduceBuffer[0], 0, session.yty, 0, rank * rank); + } + session.pushIndices.clear(); + session.pushValues.clear(); + if (session.batchData.numCommonNodeIds == 0) { + session.pushIndices.add(Long.MIN_VALUE); + session.pushValues.size(rank); + return; + } else { + session.pushIndices.size(session.batchData.numCommonNodeIds); + session.pushValues.size(rank * session.pushIndices.size()); + updatedFactorsWithNeighbors(session, rank); + } + + AlsMLSession.LOG.info( + String.format( + "Worker-%d push size %d", session.workerId, session.pushIndices.size())); + session.log(this.getClass().getSimpleName(), false); + } + + private void updatedFactorsWithNeighbors(AlsMLSession session, int rank) { + + NormalEquationSolver ls = new NormalEquationSolver(rank); + + double[] tmpVec = new double[rank]; + DenseVector x = new DenseVector(rank); + + int nonSplitId = 0; + + for (Ratings ele : session.batchData.ratingsList) { + if (ele.isSplit) { + continue; + } + + ls.reset(); + Arrays.fill(x.values, 0); + + if (!implicit) { + long[] nb = ele.neighbors; + double[] rating = ele.scores; + for (int i = 0; i < nb.length; i++) { + long index = nb[i]; + int realIndex = session.reusedNeighborIndexMapping.get(index); + System.arraycopy( + session.pullValues.elements(), realIndex * rank, tmpVec, 0, rank); + ls.add(new DenseVector(tmpVec), rating[i], 1.0); + } + ls.regularize(nb.length * lambda); + ls.solve(x, true); + } else { + ls.merge(new DenseMatrix(rank, rank, session.yty)); + + int numExplicit = 0; + long[] nb = ele.neighbors; + double[] rating = ele.scores; + + for (int i = 0; i < nb.length; i++) { + long index = nb[i]; + double r = rating[i]; + double c1 = 0.; + + if (r > 0) { + numExplicit++; + c1 = alpha * r; + } + + int realIndex = session.reusedNeighborIndexMapping.get(index); + 
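+                    // Copy the pulled factor of this neighbor into the scratch buffer and add
+                    // it to the normal equations with the implicit-feedback confidence
+                    // c1 = alpha * r; the baseline contribution of all items is already
+                    // covered by the merged yty matrix above.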
System.arraycopy( + session.pullValues.elements(), realIndex * rank, tmpVec, 0, rank); + + ls.add(new DenseVector(tmpVec), ((r > 0.0) ? (1.0 + c1) : 0.0), c1); + } + + numExplicit = Math.max(numExplicit, 1); + ls.regularize(numExplicit * lambda); + ls.solve(x, nonNegative); + } + session.pushIndices.elements()[nonSplitId] = ele.nodeId; + System.arraycopy(x.values, 0, session.pushValues.elements(), nonSplitId * rank, rank); + nonSplitId++; + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/UpdateHotPointFactors.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/UpdateHotPointFactors.java new file mode 100644 index 000000000..933733497 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/als/UpdateHotPointFactors.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.als; + +import org.apache.flink.ml.common.ps.iterations.ProcessStage; +import org.apache.flink.ml.common.ps.sarray.SharedDoubleArray; +import org.apache.flink.ml.common.ps.sarray.SharedLongArray; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.NormalEquationSolver; + +import java.io.IOException; +import java.util.Arrays; + +/** + * An iteration stage that uses the pulled least square matrices and vector data to compute the + * factors. 
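+ *
+ * <p>Each pulled entry is a packed block of rank * (rank + 1) doubles: the first rank * rank
+ * values form the AtA matrix and the trailing rank values the Atb vector assembled by
+ * {@link ComputeLsMatrixVector}. Every block is solved with a {@code NormalEquationSolver}
+ * and the resulting factor is pushed back under the decoded (non-negative) node id.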
+ */ +public class UpdateHotPointFactors extends ProcessStage { + private final boolean nonNegative; + private final int rank; + + public UpdateHotPointFactors(int rank, boolean nonNegative) { + this.rank = rank; + this.nonNegative = nonNegative; + } + + @Override + public void process(AlsMLSession session) throws IOException { + session.log(this.getClass().getSimpleName(), true); + SharedLongArray indices = session.pullIndices; + SharedDoubleArray modelValues = session.pullValues; + session.pushIndices.clear(); + session.pushValues.clear(); + if (session.batchData.numSplitNodeIds != 0) { + session.pushIndices.size(session.batchData.numSplitNodeIds); + session.pushValues.size(session.batchData.numSplitNodeIds * rank); + } else { + session.pushValues.addAll(new double[rank]); + session.pushIndices.add(Long.MIN_VALUE); + return; + } + int offset = rank * (rank + 1); + for (int i = 0; i < indices.size(); ++i) { + if (indices.get(i) == Long.MIN_VALUE) { + session.pushIndices.add(Long.MIN_VALUE); + continue; + } + DenseVector x = new DenseVector(rank); + + DenseMatrix ata = + new DenseMatrix( + rank, + rank, + Arrays.copyOfRange( + modelValues.elements(), i * offset, i * offset + rank * rank)); + DenseVector atb = + new DenseVector( + Arrays.copyOfRange( + modelValues.elements(), + i * offset + rank * rank, + (i + 1) * offset)); + NormalEquationSolver ls = new NormalEquationSolver(rank, ata, atb); + ls.solve(x, nonNegative); + System.arraycopy(x.values, 0, session.pushValues.elements(), i * rank, rank); + + if (session.pullIndices.get(i) != Long.MIN_VALUE) { + session.pushIndices.elements()[i] = -session.pullIndices.elements()[i] - 1; + } else { + session.pushIndices.elements()[i] = session.pullIndices.elements()[i]; + } + } + AlsMLSession.LOG.info( + String.format( + "Worker-%d hot point push size %d", + session.workerId, session.pushIndices.size())); + session.log(this.getClass().getSimpleName(), false); + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/MessageTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/MessageTest.java new file mode 100644 index 000000000..54a2412cc --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/MessageTest.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.ps; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.common.typeutils.TypeSerializer; + +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Iterator; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +/** Tests {@link org.apache.flink.ml.common.ps.Message}. */ +public class MessageTest { + private Message messageFromBytes; + private Message messageFromArray; + private Message messageFromPojo; + + private TypeSerializer mockPojoTypeSerializer; + + @Before + public void before() throws IOException { + messageFromArray = new Message(1, 0, 1, new long[] {1, 2}, new double[] {1, 2, 3, 4}); + messageFromBytes = new Message(messageFromArray.bytes.clone()); + mockPojoTypeSerializer = Types.POJO(MockPojo.class).createSerializer(new ExecutionConfig()); + messageFromPojo = + new Message( + 1, + 0, + 1, + new long[] {1, 2}, + new MockPojo[] {new MockPojo(1, 1), new MockPojo(2, 2)}, + mockPojoTypeSerializer); + } + + @Test + public void getKeys() { + long[] expectedKeys = new long[] {1, 2}; + assertArrayEquals(expectedKeys, messageFromArray.getKeys()); + assertArrayEquals(expectedKeys, messageFromBytes.getKeys()); + assertArrayEquals(expectedKeys, messageFromPojo.getKeys()); + } + + @Test + public void getValuesInDoubleArray() { + double[] expectedDoubleArray = new double[] {1, 2, 3, 4}; + assertArrayEquals(expectedDoubleArray, messageFromArray.getValuesInDoubleArray(), 1e-7); + assertArrayEquals(expectedDoubleArray, messageFromBytes.getValuesInDoubleArray(), 1e-7); + } + + @Test + public void getValues() throws IOException { + MockPojo[] expectedPojos = new MockPojo[] {new MockPojo(1, 1), new MockPojo(2, 2)}; + assertArrayEquals(expectedPojos, messageFromPojo.getValues(mockPojoTypeSerializer)); + } + + @Test + public void getWorkerId() { + int expectedWorkerId = 1; + assertEquals(expectedWorkerId, messageFromArray.getWorkerId()); + assertEquals(expectedWorkerId, messageFromBytes.getWorkerId()); + assertEquals(expectedWorkerId, messageFromPojo.getWorkerId()); + } + + @Test + public void setWorkerId() { + messageFromArray.setWorkerId(2); + messageFromBytes.setWorkerId(2); + messageFromPojo.setWorkerId(2); + int expectedWorkerId = 2; + assertEquals(expectedWorkerId, messageFromArray.getWorkerId()); + assertEquals(expectedWorkerId, messageFromBytes.getWorkerId()); + assertEquals(expectedWorkerId, messageFromPojo.getWorkerId()); + } + + @Test + public void getServerId() { + int expectedServerId = 0; + assertEquals(expectedServerId, messageFromArray.getServerId()); + assertEquals(expectedServerId, messageFromBytes.getServerId()); + assertEquals(expectedServerId, messageFromPojo.getServerId()); + } + + @Test + public void setServerId() { + messageFromArray.setServerId(2); + messageFromBytes.setServerId(2); + messageFromPojo.setServerId(2); + int expectedServerId = 2; + assertEquals(expectedServerId, messageFromArray.getServerId()); + assertEquals(expectedServerId, messageFromBytes.getServerId()); + assertEquals(expectedServerId, messageFromPojo.getServerId()); + } + + @Test + public void getStagedId() { + int expectedStageId = 1; + assertEquals(expectedStageId, messageFromArray.getStageId()); + assertEquals(expectedStageId, messageFromBytes.getStageId()); + assertEquals(expectedStageId, messageFromPojo.getStageId()); + } + + @Test + public void 
assembleMessages() { + int numServers = 4; + Message[] messages = new Message[numServers]; + for (int i = 0; i < numServers; i++) { + messages[i] = + new Message( + 1, + i, + 0, + new long[] {i * 2, i * 2 + 1}, + new double[] {i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3}); + } + + Iterator bytes = Arrays.stream(messages).map(x -> x.bytes).iterator(); + Message assembledMessage = Message.assembleMessages(bytes); + + assertEquals(-1, assembledMessage.getServerId()); + assertEquals(1, assembledMessage.getWorkerId()); + assertEquals(0, assembledMessage.getStageId()); + + long[] expectedKeys = new long[numServers * 2]; + for (int i = 0; i < expectedKeys.length; i++) { + expectedKeys[i] = i; + } + assertArrayEquals(expectedKeys, assembledMessage.getKeys()); + + double[] expectedValues = new double[numServers * 4]; + for (int i = 0; i < expectedValues.length; i++) { + expectedValues[i] = i; + } + assertArrayEquals(expectedValues, assembledMessage.getValuesInDoubleArray(), 1e-7); + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/MockPojo.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/MockPojo.java new file mode 100644 index 000000000..6666e3d78 --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/MockPojo.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps; + +/** Mock pojo class to test all reduce. */ +public class MockPojo { + public int i; + public int j; + + public MockPojo(int i, int j) { + this.i = i; + this.j = j; + } + + public MockPojo() {} + + @Override + public String toString() { + return i + "-" + j; + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof MockPojo) { + MockPojo other = (MockPojo) obj; + return i == other.i && j == other.j; + } + return false; + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/TrainingUtilsTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/TrainingUtilsTest.java new file mode 100644 index 000000000..5ee1965d3 --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/TrainingUtilsTest.java @@ -0,0 +1,601 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.common.ps.iterations.AllReduceStage; +import org.apache.flink.ml.common.ps.iterations.IterationStageList; +import org.apache.flink.ml.common.ps.iterations.MLSessionImpl; +import org.apache.flink.ml.common.ps.iterations.ProcessStage; +import org.apache.flink.ml.common.ps.iterations.PullStage; +import org.apache.flink.ml.common.ps.iterations.PushStage; +import org.apache.flink.ml.common.ps.iterations.ReduceScatterStage; +import org.apache.flink.ml.common.ps.sarray.SharedDoubleArray; +import org.apache.flink.ml.common.ps.sarray.SharedLongArray; +import org.apache.flink.ml.common.ps.typeinfo.Long2ObjectOpenHashMapTypeInfo; +import org.apache.flink.ml.common.ps.updater.ModelUpdater; +import org.apache.flink.ml.common.ps.utils.ProxySideOutput; +import org.apache.flink.ml.common.ps.utils.TrainingUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.test.util.TestBaseUtils; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.function.SerializableSupplier; + +import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap; +import org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.function.Supplier; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +/** Tests {@link TrainingUtils}. 
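+ *
+ * <p>Covers push/pull stages (with and without a {@link PullStage.Aggregator}), all-reduce,
+ * reduce-scatter, and scanning the cached training data on the workers.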
*/ +public class TrainingUtilsTest { + private static final int NUM_WORKERS = 2; + private static final int NUM_SERVERS = 6; + private static final int MAX_ITER = 3; + private static final int NUM_DOUBLES_PER_KEY = 2; + private DataStream inputData; + StreamExecutionEnvironment env; + + @Before + public void before() { + env = TestUtils.getExecutionEnvironment(); + env.setParallelism(NUM_WORKERS); + inputData = + env.fromCollection( + Arrays.asList( + Vectors.dense(1, 1, 1, 1), + Vectors.dense(2, 2, 2, 2), + Vectors.dense(3, 3, 3, 3), + Vectors.dense(4, 4, 4, 4))) + .map(x -> x, DenseVectorTypeInfo.INSTANCE); + } + + @Test + public void testPushSumAndPullAgg() throws Exception { + MockSession mockSession = new MockSession(); + + IterationStageList stageList = + new IterationStageList<>(mockSession) + .addStage( + new PushStage( + () -> new SharedLongArray(new long[] {1, 4}), + () -> new SharedDoubleArray(new double[] {1, 1, 4, 4}))) + .addStage( + new PullStage( + () -> new SharedLongArray(new long[] {1, 3, 4}), + () -> { + mockSession.pullResult.size(4); + return mockSession.pullResult; + }, + new MockAggregator())) + .addStage( + new ResultChecker( + () -> { + double[] expectedResult = new double[4]; + Arrays.fill( + expectedResult, + (mockSession.iterationId + 1) + * (mockSession.iterationId + 1) + * 68); + return Arrays.equals( + expectedResult, + trimToArray(mockSession.pullResult)); + })) + .setTerminationCriteria(session -> session.iterationId >= MAX_ITER); + + DataStreamList resultList = + TrainingUtils.train( + inputData, + stageList, + new TupleTypeInfo<>( + Types.LONG, + PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO), + new MockModelUpdater(NUM_DOUBLES_PER_KEY), + NUM_SERVERS); + + DataStream> modelStream = resultList.get(0); + List> collectedModelData = + IteratorUtils.toList(modelStream.executeAndCollect()); + List> expectedModelData = + Arrays.asList( + Tuple2.of( + 1L, new double[] {NUM_WORKERS * MAX_ITER, NUM_WORKERS * MAX_ITER}), + Tuple2.of(3L, new double[] {0, 0}), + Tuple2.of( + 4L, + new double[] { + NUM_WORKERS * MAX_ITER * 4, NUM_WORKERS * MAX_ITER * 4 + })); + + verifyModelData(expectedModelData, collectedModelData); + } + + @Test + public void testPushMinAndPull() throws Exception { + MockSession mockSession = new MockSession(); + + IterationStageList stageList = + new IterationStageList<>(mockSession) + .addStage( + new PushStage( + () -> new SharedLongArray(new long[] {1, 4}), + () -> new SharedDoubleArray(new double[] {1, 1, 4, 4}), + Double::min)) + .addStage( + new PullStage( + () -> new SharedLongArray(new long[] {1, 3}), + () -> { + mockSession.pullResult.size(4); + return mockSession.pullResult; + })) + .addStage( + new ResultChecker( + () -> + Arrays.equals( + new double[] { + mockSession.iterationId + 1, + mockSession.iterationId + 1, + 0, + 0 + }, + trimToArray(mockSession.pullResult)))) + .setTerminationCriteria(session -> session.iterationId >= MAX_ITER); + + DataStreamList resultList = + TrainingUtils.train( + inputData, + stageList, + new TupleTypeInfo<>( + Types.LONG, + PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO), + new MockModelUpdater(NUM_DOUBLES_PER_KEY), + NUM_SERVERS); + DataStream> modelStream = resultList.get(0); + List> collectedModelData = + IteratorUtils.toList(modelStream.executeAndCollect()); + List> expectedModelData = + Arrays.asList( + Tuple2.of(1L, new double[] {MAX_ITER, MAX_ITER}), + Tuple2.of(3L, new double[] {0, 0}), + Tuple2.of(4L, new double[] {MAX_ITER * 4, MAX_ITER * 4})); + + 
verifyModelData(expectedModelData, collectedModelData); + } + + @Test + public void testAllReduce() throws Exception { + ExecutionConfig executionConfig = inputData.getExecutionEnvironment().getConfig(); + int executionInterval = 2; + TypeSerializer mockPojoTypeSerializer = + Types.POJO(MockPojo.class).createSerializer(executionConfig); + MockSession mockSession = new MockSession(); + + IterationStageList stageList = + new IterationStageList<>(mockSession) + .addStage(new MockInitStage()) + .addStage( + new AllReduceStage<>( + () -> mockSession.allReduceInputAndResult, + () -> mockSession.allReduceInputAndResult, + (ReduceFunction) TrainingUtilsTest::sumPojo, + mockPojoTypeSerializer, + executionInterval)) + .addStage( + new ResultChecker( + () -> { + if (mockSession.iterationId % executionInterval == 0) { + MockPojo[] reduceResult = + mockSession.allReduceInputAndResult; + Assert.assertEquals(2, reduceResult.length); + MockPojo expectedPojo = + new MockPojo( + NUM_WORKERS + * (mockSession.iterationId + / executionInterval + + 1), + NUM_WORKERS + * (mockSession.iterationId + / executionInterval + + 1) + * 2); + Assert.assertEquals(expectedPojo, reduceResult[0]); + Assert.assertEquals(expectedPojo, reduceResult[1]); + } + return true; + })) + .setTerminationCriteria(session -> session.iterationId >= MAX_ITER); + + DataStreamList resultList = + TrainingUtils.train( + inputData, + stageList, + new TupleTypeInfo<>( + Types.LONG, + Types.LONG, + PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO), + new MockModelUpdater(NUM_DOUBLES_PER_KEY), + NUM_SERVERS); + DataStream> modelStream = resultList.get(0); + List> modelData = + IteratorUtils.toList(modelStream.executeAndCollect()); + Assert.assertEquals(0, modelData.size()); + } + + @Test + public void testReduceScatter() throws Exception { + ExecutionConfig executionConfig = inputData.getExecutionEnvironment().getConfig(); + int executionInterval = 2; + TypeSerializer mockPojoTypeSerializer = + Types.POJO(MockPojo.class).createSerializer(executionConfig); + MockSession mockSession = + new MockSession( + Collections.singletonList( + new OutputTag<>( + "reduceScatter", + new TupleTypeInfo<>( + Types.INT, + Types.INT, + Types.OBJECT_ARRAY(Types.POJO(MockPojo.class)))))); + + IterationStageList stageList = + new IterationStageList<>(mockSession) + .addStage(new MockInitStage()) + .addStage( + new ReduceScatterStage<>( + () -> mockSession.reduceScatterInput, + () -> mockSession.reduceScatterResult, + new int[] {1, 1}, + (ReduceFunction) TrainingUtilsTest::sumPojo, + mockPojoTypeSerializer, + executionInterval)) + .addStage( + new ResultChecker( + () -> { + if (mockSession.iterationId % executionInterval == 0) { + MockPojo[] reduceResult = + mockSession.reduceScatterResult; + Assert.assertEquals(1, reduceResult.length); + MockPojo expectedPojo = + new MockPojo(NUM_WORKERS, NUM_WORKERS * 2); + Assert.assertEquals(expectedPojo, reduceResult[0]); + } + return true; + })) + .setTerminationCriteria(session -> session.iterationId >= MAX_ITER); + + DataStreamList resultList = + TrainingUtils.train( + inputData, + stageList, + new TupleTypeInfo<>( + Types.LONG, + Types.LONG, + PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO), + new MockModelUpdater(NUM_DOUBLES_PER_KEY), + NUM_SERVERS); + DataStream> modelStream = resultList.get(0); + List> modelData = + IteratorUtils.toList(modelStream.executeAndCollect()); + Assert.assertEquals(0, modelData.size()); + } + + @Test + public void readTrainDataAndOutput() throws Exception { + MockSession mockSession 
= + new MockSession( + Collections.singletonList( + new OutputTag<>( + "numOfTrainData", + new TupleTypeInfo<>(Types.INT, Types.INT, Types.INT)))); + + IterationStageList stageList = + new IterationStageList<>(mockSession) + .addStage(new ReadDataStage()) + .addStage( + new AllReduceStage<>( + () -> mockSession.numDataScanned, + () -> mockSession.numDataScanned, + TrainingUtilsTest::sumIntArray, + IntSerializer.INSTANCE)) + .addStage(new MockOutputStage<>(() -> mockSession.numDataScanned[0])) + .setTerminationCriteria(session -> session.iterationId >= MAX_ITER); + + DataStreamList resultList = + TrainingUtils.train( + inputData, + stageList, + new TupleTypeInfo<>( + Types.LONG, + PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO), + new MockModelUpdater(NUM_DOUBLES_PER_KEY), + NUM_SERVERS); + + DataStream> pulledStream = resultList.get(1); + List> pulls = + IteratorUtils.toList(pulledStream.executeAndCollect()); + + List> expectedPulls = new ArrayList<>(); + int numDataScanned = 4; + for (int i = 0; i < MAX_ITER; i++) { + for (int w = 0; w < NUM_WORKERS; w++) { + expectedPulls.add(Tuple3.of(i, w, numDataScanned)); + } + } + Comparator> comparator = + (o1, o2) -> { + int cmp = Integer.compare(o1.f0, o2.f0); + if (cmp == 0) { + cmp = Integer.compare(o1.f1, o2.f1); + if (cmp == 0) { + cmp = Integer.compare(o1.f2, o2.f2); + } + } + return cmp; + }; + TestBaseUtils.compareResultCollections(expectedPulls, pulls, comparator); + } + + /** The session that one worker can access. */ + private static class MockSession extends MLSessionImpl { + public MockPojo[] allReduceInputAndResult; + public MockPojo[] reduceScatterInput; + public MockPojo[] reduceScatterResult; + public SharedDoubleArray pullResult; + private ProxySideOutput output; + private Integer[] numDataScanned; + + @Override + public void setOutput(ProxySideOutput output) { + this.output = output; + } + + public MockSession(List> outputTags) { + super(outputTags); + pullResult = new SharedDoubleArray(); + this.numDataScanned = new Integer[1]; + } + + public MockSession() { + this(null); + } + } + + /** The model updater on servers. 
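+     *
+     * <p>Each key maps to a double array of length {@code numDoublesPerKey}; {@code update}
+     * accumulates the pushed values into it and {@code get} returns the current model values.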
*/ + private static class MockModelUpdater implements ModelUpdater> { + private final int numDoublesPerKey; + private Long2ObjectOpenHashMap model; + private ListState> modelDataState; + + public MockModelUpdater(int numDoublesPerKey) { + this.numDoublesPerKey = numDoublesPerKey; + this.model = new Long2ObjectOpenHashMap<>(); + } + + @Override + public void update(long[] keys, double[] values) { + Preconditions.checkState(keys.length * numDoublesPerKey == values.length); + for (int i = 0; i < keys.length; i++) { + long index = keys[i]; + model.putIfAbsent(index, new double[numDoublesPerKey]); + double[] oneDimModel = model.get(index); + for (int j = 0; j < numDoublesPerKey; j++) { + oneDimModel[j] += values[i * numDoublesPerKey + j]; + } + } + } + + @Override + public double[] get(long[] keys) { + double[] values = new double[keys.length * numDoublesPerKey]; + for (int i = 0; i < keys.length; i++) { + long index = keys[i]; + model.putIfAbsent(index, new double[numDoublesPerKey]); + double[] oneDimModel = model.get(index); + for (int j = 0; j < numDoublesPerKey; j++) { + values[i * numDoublesPerKey + j] += oneDimModel[j]; + } + } + return values; + } + + @Override + public Iterator> getModelSegments() { + return model.long2ObjectEntrySet().stream() + .map(x -> Tuple2.of(x.getLongKey(), x.getValue())) + .iterator(); + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + modelDataState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "modelDataState", + new Long2ObjectOpenHashMapTypeInfo<>( + PrimitiveArrayTypeInfo + .DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO))); + model = + OperatorStateUtils.getUniqueElement(modelDataState, "modelDataState") + .orElse(new Long2ObjectOpenHashMap<>()); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + modelDataState.clear(); + modelDataState.add(model); + } + } + + /** A stage that initialize the value for all-reduce and reduce-scatter. */ + private static class MockInitStage extends ProcessStage { + + @Override + public void process(MockSession session) { + if (session.iterationId == 0) { + session.allReduceInputAndResult = new MockPojo[2]; + session.allReduceInputAndResult[0] = new MockPojo(1, 2); + session.allReduceInputAndResult[1] = new MockPojo(1, 2); + } + + session.reduceScatterInput = new MockPojo[2]; + session.reduceScatterInput[0] = new MockPojo(1, 2); + session.reduceScatterInput[1] = new MockPojo(1, 2); + session.reduceScatterResult = new MockPojo[1]; + } + } + + /** A stage that scans the data and count the number of data points scanned. */ + private static class ReadDataStage extends ProcessStage { + + @Override + public void process(MockSession session) throws Exception { + session.numDataScanned[0] = 0; + while (session.inputData.hasNext()) { + session.inputData.next(); + session.numDataScanned[0]++; + } + session.inputData.reset(); + } + } + + /** A stage that checks the value of some intermediate results. */ + private static class ResultChecker extends ProcessStage { + Supplier checker; + + public ResultChecker(SerializableSupplier checker) { + this.checker = checker; + } + + @Override + public void process(MockSession session) { + Preconditions.checkState(checker.get()); + } + } + + /** A stage that output non-model data to downstream tasks. 
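+     *
+     * <p>Emits an (iterationId, workerId, value) tuple to the first registered {@link OutputTag}
+     * through the session's {@link ProxySideOutput}.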
*/ + private static class MockOutputStage extends ProcessStage { + + private final SerializableSupplier outputSupplier; + + public MockOutputStage(SerializableSupplier outputSupplier) { + this.outputSupplier = outputSupplier; + } + + @Override + public void process(MockSession session) { + OutputTag> outputTag = + (OutputTag>) session.getOutputTags().get(0); + session.output.output( + outputTag, + new StreamRecord<>( + Tuple3.of( + session.iterationId, session.workerId, outputSupplier.get()))); + } + } + + /** An aggregator that can be used in a pull request. */ + private static class MockAggregator implements PullStage.Aggregator { + @Override + public double[] add(double[] in, double[] acc) { + if (acc == null) { + acc = new double[in.length * in.length]; + } + + for (int i = 0; i < in.length; i++) { + for (int j = 0; j < in.length; j++) { + acc[i * in.length + j] += in[i] * in[j]; + } + } + return acc; + } + + @Override + public double[] merge(double[] acc1, double[] acc2) { + for (int i = 0; i < acc1.length; i++) { + acc2[i] += acc1[i]; + } + return acc2; + } + } + + private void verifyModelData( + List> expected, List> actual) { + assertEquals(expected.size(), actual.size()); + expected.sort(Comparator.comparingLong(x -> x.f0)); + actual.sort(Comparator.comparingLong(x -> x.f0)); + for (int i = 0; i < expected.size(); i++) { + assertEquals(expected.get(i).f0, actual.get(i).f0); + assertArrayEquals(expected.get(i).f1, actual.get(i).f1, 1e-7); + } + } + + private static MockPojo[] sumPojo(MockPojo[] d1, MockPojo[] d2) { + Preconditions.checkArgument(d1.length == d2.length); + for (int i = 0; i < d1.length; i++) { + d2[i].i += d1[i].i; + d2[i].j += d1[i].j; + } + return d2; + } + + private static Integer[] sumIntArray(Integer[] d1, Integer[] d2) { + Preconditions.checkArgument(d1.length == d2.length); + for (int i = 0; i < d1.length; i++) { + d2[i] += d1[i]; + } + return d2; + } + + private static double[] trimToArray(SharedDoubleArray array) { + return Arrays.copyOfRange(array.elements(), 0, array.size()); + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/AlsTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/AlsTest.java new file mode 100644 index 000000000..fd94f0c97 --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/AlsTest.java @@ -0,0 +1,385 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.recommendation; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.recommendation.als.Als; +import org.apache.flink.ml.recommendation.als.AlsModel; +import org.apache.flink.ml.recommendation.als.AlsModelData; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +/** Tests {@link Als} and {@link AlsModel}. */ +public class AlsTest extends AbstractTestBase { + + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + + private StreamExecutionEnvironment env; + + private StreamTableEnvironment tEnv; + + private final List trainData = + Arrays.asList( + Row.of(1L, 5L, 0.1), + Row.of(2L, 8L, 0.5), + Row.of(3L, 5L, 0.8), + Row.of(4L, 7L, 0.1), + Row.of(1L, 7L, 0.7), + Row.of(2L, 5L, 0.9), + Row.of(3L, 8L, 0.1), + Row.of(2L, 6L, 0.7), + Row.of(2L, 7L, 0.4), + Row.of(1L, 8L, 0.3), + Row.of(4L, 6L, 0.4), + Row.of(3L, 7L, 0.6), + Row.of(1L, 6L, 0.5), + Row.of(4L, 8L, 0.3)); + + private static final double TOLERANCE = 1.0e-7; + private static final float FLOAT_TOLERANCE = 1.0e-6f; + + private final List smallTrainData = + Arrays.asList(Row.of(1L, 5L, 0.7), Row.of(2L, 6L, 0.4)); + + private final List testData = Collections.singletonList(Row.of(1L, 6L)); + + private Table trainDataTable; + private Table smallTrainDataTable; + private Table testDataTable; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.getConfig().enableObjectReuse(); + env.setParallelism(2); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + trainDataTable = + tEnv.fromDataStream( + env.fromCollection( + trainData, + new RowTypeInfo( + new TypeInformation[] { + Types.LONG, Types.LONG, Types.DOUBLE + }, + new String[] {"uid", "iid", "rating"}))); + + smallTrainDataTable = + tEnv.fromDataStream( + env.fromCollection( + smallTrainData, + new RowTypeInfo( + new TypeInformation[] { + Types.LONG, Types.LONG, Types.DOUBLE + }, + new String[] {"uid", "iid", "rating"}))); + + testDataTable = + tEnv.fromDataStream( + env.fromCollection( + testData, + new RowTypeInfo( + new TypeInformation[] {Types.LONG, Types.LONG}, + new String[] {"uid", "iid"}))); + } + + @SuppressWarnings("unchecked") + private void verifyPredictionResult(Table output, String predictionCol, double expectedData) + throws Exception { 
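+        // Collects all transformed rows and checks that every prediction equals the expected score.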
+ List predictResult = + IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row predictionRow : predictResult) { + double prediction = predictionRow.getFieldAs(predictionCol); + assertEquals(prediction, expectedData, TOLERANCE); + } + } + + @Test + public void testParam() { + Als als = new Als(); + assertEquals("user", als.getUserCol()); + assertEquals("item", als.getItemCol()); + assertEquals("rating", als.getRatingCol()); + assertEquals(1.0, als.getAlpha(), TOLERANCE); + assertEquals(0.1, als.getRegParam(), TOLERANCE); + assertEquals(10, als.getRank()); + assertEquals(false, als.getImplicitPrefs()); + assertEquals(false, als.getNonNegative()); + assertEquals(10, als.getMaxIter()); + assertEquals("prediction", als.getPredictionCol()); + + als.setUserCol("userCol") + .setItemCol("itemCol") + .setRatingCol("ratingCol") + .setAlpha(0.001) + .setRegParam(0.5) + .setRank(100) + .setImplicitPrefs(true) + .setNonNegative(false) + .setMaxIter(1000) + .setPredictionCol("predict_result"); + + assertEquals("userCol", als.getUserCol()); + assertEquals("itemCol", als.getItemCol()); + assertEquals("ratingCol", als.getRatingCol()); + assertEquals(0.001, als.getAlpha(), TOLERANCE); + assertEquals(0.5, als.getRegParam(), TOLERANCE); + assertEquals(100, als.getRank()); + assertEquals(true, als.getImplicitPrefs()); + assertEquals(false, als.getNonNegative()); + assertEquals(1000, als.getMaxIter()); + assertEquals("predict_result", als.getPredictionCol()); + } + + @Test + public void testOutputSchema() { + Table tempTable = trainDataTable.as("uid", "iid", "rating_col"); + Als als = + new Als() + .setUserCol("uid") + .setItemCol("iid") + .setRatingCol("rating_col") + .setPredictionCol("predict_result"); + AlsModel model = als.fit(trainDataTable); + Table output = model.transform(tempTable)[0]; + assertEquals( + Arrays.asList("uid", "iid", "rating_col", "predict_result"), + output.getResolvedSchema().getColumnNames()); + } + + @Test + public void testFitAndPredictWithImplicit() throws Exception { + Als als = + new Als() + .setUserCol("uid") + .setItemCol("iid") + .setRatingCol("rating") + .setMaxIter(5) + .setRank(10) + .setAlpha(0.1) + .setRegParam(0.1) + .setImplicitPrefs(true) + .setNonNegative(true) + .setPredictionCol("predict_result"); + Table output = als.fit(trainDataTable).transform(testDataTable)[0]; + verifyPredictionResult(output, als.getPredictionCol(), 0.8342121792822439); + } + + @Test + public void testFitAndPredictWithoutNonNegative() throws Exception { + Als als = + new Als() + .setUserCol("uid") + .setItemCol("iid") + .setRatingCol("rating") + .setMaxIter(5) + .setRank(10) + .setAlpha(0.1) + .setRegParam(0.1) + .setImplicitPrefs(true) + .setNonNegative(false) + .setPredictionCol("predict_result"); + Table output = als.fit(trainDataTable).transform(testDataTable)[0]; + verifyPredictionResult(output, als.getPredictionCol(), 0.8345180300543498); + } + + @Test + public void testFitAndPredictWithExplicit() throws Exception { + Als als = + new Als() + .setUserCol("uid") + .setItemCol("iid") + .setRatingCol("rating") + .setMaxIter(5) + .setRank(10) + .setAlpha(0.1) + .setRegParam(0.1) + .setImplicitPrefs(false) + .setNonNegative(true) + .setPredictionCol("predict_result"); + Table output = als.fit(trainDataTable).transform(testDataTable)[0]; + verifyPredictionResult(output, als.getPredictionCol(), 0.37476815535599206); + } + + @Test + public void testSaveLoadAndTransform() throws Exception { + Als als = + new Als() + .setUserCol("uid") + .setItemCol("iid") + 
.setRatingCol("rating") + .setMaxIter(10) + .setRank(10) + .setAlpha(0.1) + .setRegParam(0.1) + .setImplicitPrefs(false) + .setNonNegative(true) + .setPredictionCol("predict_result"); + AlsModel model = als.fit(trainDataTable); + AlsModel loadModel = + TestUtils.saveAndReload( + tEnv, model, tempFolder.newFolder().getAbsolutePath(), AlsModel::load); + Table output = loadModel.transform(testDataTable)[0]; + verifyPredictionResult(output, als.getPredictionCol(), 0.37558552399494904); + } + + @SuppressWarnings("unchecked") + @Test + public void testGetModelData() throws Exception { + Als als = + new Als() + .setUserCol("uid") + .setItemCol("iid") + .setRatingCol("rating") + .setMaxIter(10) + .setRank(3) + .setAlpha(0.1) + .setRegParam(0.1) + .setImplicitPrefs(true) + .setNonNegative(true) + .setPredictionCol("predict_result"); + Table model = als.fit(trainDataTable).getModelData()[0]; + List modelRows = IteratorUtils.toList(tEnv.toDataStream(model).executeAndCollect()); + for (Row modelRow : modelRows) { + AlsModelData modelData = modelRow.getFieldAs(0); + for (Tuple2 t2 : modelData.userFactors) { + if (t2.f0 == 1L) { + assertArrayEquals( + t2.f1, + new float[] {0.72853327f, 0.33467698f, 0.59506977f}, + FLOAT_TOLERANCE); + } + if (t2.f0 == 2L) { + assertArrayEquals( + t2.f1, + new float[] {0.7278192f, 0.33339077f, 0.60418415f}, + FLOAT_TOLERANCE); + } + if (t2.f0 == 3L) { + assertArrayEquals( + t2.f1, + new float[] {0.15143539f, 0.82475346f, 0.5966393f}, + FLOAT_TOLERANCE); + } + if (t2.f0 == 4L) { + assertArrayEquals( + t2.f1, new float[] {0.9454353f, 0.2567069f, 0.0f}, FLOAT_TOLERANCE); + } + } + + for (Tuple2 t2 : modelData.itemFactors) { + if (t2.f0 == 5L) { + assertArrayEquals( + t2.f1, + new float[] {0.16018498f, 0.42313296f, 0.9295262f}, + FLOAT_TOLERANCE); + } + if (t2.f0 == 6L) { + assertArrayEquals( + t2.f1, new float[] {0.980295f, 0.0f, 0.19405676f}, FLOAT_TOLERANCE); + } + if (t2.f0 == 7L) { + assertArrayEquals( + t2.f1, + new float[] {0.7008933f, 0.5764357f, 0.41699994f}, + FLOAT_TOLERANCE); + } + if (t2.f0 == 8L) { + assertArrayEquals( + t2.f1, + new float[] {0.7045122f, 0.57172996f, 0.41351604f}, + FLOAT_TOLERANCE); + } + } + } + } + + @Test + public void testSetModelData() throws Exception { + Als als = + new Als() + .setUserCol("uid") + .setItemCol("iid") + .setRatingCol("rating") + .setMaxIter(10) + .setRank(3) + .setAlpha(0.1) + .setRegParam(0.1) + .setImplicitPrefs(true) + .setNonNegative(true) + .setPredictionCol("predict_result"); + Table modelData = als.fit(trainDataTable).getModelData()[0]; + + AlsModel model = + new AlsModel() + .setModelData(modelData) + .setUserCol("uid") + .setItemCol("iid") + .setPredictionCol("predict_result"); + + Table output = model.transform(testDataTable)[0]; + verifyPredictionResult(output, als.getPredictionCol(), 0.8296548350206177); + } + + @Test + public void testMoreSubtaskThanData() throws Exception { + env.setParallelism(4); + Als als = + new Als() + .setUserCol("uid") + .setItemCol("iid") + .setRatingCol("rating") + .setMaxIter(5) + .setRank(10) + .setAlpha(0.1) + .setRegParam(0.1) + .setImplicitPrefs(false) + .setNonNegative(true) + .setPredictionCol("predict_result"); + Table output = als.fit(smallTrainDataTable).transform(testDataTable)[0]; + verifyPredictionResult(output, als.getPredictionCol(), 0.3317218226859576); + } +} diff --git a/flink-ml-python/pyflink/examples/ml/recommendation/als_example.py b/flink-ml-python/pyflink/examples/ml/recommendation/als_example.py new file mode 100644 index 000000000..1f5492f20 --- /dev/null 
+++ b/flink-ml-python/pyflink/examples/ml/recommendation/als_example.py @@ -0,0 +1,82 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +# Simple program that creates a Swing instance and gives recommendations for items. + +from pyflink.common import Types +from pyflink.datastream import StreamExecutionEnvironment +from pyflink.table import StreamTableEnvironment + +from pyflink.ml.recommendation.als import Als + +# Creates a new StreamExecutionEnvironment. +env = StreamExecutionEnvironment.get_execution_environment() + +# Creates a StreamTableEnvironment. +t_env = StreamTableEnvironment.create(env) + +# Generates input data. +input_table = t_env.from_data_stream( + env.from_collection([ + (1, 5, 0.1), + (2, 8, 0.5), + (3, 5, 0.8), + (4, 7, 0.1), + (1, 7, 0.7), + (2, 5, 0.9), + (3, 8, 0.1), + (2, 6, 0.7), + (2, 7, 0.4), + (1, 8, 0.3), + (4, 6, 0.4), + (3, 7, 0.6), + (1, 6, 0.5), + (4, 8, 0.3) + ], + type_info=Types.ROW_NAMED( + ['user', 'item', 'rating'], + [Types.LONG(), Types.LONG(), Types.DOUBLE()]) + )) + +test_table = t_env.from_data_stream( + env.from_collection([ + (1, 6), + (2, 7) + ], + type_info=Types.ROW_NAMED( + ['user', 'item'], + [Types.LONG(), Types.LONG()]) + )) + +# Creates a als object and initialize its parameters. +als = Als() + +# Transforms the data to Als algorithm result. +output_table = als.fit(input_table).transform(test_table)[0] + +# Extracts and display the results. +field_names = output_table[0].get_schema().get_field_names() + +results = t_env.to_data_stream( + output_table[0]).execute_and_collect() + +for result in results: + user = result[field_names.index(als.get_user_col())] + item = result[field_names.index(als.get_item_col())] + score = result[field_names.index(als.get_prediction_col())] + print(f'user: {user}, item : {item}, score: {score}') diff --git a/flink-ml-python/pyflink/ml/recommendation/als.py b/flink-ml-python/pyflink/ml/recommendation/als.py new file mode 100644 index 000000000..4f5c9bef0 --- /dev/null +++ b/flink-ml-python/pyflink/ml/recommendation/als.py @@ -0,0 +1,182 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +import typing + +from pyflink.ml.wrapper import JavaWithParams +from pyflink.ml.param import Param, StringParam, IntParam, FloatParam, BooleanParam, ParamValidators +from pyflink.ml.feature.common import JavaFeatureModel, JavaFeatureEstimator +from pyflink.ml.common.param import HasPredictionCol + + +class _AlsParams( + JavaWithParams, + HasPredictionCol +): + """ + Params for :class:`Als`. + """ + USER_COL: Param[str] = StringParam( + "user_col", + "User column name.", + "user", + ParamValidators.not_null()) + + ITEM_COL: Param[str] = StringParam( + "item_col", + "Item column name.", + "item", + ParamValidators.not_null()) + + RATING_COL: Param[str] = StringParam( + "rating_col", + "Rating column name.", + "rating", + ParamValidators.not_null()) + + ALPHA: Param[float] = FloatParam( + "alpha", + "Alpha for implicit preference.", + 1.0, + ParamValidators.gt_eq(0)) + + REG_PARAM: Param[float] = FloatParam( + "reg_param", + "Regularization parameter.", + 0.1, + ParamValidators.gt_eq(0)) + + IMPLICIT_PREFS: Param[bool] = BooleanParam( + "implicit_refs", + "Whether to use implicit preference.", + False) + + NON_NEGATIVE: Param[bool] = BooleanParam( + "non_negative", + "Whether to use non negative constraint for least squares.", + False) + + RANK: Param[int] = IntParam( + "rank", + "Rank of the factorization.", + 10, + ParamValidators.gt(0)) + + MAX_ITER: Param[int] = IntParam( + "max_iter", + "Maximum number of iterations.", + 10, + ParamValidators.gt(0)) + + def __init__(self, java_params): + super(_AlsParams, self).__init__(java_params) + + def set_user_col(self, value: str): + return typing.cast(_AlsParams, self.set(self.USER_COL, value)) + + def get_user_col(self) -> str: + return self.get(self.USER_COL) + + def set_item_col(self, value: str): + return typing.cast(_AlsParams, self.set(self.ITEM_COL, value)) + + def get_item_col(self) -> str: + return self.get(self.ITEM_COL) + + def set_rating_col(self, value: str): + return typing.cast(_AlsParams, self.set(self.RATING_COL, value)) + + def get_rating_col(self) -> str: + return self.get(self.RATING_COL) + + def set_alpha(self, value: float): + return typing.cast(_AlsParams, self.set(self.ALPHA, value)) + + def get_alpha(self) -> float: + return self.get(self.ALPHA) + + def set_reg_param(self, value: float): + return typing.cast(_AlsParams, self.set(self.REG_PARAM, value)) + + def get_reg_param(self) -> float: + return self.get(self.REG_PARAM) + + def set_implicit_refs(self, value: bool): + return typing.cast(_AlsParams, self.set(self.IMPLICIT_PREFS, value)) + + def get_implicit_refs(self) -> bool: + return self.get(self.NON_NEGATIVE) + + def set_non_negative(self, value: bool): + return typing.cast(_AlsParams, self.set(self.IMPLICIT_PREFS, value)) + + def get_non_negative(self) -> bool: + return self.get(self.NON_NEGATIVE) + + def set_rank(self, value: int): + return typing.cast(_AlsParams, self.set(self.RANK, value)) + + def get_rank(self) -> int: + return self.get(self.RANK) + + def set_max_iter(self, value: int): + return typing.cast(_AlsParams, 
self.set(self.MAX_ITER, value)) + + def get_max_iter(self) -> int: + return self.get(self.MAX_ITER) + + +class AlsModel(JavaFeatureModel, _AlsParams): + """ + A Model which transforms data using the model data computed by :class:`Als`. + """ + + def __init__(self, java_model=None): + super(AlsModel, self).__init__(java_model) + + @classmethod + def _java_model_package_name(cls) -> str: + return "als" + + @classmethod + def _java_model_class_name(cls) -> str: + return "AlsModel" + + +class Als(JavaFeatureEstimator, _AlsParams): + """ + An Estimator which implements the Als algorithm. ALS tries to decompose a matrix + R as R = X * Yt. Here X and Y are called factor matrices. Matrix R is usually a + sparse matrix representing ratings given from users to items. ALS tries to + find X and Y that minimize || R - X * Yt ||^2. This is done by iterations. At each + step, X is fixed and Y is solved, then Y is fixed and X is solved. + """ + + def __init__(self): + super(Als, self).__init__() + + @classmethod + def _create_model(cls, java_model) -> AlsModel: + return AlsModel(java_model) + + @classmethod + def _java_estimator_package_name(cls) -> str: + return "als" + + @classmethod + def _java_estimator_class_name(cls) -> str: + return "Als" diff --git a/flink-ml-python/pyflink/ml/recommendation/tests/test_als.py b/flink-ml-python/pyflink/ml/recommendation/tests/test_als.py new file mode 100644 index 000000000..c42232439 --- /dev/null +++ b/flink-ml-python/pyflink/ml/recommendation/tests/test_als.py @@ -0,0 +1,132 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +from pyflink.common import Types +from pyflink.table import Table +from typing import List + +from pyflink.ml.recommendation.als import Als +from pyflink.ml.tests.test_utils import PyFlinkMLTestCase + + +# Tests Als. 
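+# Covers parameter get/set, output schema, transform, and a save/load round trip.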
+class AlsTest(PyFlinkMLTestCase): + def setUp(self): + super(AlsTest, self).setUp() + self.input_table = self.t_env.from_data_stream( + self.env.from_collection([ + (1, 5, 0.1), + (2, 8, 0.5), + (3, 5, 0.8), + (4, 7, 0.1), + (1, 7, 0.7), + (2, 5, 0.9), + (3, 8, 0.1), + (2, 6, 0.7), + (2, 7, 0.4), + (1, 8, 0.3), + (4, 6, 0.4), + (3, 7, 0.6), + (1, 6, 0.5), + (4, 8, 0.3) + ], + type_info=Types.ROW_NAMED( + ['user', 'item', 'rating'], + [Types.LONG(), Types.LONG(), Types.DOUBLE()]) + )) + + self.test_table = self.t_env.from_data_stream( + self.env.from_collection([ + (1, 6) + ], + type_info=Types.ROW_NAMED( + ['user', 'item'], + [Types.LONG(), Types.LONG()]) + )) + + def test_param(self): + als = Als() + self.assertEqual('item', als.item_col) + self.assertEqual('user', als.user_col) + self.assertEqual('rating', als.rating_col) + self.assertEqual(10, als.rank) + self.assertEqual(10, als.max_iter) + self.assertEqual(False, als.non_negative) + self.assertEqual(False, als.implicit_refs) + self.assertEqual(1.0, als.alpha) + self.assertEqual(0.1, als.reg_param) + self.assertEqual(als.prediction_col, 'prediction') + + als.set_item_col('item_1') \ + .set_user_col('user_1') \ + .set_rating_col('rating_1') \ + .set_rank(50) \ + .set_max_iter(30) \ + .set_non_negative(True) \ + .set_implicit_refs(True) \ + .set_alpha(0.35) \ + .set_reg_param(0.25) \ + .set_prediction_col('prediction_col') + + self.assertEqual('item_1', Als.item_col) + self.assertEqual('user_1', Als.user_col) + self.assertEqual('rating_1', als.rating_col) + self.assertEqual(50, als.rank) + self.assertEqual(30, als.max_iter) + self.assertEqual(True, als.non_negative) + self.assertEqual(True, als.implicit_refs) + self.assertEqual(0.35, als.alpha) + self.assertEqual(0.25, als.reg_param) + self.assertEqual(als.prediction_col, 'prediction_col') + + def test_output_schema(self): + als = Als() \ + .set_item_col('test_item') \ + .set_user_col('test_user') \ + .set_rating_col('test_rating') \ + .set_prediction_col('prediction_col') + output = als.transform(self.input_table.alias('test_user', 'test_item', 'test_rating'))[0] + self.assertEqual( + ['test_user', 'test_item', 'test_rating', 'prediction_col'], + output.get_schema().get_field_names()) + + def test_transform(self): + als = Als() + output = als.fit(self.input_table).transform(self.test_table)[0] + self.verify_output_result( + output, + als.get_prediction_col(), + output.get_schema().get_field_names()) + + def test_save_load_and_transform(self): + als = Als() + reloaded_Als = self.save_and_reload(Als) + output = reloaded_Als.fit(self.input_table).transform(self.test_table)[0] + self.verify_output_result( + output, + als.get_prediction_col(), + output.get_schema().get_field_names()) + + def verify_output_result( + self, output: Table, + prediction_col: str, + field_names: List[str]): + collected_results = [result for result in + self.t_env.to_data_stream(output).execute_and_collect()] + for result in collected_results: + prediction = result[field_names.index(prediction_col)] + self.assertEqual(0.37558552399494904, prediction) diff --git a/flink-ml-servable-core/pom.xml b/flink-ml-servable-core/pom.xml index 4fe90d050..e3a3e6b7a 100644 --- a/flink-ml-servable-core/pom.xml +++ b/flink-ml-servable-core/pom.xml @@ -58,6 +58,13 @@ under the License. 
2.2.1 + + dev.ludovic.netlib + lapack + 2.2.1 + compile + + org.apache.flink flink-test-utils diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java index ad67301d0..a524ea6b0 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java @@ -85,6 +85,18 @@ public double get(int i, int j) { return values[numRows * j + i]; } + @Override + public double add(int i, int j, double value) { + Preconditions.checkArgument(i >= 0 && i < numRows && j >= 0 && j < numCols); + return values[numRows * j + i] += value; + } + + @Override + public double set(int i, int j, double value) { + Preconditions.checkArgument(i >= 0 && i < numRows && j >= 0 && j < numCols); + return values[numRows * j + i] = value; + } + @Override public DenseMatrix toDense() { return this; diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/Matrix.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/Matrix.java index 1664b7708..436cbe041 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/Matrix.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/Matrix.java @@ -35,6 +35,12 @@ public interface Matrix extends Serializable { /** Gets value of the (i,j) element. */ double get(int i, int j); + /** Adds value to the (i,j) element. */ + double add(int i, int j, double value); + + /** Sets value of the (i,j) element. */ + double set(int i, int j, double value); + /** Converts the instance to a dense matrix. */ DenseMatrix toDense(); } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/NNLS.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/NNLS.java new file mode 100644 index 000000000..e22471420 --- /dev/null +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/NNLS.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.linalg; + +import java.util.Arrays; + +/** + * Solver used to solve nonnegative least squares problems using a modified projected gradient + * method. 
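+ *
+ * <p>Given {@code ata = A^T * A} (an n*n array) and {@code atb = A^T * b}, it approximately
+ * solves {@code min ||A * x - b||^2} subject to {@code x >= 0}.
+ *
+ * <p>A minimal usage sketch (with hypothetical inputs):
+ *
+ * <pre>{@code
+ * NNLS nnls = new NNLS();
+ * nnls.initialize(n);                 // n is the number of unknowns
+ * double[] x = nnls.solve(ata, atb);  // ata has length n * n, atb has length n
+ * }</pre>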
+ */ +public class NNLS { + private static final dev.ludovic.netlib.BLAS NATIVE_BLAS = + dev.ludovic.netlib.JavaBLAS.getInstance(); + + private double[] scratch; + private double[] grad; + private double[] x; + private double[] dir; + private double[] lastDir; + private double[] res; + private int n; + private boolean initialized = false; + + public void initialize(int n) { + if (!initialized) { + this.n = n; + scratch = new double[n]; + grad = new double[n]; + x = new double[n]; + dir = new double[n]; + lastDir = new double[n]; + res = new double[n]; + initialized = true; + } + } + + public void wipe() { + Arrays.fill(scratch, 0.0); + Arrays.fill(grad, 0.0); + Arrays.fill(x, 0.0); + Arrays.fill(dir, 0.0); + Arrays.fill(lastDir, 0.0); + Arrays.fill(res, 0.0); + } + + public double[] solve(double[] ata, double[] atb) { + wipe(); + + int n = atb.length; + + int iterMax = Math.max(400, 20 * n); + double lastNorm = 0.0; + int iterno = 0; + int lastWall = 0; // Last iteration when we hit a bound constraint. + int i; + while (iterno < iterMax) { + // find the residual + NATIVE_BLAS.dgemv("N", n, n, 1.0, ata, n, x, 1, 0.0, res, 1); + NATIVE_BLAS.daxpy(n, -1.0, atb, 1, res, 1); + NATIVE_BLAS.dcopy(n, res, 1, grad, 1); + + // project the gradient + i = 0; + while (i < n) { + if (grad[i] > 0.0 && x[i] == 0.0) { + grad[i] = 0.0; + } + i = i + 1; + } + double ngrad = NATIVE_BLAS.ddot(n, grad, 1, grad, 1); + + NATIVE_BLAS.dcopy(n, grad, 1, dir, 1); + + // use a CG direction under certain conditions + double step = steplen(grad, res, ata); + double ndir; + double nx = NATIVE_BLAS.ddot(n, x, 1, x, 1); + if (iterno > lastWall + 1) { + double alpha = ngrad / lastNorm; + NATIVE_BLAS.daxpy(n, alpha, lastDir, 1, dir, 1); + double dstep = steplen(dir, res, ata); + ndir = NATIVE_BLAS.ddot(n, dir, 1, dir, 1); + if (stop(dstep, ndir, nx)) { + // reject the CG step if it could lead to premature termination + NATIVE_BLAS.dcopy(n, grad, 1, dir, 1); + ndir = NATIVE_BLAS.ddot(n, dir, 1, dir, 1); + } else { + step = dstep; + } + } else { + ndir = NATIVE_BLAS.ddot(n, dir, 1, dir, 1); + } + + // terminate or not. 
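+            // stop() returns true when the step is NaN, too small or negative, suspiciously
+            // large, or when the search direction is negligible (relative to x or absolutely).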
+ if (stop(step, ndir, nx)) { + return x.clone(); + } + + // don't run through the walls + i = 0; + while (i < n) { + if (step * dir[i] > x[i]) { + step = x[i] / dir[i]; + } + i = i + 1; + } + + // take the step + i = 0; + while (i < n) { + if (step * dir[i] > x[i] * (1 - 1e-14)) { + x[i] = 0; + lastWall = iterno; + } else { + x[i] -= step * dir[i]; + } + i = i + 1; + } + + iterno = iterno + 1; + NATIVE_BLAS.dcopy(n, dir, 1, lastDir, 1); + lastNorm = ngrad; + } + return x.clone(); + } + + // find the optimal unconstrained step + private double steplen(double[] dir, double[] res, double[] ata) { + double top = NATIVE_BLAS.ddot(n, dir, 1, res, 1); + NATIVE_BLAS.dgemv("N", n, n, 1.0, ata, n, dir, 1, 0.0, scratch, 1); + // Push the denominator upward very slightly to avoid infinities and silliness + return top / (NATIVE_BLAS.ddot(n, scratch, 1, dir, 1) + 1e-20); + } + + // stopping condition + boolean stop(Double step, double ndir, double nx) { + return ((step.isNaN()) // NaN + || (step < 1e-7) // too small or negative + || (step > 1e40) // too small; almost certainly numerical problems + || (ndir < 1e-12 * nx) // gradient relatively too small + || (ndir < 1e-32) // gradient absolutely too small; numerical issues may lurk + ); + } +} diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/NormalEquationSolver.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/NormalEquationSolver.java new file mode 100644 index 000000000..cb1ef2424 --- /dev/null +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/NormalEquationSolver.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.linalg; + +import org.netlib.util.intW; + +import java.util.Arrays; + +/** A normal equation is A^T * A * x = A^T * b, where A * x = b is a lease square problem. */ +public class NormalEquationSolver { + private static final dev.ludovic.netlib.BLAS NATIVE_BLAS = + dev.ludovic.netlib.JavaBLAS.getInstance(); + public static final dev.ludovic.netlib.LAPACK LAPACK = + dev.ludovic.netlib.lapack.F2jLAPACK.getInstance(); + /** Rank of the equation. */ + private final int n; + + /** A^T * A. */ + private final DenseMatrix ata; + + /** A^T * b. */ + private final DenseVector atb; + + NNLS nnls; + + public NormalEquationSolver(int n, DenseMatrix ata, DenseVector atb) { + this.n = n; + this.ata = ata; + this.atb = atb; + nnls = new NNLS(); + } + + public DenseMatrix getAta() { + return ata; + } + + public DenseVector getAtb() { + return atb; + } + + /** + * The constructor. + * + * @param n Rank of the equation. 
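+     *
+     * <p>Typical usage: call {@link #add} for every observed row, optionally {@link #regularize},
+     * then {@link #solve}, which writes the result into {@code x} and resets the system.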
+ */ + public NormalEquationSolver(int n) { + this.n = n; + this.ata = new DenseMatrix(n, n); + this.atb = new DenseVector(n); + nnls = new NNLS(); + } + + /** + * Add coefficients to the normal equation. + * + * @param a A row of matrix "A". + * @param b An element of right hand side "b". + * @param c The scale factor of "a". + */ + public void add(DenseVector a, double b, double c) { + // ata += c * a.t * a + NATIVE_BLAS.dger(n, n, c, a.values, 1, a.values, 1, this.ata.values, n); + + // atb += b * a.t + BLAS.axpy(b, a, this.atb); + } + + /** Reset the system to zeros. */ + public void reset() { + Arrays.fill(ata.values, 0.); + Arrays.fill(atb.values, 0.); + } + + /** Merge with another A^T*A. */ + public void merge(DenseMatrix otherAta) { + NATIVE_BLAS.daxpy(ata.values.length, 1.0, otherAta.values, 1, ata.values, 1); + } + + /** Merge with another NormalEquation. */ + public void merge(NormalEquationSolver otherEq) { + merge(otherEq.ata); + BLAS.axpy(1.0, otherEq.atb, this.atb); + } + + /** Regularize the system by adding "lambda" to diagonals. */ + public void regularize(double lambda) { + for (int i = 0; i < n; i++) { + this.ata.add(i, i, lambda); + } + } + + /** + * Solve the system. After solving the system, the result is returned in x, and the + * data in ata and atb will be reset to zeros. + * + * @param x For holding the result. + * @param nonNegative Whether to enforce non-negative constraint. + */ + public void solve(DenseVector x, boolean nonNegative) { + if (nonNegative) { + nnls.initialize(n); + double[] ret = nnls.solve(ata.values, atb.values); + System.arraycopy(ret, 0, x.values, 0, n); + } else { + int n = ata.numCols(); + int nrhs = 1; + intW info = new intW(0); + LAPACK.dposv("U", n, nrhs, ata.values, n, atb.values, n, info); + // check solution + if (info.val > 0) { + throw new RuntimeException("A is not positive definite."); + } else if (info.val < 0) { + throw new RuntimeException("Invalid input to lapack routine."); + } + System.arraycopy(atb.values, 0, x.values, 0, n); + } + this.reset(); + } +} diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/Bits.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/Bits.java index 8de3a44d4..3beae41ca 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/Bits.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/Bits.java @@ -44,6 +44,13 @@ public static double getDouble(byte[] b, int off) { return Double.longBitsToDouble(getLong(b, off)); } + public static int getInt(byte[] b, int off) { + return ((b[off + 3] & 0xFF)) + + ((b[off + 2] & 0xFF) << 8) + + ((b[off + 1] & 0xFF) << 16) + + ((b[off]) << 24); + } + /* * Methods for packing primitive values into byte arrays starting at given * offsets. @@ -63,4 +70,75 @@ public static void putLong(byte[] b, int off, long val) { public static void putDouble(byte[] b, int off, double val) { putLong(b, off, Double.doubleToLongBits(val)); } + + public static void putInt(byte[] b, int off, int val) { + b[off + 3] = (byte) (val); + b[off + 2] = (byte) (val >>> 8); + b[off + 1] = (byte) (val >>> 16); + b[off] = (byte) (val >>> 24); + } + + /** Gets a long array from the byte array starting from the given offset. 
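+     *
+     * <p>Expects the layout written by {@link #putLongArray}: a 4-byte big-endian length followed
+     * by that many 8-byte long values.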
*/ + public static long[] getLongArray(byte[] bytes, int offset) { + int size = Bits.getInt(bytes, offset); + offset += Integer.BYTES; + long[] result = new long[size]; + for (int i = 0; i < size; i++) { + result[i] = Bits.getLong(bytes, offset); + offset += Long.BYTES; + } + return result; + } + + /** + * Puts a long array to the byte array starting from the given offset. + * + * @return the next position to write on. + */ + public static int putLongArray(long[] array, byte[] bytes, int offset) { + Bits.putInt(bytes, offset, array.length); + offset += Integer.BYTES; + for (int i = 0; i < array.length; i++) { + Bits.putLong(bytes, offset, array[i]); + offset += Long.BYTES; + } + return offset; + } + + /** Returns the size of a long array in bytes. */ + public static int getLongArraySizeInBytes(long[] array) { + return Integer.BYTES + array.length * Long.BYTES; + } + + /** Gets a double array from the byte array starting from the given offset. */ + public static double[] getDoubleArray(byte[] bytes, int offset) { + int size = Bits.getInt(bytes, offset); + offset += Integer.BYTES; + double[] result = new double[size]; + for (int i = 0; i < size; i++) { + result[i] = Bits.getDouble(bytes, offset); + offset += Long.BYTES; + } + return result; + } + + /** + * Puts a double array to the byte array starting from the given offset. + * + * @return the next position to write on. + */ + public static int putDoubleArray(double[] array, byte[] bytes, int offset) { + Bits.putInt(bytes, offset, array.length); + offset += Integer.BYTES; + for (int i = 0; i < array.length; i++) { + Bits.putDouble(bytes, offset, array[i]); + offset += Double.BYTES; + } + return offset; + } + + /** Returns the size of a double array in bytes. */ + public static int getDoubleArraySizeInBytes(double[] array) { + return Integer.BYTES + array.length * Long.BYTES; + } } diff --git a/flink-ml-uber/pom.xml b/flink-ml-uber/pom.xml index 212527d48..1d6629c9d 100644 --- a/flink-ml-uber/pom.xml +++ b/flink-ml-uber/pom.xml @@ -94,6 +94,7 @@ under the License. org.apache.flink:flink-ml-lib-${flink.main.version} org.apache.flink:flink-ml-benchmark-${flink.main.version} dev.ludovic.netlib:blas + it.unimi.dsi:fastutil