diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsOneInputStreamOperator.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsOneInputStreamOperator.java new file mode 100644 index 000000000..9d0ccbfdc --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsOneInputStreamOperator.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.sharedobjects; + +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; + +import java.util.List; + +/** The base class for {@link OneInputStreamOperator}s where shared objects are accessed. */ +public abstract class AbstractSharedObjectsOneInputStreamOperator + extends AbstractSharedObjectsStreamOperator + implements OneInputStreamOperator { + + public abstract List> readRequestsInProcessElement(); +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsStreamOperator.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsStreamOperator.java new file mode 100644 index 000000000..edc5a530f --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsStreamOperator.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.sharedobjects; + +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; + +import java.util.UUID; + +/** + * A base class of stream operators where shared objects are required. + * + *

Official subclasses, i.e., {@link AbstractSharedObjectsOneInputStreamOperator} and {@link + * AbstractSharedObjectsTwoInputStreamOperator}, are strongly recommended. + * + *

If you are going to implement a subclass by yourself, you have to handle potential deadlocks. + */ +public abstract class AbstractSharedObjectsStreamOperator extends AbstractStreamOperator { + + /** + * A unique identifier for the instance, which is kept unchanged between client side and + * runtime. + */ + private final String accessorID; + + /** The context for shared objects reads/writes. */ + protected transient SharedObjectsContext context; + + AbstractSharedObjectsStreamOperator() { + super(); + accessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); + } + + void onSharedObjectsContextSet(SharedObjectsContext context) { + this.context = context; + } + + String getAccessorID() { + return accessorID; + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsTwoInputStreamOperator.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsTwoInputStreamOperator.java new file mode 100644 index 000000000..ee7b59dc5 --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsTwoInputStreamOperator.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.sharedobjects; + +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; + +import java.util.List; + +/** The base class for {@link TwoInputStreamOperator}s where shared objects are accessed. */ +public abstract class AbstractSharedObjectsTwoInputStreamOperator + extends AbstractSharedObjectsStreamOperator + implements TwoInputStreamOperator { + + public abstract List> readRequestsInProcessElement1(); + + public abstract List> readRequestsInProcessElement2(); +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsWrapperOperator.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsWrapperOperator.java new file mode 100644 index 000000000..bed1f23df --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsWrapperOperator.java @@ -0,0 +1,508 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.sharedobjects; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.memory.ManagedMemoryUseCase; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; +import org.apache.flink.iteration.proxy.state.ProxyStreamOperatorStateContext; +import org.apache.flink.metrics.groups.OperatorMetricGroup; +import org.apache.flink.ml.common.broadcast.typeinfo.CacheElement; +import org.apache.flink.ml.common.broadcast.typeinfo.CacheElementSerializer; +import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.metrics.groups.InternalOperatorIOMetricGroup; +import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups; +import org.apache.flink.runtime.state.CheckpointStreamFactory; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.graph.StreamConfig; +import org.apache.flink.streaming.api.operators.InternalTimeServiceManager; +import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures; +import org.apache.flink.streaming.api.operators.Output; +import org.apache.flink.streaming.api.operators.StreamOperator; +import org.apache.flink.streaming.api.operators.StreamOperatorFactory; +import org.apache.flink.streaming.api.operators.StreamOperatorFactoryUtil; +import org.apache.flink.streaming.api.operators.StreamOperatorParameters; +import org.apache.flink.streaming.api.operators.StreamOperatorStateContext; +import org.apache.flink.streaming.api.operators.StreamOperatorStateHandler; +import org.apache.flink.streaming.api.operators.StreamOperatorStateHandler.CheckpointedStreamOperator; +import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer; +import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.StreamTask; +import org.apache.flink.util.CloseableIterator; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.function.ThrowingConsumer; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayDeque; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Queue; + +/** Base class for the shared objects wrapper operators. 
*/ +abstract class AbstractSharedObjectsWrapperOperator< + T, S extends AbstractSharedObjectsStreamOperator> + implements StreamOperator, IterationListener, CheckpointedStreamOperator { + + private static final Logger LOG = + LoggerFactory.getLogger(AbstractSharedObjectsWrapperOperator.class); + + protected final StreamOperatorParameters parameters; + + protected final StreamConfig streamConfig; + + protected final StreamTask containingTask; + + protected final Output> output; + + protected final StreamOperatorFactory operatorFactory; + + protected final OperatorMetricGroup metrics; + + protected final S wrappedOperator; + + private final SharedObjectsContextImpl context; + private final int numInputs; + private final TypeSerializer[] inTypeSerializers; + private final ListStateWithCache>[] cachedElements; + private final Queue>[] readRequests; + private final boolean[] hasCachedElements; + + protected transient StreamOperatorStateHandler stateHandler; + + protected transient InternalTimeServiceManager timeServiceManager; + + @SuppressWarnings({"unchecked", "rawtypes"}) + AbstractSharedObjectsWrapperOperator( + StreamOperatorParameters parameters, + StreamOperatorFactory operatorFactory, + SharedObjectsContextImpl context) { + this.parameters = Objects.requireNonNull(parameters); + this.streamConfig = Objects.requireNonNull(parameters.getStreamConfig()); + this.containingTask = Objects.requireNonNull(parameters.getContainingTask()); + this.output = Objects.requireNonNull(parameters.getOutput()); + this.operatorFactory = Objects.requireNonNull(operatorFactory); + this.context = context; + this.metrics = createOperatorMetricGroup(containingTask.getEnvironment(), streamConfig); + this.wrappedOperator = + (S) + StreamOperatorFactoryUtil.createOperator( + operatorFactory, + (StreamTask) containingTask, + streamConfig, + output, + parameters.getOperatorEventDispatcher()) + .f0; + wrappedOperator.onSharedObjectsContextSet(context); + + StreamConfig.InputConfig[] inputConfigs = + streamConfig.getInputs(containingTask.getUserCodeClassLoader()); + int numNetworkInputs = 0; + while (numNetworkInputs < inputConfigs.length + && inputConfigs[numNetworkInputs] instanceof StreamConfig.NetworkInputConfig) { + numNetworkInputs++; + } + numInputs = numNetworkInputs; + + inTypeSerializers = new TypeSerializer[numInputs]; + readRequests = new Queue[numInputs]; + for (int i = 0; i < numInputs; i++) { + inTypeSerializers[i] = + streamConfig.getTypeSerializerIn(i, containingTask.getUserCodeClassLoader()); + readRequests[i] = new ArrayDeque<>(getInputReadRequests(i)); + } + cachedElements = new ListStateWithCache[numInputs]; + hasCachedElements = new boolean[numInputs]; + Arrays.fill(hasCachedElements, false); + } + + private OperatorMetricGroup createOperatorMetricGroup( + Environment environment, StreamConfig streamConfig) { + try { + OperatorMetricGroup operatorMetricGroup = + environment + .getMetricGroup() + .getOrAddOperator( + streamConfig.getOperatorID(), streamConfig.getOperatorName()); + if (streamConfig.isChainEnd()) { + ((InternalOperatorIOMetricGroup) operatorMetricGroup.getIOMetricGroup()) + .reuseOutputMetricsForTask(); + } + return operatorMetricGroup; + } catch (Exception e) { + LOG.warn("An error occurred while instantiating task metrics.", e); + return UnregisteredMetricGroups.createUnregisteredOperatorMetricGroup(); + } + } + + /** + * Checks if the read requests are satisfied for the input. + * + * @param inputId The input id, starting from 0. 
+ * @param wait Whether to wait until all requests satisfied, or not. + * @return If all requests of this input are satisfied. + */ + private boolean checkReadRequestsReady(int inputId, boolean wait) { + Queue> requests = readRequests[inputId]; + while (!requests.isEmpty()) { + ReadRequest request = requests.poll(); + try { + if (null == context.read(request, wait)) { + requests.add(request); + return false; + } + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + return true; + } + + /** + * Gets {@link ReadRequest}s required for processing elements in the input. + * + * @param inputId The input id, starting from 0. + * @return The {@link ReadRequest}s required for processing elements. + */ + protected abstract List> getInputReadRequests(int inputId); + + /** + * Extracts common processing logic in subclasses' processing elements. + * + * @param streamRecord The input record. + * @param inputId The input id, starting from 0. + * @param elementConsumer The consumer function of StreamRecord, i.e., + * operator.processElement(...). + * @param watermarkConsumer The consumer function of WaterMark, i.e., + * operator.processWatermark(...). + * @param keyContextSetter The consumer function of setting key context, i.e., + * operator.setKeyContext(...). + * @throws Exception Possible exception. + */ + @SuppressWarnings({"rawtypes"}) + protected void processElementX( + StreamRecord streamRecord, + int inputId, + ThrowingConsumer elementConsumer, + ThrowingConsumer watermarkConsumer, + ThrowingConsumer keyContextSetter) + throws Exception { + if (checkReadRequestsReady(inputId, false)) { + if (hasCachedElements[inputId]) { + processCachedElements( + inputId, elementConsumer, watermarkConsumer, keyContextSetter); + hasCachedElements[inputId] = false; + } + keyContextSetter.accept(streamRecord); + elementConsumer.accept(streamRecord); + } else { + cachedElements[inputId].add(CacheElement.newRecord(streamRecord.getValue())); + hasCachedElements[inputId] = true; + } + } + + /** + * Extracts common processing logic in subclasses' processing watermarks. + * + * @param watermark The input watermark. + * @param inputId The input id, starting from 0. + * @param elementConsumer The consumer function of StreamRecord, i.e., + * operator.processElement(...). + * @param watermarkConsumer The consumer function of WaterMark, i.e., + * operator.processWatermark(...). + * @param keyContextSetter The consumer function of setting key context, i.e., + * operator.setKeyContext(...). + * @throws Exception Possible exception. + */ + @SuppressWarnings({"rawtypes"}) + protected void processWatermarkX( + Watermark watermark, + int inputId, + ThrowingConsumer elementConsumer, + ThrowingConsumer watermarkConsumer, + ThrowingConsumer keyContextSetter) + throws Exception { + if (checkReadRequestsReady(inputId, false)) { + if (hasCachedElements[inputId]) { + processCachedElements( + inputId, elementConsumer, watermarkConsumer, keyContextSetter); + hasCachedElements[inputId] = false; + } + watermarkConsumer.accept(watermark); + } else { + cachedElements[inputId].add(CacheElement.newWatermark(watermark.getTimestamp())); + hasCachedElements[inputId] = true; + } + } + + /** + * Extracts common processing logic in subclasses' endInput(...). + * + * @param inputId The input id, starting from 0. + * @param elementConsumer The consumer function of StreamRecord, i.e., + * operator.processElement(...). + * @param watermarkConsumer The consumer function of WaterMark, i.e., + * operator.processWatermark(...). 
+ * @param keyContextSetter The consumer function of setting key context, i.e., + * operator.setKeyContext(...). + * @throws Exception Possible exception. + */ + @SuppressWarnings("rawtypes") + protected void endInputX( + int inputId, + ThrowingConsumer elementConsumer, + ThrowingConsumer watermarkConsumer, + ThrowingConsumer keyContextSetter) + throws Exception { + if (hasCachedElements[inputId]) { + checkReadRequestsReady(inputId, true); + processCachedElements(inputId, elementConsumer, watermarkConsumer, keyContextSetter); + hasCachedElements[inputId] = false; + } + } + + /** + * Processes elements that are cached by {@link ListStateWithCache}. + * + * @param inputId The input id, starting from 0. + * @param elementConsumer The consumer function of StreamRecord, i.e., + * operator.processElement(...). + * @param watermarkConsumer The consumer function of WaterMark, i.e., + * operator.processWatermark(...). + * @param keyContextSetter The consumer function of setting key context, i.e., + * operator.setKeyContext(...). + * @throws Exception Possible exception. + */ + @SuppressWarnings({"rawtypes", "unchecked"}) + private void processCachedElements( + int inputId, + ThrowingConsumer elementConsumer, + ThrowingConsumer watermarkConsumer, + ThrowingConsumer keyContextSetter) + throws Exception { + for (CacheElement cacheElement : cachedElements[inputId].get()) { + switch (cacheElement.getType()) { + case RECORD: + StreamRecord record = new StreamRecord(cacheElement.getRecord()); + keyContextSetter.accept(record); + elementConsumer.accept(record); + break; + case WATERMARK: + watermarkConsumer.accept(new Watermark(cacheElement.getWatermark())); + break; + default: + throw new RuntimeException( + "Unsupported CacheElement type: " + cacheElement.getType()); + } + } + cachedElements[inputId].clear(); + Preconditions.checkState(readRequests[inputId].isEmpty()); + readRequests[inputId].addAll(getInputReadRequests(inputId)); + } + + @Override + public void open() throws Exception { + wrappedOperator.open(); + } + + @Override + public void close() throws Exception { + wrappedOperator.close(); + context.clear(); + } + + @Override + public void finish() throws Exception { + wrappedOperator.finish(); + } + + @Override + public void prepareSnapshotPreBarrier(long checkpointId) throws Exception { + wrappedOperator.prepareSnapshotPreBarrier(checkpointId); + } + + @Override + @SuppressWarnings("unchecked, rawtypes") + public void initializeState(StateInitializationContext stateInitializationContext) + throws Exception { + StreamingRuntimeContext runtimeContext = wrappedOperator.getRuntimeContext(); + context.initializeState(wrappedOperator, runtimeContext, stateInitializationContext); + for (int i = 0; i < numInputs; i++) { + cachedElements[i] = + new ListStateWithCache<>( + new CacheElementSerializer(inTypeSerializers[i]), + containingTask, + runtimeContext, + stateInitializationContext, + streamConfig.getOperatorID()); + } + } + + @Override + public void snapshotState(StateSnapshotContext stateSnapshotContext) throws Exception { + context.snapshotState(stateSnapshotContext); + wrappedOperator.snapshotState(stateSnapshotContext); + for (int i = 0; i < numInputs; i++) { + cachedElements[i].snapshotState(stateSnapshotContext); + } + } + + @Override + public OperatorSnapshotFutures snapshotState( + long checkpointId, + long timestamp, + CheckpointOptions checkpointOptions, + CheckpointStreamFactory storageLocation) + throws Exception { + return stateHandler.snapshotState( + this, + 
Optional.ofNullable(timeServiceManager), + streamConfig.getOperatorName(), + checkpointId, + timestamp, + checkpointOptions, + storageLocation, + false); + } + + @Override + public void initializeState(StreamTaskStateInitializer streamTaskStateManager) + throws Exception { + final TypeSerializer keySerializer = + streamConfig.getStateKeySerializer(containingTask.getUserCodeClassLoader()); + + StreamOperatorStateContext streamOperatorStateContext = + streamTaskStateManager.streamOperatorStateContext( + getOperatorID(), + getClass().getSimpleName(), + parameters.getProcessingTimeService(), + this, + keySerializer, + containingTask.getCancelables(), + metrics, + streamConfig.getManagedMemoryFractionOperatorUseCaseOfSlot( + ManagedMemoryUseCase.STATE_BACKEND, + containingTask + .getEnvironment() + .getTaskManagerInfo() + .getConfiguration(), + containingTask.getUserCodeClassLoader()), + false); + stateHandler = + new StreamOperatorStateHandler( + streamOperatorStateContext, + containingTask.getExecutionConfig(), + containingTask.getCancelables()); + stateHandler.initializeOperatorState(this); + + timeServiceManager = streamOperatorStateContext.internalTimerServiceManager(); + + wrappedOperator.initializeState( + (operatorID, + operatorClassName, + processingTimeService, + keyContext, + keySerializerX, + streamTaskCloseableRegistry, + metricGroup, + managedMemoryFraction, + isUsingCustomRawKeyedState) -> + new ProxyStreamOperatorStateContext( + streamOperatorStateContext, + "wrapped-", + CloseableIterator.empty(), + 0)); + } + + @Override + public void setKeyContextElement1(StreamRecord record) throws Exception { + wrappedOperator.setKeyContextElement1(record); + } + + @Override + public void setKeyContextElement2(StreamRecord record) throws Exception { + wrappedOperator.setKeyContextElement2(record); + } + + @Override + public OperatorMetricGroup getMetricGroup() { + return wrappedOperator.getMetricGroup(); + } + + @Override + public OperatorID getOperatorID() { + return wrappedOperator.getOperatorID(); + } + + @Override + public void notifyCheckpointComplete(long checkpointId) throws Exception { + wrappedOperator.notifyCheckpointComplete(checkpointId); + } + + @Override + public void notifyCheckpointAborted(long checkpointId) throws Exception { + wrappedOperator.notifyCheckpointAborted(checkpointId); + } + + @Override + public Object getCurrentKey() { + return wrappedOperator.getCurrentKey(); + } + + @Override + public void setCurrentKey(Object key) { + wrappedOperator.setCurrentKey(key); + } + + protected abstract void processCachedElementsBeforeEpochIncremented(int inputId) + throws Exception; + + @Override + public void onEpochWatermarkIncremented( + int epochWatermark, Context context, Collector collector) throws Exception { + for (int i = 0; i < numInputs; i += 1) { + processCachedElementsBeforeEpochIncremented(i); + } + this.context.incStep(epochWatermark); + if (wrappedOperator instanceof IterationListener) { + //noinspection unchecked + ((IterationListener) wrappedOperator) + .onEpochWatermarkIncremented(epochWatermark, context, collector); + } + } + + @Override + public void onIterationTerminated(Context context, Collector collector) throws Exception { + this.context.incStep(); + if (wrappedOperator instanceof IterationListener) { + //noinspection unchecked + ((IterationListener) wrappedOperator).onIterationTerminated(context, collector); + } + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/Descriptor.java 
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/Descriptor.java new file mode 100644 index 000000000..fd61957b2 --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/Descriptor.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.sharedobjects; + +import org.apache.flink.annotation.Experimental; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.util.Preconditions; + +import javax.annotation.Nullable; + +import java.io.Serializable; + +/** + * Descriptor for a shared object. + * + *

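+ * For illustration only, a descriptor with an initial value and one without might be declared as
+ * follows; the object names, element types and serializers here are hypothetical, and any Flink
+ * {@code TypeSerializer} can be used:
+ *
+ * <pre>{@code
+ * // Shared object with an initial value, which is written with the initial write-step.
+ * static final Descriptor<Long> VERSION =
+ *         Descriptor.of("version", LongSerializer.INSTANCE, 0L);
+ *
+ * // Shared object without an initial value; it must be written before it can be read.
+ * static final Descriptor<double[]> MODEL =
+ *         Descriptor.of("model", DoublePrimitiveArraySerializer.INSTANCE);
+ * }</pre>
+ *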
A shared object can have a non-null initial value, or have no initial values. If a non-null + * initial value provided, it is set with an initial write-step (See {@link ReadRequest}). + * + * @param The type of the shared object. + */ +@Experimental +public class Descriptor implements Serializable { + + /** Name of the shared object. */ + public final String name; + + /** Type serializer. */ + public final TypeSerializer serializer; + + /** Initialize value. */ + public final @Nullable T initVal; + + private Descriptor(String name, TypeSerializer serializer, T initVal) { + this.name = name; + this.serializer = serializer; + this.initVal = initVal; + } + + public static Descriptor of(String name, TypeSerializer serializer, T initVal) { + Preconditions.checkNotNull( + initVal, "Cannot use `null` as the initial value of a shared object."); + return new Descriptor<>(name, serializer, initVal); + } + + public static Descriptor of(String name, TypeSerializer serializer) { + return new Descriptor<>(name, serializer, null); + } + + /** + * Creates a read request which always reads this shared object with same read-step as the + * operator step. + * + * @return A read request. + */ + public ReadRequest sameStep() { + return new ReadRequest<>(this, ReadRequest.OFFSET.SAME); + } + + /** + * Creates a read request which always reads this shared object with the read-step be the + * previous item of the operator step. + * + * @return A read request. + */ + public ReadRequest prevStep() { + return new ReadRequest<>(this, ReadRequest.OFFSET.PREV); + } + + /** + * Creates a read request which always reads this shared object with the read-step be the next + * item of the operator step. + * + * @return A read request. + */ + public ReadRequest nextStep() { + return new ReadRequest<>(this, ReadRequest.OFFSET.NEXT); + } + + @Override + public int hashCode() { + return name.hashCode(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Descriptor that = (Descriptor) o; + return name.equals(that.name); + } + + @Override + public String toString() { + return String.format( + "Descriptor{name='%s', serializer=%s, initVal=%s}", name, serializer, initVal); + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/OneInputSharedObjectsWrapperOperator.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/OneInputSharedObjectsWrapperOperator.java new file mode 100644 index 000000000..68f3b3687 --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/OneInputSharedObjectsWrapperOperator.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.sharedobjects; + +import org.apache.flink.iteration.operator.OperatorUtils; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.StreamOperatorFactory; +import org.apache.flink.streaming.api.operators.StreamOperatorParameters; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus; +import org.apache.flink.util.Preconditions; + +import java.util.List; + +/** Wrapper for {@link OneInputStreamOperator}. */ +class OneInputSharedObjectsWrapperOperator + extends AbstractSharedObjectsWrapperOperator< + OUT, AbstractSharedObjectsOneInputStreamOperator> + implements OneInputStreamOperator, BoundedOneInput { + + OneInputSharedObjectsWrapperOperator( + StreamOperatorParameters parameters, + StreamOperatorFactory operatorFactory, + SharedObjectsContextImpl context) { + super(parameters, operatorFactory, context); + } + + @Override + protected List> getInputReadRequests(int inputId) { + Preconditions.checkArgument(0 == inputId); + return wrappedOperator.readRequestsInProcessElement(); + } + + @Override + protected void processCachedElementsBeforeEpochIncremented(int inputId) throws Exception { + Preconditions.checkArgument(0 == inputId); + endInputX( + 0, + wrappedOperator::processElement, + wrappedOperator::processWatermark, + wrappedOperator::setKeyContextElement); + } + + @Override + public void processElement(StreamRecord streamRecord) throws Exception { + processElementX( + streamRecord, + 0, + wrappedOperator::processElement, + wrappedOperator::processWatermark, + wrappedOperator::setKeyContextElement); + } + + @Override + public void endInput() throws Exception { + endInputX( + 0, + wrappedOperator::processElement, + wrappedOperator::processWatermark, + wrappedOperator::setKeyContextElement); + OperatorUtils.processOperatorOrUdfIfSatisfy( + wrappedOperator, BoundedOneInput.class, BoundedOneInput::endInput); + } + + @Override + public void processWatermark(Watermark watermark) throws Exception { + processWatermarkX( + watermark, + 0, + wrappedOperator::processElement, + wrappedOperator::processWatermark, + wrappedOperator::setKeyContextElement); + } + + @Override + public void processWatermarkStatus(WatermarkStatus watermarkStatus) throws Exception { + wrappedOperator.processWatermarkStatus(watermarkStatus); + } + + @Override + public void processLatencyMarker(LatencyMarker latencyMarker) throws Exception { + wrappedOperator.processLatencyMarker(latencyMarker); + } + + @Override + public void setKeyContextElement(StreamRecord streamRecord) throws Exception { + wrappedOperator.setKeyContextElement(streamRecord); + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/ReadRequest.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/ReadRequest.java new file mode 100644 index 000000000..c4f6178de --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/ReadRequest.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.sharedobjects; + +import org.apache.flink.annotation.Experimental; +import org.apache.flink.iteration.IterationListener; + +import java.io.Serializable; + +/** + * A read request for a shared object with given step offset. The step {@link OFFSET} is used to + * calculate read-step from current operator step. + * + *

The concept of `step` is first defined on operators. Every operator maintains its `step` + * implicitly. For operators used outside of iterations, their `step`s are treated as constants, while + * for operators used within iterations, their `step`s are bound to the epoch watermarks: + * + *

With every call of {@link IterationListener#onEpochWatermarkIncremented}, the value of step is + * set to the epoch watermark. Before the first call of {@link + * IterationListener#onEpochWatermarkIncremented}, the step is set to a small enough value, and + * after {@link IterationListener#onIterationTerminated}, the step is set to a large enough value. + * In this way, the changes of step form an ordered sequence. Note that the `step` + * is implicitly maintained by the infrastructure even if the operator does not implement {@link + * IterationListener}. + * + *

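+ * A read request is created from a {@link Descriptor} with one of three step offsets, whose exact
+ * semantics are described below. For illustration, assuming a hypothetical descriptor
+ * {@code MODEL}:
+ *
+ * <pre>{@code
+ * ReadRequest<double[]> current = MODEL.sameStep();  // read the value written with the reader's step
+ * ReadRequest<double[]> previous = MODEL.prevStep(); // read the value written with the previous step
+ * }</pre>
+ *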
Then, the concept of `step` is defined on reads and writes of shared objects. Every write + * brings the step of its owner operator at that moment, which is named as `write-step`. To read the + * shared object with the exact `write-step`, the reader operator must provide a same `read-step`. + * The `read-step` could be different from that of the reader operator, and their difference is kept + * unchanged, which is the step offset defined in {@link ReadRequest#offset}. + * + * @param The type of the shared object. + */ +@Experimental +public class ReadRequest implements Serializable { + final Descriptor descriptor; + final OFFSET offset; + + ReadRequest(Descriptor descriptor, OFFSET offset) { + this.descriptor = descriptor; + this.offset = offset; + } + + enum OFFSET { + SAME, + PREV, + NEXT, + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsBody.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsBody.java new file mode 100644 index 000000000..9f938a233 --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsBody.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.sharedobjects; + +import org.apache.flink.annotation.Experimental; +import org.apache.flink.api.dag.Transformation; +import org.apache.flink.streaming.api.datastream.DataStream; + +import java.io.Serializable; +import java.util.List; +import java.util.Map; + +/** + * The builder of the subgraph that will be executed with a common shared objects. Users can only + * create data streams from {@code inputs}. Users can not refer to data streams outside, and can not + * add sources/sinks. + * + *

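+ * A sketch of a body is shown below; the operator class, the descriptor {@code MODEL}, and the
+ * element types are hypothetical, and imports are omitted:
+ *
+ * <pre>{@code
+ * SharedObjectsBody body = inputs -> {
+ *     DataStream<double[]> data = (DataStream<double[]>) inputs.get(0);
+ *     // MyTrainOperator is assumed to extend AbstractSharedObjectsOneInputStreamOperator
+ *     // and to own the hypothetical shared object MODEL.
+ *     MyTrainOperator trainOp = new MyTrainOperator();
+ *     DataStream<double[]> trained =
+ *             data.transform("train", TypeInformation.of(double[].class), trainOp);
+ *     return new SharedObjectsBodyResult(
+ *             Collections.singletonList(trained),
+ *             Collections.singletonList(trained.getTransformation()),
+ *             Collections.singletonMap(MODEL, trainOp));
+ * };
+ * }</pre>
+ *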
The shared objects body requires all transformations accessing the shared objects, i.e., + * {@link SharedObjectsBodyResult#coLocatedTransformations}, to have same parallelism and can be + * co-located. + */ +@Experimental +@FunctionalInterface +public interface SharedObjectsBody extends Serializable { + + /** + * This method creates the subgraph for the shared objects body. + * + * @param inputs Input data streams. + * @return Result of the subgraph, including output data streams, data streams with access to + * the shared objects, and a mapping from share objects to their owners. + */ + SharedObjectsBodyResult process(List> inputs); + + /** + * The result of a {@link SharedObjectsBody}, including output data streams, data streams with + * access to the shared objects, and a mapping from descriptors of share objects to their + * owners. + */ + @Experimental + class SharedObjectsBodyResult { + /** A list of output streams. */ + private final List> outputs; + + /** + * A list of {@link Transformation}s that should be co-located, which should include all + * subclasses of {@link AbstractSharedObjectsStreamOperator}. + */ + private final List> coLocatedTransformations; + + /** A mapping from descriptors of shared objects to their owner operators. */ + private final Map, AbstractSharedObjectsStreamOperator> ownerMap; + + public SharedObjectsBodyResult( + List> outputs, + List> coLocatedTransformations, + Map, AbstractSharedObjectsStreamOperator> ownerMap) { + this.outputs = outputs; + this.coLocatedTransformations = coLocatedTransformations; + this.ownerMap = ownerMap; + } + + public List> getOutputs() { + return outputs; + } + + public List> getCoLocatedTransformations() { + return coLocatedTransformations; + } + + public Map, AbstractSharedObjectsStreamOperator> getOwnerMap() { + return ownerMap; + } + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsContext.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsContext.java new file mode 100644 index 000000000..7ec0acf8c --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsContext.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.sharedobjects; + +import org.apache.flink.annotation.Experimental; + +/** + * Context for shared objects. Every operator implementing {@link + * AbstractSharedObjectsStreamOperator} will get an instance of this context set by {@link + * AbstractSharedObjectsStreamOperator#onSharedObjectsContextSet} in runtime. + * + *

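+ * A minimal sketch of how the context is typically used inside a subclass of {@link
+ * AbstractSharedObjectsOneInputStreamOperator}; the descriptor {@code MODEL} and the element
+ * types are hypothetical:
+ *
+ * <pre>{@code
+ * @Override
+ * public List<ReadRequest<?>> readRequestsInProcessElement() {
+ *     // Declares the reads needed by processElement, so that the wrapper can cache elements
+ *     // until the shared objects are ready.
+ *     return Collections.singletonList(MODEL.sameStep());
+ * }
+ *
+ * @Override
+ * public void processElement(StreamRecord<double[]> record) throws Exception {
+ *     double[] model = context.read(MODEL.sameStep());
+ *     // ... use the model; only the owner operator may call context.write(MODEL, ...).
+ * }
+ * }</pre>
+ *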
See {@link ReadRequest} for details about coordination between reads and writes. + */ +@Experimental +public interface SharedObjectsContext { + + /** + * Reads the value of a shared object. + * + *

For subclasses of {@link AbstractSharedObjectsOneInputStreamOperator} and {@link + * AbstractSharedObjectsTwoInputStreamOperator}, this method is guaranteed to return non-null + * values immediately. + * + * @param request A read request of a shared object. + * @return The value of the shared object. + * @param The type of the shared object. + */ + T read(ReadRequest request); + + /** + * Writes a new value to the shared object. + * + * @param descriptor The shared object descriptor. + * @param value The value to be set. + * @param The type of the shared object. + */ + void write(Descriptor descriptor, T value); + + /** + * Renew the shared object with current step. + * + *

For subclasses of {@link AbstractSharedObjectsOneInputStreamOperator} and {@link + * AbstractSharedObjectsTwoInputStreamOperator}, this method is guaranteed to return + * immediately. + * + * @param descriptor The shared object descriptor. + * @param The type of the shared object. + */ + void renew(Descriptor descriptor); +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsContextImpl.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsContextImpl.java new file mode 100644 index 000000000..9819baf82 --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsContextImpl.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.sharedobjects; + +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.StreamOperator; +import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; +import org.apache.flink.util.Preconditions; + +import javax.annotation.Nullable; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import static org.apache.flink.ml.common.sharedobjects.SharedObjectsPools.getReader; +import static org.apache.flink.ml.common.sharedobjects.SharedObjectsPools.getWriter; + +/** + * A default implementation of {@link SharedObjectsContext}. + * + *

It initializes readers and writers of shared objects according to the owner map when the + * subtask starts and clean internal states when the subtask finishes. It also handles + * `initializeState` and `snapshotState` automatically. + */ +@SuppressWarnings("rawtypes") +class SharedObjectsContextImpl implements SharedObjectsContext, Serializable { + private final SharedObjectsPools.PoolID poolID; + private final Map writers = new HashMap<>(); + private final Map readers = new HashMap<>(); + private Map, String> ownerMap; + + /** The step of corresponding operator. See {@link ReadRequest} for more information. */ + private int step; + + public SharedObjectsContextImpl() { + this.poolID = new SharedObjectsPools.PoolID(); + step = -1; + } + + void setOwnerMap(Map, String> ownerMap) { + this.ownerMap = ownerMap; + } + + void incStep(@Nullable Integer targetStep) { + step += 1; + // Sanity check + Preconditions.checkState(null == targetStep || step == targetStep); + } + + void incStep() { + incStep(null); + } + + void initializeState( + StreamOperator operator, + StreamingRuntimeContext runtimeContext, + StateInitializationContext context) { + Preconditions.checkArgument(operator instanceof AbstractSharedObjectsStreamOperator); + String ownerId = ((AbstractSharedObjectsStreamOperator) operator).getAccessorID(); + int subtaskId = runtimeContext.getIndexOfThisSubtask(); + for (Map.Entry, String> entry : ownerMap.entrySet()) { + Descriptor descriptor = entry.getKey(); + if (ownerId.equals(entry.getValue())) { + writers.put( + descriptor, + getWriter( + poolID, + subtaskId, + descriptor, + ownerId, + operator.getOperatorID(), + ((AbstractStreamOperator) operator).getContainingTask(), + runtimeContext, + context, + step)); + } + readers.put(descriptor, getReader(poolID, subtaskId, descriptor)); + } + } + + void snapshotState(StateSnapshotContext context) throws Exception { + for (SharedObjectsPools.Writer writer : writers.values()) { + writer.snapshotState(context); + } + } + + void clear() { + for (SharedObjectsPools.Writer writer : writers.values()) { + writer.remove(); + } + for (SharedObjectsPools.Reader reader : readers.values()) { + reader.remove(); + } + writers.clear(); + readers.clear(); + } + + @Override + public T read(ReadRequest request) { + try { + return read(request, false); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + /** + * Gets the value of the shared object with possible waiting. + * + * @param request A read request of a shared object. + * @param wait Whether to wait or not. + * @return The value of the shared object, or null if not set yet. + * @param The type of the shared object. 
+ */ + T read(ReadRequest request, boolean wait) throws InterruptedException { + Descriptor descriptor = request.descriptor; + //noinspection unchecked + SharedObjectsPools.Reader reader = readers.get(descriptor); + switch (request.offset) { + case SAME: + return reader.get(step, wait); + case PREV: + return reader.get(step - 1, wait); + case NEXT: + return reader.get(step + 1, wait); + default: + throw new UnsupportedOperationException(); + } + } + + @Override + public void write(Descriptor descriptor, T value) { + //noinspection unchecked + SharedObjectsPools.Writer writer = writers.get(descriptor); + Preconditions.checkState( + null != writer, + String.format( + "The operator requestes to write a shared object %s not owned by itself.", + descriptor)); + writer.set(value, step); + } + + @Override + public void renew(Descriptor descriptor) { + try { + //noinspection unchecked + write( + descriptor, + ((SharedObjectsPools.Reader) readers.get(descriptor)).get(step - 1, false)); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsPools.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsPools.java new file mode 100644 index 000000000..25f64151c --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsPools.java @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.sharedobjects; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.runtime.TupleSerializer; +import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; +import org.apache.flink.streaming.runtime.tasks.StreamTask; +import org.apache.flink.util.AbstractID; +import org.apache.flink.util.Preconditions; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.ListIterator; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; + +/** + * Stores all shared objects and coordinates their reads and writes. + * + *

Every shared object is identified by a tuple of (Pool ID, subtask ID, name). Their reads and + * writes are coordinated through the read- and write-steps. + */ +class SharedObjectsPools { + + private static final Logger LOG = LoggerFactory.getLogger(SharedObjectsPools.class); + + /** Stores values and corresponding write-steps of all shared objects. */ + private static final Map, Tuple2> values = + new ConcurrentHashMap<>(); + + /** + * Stores waiting read requests of all shared objects, including read-steps and count-down + * latches for notification when shared objects are ready. + */ + private static final Map, List>> + waitQueues = new ConcurrentHashMap<>(); + + /** + * Stores owners of all shared objects, where the owner is identified by the accessor ID + * obtained from {@link AbstractSharedObjectsStreamOperator#getAccessorID()}. + */ + private static final Map, String> owners = + new ConcurrentHashMap<>(); + + /** + * Stores number of references of all shared objects. Every {@link Reader} and {@link Writer} + * counts. A shared object is removed from the pool when its number of references decreased to + * 0. + */ + private static final ConcurrentHashMap, Integer> numRefs = + new ConcurrentHashMap<>(); + + private static void incRef(Tuple3 objId) { + numRefs.compute(objId, (k, oldV) -> null == oldV ? 1 : oldV + 1); + } + + private static void decRef(Tuple3 objId) { + int num = numRefs.compute(objId, (k, oldV) -> oldV - 1); + if (num == 0) { + values.remove(objId); + waitQueues.remove(objId); + owners.remove(objId); + numRefs.remove(objId); + } + } + + /** Gets a {@link Reader} of a shared object. */ + static Reader getReader(PoolID poolID, int subtaskId, Descriptor descriptor) { + Tuple3 objId = Tuple3.of(poolID, subtaskId, descriptor.name); + Reader reader = new Reader<>(objId); + incRef(objId); + return reader; + } + + /** Gets a {@link Writer} of a shared object. */ + static Writer getWriter( + PoolID poolId, + int subtaskId, + Descriptor descriptor, + String ownerId, + OperatorID operatorID, + StreamTask containingTask, + StreamingRuntimeContext runtimeContext, + StateInitializationContext stateInitializationContext, + int step) { + Tuple3 objId = Tuple3.of(poolId, subtaskId, descriptor.name); + String lastOwner = owners.putIfAbsent(objId, ownerId); + if (null != lastOwner) { + throw new IllegalStateException( + String.format( + "The shared object (%s, %s, %s) already has a writer %s.", + poolId, subtaskId, descriptor.name, ownerId)); + } + Writer writer = + new Writer<>( + objId, + ownerId, + descriptor.serializer, + containingTask, + runtimeContext, + stateInitializationContext, + operatorID); + incRef(objId); + if (null != descriptor.initVal) { + writer.set(descriptor.initVal, step); + } + return writer; + } + + /** + * Reader of a shared object. + * + * @param The type of the shared object. + */ + static class Reader { + protected final Tuple3 objId; + + Reader(Tuple3 objId) { + this.objId = objId; + } + + /** + * Gets the value with given read-step. There are 3 cases: + * + *

    + *
  1. The read-step is equal to the write-step: returns the value immediately. + *
  2. The read-step is larger than the write-step, or there are no values written yet: + * waits until the value with the same write-step is set if `wait` is true, or returns null + * otherwise. +
  3. The read-step is smaller than the write-step: throws an exception as it is illegal. + *
+ * + * @param readStep The read-step. + * @param wait Whether to wait until the value with same write-step presents. + * @return The value or null. A return value of null means the corresponding value if not + * presented. If `wait` is true, the return value of this function is guaranteed to be a + * non-null value if it returns. + * @throws InterruptedException Interrupted when waiting. + */ + T get(int readStep, boolean wait) throws InterruptedException { + //noinspection unchecked + Tuple2 stepV = (Tuple2) values.get(objId); + if (null != stepV) { + int writeStep = stepV.f0; + LOG.debug("Get {} with read-step {}, write-step is {}", objId, readStep, writeStep); + Preconditions.checkState( + writeStep <= readStep, + String.format( + "Current write-step %d of %s is larger than read-step %d, which is illegal.", + writeStep, objId, readStep)); + if (readStep == stepV.f0) { + return stepV.f1; + } + } + if (!wait) { + return null; + } + CountDownLatch latch = new CountDownLatch(1); + synchronized (waitQueues) { + if (!waitQueues.containsKey(objId)) { + waitQueues.put(objId, new ArrayList<>()); + } + List> q = waitQueues.get(objId); + q.add(Tuple2.of(readStep, latch)); + } + latch.await(); + //noinspection unchecked + stepV = (Tuple2) values.get(objId); + Preconditions.checkState(stepV.f0 == readStep); + return stepV.f1; + } + + void remove() { + decRef(objId); + } + } + + /** + * Writer of a shared object. + * + * @param The type of the shared object. + */ + static class Writer extends Reader { + private final String ownerId; + private final ListStateWithCache> cache; + private boolean isDirty; + + Writer( + Tuple3 objId, + String ownerId, + TypeSerializer serializer, + StreamTask containingTask, + StreamingRuntimeContext runtimeContext, + StateInitializationContext stateInitializationContext, + OperatorID operatorID) { + super(objId); + this.ownerId = ownerId; + try { + //noinspection unchecked + cache = + new ListStateWithCache<>( + new TupleSerializer<>( + (Class>) (Class) Tuple2.class, + new TypeSerializer[] {IntSerializer.INSTANCE, serializer}), + containingTask, + runtimeContext, + stateInitializationContext, + operatorID); + Iterator> iterator = cache.get().iterator(); + if (iterator.hasNext()) { + Tuple2 stepV = iterator.next(); + ensureOwner(); + //noinspection unchecked + values.put(objId, (Tuple2) stepV); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + isDirty = false; + } + + private void ensureOwner() { + Preconditions.checkState(owners.get(objId).equals(ownerId)); + } + + /** + * Sets the value with given write-step. If there are read requests waiting for the value of + * exact the same write-step, they are notified. + * + * @param value The value. + * @param writeStep The write-step. 
+ */ + void set(T value, int writeStep) { + ensureOwner(); + values.put(objId, Tuple2.of(writeStep, value)); + LOG.debug("Set {} with write-step {}", objId, writeStep); + isDirty = true; + synchronized (waitQueues) { + if (!waitQueues.containsKey(objId)) { + waitQueues.put(objId, new ArrayList<>()); + } + List> q = waitQueues.get(objId); + ListIterator> iter = q.listIterator(); + while (iter.hasNext()) { + Tuple2 next = iter.next(); + if (writeStep == next.f0) { + iter.remove(); + next.f1.countDown(); + } + } + } + } + + @Override + void remove() { + ensureOwner(); + super.remove(); + cache.clear(); + } + + void snapshotState(StateSnapshotContext context) throws Exception { + if (isDirty) { + //noinspection unchecked + cache.update(Collections.singletonList((Tuple2) values.get(objId))); + isDirty = false; + } + cache.snapshotState(context); + } + } + + /** ID of a pool for shared objects. */ + static class PoolID extends AbstractID { + private static final long serialVersionUID = 1L; + + public PoolID(byte[] bytes) { + super(bytes); + } + + public PoolID() {} + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsUtils.java new file mode 100644 index 000000000..2c884006a --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsUtils.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.sharedobjects; + +import org.apache.flink.annotation.Experimental; +import org.apache.flink.api.dag.Transformation; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.compile.DraftExecutionEnvironment; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.util.Preconditions; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.Collectors; + +/** Utility class to support shared objects mechanism in DataStream. */ +@Experimental +public class SharedObjectsUtils { + + /** + * Supports read/write access of data in the shared objects from operators which inherit {@link + * AbstractSharedObjectsStreamOperator}. + * + *

In the shared objects `body`, users build the subgraph using only the data streams from + * `inputs`, return the result streams whose operators access the shared objects, and return the + * mapping from each shared object to its owner operator. + * + *

There are several limitations when using this function: + * + *

    + *
  1. Only synchronized iterations and non-iterations are supported. + *
  2. Reads and writes of shared objects must obey strict rules defined on `step`s, as stated + * in {@link ReadRequest}. + *
  3. When in iterations, writes of shared objects can only occur in {@link + * IterationListener#onEpochWatermarkIncremented} and {@link + * IterationListener#onIterationTerminated}. + *
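+ *
+ * <p>A sketch of a typical body follows; operator names such as {@code OwnerOperator} and
+ * {@code ReaderOperator} are only illustrative, see {@code SharedObjectsUtilsTest} in this change
+ * for complete examples:
+ *
+ * <pre>{@code
+ * // Shared object owned by OwnerOperator, which writes it via context.write(SUM, value);
+ * // ReaderOperator reads it via context.read(SUM.sameStep()).
+ * static final Descriptor<Long> SUM = Descriptor.of("sum", LongSerializer.INSTANCE, 0L);
+ *
+ * // Inside SharedObjectsBody#process(List<DataStream<?>> inputs):
+ * DataStream<Long> data = (DataStream<Long>) inputs.get(0);
+ * OwnerOperator ownerOp = new OwnerOperator();
+ * SingleOutputStreamOperator<Long> s1 = data.transform("owner", Types.LONG, ownerOp);
+ * SingleOutputStreamOperator<Long> s2 =
+ *         s1.transform("reader", Types.LONG, new ReaderOperator());
+ * Map<Descriptor<?>, AbstractSharedObjectsStreamOperator> ownerMap = new HashMap<>();
+ * ownerMap.put(SUM, ownerOp);
+ * return new SharedObjectsBodyResult(
+ *         Collections.singletonList(s2),
+ *         Arrays.asList(s1.getTransformation(), s2.getTransformation()),
+ *         ownerMap);
+ * }</pre>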
+ * + * @param inputs Input data streams. + * @param body User defined logic to build subgraph and to specify owners of every shared + * object. + * @return The output data streams. + */ + public static List> withSharedObjects( + List> inputs, SharedObjectsBody body) { + Preconditions.checkArgument(!inputs.isEmpty()); + StreamExecutionEnvironment env = inputs.get(0).getExecutionEnvironment(); + String coLocationID = "shared-storage-" + UUID.randomUUID(); + SharedObjectsContextImpl context = new SharedObjectsContextImpl(); + + DraftExecutionEnvironment draftEnv = + new DraftExecutionEnvironment(env, new SharedObjectsWrapper<>(context)); + List> draftSources = + inputs.stream() + .map( + dataStream -> + draftEnv.addDraftSource(dataStream, dataStream.getType())) + .collect(Collectors.toList()); + SharedObjectsBody.SharedObjectsBodyResult result = body.process(draftSources); + + List> draftOutputs = result.getOutputs(); + Map, AbstractSharedObjectsStreamOperator> rawOwnerMap = + result.getOwnerMap(); + Map, String> ownerMap = new HashMap<>(); + for (Descriptor descriptor : rawOwnerMap.keySet()) { + ownerMap.put(descriptor, rawOwnerMap.get(descriptor).getAccessorID()); + } + context.setOwnerMap(ownerMap); + + for (DataStream draftOutput : draftOutputs) { + draftEnv.addOperator(draftOutput.getTransformation()); + } + draftEnv.copyToActualEnvironment(); + + for (Transformation transformation : result.getCoLocatedTransformations()) { + DataStream ds = draftEnv.getActualStream(transformation.getId()); + ds.getTransformation().setCoLocationGroupKey(coLocationID); + } + + List> outputs = new ArrayList<>(); + for (DataStream draftOutput : draftOutputs) { + outputs.add(draftEnv.getActualStream(draftOutput.getId())); + } + return outputs; + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsWrapper.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsWrapper.java new file mode 100644 index 000000000..c4e4e0102 --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsWrapper.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.sharedobjects; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.iteration.operator.OperatorWrapper; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.StreamOperator; +import org.apache.flink.streaming.api.operators.StreamOperatorFactory; +import org.apache.flink.streaming.api.operators.StreamOperatorFactoryUtil; +import org.apache.flink.streaming.api.operators.StreamOperatorParameters; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; +import org.apache.flink.streaming.runtime.tasks.StreamTask; +import org.apache.flink.util.OutputTag; + +/** The operator wrapper for {@link AbstractSharedObjectsWrapperOperator}. */ +class SharedObjectsWrapper implements OperatorWrapper { + + /** Shared objects context. */ + private final SharedObjectsContextImpl context; + + public SharedObjectsWrapper(SharedObjectsContextImpl context) { + this.context = context; + } + + @Override + public StreamOperator wrap( + StreamOperatorParameters operatorParameters, + StreamOperatorFactory operatorFactory) { + Class operatorClass = + operatorFactory.getStreamOperatorClass(getClass().getClassLoader()); + if (AbstractSharedObjectsStreamOperator.class.isAssignableFrom(operatorClass)) { + if (OneInputStreamOperator.class.isAssignableFrom(operatorClass)) { + return new OneInputSharedObjectsWrapperOperator<>( + operatorParameters, operatorFactory, context); + } else if (TwoInputStreamOperator.class.isAssignableFrom(operatorClass)) { + return new TwoInputSharedObjectsWrapperOperator<>( + operatorParameters, operatorFactory, context); + } else { + return nowrap(operatorParameters, operatorFactory); + } + } else { + return nowrap(operatorParameters, operatorFactory); + } + } + + public StreamOperator nowrap( + StreamOperatorParameters parameters, StreamOperatorFactory operatorFactory) { + return StreamOperatorFactoryUtil.createOperator( + operatorFactory, + (StreamTask) parameters.getContainingTask(), + parameters.getStreamConfig(), + parameters.getOutput(), + parameters.getOperatorEventDispatcher()) + .f0; + } + + @Override + public Class getStreamOperatorClass( + ClassLoader classLoader, StreamOperatorFactory operatorFactory) { + Class operatorClass = + operatorFactory.getStreamOperatorClass(getClass().getClassLoader()); + if (OneInputStreamOperator.class.isAssignableFrom(operatorClass)) { + return OneInputSharedObjectsWrapperOperator.class; + } else if (TwoInputStreamOperator.class.isAssignableFrom(operatorClass)) { + return TwoInputSharedObjectsWrapperOperator.class; + } else { + throw new UnsupportedOperationException( + "Unsupported operator class for shared objects wrapper: " + operatorClass); + } + } + + @Override + public KeySelector wrapKeySelector(KeySelector keySelector) { + return keySelector; + } + + @Override + public StreamPartitioner wrapStreamPartitioner(StreamPartitioner streamPartitioner) { + return streamPartitioner; + } + + @Override + public OutputTag wrapOutputTag(OutputTag outputTag) { + return outputTag; + } + + @Override + public TypeInformation getWrappedTypeInfo(TypeInformation typeInfo) { + return typeInfo; + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/TwoInputSharedObjectsWrapperOperator.java 
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/TwoInputSharedObjectsWrapperOperator.java new file mode 100644 index 000000000..fedc8f3ac --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/TwoInputSharedObjectsWrapperOperator.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.sharedobjects; + +import org.apache.flink.iteration.operator.OperatorUtils; +import org.apache.flink.streaming.api.operators.BoundedMultiInput; +import org.apache.flink.streaming.api.operators.StreamOperatorFactory; +import org.apache.flink.streaming.api.operators.StreamOperatorParameters; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus; +import org.apache.flink.util.Preconditions; + +import java.util.List; + +/** Wrapper for {@link TwoInputStreamOperator}. 
*/ +class TwoInputSharedObjectsWrapperOperator + extends AbstractSharedObjectsWrapperOperator< + OUT, AbstractSharedObjectsTwoInputStreamOperator> + implements TwoInputStreamOperator, BoundedMultiInput { + + TwoInputSharedObjectsWrapperOperator( + StreamOperatorParameters parameters, + StreamOperatorFactory operatorFactory, + SharedObjectsContextImpl context) { + super(parameters, operatorFactory, context); + } + + @Override + protected List> getInputReadRequests(int inputId) { + Preconditions.checkArgument(0 == inputId || 1 == inputId); + if (0 == inputId) { + return wrappedOperator.readRequestsInProcessElement1(); + } else { + return wrappedOperator.readRequestsInProcessElement2(); + } + } + + @Override + protected void processCachedElementsBeforeEpochIncremented(int inputId) throws Exception { + Preconditions.checkArgument(0 == inputId || 1 == inputId); + if (inputId == 0) { + endInputX( + inputId, + wrappedOperator::processElement1, + wrappedOperator::processWatermark1, + wrappedOperator::setKeyContextElement1); + } else { + endInputX( + inputId, + wrappedOperator::processElement2, + wrappedOperator::processWatermark2, + wrappedOperator::setKeyContextElement2); + } + } + + @Override + public void processElement1(StreamRecord streamRecord) throws Exception { + processElementX( + streamRecord, + 0, + wrappedOperator::processElement1, + wrappedOperator::processWatermark1, + wrappedOperator::setKeyContextElement1); + } + + @Override + public void processElement2(StreamRecord streamRecord) throws Exception { + processElementX( + streamRecord, + 1, + wrappedOperator::processElement2, + wrappedOperator::processWatermark2, + wrappedOperator::setKeyContextElement2); + } + + @Override + public void endInput(int inputId) throws Exception { + Preconditions.checkArgument(1 == inputId || 2 == inputId); + if (1 == inputId) { + endInputX( + 0, + wrappedOperator::processElement1, + wrappedOperator::processWatermark1, + wrappedOperator::setKeyContextElement1); + } else { + endInputX( + inputId - 1, + wrappedOperator::processElement2, + wrappedOperator::processWatermark2, + wrappedOperator::setKeyContextElement2); + } + OperatorUtils.processOperatorOrUdfIfSatisfy( + wrappedOperator, + BoundedMultiInput.class, + boundedMultipleInput -> boundedMultipleInput.endInput(inputId)); + } + + @Override + public void processWatermark1(Watermark watermark) throws Exception { + processWatermarkX( + watermark, + 0, + wrappedOperator::processElement1, + wrappedOperator::processWatermark1, + wrappedOperator::setKeyContextElement1); + } + + @Override + public void processWatermark2(Watermark watermark) throws Exception { + processWatermarkX( + watermark, + 1, + wrappedOperator::processElement2, + wrappedOperator::processWatermark2, + wrappedOperator::setKeyContextElement2); + } + + @Override + public void processLatencyMarker1(LatencyMarker latencyMarker) throws Exception { + wrappedOperator.processLatencyMarker1(latencyMarker); + } + + @Override + public void processLatencyMarker2(LatencyMarker latencyMarker) throws Exception { + wrappedOperator.processLatencyMarker2(latencyMarker); + } + + @Override + public void processWatermarkStatus1(WatermarkStatus watermarkStatus) throws Exception { + wrappedOperator.processWatermarkStatus1(watermarkStatus); + } + + @Override + public void processWatermarkStatus2(WatermarkStatus watermarkStatus) throws Exception { + wrappedOperator.processWatermarkStatus2(watermarkStatus); + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java 
b/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java index b284bbb3d..989b5a8a8 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java @@ -20,6 +20,7 @@ import org.apache.flink.api.common.eventtime.WatermarkStrategy; import org.apache.flink.api.common.serialization.Encoder; +import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.connector.source.Source; import org.apache.flink.connector.file.sink.FileSink; import org.apache.flink.connector.file.src.FileSource; @@ -323,4 +324,28 @@ public static Table loadModelData( env.fromSource(source, WatermarkStrategy.noWatermarks(), "modelData"); return tEnv.fromDataStream(modelDataStream); } + + /** + * Loads the model data from the given path using the model decoder. This overloaded version + * returns a table with only 1 column whose type is the class of the model data. + * + * @param tEnv A StreamTableEnvironment instance. + * @param path The parent directory of the model data file. + * @param modelDecoder The decoder used to decode the model data. + * @param typeInfo The type information of model data. + * @param The class type of the model data. + * @return The loaded model data. + */ + public static Table loadModelData( + StreamTableEnvironment tEnv, + String path, + SimpleStreamFormat modelDecoder, + TypeInformation typeInfo) { + StreamExecutionEnvironment env = TableUtils.getExecutionEnvironment(tEnv); + Source source = + FileSource.forRecordStreamFormat(modelDecoder, FileUtils.getDataPath(path)).build(); + DataStream modelDataStream = + env.fromSource(source, WatermarkStrategy.noWatermarks(), "modelData", typeInfo); + return tEnv.fromDataStream(modelDataStream); + } } diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsUtilsTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsUtilsTest.java new file mode 100644 index 000000000..910096abb --- /dev/null +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsUtilsTest.java @@ -0,0 +1,333 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.sharedobjects; + +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.fs.Path; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.runtime.state.storage.FileSystemCheckpointStorage; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.SinkFunction; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +import org.apache.commons.collections.IteratorUtils; +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.File; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** Tests the {@link SharedObjectsUtils}. */ +public class SharedObjectsUtilsTest { + + @Rule public TemporaryFolder tempFolder = new TemporaryFolder(); + + @Test + public void testWithDataDeps() throws Exception { + StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment(); + + DataStream data = env.fromSequence(1, 100); + List> outputs = + SharedObjectsUtils.withSharedObjects( + Collections.singletonList(data), new SharedObjectsBodyWithDataDeps()); + //noinspection unchecked + DataStream partitionSum = (DataStream) outputs.get(0); + DataStream allSum = + DataStreamUtils.reduce( + partitionSum, new SharedObjectsBodyWithDataDeps.SumReduceFunction()); + allSum.getTransformation().setParallelism(1); + //noinspection unchecked + List results = IteratorUtils.toList(allSum.executeAndCollect()); + Assert.assertEquals(Collections.singletonList(5050L), results); + } + + @Test + public void testWithoutDataDeps() throws Exception { + StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment(); + + DataStream data = env.fromSequence(1, 100); + List> outputs = + SharedObjectsUtils.withSharedObjects( + Collections.singletonList(data), new SharedObjectsBodyWithoutDataDeps()); + //noinspection unchecked + DataStream added = (DataStream) outputs.get(0); + //noinspection unchecked + List results = IteratorUtils.toList(added.executeAndCollect()); + Collections.sort(results); + List expected = new ArrayList<>(); + for (long i = 1; i <= 100; i += 1) { + expected.add(i + 5050); + } + Assert.assertEquals(expected, results); + } + + @Test + public void testPotentialDeadlock() throws Exception { + Configuration configuration = new Configuration(); + StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment(configuration); + File stateFolder = tempFolder.newFolder(); + env.getCheckpointConfig() + .setCheckpointStorage( + new FileSystemCheckpointStorage( + new Path("file://", stateFolder.getPath()))); + final int n = 100; + // Set it to a large value, thus causing a deadlock situation. 
+ final int len = 1 << 20; + DataStream data = + env.fromSequence(1, n).map(d -> RandomStringUtils.randomAlphabetic(len)); + List> outputs = + SharedObjectsUtils.withSharedObjects( + Collections.singletonList(data), new SharedObjectsBodyPotentialDeadlock()); + //noinspection unchecked + DataStream added = (DataStream) outputs.get(0); + added.addSink( + new SinkFunction() { + @Override + public void invoke(String value, Context context) { + Assert.assertEquals(2 * len, value.length()); + } + }); + env.execute(); + } + + static class SharedObjectsBodyWithDataDeps implements SharedObjectsBody { + private static final Descriptor SUM = + Descriptor.of("sum", LongSerializer.INSTANCE, 0L); + + @Override + public SharedObjectsBodyResult process(List> inputs) { + //noinspection unchecked + DataStream data = (DataStream) inputs.get(0); + + AOperator aOp = new AOperator(); + SingleOutputStreamOperator afterAOp = + data.transform("a", TypeInformation.of(Long.class), aOp); + + BOperator bOp = new BOperator(); + SingleOutputStreamOperator afterBOp = + afterAOp.transform("b", TypeInformation.of(Long.class), bOp); + + Map, AbstractSharedObjectsStreamOperator> ownerMap = new HashMap<>(); + ownerMap.put(SUM, aOp); + + return new SharedObjectsBodyResult( + Collections.singletonList(afterBOp), + Arrays.asList(afterAOp.getTransformation(), afterBOp.getTransformation()), + ownerMap); + } + + /** Operator A: add input elements to the shared {@link #SUM}. */ + static class AOperator extends AbstractSharedObjectsOneInputStreamOperator + implements BoundedOneInput { + + private transient long sum = 0; + + @Override + public void processElement(StreamRecord element) throws Exception { + sum += element.getValue(); + } + + @Override + public void endInput() throws Exception { + context.write(SUM, sum); + // Informs BOperator to get the value from shared {@link #SUM}. + output.collect(new StreamRecord<>(0L)); + } + + @Override + public List> readRequestsInProcessElement() { + return Collections.emptyList(); + } + } + + /** Operator B: when input ends, get the value from shared {@link #SUM}. 
*/ + static class BOperator extends AbstractSharedObjectsOneInputStreamOperator { + + @Override + public void processElement(StreamRecord element) throws Exception { + output.collect(new StreamRecord<>(context.read(SUM.sameStep()))); + } + + @Override + public List> readRequestsInProcessElement() { + return Collections.singletonList(SUM.sameStep()); + } + } + + static class SumReduceFunction implements ReduceFunction { + @Override + public Long reduce(Long value1, Long value2) { + return value1 + value2; + } + } + } + + static class SharedObjectsBodyWithoutDataDeps implements SharedObjectsBody { + private static final Descriptor SUM = Descriptor.of("sum", LongSerializer.INSTANCE); + + @Override + public SharedObjectsBodyResult process(List> inputs) { + //noinspection unchecked + DataStream data = (DataStream) inputs.get(0); + DataStream sum = DataStreamUtils.reduce(data, Long::sum); + + COperator cOp = new COperator(); + SingleOutputStreamOperator afterCOp = + sum.broadcast().transform("c", TypeInformation.of(Long.class), cOp); + + DOperator dOp = new DOperator(); + SingleOutputStreamOperator afterDOp = + data.transform("d", TypeInformation.of(Long.class), dOp); + + Map, AbstractSharedObjectsStreamOperator> ownerMap = new HashMap<>(); + ownerMap.put(SUM, cOp); + + return new SharedObjectsBodyResult( + Collections.singletonList(afterDOp), + Arrays.asList(afterCOp.getTransformation(), afterDOp.getTransformation()), + ownerMap); + } + + /** Operator C: set the shared object. */ + static class COperator extends AbstractSharedObjectsOneInputStreamOperator + implements BoundedOneInput { + private transient long sum; + + @Override + public void processElement(StreamRecord element) throws Exception { + sum = element.getValue(); + } + + @Override + public void endInput() throws Exception { + Thread.sleep(2 * 1000); + context.write(SUM, sum); + } + + @Override + public List> readRequestsInProcessElement() { + return Collections.emptyList(); + } + } + + /** Operator D: get the value from shared {@link #SUM}. */ + static class DOperator extends AbstractSharedObjectsOneInputStreamOperator { + + private Long sum; + + @Override + public void processElement(StreamRecord element) throws Exception { + if (null == sum) { + sum = context.read(SUM.sameStep()); + } + output.collect(new StreamRecord<>(sum + element.getValue())); + } + + @Override + public List> readRequestsInProcessElement() { + return Collections.singletonList(SUM.sameStep()); + } + } + } + + static class SharedObjectsBodyPotentialDeadlock implements SharedObjectsBody { + private static final Descriptor LAST = + Descriptor.of("last", StringSerializer.INSTANCE); + + @Override + public SharedObjectsBodyResult process(List> inputs) { + //noinspection unchecked + DataStream data = (DataStream) inputs.get(0); + DataStream sum = DataStreamUtils.reduce(data, (v1, v2) -> v2); + + EOperator eOp = new EOperator(); + SingleOutputStreamOperator afterCOp = + sum.broadcast().transform("e", TypeInformation.of(String.class), eOp); + + FOperator dOp = new FOperator(); + SingleOutputStreamOperator afterDOp = + data.transform("d", TypeInformation.of(String.class), dOp); + + Map, AbstractSharedObjectsStreamOperator> ownerMap = new HashMap<>(); + ownerMap.put(LAST, eOp); + + return new SharedObjectsBodyResult( + Collections.singletonList(afterDOp), + Arrays.asList(afterCOp.getTransformation(), afterDOp.getTransformation()), + ownerMap); + } + + /** Operator E: set the shared object. 
*/ + static class EOperator extends AbstractSharedObjectsOneInputStreamOperator + implements BoundedOneInput { + private transient String last; + + @Override + public void processElement(StreamRecord element) throws Exception { + last = element.getValue(); + } + + @Override + public void endInput() throws Exception { + Thread.sleep(2 * 1000); + context.write(LAST, last); + } + + @Override + public List> readRequestsInProcessElement() { + return Collections.emptyList(); + } + } + + /** Operator F: get the value from shared {@link #LAST}. */ + static class FOperator extends AbstractSharedObjectsOneInputStreamOperator { + + private String last; + + @Override + public void processElement(StreamRecord element) throws Exception { + if (null == last) { + last = context.read(LAST.sameStep()); + } + output.collect(new StreamRecord<>(last + element.getValue())); + } + + @Override + public List> readRequestsInProcessElement() { + return Collections.singletonList(LAST.sameStep()); + } + } + } +} diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializerTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializerTest.java new file mode 100644 index 000000000..84dba96a3 --- /dev/null +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializerTest.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.linalg.typeinfo; + +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.ml.linalg.DenseVector; + +import org.junit.Assert; +import org.junit.Test; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Random; + +/** Tests the serialization and deserialization from {@link DenseVectorSerializer}. 
*/ +public class DenseVectorSerializerTest { + @Test + public void testSerializationDeserialization() throws IOException { + Random random = new Random(0); + int[] lens = new int[] {0, 100, 128, 500, 1024, 4096}; + + DenseVectorSerializer serializer = new DenseVectorSerializer(); + for (int len : lens) { + double[] arr = new double[len]; + for (int i = 0; i < len; i += 1) { + arr[i] = random.nextDouble(); + } + DenseVector expected = new DenseVector(arr); + + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + serializer.serialize(expected, new DataOutputViewStreamWrapper(baos)); + DenseVector actual = + serializer.deserialize( + new DataInputViewStreamWrapper( + new ByteArrayInputStream(baos.toByteArray()))); + Assert.assertEquals(expected, actual); + } + } +} diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/typeinfo/OptimizedDoublePrimitiveArraySerializerTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/typeinfo/OptimizedDoublePrimitiveArraySerializerTest.java new file mode 100644 index 000000000..24225c069 --- /dev/null +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/typeinfo/OptimizedDoublePrimitiveArraySerializerTest.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.linalg.typeinfo; + +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; + +import org.junit.Assert; +import org.junit.Test; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Random; + +/** + * Tests the serialization and deserialization for the double array from {@link + * OptimizedDoublePrimitiveArraySerializer}. 
+ */ +public class OptimizedDoublePrimitiveArraySerializerTest { + @Test + public void testSerializationDeserialization() throws IOException { + Random random = new Random(0); + int[] lens = new int[] {0, 100, 128, 500, 1024, 4096}; + + OptimizedDoublePrimitiveArraySerializer serializer = + new OptimizedDoublePrimitiveArraySerializer(); + for (int len : lens) { + double[] arr = new double[len]; + for (int i = 0; i < len; i += 1) { + arr[i] = random.nextDouble(); + } + + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + serializer.serialize(arr, new DataOutputViewStreamWrapper(baos)); + double[] actual = + serializer.deserialize( + new DataInputViewStreamWrapper( + new ByteArrayInputStream(baos.toByteArray()))); + Assert.assertArrayEquals(arr, actual, 0.); + } + } +} diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java b/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java index ec97b48c6..33fc59cd4 100644 --- a/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java @@ -51,6 +51,7 @@ import org.apache.commons.collections.IteratorUtils; import org.apache.commons.lang3.ArrayUtils; +import org.junit.Assert; import java.io.DataInputStream; import java.io.DataOutputStream; @@ -324,4 +325,64 @@ public static DataFrame constructDataFrame( } return new DataFrame(columnNames, dataTypes, rowList); } + + /** + * Compare two lists of elements with the given comparator. Different from {@link + * org.apache.flink.test.util.TestBaseUtils#compareResultCollections}, the comparator is also + * used when comparing elements. + */ + public static void compareResultCollectionsWithComparator( + List expected, List actual, Comparator comparator) { + Assert.assertEquals(expected.size(), actual.size()); + expected.sort(comparator); + actual.sort(comparator); + for (int i = 0; i < expected.size(); i++) { + Assert.assertEquals(0, comparator.compare(expected.get(i), actual.get(i))); + } + } + + /** + * Compare two doubles with specified delta. If the differences between the two doubles are + * equal or less than delta, they are considered equal. Otherwise, they are compared with + * default comparison. + */ + public static class DoubleComparatorWithDelta implements Comparator { + private final double delta; + + public DoubleComparatorWithDelta(double delta) { + this.delta = delta; + } + + @Override + public int compare(Double o1, Double o2) { + return Math.abs(o1 - o2) <= delta ? 0 : Double.compare(o1, o2); + } + } + + /** + * Compare two dense vectors with specified delta. When comparing their values, {@link + * DoubleComparatorWithDelta} is used. + */ + public static class DenseVectorComparatorWithDelta implements Comparator { + private final DoubleComparatorWithDelta doubleComparatorWithDelta; + + public DenseVectorComparatorWithDelta(double delta) { + doubleComparatorWithDelta = new DoubleComparatorWithDelta(delta); + } + + @Override + public int compare(DenseVector o1, DenseVector o2) { + if (o1.size() != o2.size()) { + return Integer.compare(o1.size(), o2.size()); + } else { + for (int i = 0; i < o1.size(); i++) { + int cmp = doubleComparatorWithDelta.compare(o1.get(i), o2.get(i)); + if (cmp != 0) { + return cmp; + } + } + } + return 0; + } + } } diff --git a/flink-ml-lib/pom.xml b/flink-ml-lib/pom.xml index 1773fc7d4..1b5f06539 100644 --- a/flink-ml-lib/pom.xml +++ b/flink-ml-lib/pom.xml @@ -138,6 +138,16 @@ under the License. 
test test-jar + + org.eclipse.collections + eclipse-collections-api + 11.1.0 + + + org.eclipse.collections + eclipse-collections + 11.1.0 + @@ -153,10 +163,22 @@ under the License. shade + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + com.github.wendykierp:JTransforms pl.edu.icm:JLargeArrays + org.eclipse.collections:eclipse-collections-api + org.eclipse.collections:eclipse-collections diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java new file mode 100644 index 000000000..c33614556 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.gbtclassifier; + +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.gbt.GBTModelData; +import org.apache.flink.ml.common.gbt.GBTRunner; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.apache.flink.table.api.Expressions.$; + +/** + * An Estimator which implements the gradient boosting trees classification algorithm (Gradient Boosting). + * + *

The implementation is inspired by advanced implementations such as XGBoost and LightGBM. + * It supports features like a regularized learning objective with second-order approximation and a + * histogram-based, sparsity-aware split-finding algorithm. + * + *
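+ * <p>For reference, the general form of such a regularized second-order objective (as popularized
+ * by XGBoost; the exact formulation used here may differ in details) is
+ *
+ * <pre>
+ * obj ≈ sum_i [ g_i * f(x_i) + 1/2 * h_i * f(x_i)^2 ] + gamma * T + 1/2 * lambda * sum_j w_j^2
+ * </pre>
+ *
+ * where g_i and h_i are the first- and second-order derivatives of the loss, T is the number of
+ * leaves and w_j are the leaf weights; minimizing per leaf gives w_j* = -G_j / (H_j + lambda),
+ * with G_j and H_j the sums of g_i and h_i over the instances falling into leaf j.
+ *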

The distributed implementation takes this work as a reference. Right now, we support + * horizontal partitioning of data and row-store storage of instances. + * + *
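+ * <p>A usage sketch (column names are illustrative):
+ *
+ * <pre>{@code
+ * GBTClassifier gbt =
+ *         new GBTClassifier()
+ *                 .setFeaturesCols("f0", "f1", "f2")
+ *                 .setCategoricalCols("f2")
+ *                 .setLabelCol("label")
+ *                 .setMaxDepth(5)
+ *                 .setMaxIter(20);
+ * GBTClassifierModel model = gbt.fit(trainTable);
+ * Table prediction = model.transform(testTable)[0];
+ * }</pre>
+ *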

NOTE: Currently, some features are not supported yet: weighted input samples, early-stopping + * with validation set, encoding with leaf ids, etc. + */ +public class GBTClassifier + implements Estimator, + GBTClassifierParams { + + private final Map, Object> paramMap = new HashMap<>(); + + public GBTClassifier() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + public static GBTClassifier load(StreamTableEnvironment tEnv, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map, Object> getParamMap() { + return paramMap; + } + + @Override + public GBTClassifierModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream modelData = GBTRunner.train(inputs[0], this); + DataStream> featureImportance = + GBTRunner.getFeatureImportance(modelData); + GBTClassifierModel model = new GBTClassifierModel(); + model.setModelData( + tEnv.fromDataStream(modelData).renameColumns($("f0").as("modelData")), + tEnv.fromDataStream(featureImportance) + .renameColumns($("f0").as("featureImportance"))); + ParamUtils.updateExistingParams(model, getParamMap()); + return model; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModel.java new file mode 100644 index 000000000..7833e5d52 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModel.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.gbtclassifier; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.gbt.BaseGBTModel; +import org.apache.flink.ml.common.gbt.GBTModelData; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; +import org.apache.commons.math3.analysis.function.Sigmoid; +import org.eclipse.collections.impl.map.mutable.primitive.IntDoubleHashMap; + +import java.io.IOException; +import java.util.Collections; + +/** A Model computed by {@link GBTClassifier}. */ +public class GBTClassifierModel extends BaseGBTModel + implements GBTClassifierModelParams { + + /** + * Loads model data from path. + * + * @param tEnv A StreamTableEnvironment instance. + * @param path Model path. + * @return GBT classification model. + */ + public static GBTClassifierModel load(StreamTableEnvironment tEnv, String path) + throws IOException { + return BaseGBTModel.load(tEnv, path); + } + + @Override + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream inputStream = tEnv.toDataStream(inputs[0]); + final String broadcastModelKey = "broadcastModelKey"; + DataStream modelDataStream = GBTModelData.getModelDataStream(modelDataTable); + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), + Types.DOUBLE, + DenseVectorTypeInfo.INSTANCE, + DenseVectorTypeInfo.INSTANCE), + ArrayUtils.addAll( + inputTypeInfo.getFieldNames(), + getPredictionCol(), + getRawPredictionCol(), + getProbabilityCol())); + DataStream predictionResult = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(inputStream), + Collections.singletonMap(broadcastModelKey, modelDataStream), + inputList -> { + //noinspection unchecked + DataStream inputData = (DataStream) inputList.get(0); + return inputData.map( + new PredictLabelFunction(broadcastModelKey, getFeaturesCols()), + outputTypeInfo); + }); + return new Table[] {tEnv.fromDataStream(predictionResult)}; + } + + private static class PredictLabelFunction extends RichMapFunction { + + private static final Sigmoid sigmoid = new Sigmoid(); + + private final String broadcastModelKey; + private final String[] featuresCols; + private GBTModelData modelData; + + public PredictLabelFunction(String broadcastModelKey, String[] featuresCols) { + this.broadcastModelKey = broadcastModelKey; + this.featuresCols = featuresCols; + } + + @Override + public Row map(Row value) throws Exception { + if (null == modelData) { + modelData = + (GBTModelData) + getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0); + } + IntDoubleHashMap features = modelData.rowToFeatures(value, featuresCols); + double 
logits = modelData.predictRaw(features); + double prob = sigmoid.value(logits); + return Row.join( + value, + Row.of( + logits >= 0. ? 1. : 0., + Vectors.dense(-logits, logits), + Vectors.dense(1 - prob, prob))); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModelParams.java new file mode 100644 index 000000000..e4625e9e4 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModelParams.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.gbtclassifier; + +import org.apache.flink.ml.common.gbt.BaseGBTModelParams; +import org.apache.flink.ml.common.param.HasProbabilityCol; +import org.apache.flink.ml.common.param.HasRawPredictionCol; + +/** + * Parameters for {@link GBTClassifierModel}. + * + * @param The class type of this instance. + */ +public interface GBTClassifierModelParams + extends BaseGBTModelParams, HasRawPredictionCol, HasProbabilityCol {} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierParams.java new file mode 100644 index 000000000..0640f56c5 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierParams.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.gbtclassifier; + +import org.apache.flink.ml.common.gbt.BaseGBTParams; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; + +/** + * Parameters for {@link GBTClassifier}. + * + * @param The class type of this instance. 
+ */ +public interface GBTClassifierParams extends BaseGBTParams, GBTClassifierModelParams { + + Param LOSS_TYPE = + new StringParam( + "lossType", "Loss type.", "logistic", ParamValidators.inArray("logistic")); + + default String getLossType() { + return get(LOSS_TYPE); + } + + default T setLossType(String value) { + return set(LOSS_TYPE, value); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModel.java new file mode 100644 index 000000000..11c1ff3c6 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModel.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt; + +import org.apache.flink.api.common.serialization.Encoder; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.common.typeutils.base.DoubleSerializer; +import org.apache.flink.api.common.typeutils.base.MapSerializer; +import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.fs.Path; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.classification.gbtclassifier.GBTClassifier; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.regression.gbtregressor.GBTRegressor; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.HashMap; +import java.util.Map; + +/** Base model computed by {@link GBTClassifier} or {@link GBTRegressor}. 
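+ *
+ * <p>The model data and the feature importance table are saved under two sub-paths of the given
+ * path, {@code model_data} and {@code feature_importance}. A usage sketch:
+ *
+ * <pre>{@code
+ * model.save(path);
+ * // ... after the job that writes the model data has finished ...
+ * GBTClassifierModel restored = GBTClassifierModel.load(tEnv, path);
+ * }</pre>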
*/ +public abstract class BaseGBTModel> implements Model { + protected static final String MODEL_DATA_PATH = "model_data"; + protected static final String FEATURE_IMPORTANCE_PATH = "feature_importance"; + + protected final Map, Object> paramMap = new HashMap<>(); + protected Table modelDataTable; + protected Table featureImportanceTable; + + public BaseGBTModel() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + protected static > T load(StreamTableEnvironment tEnv, String path) + throws IOException { + T model = ReadWriteUtils.loadStageParam(path); + Table modelDataTable = + ReadWriteUtils.loadModelData( + tEnv, + new Path(path, MODEL_DATA_PATH).toString(), + new GBTModelData.ModelDataDecoder()); + Table featureImportanceTable = + ReadWriteUtils.loadModelData( + tEnv, + new Path(path, FEATURE_IMPORTANCE_PATH).toString(), + new FeatureImportanceEncoderDecoder()); + return model.setModelData(modelDataTable, featureImportanceTable); + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable, featureImportanceTable}; + } + + @Override + public T setModelData(Table... inputs) { + Preconditions.checkArgument(inputs.length == 2); + modelDataTable = inputs[0]; + featureImportanceTable = inputs[1]; + //noinspection unchecked + return (T) this; + } + + @Override + public Map, Object> getParamMap() { + return paramMap; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + ReadWriteUtils.saveModelData( + GBTModelData.getModelDataStream(modelDataTable), + new Path(path, MODEL_DATA_PATH).toString(), + new GBTModelData.ModelDataEncoder()); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) featureImportanceTable).getTableEnvironment(); + ReadWriteUtils.saveModelData( + tEnv.toDataStream( + featureImportanceTable, + DataTypes.MAP(DataTypes.STRING(), DataTypes.DOUBLE())), + new Path(path, FEATURE_IMPORTANCE_PATH).toString(), + new FeatureImportanceEncoderDecoder()); + } + + private static class FeatureImportanceEncoderDecoder + extends SimpleStreamFormat> + implements Encoder> { + + final MapSerializer serializer = + new MapSerializer<>(StringSerializer.INSTANCE, DoubleSerializer.INSTANCE); + + @Override + public void encode(Map element, OutputStream stream) throws IOException { + DataOutputView dataOutputView = new DataOutputViewStreamWrapper(stream); + serializer.serialize(element, dataOutputView); + } + + @Override + public Reader> createReader( + Configuration config, FSDataInputStream stream) throws IOException { + return new Reader>() { + @Override + public Map read() { + DataInputView source = new DataInputViewStreamWrapper(stream); + try { + return serializer.deserialize(source); + } catch (IOException e) { + return null; + } + } + + @Override + public void close() throws IOException { + stream.close(); + } + }; + } + + @Override + public TypeInformation> getProducedType() { + return Types.MAP(Types.STRING, Types.DOUBLE); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModelParams.java new file mode 100644 index 000000000..ec221472d --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModelParams.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt; + +import org.apache.flink.ml.classification.gbtclassifier.GBTClassifierModel; +import org.apache.flink.ml.common.param.HasCategoricalCols; +import org.apache.flink.ml.common.param.HasFeaturesCols; +import org.apache.flink.ml.common.param.HasLabelCol; +import org.apache.flink.ml.common.param.HasPredictionCol; +import org.apache.flink.ml.regression.gbtregressor.GBTRegressorModel; + +/** + * Params of {@link GBTClassifierModel} and {@link GBTRegressorModel}. + * + *

The value `featuresCols` can be either one column name of vector type, or multiple column + * names of non-vector types. For the latter case, `categoricalCols` can be further set to + * specify columns that need to be treated as categorical features. + * + * @param The class type of this instance. + */ +public interface BaseGBTModelParams + extends HasFeaturesCols, HasLabelCol, HasCategoricalCols, HasPredictionCol {} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTParams.java new file mode 100644 index 000000000..a8c8272e4 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTParams.java @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt; + +import org.apache.flink.ml.common.param.HasMaxIter; +import org.apache.flink.ml.common.param.HasSeed; +import org.apache.flink.ml.common.param.HasWeightCol; +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; + +/** + * Common parameters for GBT classifier and regressor. + * + *
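+ * <p>For example (with illustrative column names), the features can be given either as a single
+ * vector column, {@code setFeaturesCols("features")}, or as several scalar columns, some of which
+ * may be marked as categorical, e.g. {@code setFeaturesCols("age", "income", "city")} together
+ * with {@code setCategoricalCols("city")}.
+ *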

NOTE: Features related to {@link #WEIGHT_COL}, {@link #LEAF_COL}, and {@link + * #VALIDATION_INDICATOR_COL} are not implemented yet. + * + * @param The class type of this instance. + */ +public interface BaseGBTParams + extends BaseGBTModelParams, HasWeightCol, HasMaxIter, HasSeed { + Param REG_LAMBDA = + new DoubleParam( + "regLambda", + "Regularization term for the number of leaves.", + 0., + ParamValidators.gtEq(0.)); + + Param REG_GAMMA = + new DoubleParam( + "regGamma", + "L2 regularization term for the weights of leaves.", + 1., + ParamValidators.gtEq(0)); + + Param LEAF_COL = + new StringParam("leafCol", "Predicted leaf index of each instance in each tree.", null); + + Param MAX_DEPTH = + new IntParam("maxDepth", "Maximum depth of the tree.", 5, ParamValidators.gtEq(1)); + + Param MAX_BINS = + new IntParam( + "maxBins", + "Maximum number of bins used for discretizing continuous features.", + 32, + ParamValidators.gtEq(2)); + + Param MIN_INSTANCES_PER_NODE = + new IntParam( + "minInstancesPerNode", + "Minimum number of instances each node must have. If a split causes the left or right child to have fewer instances than minInstancesPerNode, the split is invalid.", + 1, + ParamValidators.gtEq(1)); + + Param MIN_WEIGHT_FRACTION_PER_NODE = + new DoubleParam( + "minWeightFractionPerNode", + "Minimum fraction of the weighted sample count that each node must have. If a split causes the left or right child to have a smaller fraction of the total weight than minWeightFractionPerNode, the split is invalid.", + 0., + ParamValidators.gtEq(0.)); + + Param MIN_INFO_GAIN = + new DoubleParam( + "minInfoGain", + "Minimum information gain for a split to be considered valid.", + 0., + ParamValidators.gtEq(0.)); + + Param STEP_SIZE = + new DoubleParam( + "stepSize", + "Step size for shrinking the contribution of each estimator.", + 0.1, + ParamValidators.inRange(0., 1.)); + + Param SUBSAMPLING_RATE = + new DoubleParam( + "subsamplingRate", + "Fraction of the training data used for learning one tree.", + 1., + ParamValidators.inRange(0., 1.)); + + Param FEATURE_SUBSET_STRATEGY = + new StringParam( + "featureSubsetStrategy.", + "Fraction of the training data used for learning one tree. 
Supports \"auto\", \"all\", \"onethird\", \"sqrt\", \"log2\", (0.0 - 1.0], and [1 - n].", + "auto", + ParamValidators.notNull()); + + Param VALIDATION_INDICATOR_COL = + new StringParam( + "validationIndicatorCol", + "The name of the column that indicates whether each row is for training or for validation.", + null); + + Param VALIDATION_TOL = + new DoubleParam( + "validationTol", + "Threshold for early stopping when fitting with validation is used.", + .01, + ParamValidators.gtEq(0)); + + default double getRegLambda() { + return get(REG_LAMBDA); + } + + default T setRegLambda(Double value) { + return set(REG_LAMBDA, value); + } + + default double getRegGamma() { + return get(REG_GAMMA); + } + + default T setRegGamma(Double value) { + return set(REG_GAMMA, value); + } + + default String getLeafCol() { + return get(LEAF_COL); + } + + default T setLeafCol(String value) { + return set(LEAF_COL, value); + } + + default int getMaxDepth() { + return get(MAX_DEPTH); + } + + default T setMaxDepth(int value) { + return set(MAX_DEPTH, value); + } + + default int getMaxBins() { + return get(MAX_BINS); + } + + default T setMaxBins(int value) { + return set(MAX_BINS, value); + } + + default int getMinInstancesPerNode() { + return get(MIN_INSTANCES_PER_NODE); + } + + default T setMinInstancesPerNode(int value) { + return set(MIN_INSTANCES_PER_NODE, value); + } + + default double getMinWeightFractionPerNode() { + return get(MIN_WEIGHT_FRACTION_PER_NODE); + } + + default T setMinWeightFractionPerNode(Double value) { + return set(MIN_WEIGHT_FRACTION_PER_NODE, value); + } + + default double getMinInfoGain() { + return get(MIN_INFO_GAIN); + } + + default T setMinInfoGain(Double value) { + return set(MIN_INFO_GAIN, value); + } + + default double getStepSize() { + return get(STEP_SIZE); + } + + default T setStepSize(Double value) { + return set(STEP_SIZE, value); + } + + default double getSubsamplingRate() { + return get(SUBSAMPLING_RATE); + } + + default T setSubsamplingRate(Double value) { + return set(SUBSAMPLING_RATE, value); + } + + default String getFeatureSubsetStrategy() { + return get(FEATURE_SUBSET_STRATEGY); + } + + default T setFeatureSubsetStrategy(String value) { + return set(FEATURE_SUBSET_STRATEGY, value); + } + + default String getValidationIndicatorCol() { + return get(VALIDATION_INDICATOR_COL); + } + + default T setValidationIndicatorCol(String value) { + return set(VALIDATION_INDICATOR_COL, value); + } + + default double getValidationTol() { + return get(VALIDATION_TOL); + } + + default T setValidationTol(Double value) { + return set(VALIDATION_TOL, value); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java new file mode 100644 index 000000000..34ec8d1f1 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.ml.common.gbt.defs.BoostingStrategy; +import org.apache.flink.ml.common.gbt.defs.Histogram; +import org.apache.flink.ml.common.gbt.defs.Split; +import org.apache.flink.ml.common.gbt.defs.TrainContext; +import org.apache.flink.ml.common.gbt.operators.CacheDataCalcLocalHistsOperator; +import org.apache.flink.ml.common.gbt.operators.CalcLocalSplitsOperator; +import org.apache.flink.ml.common.gbt.operators.PostSplitsOperator; +import org.apache.flink.ml.common.gbt.operators.ReduceHistogramFunction; +import org.apache.flink.ml.common.gbt.operators.ReduceSplitsOperator; +import org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants; +import org.apache.flink.ml.common.gbt.operators.TerminationOperator; +import org.apache.flink.ml.common.sharedobjects.AbstractSharedObjectsStreamOperator; +import org.apache.flink.ml.common.sharedobjects.Descriptor; +import org.apache.flink.ml.common.sharedobjects.SharedObjectsBody; +import org.apache.flink.ml.common.sharedobjects.SharedObjectsUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.types.Row; +import org.apache.flink.util.OutputTag; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Implements iteration body for boosting algorithms. This implementation uses horizontal partition + * of data and row-store storage of instances. 
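+ *
+ * <p>Roughly, the operators wired in {@code sharedObjectsBody} below form the following chain in
+ * every iteration: cached instances produce per-subtask histograms
+ * ({@code CacheDataCalcLocalHistsOperator}), histograms are reduced globally
+ * ({@code ReduceHistogramFunction}), candidate splits are computed locally
+ * ({@code CalcLocalSplitsOperator}) and reduced to the best global splits
+ * ({@code ReduceSplitsOperator}), the chosen splits are applied to grow the trees
+ * ({@code PostSplitsOperator}), and {@code TerminationOperator} checks whether to stop and emits
+ * the final model data as a side output.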
+ */ +class BoostIterationBody implements IterationBody { + private final BoostingStrategy strategy; + + public BoostIterationBody(BoostingStrategy strategy) { + this.strategy = strategy; + } + + private SharedObjectsBody.SharedObjectsBodyResult sharedObjectsBody( + List> inputs) { + //noinspection unchecked + DataStream data = (DataStream) inputs.get(0); + //noinspection unchecked + DataStream trainContext = (DataStream) inputs.get(1); + + Map, AbstractSharedObjectsStreamOperator> ownerMap = new HashMap<>(); + + CacheDataCalcLocalHistsOperator cacheDataCalcLocalHistsOp = + new CacheDataCalcLocalHistsOperator(strategy); + SingleOutputStreamOperator> localHists = + data.connect(trainContext.broadcast()) + .transform( + "CacheDataCalcLocalHists", + Types.TUPLE( + Types.INT, Types.INT, TypeInformation.of(Histogram.class)), + cacheDataCalcLocalHistsOp); + for (Descriptor s : SharedObjectsConstants.OWNED_BY_CACHE_DATA_CALC_LOCAL_HISTS_OP) { + ownerMap.put(s, cacheDataCalcLocalHistsOp); + } + + DataStream> globalHists = + localHists.keyBy(d -> d.f1).flatMap(new ReduceHistogramFunction()); + + SingleOutputStreamOperator> localSplits = + globalHists.transform( + "CalcLocalSplits", + Types.TUPLE(Types.INT, Types.INT, TypeInformation.of(Split.class)), + new CalcLocalSplitsOperator()); + + DataStream> globalSplits = + localSplits + .keyBy(d -> d.f0) + .transform( + "ReduceGlobalSplits", + Types.TUPLE(Types.INT, TypeInformation.of(Split.class)), + new ReduceSplitsOperator()); + + PostSplitsOperator postSplitsOp = new PostSplitsOperator(); + SingleOutputStreamOperator updatedModelData = + globalSplits + .broadcast() + .transform("PostSplits", TypeInformation.of(Integer.class), postSplitsOp); + for (Descriptor descriptor : SharedObjectsConstants.OWNED_BY_POST_SPLITS_OP) { + ownerMap.put(descriptor, postSplitsOp); + } + + final OutputTag finalModelDataOutputTag = + new OutputTag<>("model_data", TypeInformation.of(GBTModelData.class)); + SingleOutputStreamOperator termination = + updatedModelData.transform( + "CheckTermination", + Types.INT, + new TerminationOperator(finalModelDataOutputTag)); + DataStream finalModelData = + termination.getSideOutput(finalModelDataOutputTag); + + return new SharedObjectsBody.SharedObjectsBodyResult( + Arrays.asList(updatedModelData, finalModelData, termination), + Arrays.asList( + localHists.getTransformation(), + localSplits.getTransformation(), + globalSplits.getTransformation(), + updatedModelData.getTransformation(), + termination.getTransformation()), + ownerMap); + } + + @Override + public IterationBodyResult process(DataStreamList variableStreams, DataStreamList dataStreams) { + DataStream data = dataStreams.get(0); + DataStream trainContext = variableStreams.get(0); + + List> outputs = + SharedObjectsUtils.withSharedObjects( + Arrays.asList(data, trainContext), this::sharedObjectsBody); + + DataStream updatedModelData = outputs.get(0); + DataStream finalModelData = outputs.get(1); + DataStream termination = outputs.get(2); + return new IterationBodyResult( + DataStreamList.of( + updatedModelData.flatMap( + (d, out) -> {}, TypeInformation.of(TrainContext.class))), + DataStreamList.of(finalModelData), + termination); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/DataUtils.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/DataUtils.java new file mode 100644 index 000000000..32331cd31 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/DataUtils.java @@ -0,0 +1,68 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt; + +import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModel; + +import java.util.Arrays; +import java.util.Random; + +/** Some data utilities. */ +public class DataUtils { + + // Stores 4 values for one histogram bin, i.e., gradient, hessian, weight, and count. + public static final int BIN_SIZE = 4; + + public static void shuffle(int[] array, Random random) { + int n = array.length; + for (int i = 0; i < n; i += 1) { + int index = i + random.nextInt(n - i); + int tmp = array[index]; + array[index] = array[i]; + array[i] = tmp; + } + } + + public static int[] sample(int[] values, int numSamples, Random random) { + int n = values.length; + int[] sampled = new int[numSamples]; + + for (int i = 0; i < numSamples; i += 1) { + int index = i + random.nextInt(n - i); + sampled[i] = values[index]; + + int temp = values[i]; + values[i] = values[index]; + values[index] = temp; + } + return sampled; + } + + /** The mapping computation is from {@link KBinsDiscretizerModel}. */ + public static int findBin(double[] binEdges, double v) { + int index = Arrays.binarySearch(binEdges, v); + if (index < 0) { + // Computes the index to insert. + index = -index - 1; + // Puts it in the left bin. + index--; + } + return Math.max(Math.min(index, (binEdges.length - 2)), 0); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java new file mode 100644 index 000000000..49831f50f --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.gbt; + +import org.apache.flink.api.common.serialization.Encoder; +import org.apache.flink.api.common.typeinfo.TypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.ml.classification.gbtclassifier.GBTClassifierModel; +import org.apache.flink.ml.common.gbt.defs.FeatureMeta; +import org.apache.flink.ml.common.gbt.defs.Node; +import org.apache.flink.ml.common.gbt.defs.TrainContext; +import org.apache.flink.ml.common.gbt.typeinfo.GBTModelDataSerializer; +import org.apache.flink.ml.common.gbt.typeinfo.GBTModelDataTypeInfoFactory; +import org.apache.flink.ml.feature.stringindexer.StringIndexerModel; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.regression.gbtregressor.GBTRegressorModel; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; + +import org.eclipse.collections.impl.map.mutable.primitive.IntDoubleHashMap; +import org.eclipse.collections.impl.map.mutable.primitive.IntObjectHashMap; +import org.eclipse.collections.impl.map.mutable.primitive.ObjectIntHashMap; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.BitSet; +import java.util.List; + +/** + * Model data of {@link GBTClassifierModel} and {@link GBTRegressorModel}. + * + *

This class also provides methods to convert model data from Table to Datastream, and classes + * to save/load model data. + */ +@TypeInfo(GBTModelDataTypeInfoFactory.class) +public class GBTModelData { + + public String type; + public boolean isInputVector; + + public double prior; + public double stepSize; + + public List> allTrees; + public List featureNames; + public IntObjectHashMap> categoryToIdMaps; + public IntObjectHashMap featureIdToBinEdges; + public BitSet isCategorical; + + public GBTModelData() {} + + public GBTModelData( + String type, + boolean isInputVector, + double prior, + double stepSize, + List> allTrees, + List featureNames, + IntObjectHashMap> categoryToIdMaps, + IntObjectHashMap featureIdToBinEdges, + BitSet isCategorical) { + this.type = type; + this.isInputVector = isInputVector; + this.prior = prior; + this.stepSize = stepSize; + this.allTrees = allTrees; + this.featureNames = featureNames; + this.categoryToIdMaps = categoryToIdMaps; + this.featureIdToBinEdges = featureIdToBinEdges; + this.isCategorical = isCategorical; + } + + public static GBTModelData from(TrainContext trainContext, List> allTrees) { + List featureNames = new ArrayList<>(); + IntObjectHashMap> categoryToIdMaps = new IntObjectHashMap<>(); + IntObjectHashMap featureIdToBinEdges = new IntObjectHashMap<>(); + BitSet isCategorical = new BitSet(); + + FeatureMeta[] featureMetas = trainContext.featureMetas; + for (int k = 0; k < featureMetas.length; k += 1) { + FeatureMeta featureMeta = featureMetas[k]; + featureNames.add(featureMeta.name); + if (featureMeta instanceof FeatureMeta.CategoricalFeatureMeta) { + String[] categories = ((FeatureMeta.CategoricalFeatureMeta) featureMeta).categories; + ObjectIntHashMap categoryToId = new ObjectIntHashMap<>(); + for (int i = 0; i < categories.length; i += 1) { + categoryToId.put(categories[i], i); + } + categoryToIdMaps.put(k, categoryToId); + isCategorical.set(k); + } else if (featureMeta instanceof FeatureMeta.ContinuousFeatureMeta) { + featureIdToBinEdges.put( + k, ((FeatureMeta.ContinuousFeatureMeta) featureMeta).binEdges); + } + } + return new GBTModelData( + trainContext.strategy.taskType.name(), + trainContext.strategy.isInputVector, + trainContext.prior, + trainContext.strategy.stepSize, + allTrees, + featureNames, + categoryToIdMaps, + featureIdToBinEdges, + isCategorical); + } + + public static DataStream getModelDataStream(Table modelDataTable) { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment(); + return tEnv.toDataStream(modelDataTable).map(x -> x.getFieldAs(0)); + } + + /** The mapping computation is from {@link StringIndexerModel}. 
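+     * Seen categories are mapped to their indices in {@code categoryToId}; an unseen (or null)
+     * value falls back to {@code categoryToId.size()}, matching the extra bucket that the string
+     * indexer reserves when {@code handleInvalid} is set to "keep" during preprocessing.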
*/ + private static int mapCategoricalFeature(ObjectIntHashMap categoryToId, Object v) { + String s; + if (v instanceof String) { + s = (String) v; + } else if (v instanceof Number) { + s = String.valueOf(v); + } else if (null == v) { + s = null; + } else { + throw new RuntimeException("Categorical column only supports string and numeric type."); + } + return categoryToId.getIfAbsent(s, categoryToId.size()); + } + + public IntDoubleHashMap rowToFeatures(Row row, String[] featuresCols) { + IntDoubleHashMap features = new IntDoubleHashMap(); + if (isInputVector) { + Vector vec = row.getFieldAs(featuresCols[0]); + SparseVector sv = vec.toSparse(); + for (int i = 0; i < sv.indices.length; i += 1) { + features.put(sv.indices[i], sv.values[i]); + } + } else { + for (int i = 0; i < featuresCols.length; i += 1) { + Object obj = row.getField(featuresCols[i]); + double v; + if (isCategorical.get(i)) { + v = mapCategoricalFeature(categoryToIdMaps.get(i), obj); + } else { + Number number = (Number) obj; + v = (null == number) ? Double.NaN : number.doubleValue(); + } + features.put(i, v); + } + } + return features; + } + + public double predictRaw(IntDoubleHashMap rawFeatures) { + double v = prior; + for (List treeNodes : allTrees) { + Node node = treeNodes.get(0); + while (!node.isLeaf) { + boolean goLeft = node.split.shouldGoLeft(rawFeatures); + node = goLeft ? treeNodes.get(node.left) : treeNodes.get(node.right); + } + v += stepSize * node.split.prediction; + } + return v; + } + + @Override + public String toString() { + return String.format( + "GBTModelData{type=%s, prior=%s, allTrees=%s, categoryToIdMaps=%s, featureIdToBinEdges=%s, isCategorical=%s}", + type, prior, allTrees, categoryToIdMaps, featureIdToBinEdges, isCategorical); + } + + /** Encoder for {@link GBTModelData}. */ + public static class ModelDataEncoder implements Encoder { + @Override + public void encode(GBTModelData modelData, OutputStream outputStream) throws IOException { + DataOutputView dataOutputView = new DataOutputViewStreamWrapper(outputStream); + final GBTModelDataSerializer serializer = GBTModelDataSerializer.INSTANCE; + serializer.serialize(modelData, dataOutputView); + } + } + + /** Decoder for {@link GBTModelData}. */ + public static class ModelDataDecoder extends SimpleStreamFormat { + @Override + public Reader createReader(Configuration config, FSDataInputStream stream) { + return new Reader() { + + private final GBTModelDataSerializer serializer = GBTModelDataSerializer.INSTANCE; + + @Override + public GBTModelData read() { + DataInputView source = new DataInputViewStreamWrapper(stream); + try { + return serializer.deserialize(source); + } catch (IOException e) { + return null; + } + } + + @Override + public void close() throws IOException { + stream.close(); + } + }; + } + + @Override + public TypeInformation getProducedType() { + return TypeInformation.of(GBTModelData.class); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java new file mode 100644 index 000000000..28aa36277 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt; + +import org.apache.flink.api.common.functions.AggregateFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationConfig; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.ReplayableDataStreamList; +import org.apache.flink.ml.classification.gbtclassifier.GBTClassifier; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.gbt.defs.BoostingStrategy; +import org.apache.flink.ml.common.gbt.defs.FeatureMeta; +import org.apache.flink.ml.common.gbt.defs.LossType; +import org.apache.flink.ml.common.gbt.defs.Node; +import org.apache.flink.ml.common.gbt.defs.TaskType; +import org.apache.flink.ml.common.gbt.defs.TrainContext; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.regression.gbtregressor.GBTRegressor; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** Runs a gradient boosting trees implementation. 
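+ *
+ * <p>In short, {@code train} first maps the feature column(s) to integer bin/category indices
+ * (see {@code Preprocess}), aggregates the label sum and count used to initialize the training
+ * context, and then grows one tree per boosting step inside a bounded iteration driven by
+ * {@code BoostIterationBody}. A minimal usage sketch (the table and estimator are assumed to be
+ * defined elsewhere):
+ *
+ * <pre>{@code
+ * DataStream<GBTModelData> modelData = GBTRunner.train(trainTable, gbtClassifier);
+ * }</pre>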
*/ +public class GBTRunner { + + private static boolean isVectorType(TypeInformation typeInfo) { + return typeInfo instanceof DenseVectorTypeInfo + || typeInfo instanceof SparseVectorTypeInfo + || typeInfo instanceof VectorTypeInfo; + } + + public static DataStream train(Table data, BaseGBTParams estimator) { + String[] featuresCols = estimator.getFeaturesCols(); + TypeInformation[] featuresTypes = + Arrays.stream(featuresCols) + .map(d -> TableUtils.getTypeInfoByName(data.getResolvedSchema(), d)) + .toArray(TypeInformation[]::new); + for (int i = 0; i < featuresCols.length; i += 1) { + Preconditions.checkArgument( + null != featuresTypes[i], + String.format( + "Column name %s not existed in the input data.", featuresCols[i])); + } + + boolean isInputVector = featuresCols.length == 1 && isVectorType(featuresTypes[0]); + return train(data, getStrategy(estimator, isInputVector)); + } + + /** Trains a gradient boosting tree model with given data and parameters. */ + static DataStream train(Table dataTable, BoostingStrategy strategy) { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) dataTable).getTableEnvironment(); + Tuple2> preprocessResult = + strategy.isInputVector + ? Preprocess.preprocessVecCol(dataTable, strategy) + : Preprocess.preprocessCols(dataTable, strategy); + dataTable = preprocessResult.f0; + DataStream featureMeta = preprocessResult.f1; + + DataStream data = tEnv.toDataStream(dataTable); + DataStream> labelSumCount = + DataStreamUtils.aggregate(data, new LabelSumCountFunction(strategy.labelCol)); + return boost(dataTable, strategy, featureMeta, labelSumCount); + } + + public static DataStream> getFeatureImportance( + DataStream modelData) { + return modelData + .map( + value -> { + Map featureImportanceMap = new HashMap<>(); + double sum = 0.; + for (List tree : value.allTrees) { + for (Node node : tree) { + if (node.isLeaf) { + continue; + } + featureImportanceMap.merge( + node.split.featureId, node.split.gain, Double::sum); + sum += node.split.gain; + } + } + if (sum > 0.) 
{ + for (Map.Entry entry : + featureImportanceMap.entrySet()) { + entry.setValue(entry.getValue() / sum); + } + } + + List featureNames = value.featureNames; + return featureImportanceMap.entrySet().stream() + .collect( + Collectors.toMap( + d -> featureNames.get(d.getKey()), + Map.Entry::getValue)); + }) + .returns(Types.MAP(Types.STRING, Types.DOUBLE)); + } + + private static DataStream boost( + Table dataTable, + BoostingStrategy strategy, + DataStream featureMeta, + DataStream> labelSumCount) { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) dataTable).getTableEnvironment(); + + final String featureMetaBcName = "featureMeta"; + final String labelSumCountBcName = "labelSumCount"; + Map> bcMap = new HashMap<>(); + bcMap.put(featureMetaBcName, featureMeta); + bcMap.put(labelSumCountBcName, labelSumCount); + + DataStream initTrainContext = + BroadcastUtils.withBroadcastStream( + Collections.singletonList( + tEnv.toDataStream(tEnv.fromValues(0), Integer.class)), + bcMap, + (inputs) -> { + //noinspection unchecked + DataStream input = (DataStream) (inputs.get(0)); + return input.map( + new InitTrainContextFunction( + featureMetaBcName, labelSumCountBcName, strategy)); + }); + + DataStream data = tEnv.toDataStream(dataTable); + DataStreamList dataStreamList = + Iterations.iterateBoundedStreamsUntilTermination( + DataStreamList.of(initTrainContext), + ReplayableDataStreamList.notReplay(data), + IterationConfig.newBuilder() + .setOperatorLifeCycle(IterationConfig.OperatorLifeCycle.ALL_ROUND) + .build(), + new BoostIterationBody(strategy)); + return dataStreamList.get(0); + } + + public static BoostingStrategy getStrategy(BaseGBTParams estimator, boolean isInputVector) { + final Map, Object> paramMap = estimator.getParamMap(); + final Set> unsupported = + new HashSet<>( + Arrays.asList( + BaseGBTParams.WEIGHT_COL, + BaseGBTParams.LEAF_COL, + BaseGBTParams.VALIDATION_INDICATOR_COL)); + List> unsupportedButSet = + unsupported.stream() + .filter(d -> null != paramMap.get(d)) + .collect(Collectors.toList()); + if (!unsupportedButSet.isEmpty()) { + throw new UnsupportedOperationException( + String.format( + "Parameters %s are not supported yet right now.", + unsupportedButSet.stream() + .map(d -> d.name) + .collect(Collectors.joining(", ")))); + } + + BoostingStrategy strategy = new BoostingStrategy(); + strategy.featuresCols = estimator.getFeaturesCols(); + strategy.isInputVector = isInputVector; + strategy.labelCol = estimator.getLabelCol(); + strategy.categoricalCols = estimator.getCategoricalCols(); + + strategy.maxDepth = estimator.getMaxDepth(); + strategy.maxBins = estimator.getMaxBins(); + strategy.minInstancesPerNode = estimator.getMinInstancesPerNode(); + strategy.minWeightFractionPerNode = estimator.getMinWeightFractionPerNode(); + strategy.minInfoGain = estimator.getMinInfoGain(); + strategy.maxIter = estimator.getMaxIter(); + strategy.stepSize = estimator.getStepSize(); + strategy.seed = estimator.getSeed(); + strategy.subsamplingRate = estimator.getSubsamplingRate(); + strategy.featureSubsetStrategy = estimator.getFeatureSubsetStrategy(); + strategy.regGamma = estimator.getRegGamma(); + strategy.regLambda = estimator.getRegLambda(); + + String lossTypeStr; + if (estimator instanceof GBTClassifier) { + strategy.taskType = TaskType.CLASSIFICATION; + lossTypeStr = ((GBTClassifier) estimator).getLossType(); + } else if (estimator instanceof GBTRegressor) { + strategy.taskType = TaskType.REGRESSION; + lossTypeStr = ((GBTRegressor) estimator).getLossType(); + } else { 
+ throw new IllegalArgumentException( + String.format( + "Unexpected type of estimator: %s.", + estimator.getClass().getSimpleName())); + } + strategy.lossType = LossType.valueOf(lossTypeStr.toUpperCase()); + strategy.maxNumLeaves = 1 << strategy.maxDepth - 1; + strategy.useMissing = true; + return strategy; + } + + private static class InitTrainContextFunction extends RichMapFunction { + private final String featureMetaBcName; + private final String labelSumCountBcName; + private final BoostingStrategy strategy; + + private InitTrainContextFunction( + String featureMetaBcName, String labelSumCountBcName, BoostingStrategy strategy) { + this.featureMetaBcName = featureMetaBcName; + this.labelSumCountBcName = labelSumCountBcName; + this.strategy = strategy; + } + + @Override + public TrainContext map(Integer value) { + TrainContext trainContext = new TrainContext(); + trainContext.strategy = strategy; + trainContext.featureMetas = + getRuntimeContext() + .getBroadcastVariable(featureMetaBcName) + .toArray(new FeatureMeta[0]); + if (!trainContext.strategy.isInputVector) { + Arrays.sort( + trainContext.featureMetas, + Comparator.comparing( + d -> ArrayUtils.indexOf(strategy.featuresCols, d.name))); + } + trainContext.numFeatures = trainContext.featureMetas.length; + trainContext.labelSumCount = + getRuntimeContext() + .>getBroadcastVariable(labelSumCountBcName) + .get(0); + return trainContext; + } + } + + private static class LabelSumCountFunction + implements AggregateFunction, Tuple2> { + + private final String labelCol; + + private LabelSumCountFunction(String labelCol) { + this.labelCol = labelCol; + } + + @Override + public Tuple2 createAccumulator() { + return Tuple2.of(0., 0L); + } + + @Override + public Tuple2 add(Row value, Tuple2 accumulator) { + double label = ((Number) value.getFieldAs(labelCol)).doubleValue(); + return Tuple2.of(accumulator.f0 + label, accumulator.f1 + 1); + } + + @Override + public Tuple2 getResult(Tuple2 accumulator) { + return accumulator; + } + + @Override + public Tuple2 merge(Tuple2 a, Tuple2 b) { + return Tuple2.of(a.f0 + b.f0, a.f1 + b.f1); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/Preprocess.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/Preprocess.java new file mode 100644 index 000000000..5e43cf846 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/Preprocess.java @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.gbt; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.gbt.defs.BoostingStrategy; +import org.apache.flink.ml.common.gbt.defs.FeatureMeta; +import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizer; +import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModel; +import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModelData; +import org.apache.flink.ml.feature.stringindexer.StringIndexer; +import org.apache.flink.ml.feature.stringindexer.StringIndexerModel; +import org.apache.flink.ml.feature.stringindexer.StringIndexerModelData; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.ApiExpression; +import org.apache.flink.table.api.Expressions; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.util.Arrays; + +import static org.apache.flink.table.api.Expressions.$; + +/** + * Preprocesses input data table for gradient boosting trees algorithms. + * + *
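+ * <p>For example (made-up values), an input row {@code (price=3.2, color="red", label=1.0)}
+ * could come out as {@code (price=7, color=2, label=1.0)}: the continuous column is replaced by
+ * its quantile-bin index, the categorical column by its string index, and the label column is
+ * left untouched.
+ *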

Multiple non-vector columns or a single vector column can be specified for preprocessing. + * Values of these column(s) are mapped to integers inplace through discretizer or string indexer, + * and the meta information of column(s) are obtained. + */ +class Preprocess { + + /** + * Maps continuous and categorical columns to integers inplace using quantile discretizer and + * string indexer respectively, and obtains meta information for all columns. + */ + static Tuple2> preprocessCols( + Table dataTable, BoostingStrategy strategy) { + + final String[] relatedCols = ArrayUtils.add(strategy.featuresCols, strategy.labelCol); + dataTable = + dataTable.select( + Arrays.stream(relatedCols) + .map(Expressions::$) + .toArray(ApiExpression[]::new)); + + // Maps continuous columns to integers, and obtain corresponding discretizer model. + String[] continuousCols = + ArrayUtils.removeElements(strategy.featuresCols, strategy.categoricalCols); + Tuple2> continuousMappedDataAndModelData = + discretizeContinuousCols(dataTable, continuousCols, strategy.maxBins); + dataTable = continuousMappedDataAndModelData.f0; + DataStream continuousFeatureMeta = + buildContinuousFeatureMeta(continuousMappedDataAndModelData.f1, continuousCols); + + // Maps categorical columns to integers, and obtain string indexer model. + DataStream categoricalFeatureMeta; + if (strategy.categoricalCols.length > 0) { + String[] mappedCategoricalCols = + Arrays.stream(strategy.categoricalCols) + .map(d -> d + "_output") + .toArray(String[]::new); + StringIndexer stringIndexer = + new StringIndexer() + .setInputCols(strategy.categoricalCols) + .setOutputCols(mappedCategoricalCols) + .setHandleInvalid("keep"); + StringIndexerModel stringIndexerModel = stringIndexer.fit(dataTable); + dataTable = stringIndexerModel.transform(dataTable)[0]; + + categoricalFeatureMeta = + buildCategoricalFeatureMeta( + StringIndexerModelData.getModelDataStream( + stringIndexerModel.getModelData()[0]), + strategy.categoricalCols); + } else { + categoricalFeatureMeta = + continuousFeatureMeta + .flatMap((value, out) -> {}) + .returns(TypeInformation.of(FeatureMeta.class)); + } + + // Rename results columns. + ApiExpression[] dropColumnExprs = + Arrays.stream(strategy.categoricalCols) + .map(Expressions::$) + .toArray(ApiExpression[]::new); + ApiExpression[] renameColumnExprs = + Arrays.stream(strategy.categoricalCols) + .map(d -> $(d + "_output").as(d)) + .toArray(ApiExpression[]::new); + dataTable = dataTable.dropColumns(dropColumnExprs).renameColumns(renameColumnExprs); + + return Tuple2.of(dataTable, continuousFeatureMeta.union(categoricalFeatureMeta)); + } + + /** + * Maps features values in vectors to integers using quantile discretizer, and obtains meta + * information for all features. + */ + static Tuple2> preprocessVecCol( + Table dataTable, BoostingStrategy strategy) { + dataTable = dataTable.select($(strategy.featuresCols[0]), $(strategy.labelCol)); + Tuple2> mappedDataAndModelData = + discretizeVectorCol(dataTable, strategy.featuresCols[0], strategy.maxBins); + dataTable = mappedDataAndModelData.f0; + DataStream featureMeta = + buildContinuousFeatureMeta(mappedDataAndModelData.f1, null); + return Tuple2.of(dataTable, featureMeta); + } + + /** Builds {@link FeatureMeta} from {@link StringIndexerModelData}. 
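+     * Each column {@code cols[i]} yields {@code FeatureMeta.categorical(cols[i],
+     * stringArrays[i].length, stringArrays[i])}, i.e. the missing-value bin index is placed
+     * right after the indices of the observed categories.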
*/ + private static DataStream buildCategoricalFeatureMeta( + DataStream stringIndexerModelData, String[] cols) { + return stringIndexerModelData + .flatMap( + (d, out) -> { + Preconditions.checkArgument(d.stringArrays.length == cols.length); + for (int i = 0; i < cols.length; i += 1) { + out.collect( + FeatureMeta.categorical( + cols[i], + d.stringArrays[i].length, + d.stringArrays[i])); + } + }) + .returns(TypeInformation.of(FeatureMeta.class)); + } + + /** Builds {@link FeatureMeta} from {@link KBinsDiscretizerModelData}. */ + private static DataStream buildContinuousFeatureMeta( + DataStream discretizerModelData, String[] cols) { + return discretizerModelData + .flatMap( + (d, out) -> { + double[][] binEdges = d.binEdges; + for (int i = 0; i < binEdges.length; i += 1) { + String name = (null != cols) ? cols[i] : "_vec_f" + i; + out.collect( + FeatureMeta.continuous( + name, binEdges[i].length - 1, binEdges[i])); + } + }) + .returns(TypeInformation.of(FeatureMeta.class)); + } + + /** Discretizes continuous columns inplace, and obtains quantile discretizer model data. */ + @SuppressWarnings("checkstyle:RegexpSingleline") + private static Tuple2> discretizeContinuousCols( + Table dataTable, String[] continuousCols, int numBins) { + final StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) dataTable).getTableEnvironment(); + final int nCols = continuousCols.length; + + // Merges all continuous columns into a vector columns. + final String vectorCol = "_vec"; + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(dataTable.getResolvedSchema()); + DataStream data = tEnv.toDataStream(dataTable, Row.class); + DataStream dataWithVectors = + data.map( + (row) -> { + double[] values = new double[nCols]; + for (int i = 0; i < nCols; i += 1) { + Number number = row.getFieldAs(continuousCols[i]); + // Null values are represented using `Double.NaN` in `DenseVector`. + values[i] = (null == number) ? Double.NaN : number.doubleValue(); + } + return Row.join(row, Row.of(Vectors.dense(values))); + }, + new RowTypeInfo( + ArrayUtils.add( + inputTypeInfo.getFieldTypes(), + DenseVectorTypeInfo.INSTANCE), + ArrayUtils.add(inputTypeInfo.getFieldNames(), vectorCol))); + + Tuple2> mappedDataAndModelData = + discretizeVectorCol(tEnv.fromDataStream(dataWithVectors), vectorCol, numBins); + DataStream discretized = tEnv.toDataStream(mappedDataAndModelData.f0); + + // Maps the result vector back to multiple continuous columns. + final String[] otherCols = + ArrayUtils.removeElements(inputTypeInfo.getFieldNames(), continuousCols); + final TypeInformation[] otherColTypes = + Arrays.stream(otherCols) + .map(inputTypeInfo::getTypeAt) + .toArray(TypeInformation[]::new); + final TypeInformation[] mappedColTypes = + Arrays.stream(continuousCols).map(d -> Types.INT).toArray(TypeInformation[]::new); + + DataStream mapped = + discretized.map( + (row) -> { + DenseVector vec = row.getFieldAs(vectorCol); + Integer[] ints = + Arrays.stream(vec.values) + .mapToObj(d -> (Integer) ((int) d)) + .toArray(Integer[]::new); + Row result = Row.project(row, otherCols); + for (int i = 0; i < ints.length; i += 1) { + result.setField(continuousCols[i], ints[i]); + } + return result; + }, + new RowTypeInfo( + ArrayUtils.addAll(otherColTypes, mappedColTypes), + ArrayUtils.addAll(otherCols, continuousCols))); + + return Tuple2.of(tEnv.fromDataStream(mapped), mappedDataAndModelData.f1); + } + + /** + * Discretize the vector column inplace using quantile discretizer, and obtains quantile + * discretizer model data.. 
+ */ + private static Tuple2> discretizeVectorCol( + Table dataTable, String vectorCol, int numBins) { + final String outputCol = "_output_col"; + KBinsDiscretizer kBinsDiscretizer = + new KBinsDiscretizer() + .setInputCol(vectorCol) + .setOutputCol(outputCol) + .setStrategy("quantile") + .setNumBins(numBins); + KBinsDiscretizerModel model = kBinsDiscretizer.fit(dataTable); + Table discretizedDataTable = model.transform(dataTable)[0]; + return Tuple2.of( + discretizedDataTable + .dropColumns($(vectorCol)) + .renameColumns($(outputCol).as(vectorCol)), + KBinsDiscretizerModelData.getModelDataStream(model.getModelData()[0])); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/BinnedInstance.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/BinnedInstance.java new file mode 100644 index 000000000..1e03e736a --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/BinnedInstance.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.defs; + +import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizer; +import org.apache.flink.ml.feature.stringindexer.StringIndexer; +import org.apache.flink.ml.linalg.SparseVector; + +import javax.annotation.Nullable; + +import java.util.Arrays; + +/** + * Represents an instance including binned values of all features, weight, and label. + * + *

Categorical and continuous features are mapped to integers by {@link StringIndexer} and {@link
+ * KBinsDiscretizer}, respectively. Null values (`null` or `Double.NaN`) are also mapped to
+ * integers; each {@link FeatureMeta} reserves a dedicated bin index ({@code missingBin}) for
+ * them.
+ *
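+ * <p>When {@code featureIds} is null the instance is stored densely and {@code featureValues[i]}
+ * holds the bin of feature {@code i}; otherwise {@code featureIds} holds the sorted IDs of the
+ * stored features and {@link #getFeatureIndex} performs a binary search. A sketch with made-up
+ * values:
+ *
+ * <pre>{@code
+ * BinnedInstance dense = new BinnedInstance();
+ * dense.featureValues = new int[] {3, 0, 7, 1};
+ * dense.getFeatureIndex(2);   // == 2, i.e. bin 7
+ *
+ * BinnedInstance sparse = new BinnedInstance();
+ * sparse.featureIds = new int[] {1, 5};
+ * sparse.featureValues = new int[] {4, 2};
+ * sparse.getFeatureIndex(5);  // == 1, i.e. bin 2
+ * sparse.getFeatureIndex(3);  // negative: feature 3 is not stored
+ * }</pre>
+ *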

NOTE: When the input features are sparse, i.e., from {@link SparseVector}s, unseen indices are + * not stored in `features`. They should be handled separately. + */ +public class BinnedInstance { + + @Nullable public int[] featureIds; + public int[] featureValues; + public double weight; + public double label; + + public BinnedInstance() {} + + /** + * Get the index of `featureId` in `featureValues`. + * + * @param featureId The feature ID. + * @return The index in `featureValues`. If the index is negative, the corresponding feature is + * not stored in `featureValues`. + */ + public int getFeatureIndex(int featureId) { + return null == featureIds ? featureId : Arrays.binarySearch(featureIds, featureId); + } + + @Override + public String toString() { + return String.format( + "BinnedInstance{featureIds=%s, featureValues=%s, weight=%s, label=%s}", + Arrays.toString(featureIds), Arrays.toString(featureValues), weight, label); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/BoostingStrategy.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/BoostingStrategy.java new file mode 100644 index 000000000..c64908241 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/BoostingStrategy.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.defs; + +import java.io.Serializable; + +/** Configurations for {@link org.apache.flink.ml.common.gbt.GBTRunner}. */ +public class BoostingStrategy implements Serializable { + + /** Indicates the task is classification or regression. */ + public TaskType taskType; + + /** + * Indicates whether the features are in one column of vector type or multiple columns of + * non-vector types. + */ + public boolean isInputVector; + + /** + * Names of features columns used for training. Contains only 1 column name when `isInputVector` + * is `true`. + */ + public String[] featuresCols; + + /** Name of label column. */ + public String labelCol; + + /** + * Names of columns which should be treated as categorical features, when `isInputVector` is + * `false`. + */ + public String[] categoricalCols; + + /** + * Max depth of the tree (root node is the 1st level). Depth 1 means 1 leaf node, i.e., the root + * node; Depth 2 means 1 internal node + 2 leaf nodes; etc. + */ + public int maxDepth; + + /** Maximum number of bins used for discretizing continuous features. */ + public int maxBins; + + /** + * Minimum number of instances each node must have. If a split causes the left or right child to + * have fewer instances than minInstancesPerNode, the split is invalid. + */ + public int minInstancesPerNode; + + /** + * Minimum fraction of the weighted sample count that each node must have. 
If a split causes the + * left or right child to have a smaller fraction of the total weight than + * minWeightFractionPerNode, the split is invalid. + * + *

NOTE: Weight column is not supported right now, so all samples have equal weights. + */ + public double minWeightFractionPerNode; + + /** Minimum information gain for a split to be considered valid. */ + public double minInfoGain; + + /** Maximum number of iterations of boosting, i.e. the number of trees in the final model. */ + public int maxIter; + + /** Step size for shrinking the contribution of each estimator. */ + public double stepSize; + + /** The random seed used in samples/features subsampling. */ + public long seed; + + /** Fraction of the training data used for learning one tree. */ + public double subsamplingRate; + + /** + * Fraction of the training data used for learning one tree. Supports "auto", "all", "onethird", + * "sqrt", "log2", (0.0 - 1.0], and [1 - n]. + */ + public String featureSubsetStrategy; + + /** Regularization term for the number of leaves. */ + public double regLambda; + + /** L2 regularization term for the weights of leaves. */ + public double regGamma; + + /** The type of loss used in boosting. */ + public LossType lossType; + + // Derived parameters. + /** Maximum number leaves. */ + public int maxNumLeaves; + /** Whether to consider missing values in the model. Always `true` right now. */ + public boolean useMissing; + + public BoostingStrategy() {} +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/FeatureMeta.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/FeatureMeta.java new file mode 100644 index 000000000..f14caed61 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/FeatureMeta.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.defs; + +import org.apache.flink.ml.common.gbt.DataUtils; + +import java.io.Serializable; +import java.util.Arrays; + +/** Stores meta information of a feature. */ +public abstract class FeatureMeta { + + public String name; + public Type type; + // The bin index representing the missing values. + public int missingBin; + + public FeatureMeta() {} + + public FeatureMeta(String name, int missingBin, Type type) { + this.name = name; + this.missingBin = missingBin; + this.type = type; + } + + public static CategoricalFeatureMeta categorical( + String name, int missingBin, String[] categories) { + return new CategoricalFeatureMeta(name, missingBin, categories); + } + + public static ContinuousFeatureMeta continuous(String name, int missingBin, double[] binEdges) { + return new ContinuousFeatureMeta(name, missingBin, binEdges); + } + + /** + * Calculate number of bins used for this feature. + * + * @param useMissing Whether to assign an addition bin for missing values. + * @return The number of bins. 
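+     *     For example, a continuous feature with {@code binEdges = {0.0, 1.5, 3.0}} has 2 bins
+     *     without the missing-value bin and 3 bins with it; a categorical feature with 4
+     *     categories has 4 or 5 bins, respectively.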
+ */ + public abstract int numBins(boolean useMissing); + + @Override + public String toString() { + return String.format( + "FeatureMeta{name='%s', type=%s, missingBin=%d}", name, type, missingBin); + } + + /** Indicates the feature type. */ + public enum Type implements Serializable { + CATEGORICAL, + CONTINUOUS + } + + /** Stores meta information for a categorical feature. */ + public static class CategoricalFeatureMeta extends FeatureMeta { + // Stores ordered categorical values. + public String[] categories; + + public CategoricalFeatureMeta() {} + + public CategoricalFeatureMeta(String name, int missingBin, String[] categories) { + super(name, missingBin, Type.CATEGORICAL); + this.categories = categories; + } + + @Override + public int numBins(boolean useMissing) { + return useMissing ? categories.length + 1 : categories.length; + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + return obj instanceof CategoricalFeatureMeta + && this.type.equals(((CategoricalFeatureMeta) obj).type) + && (this.name.equals(((CategoricalFeatureMeta) obj).name)) + && (this.missingBin == ((CategoricalFeatureMeta) obj).missingBin) + && (Arrays.equals(this.categories, ((CategoricalFeatureMeta) obj).categories)); + } + + @Override + public String toString() { + return String.format( + "CategoricalFeatureMeta{categories=%s} %s", + Arrays.toString(categories), super.toString()); + } + } + + /** Stores meta information for a continuous feature. */ + public static class ContinuousFeatureMeta extends FeatureMeta { + // Stores the edges of bins. + public double[] binEdges; + // The bin index for value 0. + public int zeroBin; + + public ContinuousFeatureMeta() {} + + public ContinuousFeatureMeta(String name, int missingBin, double[] binEdges) { + super(name, missingBin, Type.CONTINUOUS); + this.binEdges = binEdges; + this.zeroBin = DataUtils.findBin(binEdges, 0.); + } + + @Override + public int numBins(boolean useMissing) { + return useMissing ? binEdges.length : binEdges.length - 1; + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + return obj instanceof ContinuousFeatureMeta + && this.type.equals(((ContinuousFeatureMeta) obj).type) + && (this.name.equals(((ContinuousFeatureMeta) obj).name)) + && (this.missingBin == ((ContinuousFeatureMeta) obj).missingBin) + && (Arrays.equals(this.binEdges, ((ContinuousFeatureMeta) obj).binEdges)) + && (this.zeroBin == ((ContinuousFeatureMeta) obj).zeroBin); + } + + @Override + public String toString() { + return String.format( + "ContinuousFeatureMeta{binEdges=%s, zeroBin=%d} %s", + Arrays.toString(binEdges), zeroBin, super.toString()); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/HessianImpurity.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/HessianImpurity.java new file mode 100644 index 000000000..063572625 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/HessianImpurity.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.defs; + +/** + * The impurity introduced in XGBoost. + * + *
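+ * <p>With {@code G} the sum of gradients and {@code H} the sum of hessians over the instances in
+ * a node, the implementation below uses {@code -G / (H + gamma)} as the leaf prediction,
+ * {@code G * G / (H + lambda)} as the node score, and {@code 0.5 * (sum of child scores - parent
+ * score) - gamma} as the gain of a split.
+ *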

See: Introduction to + * Boosted Trees. + */ +public class HessianImpurity extends Impurity { + + // Regularization of the leaf number. + protected final double lambda; + // Regularization of leaf scores. + protected final double gamma; + // Total of instance gradients. + protected double totalGradients; + // Total of instance hessians. + protected double totalHessians; + + public HessianImpurity( + double lambda, + double gamma, + int numInstances, + double totalWeights, + double totalGradients, + double totalHessians) { + super(numInstances, totalWeights); + this.lambda = lambda; + this.gamma = gamma; + this.totalGradients = totalGradients; + this.totalHessians = totalHessians; + } + + @Override + public double prediction() { + return -totalGradients / (totalHessians + gamma); + } + + @Override + public double impurity() { + if (totalHessians + lambda == 0) { + return 0.; + } + return totalGradients * totalGradients / (totalHessians + lambda); + } + + @Override + public double gain(Impurity... others) { + double sum = 0.; + for (Impurity other : others) { + sum += other.impurity(); + } + return .5 * (sum - impurity()) - gamma; + } + + @Override + public HessianImpurity add(Impurity other) { + HessianImpurity impurity = (HessianImpurity) other; + this.numInstances += impurity.numInstances; + this.totalWeights += impurity.totalWeights; + this.totalGradients += impurity.totalGradients; + this.totalHessians += impurity.totalHessians; + return this; + } + + @Override + public HessianImpurity subtract(Impurity other) { + HessianImpurity impurity = (HessianImpurity) other; + this.numInstances -= impurity.numInstances; + this.totalWeights -= impurity.totalWeights; + this.totalGradients -= impurity.totalGradients; + this.totalHessians -= impurity.totalHessians; + return this; + } + + public void add(int numInstances, double weights, double gradients, double hessians) { + this.numInstances += numInstances; + this.totalWeights += weights; + this.totalGradients += gradients; + this.totalHessians += hessians; + } + + public void subtract(int numInstances, double weights, double gradients, double hessians) { + this.numInstances -= numInstances; + this.totalWeights -= weights; + this.totalGradients -= gradients; + this.totalHessians -= hessians; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Histogram.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Histogram.java new file mode 100644 index 000000000..de6dc42bd --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Histogram.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.gbt.defs; + +import org.apache.flink.api.common.typeinfo.TypeInfo; +import org.apache.flink.ml.common.gbt.typeinfo.HistogramTypeInfoFactory; +import org.apache.flink.util.Preconditions; + +import java.io.Serializable; + +/** + * This class stores values of histogram bins. + * + *

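 * Layout (as assumed from {@code HistBuilder}): each histogram bin occupies
 * {@code DataUtils.BIN_SIZE} consecutive doubles storing, in order, the sum of gradients, the sum
 * of hessians, the sum of instance weights, and the instance count of that bin.
 *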
Note that only the part of {@link Histogram#hists} specified by {@link Histogram#slice} is + * valid. + */ +@TypeInfo(HistogramTypeInfoFactory.class) +public class Histogram implements Serializable { + // Stores values of histogram bins. + public double[] hists; + // Stores the valid slice of `hists`. + public Slice slice = new Slice(); + + public Histogram() {} + + public Histogram(double[] hists, Slice slice) { + this.hists = hists; + this.slice = slice; + } + + public Histogram accumulate(Histogram other) { + Preconditions.checkArgument(slice.size() == other.slice.size()); + for (int i = 0; i < slice.size(); i += 1) { + hists[slice.start + i] += other.hists[other.slice.start + i]; + } + return this; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Impurity.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Impurity.java new file mode 100644 index 000000000..4e30c57ab --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Impurity.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.defs; + +import java.io.Serializable; + +/** The base class for calculating information gain from statistics. */ +public abstract class Impurity implements Cloneable, Serializable { + + // Number of instances. + protected int numInstances; + // Total of instance weights. + protected double totalWeights; + + public Impurity(int numInstances, double totalWeights) { + this.numInstances = numInstances; + this.totalWeights = totalWeights; + } + + /** + * Calculates the prediction. + * + * @return The prediction. + */ + public abstract double prediction(); + + /** + * Calculates the impurity. + * + * @return The impurity score. + */ + public abstract double impurity(); + + /** + * Calculate the information gain over other impurity instances, usually coming from splitting + * nodes. + * + * @param others Other impurity instances. + * @return The value of information gain. + */ + public abstract double gain(Impurity... others); + + /** + * Add statistics from other impurity instance. + * + * @param other The other impurity instance. + * @return The result after adding. + */ + public abstract Impurity add(Impurity other); + + /** + * Subtract statistics from other impurity instance. + * + * @param other The other impurity instance. + * @return The result after subtracting. + */ + public abstract Impurity subtract(Impurity other); + + /** + * Get the total of instance weights. + * + * @return The total of instance weights. + */ + public double getTotalWeights() { + return totalWeights; + } + + /** + * Get the number of instances. + * + * @return The number of instances. 
+ */ + public int getNumInstances() { + return numInstances; + } + + @Override + public Impurity clone() { + try { + return (Impurity) super.clone(); + } catch (CloneNotSupportedException e) { + throw new IllegalStateException("Can not clone the impurity instance.", e); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LearningNode.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LearningNode.java new file mode 100644 index 000000000..71a54c333 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LearningNode.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.defs; + +import java.io.Serializable; + +/** A node used in learning procedure. */ +public class LearningNode implements Serializable { + + // The node index in `currentTreeNodes` used in `PostSplitsOperator`. + public int nodeIndex; + // Slice of indices of bagging instances. + public Slice slice = new Slice(); + // Slice of indices of non-bagging instances. + public Slice oob = new Slice(); + // Depth of corresponding tree node. + public int depth; + + public LearningNode() {} + + public LearningNode(int nodeIndex, Slice slice, Slice oob, int depth) { + this.nodeIndex = nodeIndex; + this.slice = slice; + this.oob = oob; + this.depth = depth; + } + + @Override + public String toString() { + return String.format( + "LearningNode{nodeIndex=%s, slice=%s, oob=%s, depth=%d}", + nodeIndex, slice, oob, depth); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LossType.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LossType.java new file mode 100644 index 000000000..58f047b57 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LossType.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.defs; + +/** Indicates the type of loss. 
*/ +public enum LossType { + SQUARED, + LOGISTIC, +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Node.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Node.java new file mode 100644 index 000000000..c83ab2a07 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Node.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.defs; + +import org.apache.flink.api.common.typeinfo.TypeInfo; +import org.apache.flink.ml.common.gbt.typeinfo.NodeTypeInfoFactory; + +import java.io.Serializable; + +/** + * Represents a tree node in a decision tree. + * + *

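 * When {@link #isLeaf} is true, the node is a leaf and its prediction value is taken from the
 * associated {@link #split} (as assumed from {@code InstanceUpdater}).
 *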
NOTE: This class should be used together with a linear indexable structure, e.g., a list or an + * array, which stores all tree nodes, because {@link #left} and {@link #right} are indices of nodes + * in the linear structure. + */ +@TypeInfo(NodeTypeInfoFactory.class) +public class Node implements Serializable { + + public Split split; + public boolean isLeaf = false; + public int left; + public int right; + + public Node() {} +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Slice.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Slice.java new file mode 100644 index 000000000..4c2a1acef --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Slice.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.defs; + +import java.io.Serializable; + +/** Represents a slice of an indexable linear structure, like an array. */ +public final class Slice implements Serializable { + + public int start; + public int end; + + public Slice() {} + + public Slice(int start, int end) { + this.start = start; + this.end = end; + } + + public int size() { + return end - start; + } + + @Override + public String toString() { + return String.format("Slice{start=%d, end=%d}", start, end); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Split.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Split.java new file mode 100644 index 000000000..1baeb35f4 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Split.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.defs; + +import org.apache.flink.api.common.typeinfo.TypeInfo; +import org.apache.flink.ml.common.gbt.typeinfo.SplitTypeInfoFactory; + +import org.eclipse.collections.impl.map.mutable.primitive.IntDoubleHashMap; + +import java.util.BitSet; + +/** Stores a split on a feature. 
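 *
 * {@code accumulate(Split)} is assumed to be used for reducing candidate splits: it keeps the
 * split with the larger gain, breaking ties towards the larger feature id, e.g.
 * {@code best = best.accumulate(candidate)} (illustrative only).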
*/ +@TypeInfo(SplitTypeInfoFactory.class) +public abstract class Split { + public static final double INVALID_GAIN = 0.0; + + // Stores the feature index of this split. + public final int featureId; + + // Stores impurity gain. A value of `INVALID_GAIN` indicates this split is invalid. + public final double gain; + + // Bin index for missing values of this feature. + public final int missingBin; + // Whether the missing values should go left. + public final boolean missingGoLeft; + + // The prediction value if this split is invalid. + public final double prediction; + + public Split( + int featureId, double gain, int missingBin, boolean missingGoLeft, double prediction) { + this.featureId = featureId; + this.gain = gain; + this.missingBin = missingBin; + this.missingGoLeft = missingGoLeft; + this.prediction = prediction; + } + + public Split accumulate(Split other) { + if (gain < other.gain) { + return other; + } else if (gain == other.gain) { + if (featureId < other.featureId) { + return other; + } + } + return this; + } + + /** + * Test the binned instance should go to the left child or the right child. + * + * @param binnedInstance The instance after binned. + * @return True if the instance should go to the left child. + */ + public abstract boolean shouldGoLeft(BinnedInstance binnedInstance); + + /** + * Test the raw features should go to the left child or the right child. In the raw features, + * the categorical values are mapped to integers, while the continuous values are kept unmapped. + * + * @param rawFeatures The feature map from feature indices to values. + * @return True if the raw features should go to the left child. + */ + public abstract boolean shouldGoLeft(IntDoubleHashMap rawFeatures); + + public boolean isValid() { + return gain != INVALID_GAIN; + } + + /** Stores a split on a continuous feature. */ + public static class ContinuousSplit extends Split { + + /** + * Stores the threshold that one continuous feature should go the left or right. Before + * splitting the node, the threshold is the bin index. After that, the threshold is replaced + * with the actual value of the bin edge. + */ + public double threshold; + + // True if treat unseen values as missing values, otherwise treat them as 0s. + public boolean isUnseenMissing; + + // Bin index for 0 values. + public int zeroBin; + + public ContinuousSplit( + int featureIndex, + double gain, + int missingBin, + boolean missingGoLeft, + double prediction, + double threshold, + boolean isUnseenMissing, + int zeroBin) { + super(featureIndex, gain, missingBin, missingGoLeft, prediction); + this.threshold = threshold; + this.isUnseenMissing = isUnseenMissing; + this.zeroBin = zeroBin; + } + + public static ContinuousSplit invalid(double prediction) { + return new ContinuousSplit(0, INVALID_GAIN, 0, false, prediction, 0., false, 0); + } + + @Override + public boolean shouldGoLeft(BinnedInstance binnedInstance) { + int index = binnedInstance.getFeatureIndex(featureId); + if (index < 0 && isUnseenMissing) { + return missingGoLeft; + } + int binId = index >= 0 ? binnedInstance.featureValues[index] : zeroBin; + return binId == missingBin ? missingGoLeft : binId <= threshold; + } + + @Override + public boolean shouldGoLeft(IntDoubleHashMap rawFeatures) { + if (!rawFeatures.containsKey(featureId) && isUnseenMissing) { + return missingGoLeft; + } + double v = rawFeatures.getIfAbsent(featureId, 0.); + return Double.isNaN(v) ? missingGoLeft : v < threshold; + } + } + + /** Stores a split on a categorical feature. 
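 * An instance is routed to the left child iff its bin id is contained in
 * {@code categoriesGoLeft}; missing values follow {@code missingGoLeft} (see
 * {@code shouldGoLeft}).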
*/ + public static class CategoricalSplit extends Split { + // Stores the indices of categorical values that should go to the left child. + public final BitSet categoriesGoLeft; + + public CategoricalSplit( + int featureId, + double gain, + int missingBin, + boolean missingGoLeft, + double prediction, + BitSet categoriesGoLeft) { + super(featureId, gain, missingBin, missingGoLeft, prediction); + this.categoriesGoLeft = categoriesGoLeft; + } + + public static CategoricalSplit invalid(double prediction) { + return new CategoricalSplit(0, INVALID_GAIN, 0, false, prediction, new BitSet()); + } + + @Override + public boolean shouldGoLeft(BinnedInstance binnedInstance) { + int index = binnedInstance.getFeatureIndex(featureId); + if (index < 0) { + return missingGoLeft; + } + int binId = binnedInstance.featureValues[index]; + return binId == missingBin ? missingGoLeft : categoriesGoLeft.get(binId); + } + + @Override + public boolean shouldGoLeft(IntDoubleHashMap rawFeatures) { + if (!rawFeatures.containsKey(featureId)) { + return missingGoLeft; + } + return categoriesGoLeft.get((int) rawFeatures.get(featureId)); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/TaskType.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/TaskType.java new file mode 100644 index 000000000..3d375823e --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/TaskType.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.defs; + +/** Indicates the type of task. */ +public enum TaskType { + CLASSIFICATION, + REGRESSION, +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/TrainContext.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/TrainContext.java new file mode 100644 index 000000000..b66a78b7c --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/TrainContext.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.defs; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.common.lossfunc.LossFunc; + +import java.io.Serializable; +import java.util.Random; + +/** + * Stores necessary static context information for training. Subtasks of co-located operators + * scheduled in a same TaskManager share a same context. + */ +public class TrainContext implements Serializable { + /** Subtask ID of co-located operators. */ + public int subtaskId; + + /** Number of subtasks of co-located operators. */ + public int numSubtasks; + + /** Configurations for the boosting. */ + public BoostingStrategy strategy; + + /** Number of instances. */ + public int numInstances; + + /** Number of bagging instances used for training one tree. */ + public int numBaggingInstances; + + /** Randomizer for sampling instances. */ + public Random instanceRandomizer; + + /** Number of features. */ + public int numFeatures; + + /** Number of bagging features tested for splitting one node. */ + public int numBaggingFeatures; + + /** Randomizer for sampling features. */ + public Random featureRandomizer; + + /** Meta information of every feature. */ + public FeatureMeta[] featureMetas; + + /** Number of bins for every feature. */ + public int[] numFeatureBins; + + /** Sum and count of labels of all samples. */ + public Tuple2 labelSumCount; + + /** The prior value for prediction. */ + public double prior; + + /** The loss function. */ + public LossFunc loss; +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java new file mode 100644 index 000000000..a36748aa4 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.common.gbt.defs.BinnedInstance; +import org.apache.flink.ml.common.gbt.defs.BoostingStrategy; +import org.apache.flink.ml.common.gbt.defs.Histogram; +import org.apache.flink.ml.common.gbt.defs.LearningNode; +import org.apache.flink.ml.common.gbt.defs.TrainContext; +import org.apache.flink.ml.common.gbt.typeinfo.BinnedInstanceSerializer; +import org.apache.flink.ml.common.lossfunc.LossFunc; +import org.apache.flink.ml.common.sharedobjects.AbstractSharedObjectsTwoInputStreamOperator; +import org.apache.flink.ml.common.sharedobjects.ReadRequest; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.collections.IteratorUtils; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.ALL_TREES; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.HAS_INITED_TREE; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.INSTANCES; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.LAYER; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.NEED_INIT_TREE; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.NODE_FEATURE_PAIRS; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.PREDS_GRADS_HESSIANS; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.ROOT_LEARNING_NODE; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.SHUFFLED_INDICES; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.SWAPPED_INDICES; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.TRAIN_CONTEXT; + +/** + * Calculates local histograms for local data partition. + * + *

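 * In the first round it is assumed to cache the binned instances and the initialized
 * {@code TrainContext} into the shared-objects pool; in later rounds it renews them. In every
 * round it builds local histograms for the current layer of tree nodes.
 *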
This operator only has input elements in the first round, including data instances and raw + * training context. There will be no input elements in other rounds. The output elements are tuples + * of (subtask index, (nodeId, featureId) pair index, Histogram). + */ +public class CacheDataCalcLocalHistsOperator + extends AbstractSharedObjectsTwoInputStreamOperator< + Row, TrainContext, Tuple3> + implements IterationListener> { + + private static final String TREE_INITIALIZER_STATE_NAME = "tree_initializer"; + private static final String HIST_BUILDER_STATE_NAME = "hist_builder"; + + private final BoostingStrategy strategy; + + // States of local data. + private transient TrainContext rawTrainContext; + private transient ListStateWithCache instancesCollecting; + private transient ListStateWithCache treeInitializerState; + private transient TreeInitializer treeInitializer; + private transient ListStateWithCache histBuilderState; + private transient HistBuilder histBuilder; + + public CacheDataCalcLocalHistsOperator(BoostingStrategy strategy) { + super(); + this.strategy = strategy; + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + + instancesCollecting = + new ListStateWithCache<>( + BinnedInstanceSerializer.INSTANCE, + getContainingTask(), + getRuntimeContext(), + context, + getOperatorID()); + treeInitializerState = + new ListStateWithCache<>( + new KryoSerializer<>(TreeInitializer.class, getExecutionConfig()), + getContainingTask(), + getRuntimeContext(), + context, + getOperatorID()); + treeInitializer = + OperatorStateUtils.getUniqueElement( + treeInitializerState, TREE_INITIALIZER_STATE_NAME) + .orElse(null); + histBuilderState = + new ListStateWithCache<>( + new KryoSerializer<>(HistBuilder.class, getExecutionConfig()), + getContainingTask(), + getRuntimeContext(), + context, + getOperatorID()); + histBuilder = + OperatorStateUtils.getUniqueElement(histBuilderState, HIST_BUILDER_STATE_NAME) + .orElse(null); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + instancesCollecting.snapshotState(context); + treeInitializerState.snapshotState(context); + histBuilderState.snapshotState(context); + } + + @Override + public void processElement1(StreamRecord streamRecord) throws Exception { + Row row = streamRecord.getValue(); + BinnedInstance instance = new BinnedInstance(); + instance.weight = 1.; + instance.label = row.getFieldAs(strategy.labelCol).doubleValue(); + + if (strategy.isInputVector) { + Vector vec = row.getFieldAs(strategy.featuresCols[0]); + SparseVector sv = vec.toSparse(); + instance.featureIds = sv.indices.length == sv.size() ? 
null : sv.indices; + instance.featureValues = Arrays.stream(sv.values).mapToInt(d -> (int) d).toArray(); + } else { + instance.featureValues = + Arrays.stream(strategy.featuresCols) + .mapToInt(col -> ((Number) row.getFieldAs(col)).intValue()) + .toArray(); + } + instancesCollecting.add(instance); + } + + @Override + public List> readRequestsInProcessElement1() { + return Collections.emptyList(); + } + + @Override + public void processElement2(StreamRecord streamRecord) { + rawTrainContext = streamRecord.getValue(); + } + + @Override + public List> readRequestsInProcessElement2() { + return Collections.emptyList(); + } + + public void onEpochWatermarkIncremented( + int epochWatermark, Context c, Collector> out) + throws Exception { + if (0 == epochWatermark) { + // Initializes local state in first round. + BinnedInstance[] instances = + (BinnedInstance[]) + IteratorUtils.toArray( + instancesCollecting.get().iterator(), BinnedInstance.class); + context.write(INSTANCES, instances); + instancesCollecting.clear(); + + TrainContext trainContext = + new TrainContextInitializer(strategy) + .init( + rawTrainContext, + getRuntimeContext().getIndexOfThisSubtask(), + getRuntimeContext().getNumberOfParallelSubtasks(), + instances); + context.write(TRAIN_CONTEXT, trainContext); + + treeInitializer = new TreeInitializer(trainContext); + treeInitializerState.update(Collections.singletonList(treeInitializer)); + histBuilder = new HistBuilder(trainContext); + histBuilderState.update(Collections.singletonList(histBuilder)); + + } else { + context.renew(TRAIN_CONTEXT); + context.renew(INSTANCES); + } + + TrainContext trainContext = context.read(TRAIN_CONTEXT.sameStep()); + Preconditions.checkArgument( + getRuntimeContext().getIndexOfThisSubtask() == trainContext.subtaskId); + BinnedInstance[] instances = context.read(INSTANCES.sameStep()); + + double[] pgh = new double[0]; + boolean needInitTree = true; + int numTrees = 0; + if (epochWatermark > 0) { + pgh = context.read(PREDS_GRADS_HESSIANS.prevStep()); + needInitTree = context.read(NEED_INIT_TREE.prevStep()); + numTrees = context.read(ALL_TREES.prevStep()).size(); + } + // In the first round, use prior as the predictions. + if (0 == pgh.length) { + pgh = new double[instances.length * 3]; + double prior = trainContext.prior; + LossFunc loss = trainContext.loss; + for (int i = 0; i < instances.length; i += 1) { + double label = instances[i].label; + pgh[3 * i] = prior; + pgh[3 * i + 1] = loss.gradient(prior, label); + pgh[3 * i + 2] = loss.hessian(prior, label); + } + } + + int[] indices; + List layer; + if (needInitTree) { + // When last tree is finished, initializes a new tree, and shuffle instance + // indices. + treeInitializer.init(numTrees, d -> context.write(SHUFFLED_INDICES, d)); + LearningNode rootLearningNode = treeInitializer.getRootLearningNode(); + indices = context.read(SHUFFLED_INDICES.sameStep()); + layer = Collections.singletonList(rootLearningNode); + context.write(ROOT_LEARNING_NODE, rootLearningNode); + context.write(HAS_INITED_TREE, true); + } else { + // Otherwise, uses the swapped instance indices. 
+ indices = context.read(SWAPPED_INDICES.prevStep()); + layer = context.read(LAYER.prevStep()); + context.write(SHUFFLED_INDICES, new int[0]); + context.write(HAS_INITED_TREE, false); + context.renew(ROOT_LEARNING_NODE); + } + + histBuilder.build( + layer, indices, instances, pgh, d -> context.write(NODE_FEATURE_PAIRS, d), out); + } + + @Override + public void onIterationTerminated( + Context c, Collector> collector) { + instancesCollecting.clear(); + treeInitializerState.clear(); + histBuilderState.clear(); + + context.write(INSTANCES, new BinnedInstance[0]); + context.write(SHUFFLED_INDICES, new int[0]); + context.write(NODE_FEATURE_PAIRS, new int[0]); + } + + @Override + public void close() throws Exception { + instancesCollecting.clear(); + treeInitializerState.clear(); + histBuilderState.clear(); + super.close(); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java new file mode 100644 index 000000000..9a33b95dc --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer; +import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.common.gbt.defs.Histogram; +import org.apache.flink.ml.common.gbt.defs.LearningNode; +import org.apache.flink.ml.common.gbt.defs.Split; +import org.apache.flink.ml.common.sharedobjects.AbstractSharedObjectsOneInputStreamOperator; +import org.apache.flink.ml.common.sharedobjects.ReadRequest; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.LAYER; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.LEAVES; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.NODE_FEATURE_PAIRS; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.ROOT_LEARNING_NODE; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.TRAIN_CONTEXT; + +/** + * Calculates best splits from histograms for (nodeId, featureId) pairs. + * + *

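 * For each incoming histogram, the (nodeId, featureId) pair is looked up in the shared
 * NODE_FEATURE_PAIRS array and a {@code SplitFinder} computes the best split for that pair.
 *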
The input elements are tuples of ((nodeId, featureId) pair index, Histogram). The output + * elements are tuples of (node index, (nodeId, featureId) pair index, Split). + */ +public class CalcLocalSplitsOperator + extends AbstractSharedObjectsOneInputStreamOperator< + Tuple2, Tuple3> { + + private static final Logger LOG = LoggerFactory.getLogger(CalcLocalSplitsOperator.class); + private static final String SPLIT_FINDER_STATE_NAME = "split_finder"; + // States of local data. + private transient ListStateWithCache splitFinderState; + private transient SplitFinder splitFinder; + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + splitFinderState = + new ListStateWithCache<>( + new KryoSerializer<>(SplitFinder.class, getExecutionConfig()), + getContainingTask(), + getRuntimeContext(), + context, + getOperatorID()); + splitFinder = + OperatorStateUtils.getUniqueElement(splitFinderState, SPLIT_FINDER_STATE_NAME) + .orElse(null); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + splitFinderState.snapshotState(context); + } + + @Override + public void processElement(StreamRecord> element) throws Exception { + if (null == splitFinder) { + splitFinder = new SplitFinder(context.read(TRAIN_CONTEXT.nextStep())); + splitFinderState.update(Collections.singletonList(splitFinder)); + } + + Tuple2 value = element.getValue(); + int pairId = value.f0; + Histogram histogram = value.f1; + LOG.debug("Received histogram for pairId: {}", pairId); + + List layer = context.read(LAYER.sameStep()); + if (layer.isEmpty()) { + layer = Collections.singletonList(context.read(ROOT_LEARNING_NODE.nextStep())); + } + + int[] nodeFeaturePairs = context.read(NODE_FEATURE_PAIRS.nextStep()); + int nodeId = nodeFeaturePairs[2 * pairId]; + int featureId = nodeFeaturePairs[2 * pairId + 1]; + LearningNode node = layer.get(nodeId); + + Split bestSplit = + splitFinder.calc( + node, featureId, context.read(LEAVES.sameStep()).size(), histogram); + output.collect(new StreamRecord<>(Tuple3.of(nodeId, pairId, bestSplit))); + LOG.debug("Output split for pairId: {}", pairId); + } + + @Override + public List> readRequestsInProcessElement() { + return Arrays.asList( + TRAIN_CONTEXT.nextStep(), + LAYER.sameStep(), + ROOT_LEARNING_NODE.nextStep(), + NODE_FEATURE_PAIRS.nextStep(), + LEAVES.sameStep()); + } + + @Override + public void close() throws Exception { + super.close(); + splitFinderState.clear(); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java new file mode 100644 index 000000000..9ccc62211 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java @@ -0,0 +1,354 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.common.gbt.DataUtils; +import org.apache.flink.ml.common.gbt.defs.BinnedInstance; +import org.apache.flink.ml.common.gbt.defs.FeatureMeta; +import org.apache.flink.ml.common.gbt.defs.Histogram; +import org.apache.flink.ml.common.gbt.defs.LearningNode; +import org.apache.flink.ml.common.gbt.defs.Slice; +import org.apache.flink.ml.common.gbt.defs.TrainContext; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.eclipse.collections.impl.list.mutable.primitive.IntArrayList; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Arrays; +import java.util.BitSet; +import java.util.List; +import java.util.Random; +import java.util.function.Consumer; +import java.util.stream.IntStream; + +import static org.apache.flink.ml.common.gbt.DataUtils.BIN_SIZE; + +class HistBuilder { + private static final Logger LOG = LoggerFactory.getLogger(HistBuilder.class); + + private final int subtaskId; + + private final int numFeatures; + private final int[] numFeatureBins; + private final FeatureMeta[] featureMetas; + + private final int numBaggingFeatures; + private final Random featureRandomizer; + private final int[] featureIndicesPool; + + private final boolean isInputVector; + private final int maxDepth; + + public HistBuilder(TrainContext trainContext) { + subtaskId = trainContext.subtaskId; + + numFeatures = trainContext.numFeatures; + numFeatureBins = trainContext.numFeatureBins; + featureMetas = trainContext.featureMetas; + + numBaggingFeatures = trainContext.numBaggingFeatures; + featureRandomizer = trainContext.featureRandomizer; + featureIndicesPool = IntStream.range(0, trainContext.numFeatures).toArray(); + + isInputVector = trainContext.strategy.isInputVector; + maxDepth = trainContext.strategy.maxDepth; + } + + /** Calculate local histograms for nodes in current layer of tree. */ + void build( + List layer, + int[] indices, + BinnedInstance[] instances, + double[] pgh, + Consumer nodeFeaturePairsSetter, + Collector> out) { + LOG.info("subtaskId: {}, {} start", subtaskId, HistBuilder.class.getSimpleName()); + int numNodes = layer.size(); + + // Generates (nodeId, featureId) pairs that are required to build histograms. + int[][] nodeToFeatures = new int[numNodes][]; + IntArrayList nodeFeaturePairs = new IntArrayList(numNodes * numBaggingFeatures * 2); + for (int k = 0; k < numNodes; k += 1) { + LearningNode node = layer.get(k); + if (node.depth == maxDepth) { + // Ignores the results, just to consume the randomizer. + DataUtils.sample(featureIndicesPool, numBaggingFeatures, featureRandomizer); + // No need to calculate histograms for features, only sum of gradients and hessians + // are needed. Uses `numFeatures` to indicate this special "feature". 
+ nodeToFeatures[k] = new int[] {numFeatures}; + } else { + nodeToFeatures[k] = + DataUtils.sample(featureIndicesPool, numBaggingFeatures, featureRandomizer); + Arrays.sort(nodeToFeatures[k]); + } + for (int featureId : nodeToFeatures[k]) { + nodeFeaturePairs.add(k); + nodeFeaturePairs.add(featureId); + } + } + nodeFeaturePairsSetter.accept(nodeFeaturePairs.toArray()); + + // Calculates histograms for (nodeId, featureId) pairs. + HistBuilderImpl builderImpl = + new HistBuilderImpl( + layer, + maxDepth, + numFeatures, + numFeatureBins, + nodeToFeatures, + indices, + instances, + pgh); + builderImpl.init(isInputVector, featureMetas); + builderImpl.calcHistsForPairs(subtaskId, out); + + LOG.info("subtaskId: {}, {} end", subtaskId, HistBuilder.class.getSimpleName()); + } + + static class HistBuilderImpl { + private final List layer; + private final int maxDepth; + private final int numFeatures; + private final int[] numFeatureBins; + private final int[][] nodeToFeatures; + private final int[] indices; + private final BinnedInstance[] instances; + private final double[] pgh; + + private int[] featureDefaultVal; + + public HistBuilderImpl( + List layer, + int maxDepth, + int numFeatures, + int[] numFeatureBins, + int[][] nodeToFeatures, + int[] indices, + BinnedInstance[] instances, + double[] pgh) { + this.layer = layer; + this.maxDepth = maxDepth; + this.numFeatures = numFeatures; + this.numFeatureBins = numFeatureBins; + this.nodeToFeatures = nodeToFeatures; + this.indices = indices; + this.instances = instances; + this.pgh = pgh; + Preconditions.checkArgument(numFeatureBins.length == numFeatures + 1); + } + + private static void calcHistsForDefaultBin( + int defaultVal, + int featureOffset, + int numBins, + double[] totalHists, + double[] hists, + int nodeOffset) { + int defaultValIndex = (nodeOffset + featureOffset + defaultVal) * BIN_SIZE; + hists[defaultValIndex] = totalHists[0]; + hists[defaultValIndex + 1] = totalHists[1]; + hists[defaultValIndex + 2] = totalHists[2]; + hists[defaultValIndex + 3] = totalHists[3]; + for (int i = 0; i < numBins; i += 1) { + if (i != defaultVal) { + int index = (nodeOffset + featureOffset + i) * BIN_SIZE; + add( + hists, + nodeOffset + featureOffset, + defaultVal, + -hists[index], + -hists[index + 1], + -hists[index + 2], + -hists[index + 3]); + } + } + } + + private static void add( + double[] hists, int offset, int val, double d0, double d1, double d2, double d3) { + int index = (offset + val) * BIN_SIZE; + hists[index] += d0; + hists[index + 1] += d1; + hists[index + 2] += d2; + hists[index + 3] += d3; + } + + private void init(boolean isInputVector, FeatureMeta[] featureMetas) { + featureDefaultVal = new int[numFeatures]; + for (int i = 0; i < numFeatures; i += 1) { + FeatureMeta d = featureMetas[i]; + featureDefaultVal[i] = + isInputVector && d instanceof FeatureMeta.ContinuousFeatureMeta + ? 
((FeatureMeta.ContinuousFeatureMeta) d).zeroBin + : d.missingBin; + } + } + + private void calcTotalHists(LearningNode node, double[] totalHists, int offset) { + for (int i = node.slice.start; i < node.slice.end; i += 1) { + int instanceId = indices[i]; + BinnedInstance binnedInstance = instances[instanceId]; + double weight = binnedInstance.weight; + double gradient = pgh[3 * instanceId + 1]; + double hessian = pgh[3 * instanceId + 2]; + add(totalHists, offset, 0, gradient, hessian, weight, 1.); + } + } + + private void calcHistsForNonDefaultBins( + LearningNode node, + boolean allFeatureValid, + BitSet featureValid, + int[] featureOffset, + double[] hists, + int nodeOffset) { + for (int i = node.slice.start; i < node.slice.end; i += 1) { + int instanceId = indices[i]; + BinnedInstance binnedInstance = instances[instanceId]; + double weight = binnedInstance.weight; + double gradient = pgh[3 * instanceId + 1]; + double hessian = pgh[3 * instanceId + 2]; + + if (null == binnedInstance.featureIds) { + for (int j = 0; j < binnedInstance.featureValues.length; j += 1) { + if (allFeatureValid || featureValid.get(j)) { + add( + hists, + nodeOffset + featureOffset[j], + binnedInstance.featureValues[j], + gradient, + hessian, + weight, + 1.); + } + } + } else { + for (int j = 0; j < binnedInstance.featureIds.length; j += 1) { + int featureId = binnedInstance.featureIds[j]; + if (allFeatureValid || featureValid.get(featureId)) { + add( + hists, + nodeOffset + featureOffset[featureId], + binnedInstance.featureValues[j], + gradient, + hessian, + weight, + 1.); + } + } + } + } + } + + private void calcHistsForSplitNode( + LearningNode node, + int[] features, + int[] binOffsets, + double[] hists, + int nodeOffset) { + double[] totalHists = new double[4]; + calcTotalHists(node, totalHists, 0); + + int[] featureOffsets = new int[numFeatures]; + BitSet featureValid = null; + boolean allFeatureValid; + if (numFeatures != features.length) { + allFeatureValid = false; + featureValid = new BitSet(numFeatures); + for (int i = 0; i < features.length; i += 1) { + featureValid.set(features[i]); + featureOffsets[features[i]] = binOffsets[i]; + } + } else { + allFeatureValid = true; + System.arraycopy(binOffsets, 0, featureOffsets, 0, numFeatures); + } + + calcHistsForNonDefaultBins( + node, allFeatureValid, featureValid, featureOffsets, hists, nodeOffset); + + for (int featureId : features) { + calcHistsForDefaultBin( + featureDefaultVal[featureId], + featureOffsets[featureId], + numFeatureBins[featureId], + totalHists, + hists, + nodeOffset); + } + } + + /** Calculate histograms for all (nodeId, featureId) pairs. 
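 * Bins of all sampled features of one node are computed into a single array, and each
 * (nodeId, featureId) pair is emitted as a {@link Histogram} whose {@code Slice} selects that
 * feature's bins.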
*/ + private void calcHistsForPairs( + int subtaskId, Collector> out) { + long start = System.currentTimeMillis(); + int numNodes = layer.size(); + int offset = 0; + int pairBaseId = 0; + for (int k = 0; k < numNodes; k += 1) { + int[] features = nodeToFeatures[k]; + final int nodeOffset = offset; + int[] binOffsets = new int[features.length]; + for (int i = 0; i < features.length; i += 1) { + binOffsets[i] = offset - nodeOffset; + offset += numFeatureBins[features[i]]; + } + + double[] nodeHists = new double[(offset - nodeOffset) * BIN_SIZE]; + long nodeStart = System.currentTimeMillis(); + LearningNode node = layer.get(k); + if (node.depth != maxDepth) { + calcHistsForSplitNode(node, features, binOffsets, nodeHists, 0); + } else { + calcTotalHists(node, nodeHists, 0); + } + LOG.info( + "subtaskId: {}, node {}, {} #instances, {} #features, {} ms", + subtaskId, + k, + node.slice.size(), + features.length, + System.currentTimeMillis() - nodeStart); + + int sliceStart = 0; + for (int i = 0; i < features.length; i += 1) { + int sliceSize = numFeatureBins[features[i]] * BIN_SIZE; + int pairId = pairBaseId + i; + out.collect( + Tuple3.of( + subtaskId, + pairId, + new Histogram( + nodeHists, + new Slice(sliceStart, sliceStart + sliceSize)))); + sliceStart += sliceSize; + } + pairBaseId += features.length; + } + + LOG.info( + "subtaskId: {}, elapsed time for calculating histograms: {} ms", + subtaskId, + System.currentTimeMillis() - start); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java new file mode 100644 index 000000000..3cb3bd4cd --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.ml.common.gbt.defs.BinnedInstance; +import org.apache.flink.ml.common.gbt.defs.LearningNode; +import org.apache.flink.ml.common.gbt.defs.Node; +import org.apache.flink.ml.common.gbt.defs.Split; +import org.apache.flink.ml.common.gbt.defs.TrainContext; +import org.apache.flink.ml.common.lossfunc.LossFunc; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.function.Consumer; + +class InstanceUpdater { + private static final Logger LOG = LoggerFactory.getLogger(InstanceUpdater.class); + + private final int subtaskId; + private final LossFunc loss; + private final double stepSize; + private final double prior; + + public InstanceUpdater(TrainContext trainContext) { + subtaskId = trainContext.subtaskId; + loss = trainContext.loss; + stepSize = trainContext.strategy.stepSize; + prior = trainContext.prior; + } + + public void update( + double[] pgh, + List leaves, + int[] indices, + BinnedInstance[] instances, + Consumer pghSetter, + List treeNodes) { + LOG.info("subtaskId: {}, {} start", subtaskId, InstanceUpdater.class.getSimpleName()); + long start = System.currentTimeMillis(); + if (pgh.length == 0) { + pgh = new double[instances.length * 3]; + for (int i = 0; i < instances.length; i += 1) { + double label = instances[i].label; + pgh[3 * i] = prior; + pgh[3 * i + 1] = loss.gradient(prior, label); + pgh[3 * i + 2] = loss.hessian(prior, label); + } + } + + for (LearningNode nodeInfo : leaves) { + Split split = treeNodes.get(nodeInfo.nodeIndex).split; + double pred = split.prediction * stepSize; + for (int i = nodeInfo.slice.start; i < nodeInfo.slice.end; ++i) { + int instanceId = indices[i]; + updatePgh(instanceId, pred, instances[instanceId].label, pgh); + } + for (int i = nodeInfo.oob.start; i < nodeInfo.oob.end; ++i) { + int instanceId = indices[i]; + updatePgh(instanceId, pred, instances[instanceId].label, pgh); + } + } + pghSetter.accept(pgh); + LOG.info("subtaskId: {}, {} end", subtaskId, InstanceUpdater.class.getSimpleName()); + LOG.info( + "subtaskId: {}, elapsed time for updating instances: {} ms", + subtaskId, + System.currentTimeMillis() - start); + } + + private void updatePgh(int instanceId, double pred, double label, double[] pgh) { + pgh[instanceId * 3] += pred; + pgh[instanceId * 3 + 1] = loss.gradient(pgh[instanceId * 3], label); + pgh[instanceId * 3 + 2] = loss.hessian(pgh[instanceId * 3], label); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/NodeSplitter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/NodeSplitter.java new file mode 100644 index 000000000..6e6dd071b --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/NodeSplitter.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.ml.common.gbt.defs.BinnedInstance; +import org.apache.flink.ml.common.gbt.defs.FeatureMeta; +import org.apache.flink.ml.common.gbt.defs.LearningNode; +import org.apache.flink.ml.common.gbt.defs.Node; +import org.apache.flink.ml.common.gbt.defs.Slice; +import org.apache.flink.ml.common.gbt.defs.Split; +import org.apache.flink.ml.common.gbt.defs.TrainContext; +import org.apache.flink.util.Preconditions; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.List; + +class NodeSplitter { + private static final Logger LOG = LoggerFactory.getLogger(NodeSplitter.class); + + private final int subtaskId; + private final FeatureMeta[] featureMetas; + private final int maxLeaves; + private final int maxDepth; + + public NodeSplitter(TrainContext trainContext) { + subtaskId = trainContext.subtaskId; + featureMetas = trainContext.featureMetas; + maxLeaves = trainContext.strategy.maxNumLeaves; + maxDepth = trainContext.strategy.maxDepth; + } + + private int partitionInstances( + Split split, Slice slice, int[] indices, BinnedInstance[] instances) { + int lstart = slice.start; + int lend = slice.end - 1; + while (lstart <= lend) { + while (lstart <= lend && split.shouldGoLeft(instances[indices[lstart]])) { + lstart += 1; + } + while (lstart <= lend && !split.shouldGoLeft(instances[indices[lend]])) { + lend -= 1; + } + if (lstart < lend) { + int temp = indices[lstart]; + indices[lstart] = indices[lend]; + indices[lend] = temp; + } + } + return lstart; + } + + private void splitNode( + Node treeNode, + LearningNode nodeInfo, + int[] indices, + BinnedInstance[] instances, + List nextLayer, + List treeNodes) { + int mid = partitionInstances(treeNode.split, nodeInfo.slice, indices, instances); + int oobMid = partitionInstances(treeNode.split, nodeInfo.oob, indices, instances); + + treeNode.left = treeNodes.size(); + treeNodes.add(new Node()); + treeNode.right = treeNodes.size(); + treeNodes.add(new Node()); + + nextLayer.add( + new LearningNode( + treeNode.left, + new Slice(nodeInfo.slice.start, mid), + new Slice(nodeInfo.oob.start, oobMid), + nodeInfo.depth + 1)); + nextLayer.add( + new LearningNode( + treeNode.right, + new Slice(mid, nodeInfo.slice.end), + new Slice(oobMid, nodeInfo.oob.end), + nodeInfo.depth + 1)); + } + + public List split( + List treeNodes, + List layer, + List leaves, + Split[] splits, + int[] indices, + BinnedInstance[] instances) { + LOG.info("subtaskId: {}, {} start", subtaskId, NodeSplitter.class.getSimpleName()); + long start = System.currentTimeMillis(); + Preconditions.checkState(splits.length == layer.size()); + + List nextLayer = new ArrayList<>(); + + // nodes in current layer or next layer are expected to generate at least 1 leaf. 
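        // `numQueued` tracks how many nodes of the current layer are still pending; a node is
        // split further only when the resulting number of leaves cannot exceed `maxLeaves`
        // (checked via `leaves.size() + numQueued + 2` below).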
+ int numQueued = layer.size(); + for (int i = 0; i < layer.size(); i += 1) { + LearningNode node = layer.get(i); + Split split = splits[i]; + numQueued -= 1; + Node treeNode = treeNodes.get(node.nodeIndex); + treeNode.split = split; + if (!split.isValid() + || treeNode.isLeaf + || (leaves.size() + numQueued + 2) > maxLeaves + || node.depth + 1 > maxDepth) { + treeNode.isLeaf = true; + leaves.add(node); + } else { + splitNode(treeNode, node, indices, instances, nextLayer, treeNodes); + // Converts splits point from bin id to real feature value after splitting node. + if (split instanceof Split.ContinuousSplit) { + Split.ContinuousSplit cs = (Split.ContinuousSplit) split; + FeatureMeta.ContinuousFeatureMeta featureMeta = + (FeatureMeta.ContinuousFeatureMeta) featureMetas[cs.featureId]; + cs.threshold = featureMeta.binEdges[(int) cs.threshold + 1]; + } + numQueued += 2; + } + } + LOG.info("subtaskId: {}, {} end", subtaskId, NodeSplitter.class.getSimpleName()); + LOG.info( + "subtaskId: {}, elapsed time for splitting nodes: {} ms", + subtaskId, + System.currentTimeMillis() - start); + return nextLayer; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java new file mode 100644 index 000000000..d8a0b909c --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.common.gbt.defs.BinnedInstance; +import org.apache.flink.ml.common.gbt.defs.LearningNode; +import org.apache.flink.ml.common.gbt.defs.Node; +import org.apache.flink.ml.common.gbt.defs.Split; +import org.apache.flink.ml.common.gbt.defs.TrainContext; +import org.apache.flink.ml.common.sharedobjects.AbstractSharedObjectsOneInputStreamOperator; +import org.apache.flink.ml.common.sharedobjects.ReadRequest; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Collector; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.ALL_TREES; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.CURRENT_TREE_NODES; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.INSTANCES; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.LAYER; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.LEAVES; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.NEED_INIT_TREE; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.PREDS_GRADS_HESSIANS; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.ROOT_LEARNING_NODE; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.SHUFFLED_INDICES; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.SWAPPED_INDICES; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.TRAIN_CONTEXT; + +/** + * Post-process after global splits obtained, including split instances to left or child nodes, and + * update instances scores after a tree is complete. + */ +public class PostSplitsOperator + extends AbstractSharedObjectsOneInputStreamOperator, Integer> + implements IterationListener { + + private static final String NODE_SPLITTER_STATE_NAME = "node_splitter"; + private static final String INSTANCE_UPDATER_STATE_NAME = "instance_updater"; + + private static final Logger LOG = LoggerFactory.getLogger(PostSplitsOperator.class); + + // States of local data. 
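+    // `nodeSplits` buffers the best split received for each node of the current layer in
+    // processElement and is consumed (then reset) when the epoch watermark arrives. The splitter
+    // and updater helpers are kept in ListStateWithCache so they can be restored from a snapshot
+    // rather than rebuilt after a failure.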
+ private transient Split[] nodeSplits; + private transient ListStateWithCache nodeSplitterState; + private transient NodeSplitter nodeSplitter; + private transient ListStateWithCache instanceUpdaterState; + private transient InstanceUpdater instanceUpdater; + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + + nodeSplitterState = + new ListStateWithCache<>( + new KryoSerializer<>(NodeSplitter.class, getExecutionConfig()), + getContainingTask(), + getRuntimeContext(), + context, + getOperatorID()); + nodeSplitter = + OperatorStateUtils.getUniqueElement(nodeSplitterState, NODE_SPLITTER_STATE_NAME) + .orElse(null); + instanceUpdaterState = + new ListStateWithCache<>( + new KryoSerializer<>(InstanceUpdater.class, getExecutionConfig()), + getContainingTask(), + getRuntimeContext(), + context, + getOperatorID()); + instanceUpdater = + OperatorStateUtils.getUniqueElement( + instanceUpdaterState, INSTANCE_UPDATER_STATE_NAME) + .orElse(null); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + nodeSplitterState.snapshotState(context); + instanceUpdaterState.snapshotState(context); + } + + @Override + public void onEpochWatermarkIncremented( + int epochWatermark, Context c, Collector collector) throws Exception { + if (0 == epochWatermark) { + TrainContext trainContext = context.read(TRAIN_CONTEXT.sameStep()); + nodeSplitter = new NodeSplitter(trainContext); + nodeSplitterState.update(Collections.singletonList(nodeSplitter)); + instanceUpdater = new InstanceUpdater(trainContext); + instanceUpdaterState.update(Collections.singletonList(instanceUpdater)); + } + + int[] indices = new int[0]; + if (epochWatermark > 0) { + indices = context.read(SWAPPED_INDICES.prevStep()); + } + if (0 == indices.length) { + indices = context.read(SHUFFLED_INDICES.sameStep()).clone(); + } + + BinnedInstance[] instances = context.read(INSTANCES.sameStep()); + List leaves = context.read(LEAVES.prevStep()); + List layer = context.read(LAYER.prevStep()); + List currentTreeNodes; + if (layer.isEmpty()) { + layer = Collections.singletonList(context.read(ROOT_LEARNING_NODE.sameStep())); + currentTreeNodes = new ArrayList<>(); + currentTreeNodes.add(new Node()); + } else { + currentTreeNodes = context.read(CURRENT_TREE_NODES.prevStep()); + } + + List nextLayer = + nodeSplitter.split(currentTreeNodes, layer, leaves, nodeSplits, indices, instances); + nodeSplits = null; + context.write(LEAVES, leaves); + context.write(LAYER, nextLayer); + context.write(CURRENT_TREE_NODES, currentTreeNodes); + + if (nextLayer.isEmpty()) { + // Current tree is finished. 
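+            // Finishing a tree: request initialization of a new tree, fold the leaf predictions
+            // into the cached (prediction, gradient, hessian) triples, append the finished node
+            // list to ALL_TREES, and reset the per-tree state (leaves and swapped indices).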
+ context.write(NEED_INIT_TREE, true); + instanceUpdater.update( + context.read(PREDS_GRADS_HESSIANS.prevStep()), + leaves, + indices, + instances, + d -> context.write(PREDS_GRADS_HESSIANS, d), + currentTreeNodes); + leaves.clear(); + List> allTrees = context.read(ALL_TREES.prevStep()); + allTrees.add(currentTreeNodes); + + context.write(LEAVES, new ArrayList<>()); + context.write(SWAPPED_INDICES, new int[0]); + context.write(ALL_TREES, allTrees); + LOG.info("finalize {}-th tree", allTrees.size()); + } else { + context.write(SWAPPED_INDICES, indices); + context.write(NEED_INIT_TREE, false); + + context.renew(PREDS_GRADS_HESSIANS); + context.renew(ALL_TREES); + } + } + + @Override + public void onIterationTerminated(Context c, Collector collector) { + context.write(PREDS_GRADS_HESSIANS, new double[0]); + context.write(SWAPPED_INDICES, new int[0]); + context.write(LEAVES, Collections.emptyList()); + context.write(LAYER, Collections.emptyList()); + context.write(CURRENT_TREE_NODES, Collections.emptyList()); + } + + @Override + public void processElement(StreamRecord> element) throws Exception { + if (null == nodeSplits) { + List layer = context.read(LAYER.sameStep()); + int numNodes = (layer.isEmpty()) ? 1 : layer.size(); + nodeSplits = new Split[numNodes]; + } + Tuple2 value = element.getValue(); + int nodeId = value.f0; + Split split = value.f1; + LOG.debug("Received split for node {}", nodeId); + nodeSplits[nodeId] = split; + } + + @Override + public List> readRequestsInProcessElement() { + return Collections.singletonList(LAYER.sameStep()); + } + + @Override + public void close() throws Exception { + nodeSplitterState.clear(); + instanceUpdaterState.clear(); + super.close(); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/ReduceHistogramFunction.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/ReduceHistogramFunction.java new file mode 100644 index 000000000..c3a0617bf --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/ReduceHistogramFunction.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.common.gbt.defs.Histogram; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.BitSet; +import java.util.HashMap; +import java.util.Map; + +/** + * This operator reduces histograms for (nodeId, featureId) pairs. + * + *

The input elements are tuples of (subtask index, (nodeId, featureId) pair index, Histogram). + * The output elements are tuples of ((nodeId, featureId) pair index, Histogram). + */ +public class ReduceHistogramFunction + extends RichFlatMapFunction< + Tuple3, Tuple2> { + + private static final Logger LOG = LoggerFactory.getLogger(ReduceHistogramFunction.class); + + private final Map pairAccepted = new HashMap<>(); + private final Map pairHistogram = new HashMap<>(); + private int numSubtasks; + + @Override + public void open(Configuration parameters) throws Exception { + numSubtasks = getRuntimeContext().getNumberOfParallelSubtasks(); + } + + @Override + public void flatMap( + Tuple3 value, Collector> out) + throws Exception { + int sourceSubtaskId = value.f0; + int pairId = value.f1; + Histogram histogram = value.f2; + + BitSet accepted = pairAccepted.getOrDefault(pairId, new BitSet(numSubtasks)); + if (accepted.isEmpty()) { + LOG.debug("Received histogram for new pair {}", pairId); + } + Preconditions.checkState(!accepted.get(sourceSubtaskId)); + accepted.set(sourceSubtaskId); + pairAccepted.put(pairId, accepted); + + pairHistogram.compute(pairId, (k, v) -> null == v ? histogram : v.accumulate(histogram)); + if (numSubtasks == accepted.cardinality()) { + out.collect(Tuple2.of(pairId, pairHistogram.get(pairId))); + LOG.debug("Output accumulated histogram for pair {}", pairId); + pairAccepted.remove(pairId); + pairHistogram.remove(pairId); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/ReduceSplitsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/ReduceSplitsOperator.java new file mode 100644 index 000000000..289b68c8a --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/ReduceSplitsOperator.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.common.gbt.defs.Split; +import org.apache.flink.ml.common.sharedobjects.AbstractSharedObjectsOneInputStreamOperator; +import org.apache.flink.ml.common.sharedobjects.ReadRequest; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Preconditions; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.BitSet; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.NODE_FEATURE_PAIRS; + +/** + * Reduces best splits for nodes. 
+ * + *

The input elements are tuples of (node index, (nodeId, featureId) pair index, Split). The + * output elements are tuples of (node index, Split). + */ +public class ReduceSplitsOperator + extends AbstractSharedObjectsOneInputStreamOperator< + Tuple3, Tuple2> { + + private static final Logger LOG = LoggerFactory.getLogger(ReduceSplitsOperator.class); + + private Map nodeFeatureMap; + private Map nodeBestSplit; + private Map nodeFeatureCounter; + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + nodeFeatureMap = new HashMap<>(); + nodeBestSplit = new HashMap<>(); + nodeFeatureCounter = new HashMap<>(); + } + + @Override + public void processElement(StreamRecord> element) + throws Exception { + if (nodeFeatureMap.isEmpty()) { + Preconditions.checkState(nodeBestSplit.isEmpty()); + nodeFeatureCounter.clear(); + int[] nodeFeaturePairs = context.read(NODE_FEATURE_PAIRS.nextStep()); + for (int i = 0; i < nodeFeaturePairs.length / 2; i += 1) { + int nodeId = nodeFeaturePairs[2 * i]; + nodeFeatureCounter.compute(nodeId, (k, v) -> null == v ? 1 : v + 1); + } + } + + Tuple3 value = element.getValue(); + int nodeId = value.f0; + int pairId = value.f1; + Split split = value.f2; + BitSet featureMap = nodeFeatureMap.getOrDefault(nodeId, new BitSet()); + if (featureMap.isEmpty()) { + LOG.debug("Received split for new node {}", nodeId); + } + int[] nodeFeaturePairs = context.read(NODE_FEATURE_PAIRS.nextStep()); + Preconditions.checkState(nodeId == nodeFeaturePairs[pairId * 2]); + int featureId = nodeFeaturePairs[pairId * 2 + 1]; + Preconditions.checkState(!featureMap.get(featureId)); + featureMap.set(featureId); + nodeFeatureMap.put(nodeId, featureMap); + + nodeBestSplit.compute(nodeId, (k, v) -> null == v ? split : v.accumulate(split)); + if (featureMap.cardinality() == nodeFeatureCounter.get(nodeId)) { + output.collect(new StreamRecord<>(Tuple2.of(nodeId, nodeBestSplit.get(nodeId)))); + LOG.debug("Output accumulated split for node {}", nodeId); + nodeBestSplit.remove(nodeId); + nodeFeatureMap.remove(nodeId); + nodeFeatureCounter.remove(nodeId); + } + } + + @Override + public List> readRequestsInProcessElement() { + return Collections.singletonList(NODE_FEATURE_PAIRS.nextStep()); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedObjectsConstants.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedObjectsConstants.java new file mode 100644 index 000000000..1359fb4a8 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedObjectsConstants.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeutils.base.BooleanSerializer; +import org.apache.flink.api.common.typeutils.base.GenericArraySerializer; +import org.apache.flink.api.common.typeutils.base.ListSerializer; +import org.apache.flink.api.common.typeutils.base.array.IntPrimitiveArraySerializer; +import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer; +import org.apache.flink.ml.common.gbt.GBTRunner; +import org.apache.flink.ml.common.gbt.defs.BinnedInstance; +import org.apache.flink.ml.common.gbt.defs.LearningNode; +import org.apache.flink.ml.common.gbt.defs.Node; +import org.apache.flink.ml.common.gbt.defs.TrainContext; +import org.apache.flink.ml.common.gbt.typeinfo.BinnedInstanceSerializer; +import org.apache.flink.ml.common.gbt.typeinfo.LearningNodeSerializer; +import org.apache.flink.ml.common.gbt.typeinfo.NodeSerializer; +import org.apache.flink.ml.common.sharedobjects.Descriptor; +import org.apache.flink.ml.common.sharedobjects.SharedObjectsUtils; +import org.apache.flink.ml.linalg.typeinfo.OptimizedDoublePrimitiveArraySerializer; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Stores constants used for {@link SharedObjectsUtils} in {@link GBTRunner}. + * + *

In the iteration, some data needs to be shared and accessed between subtasks of different + * operators within one JVM to reduce memory footprint and communication cost. We use {@link + * SharedObjectsUtils} with co-location mechanism to achieve such purpose. + * + *

All shared objects have corresponding {@link Descriptor}s, and can be read/written through + * {@link Descriptor}s from different operator subtasks. Note that every shared object has an owner, + * and the owner can set new values and snapshot the object. + * + *

This class records all {@link Descriptor}s used in {@link GBTRunner} and their owners. + */ +@Internal +public class SharedObjectsConstants { + + /** Instances (after binned). */ + static final Descriptor INSTANCES = + Descriptor.of( + "instances", + new GenericArraySerializer<>( + BinnedInstance.class, BinnedInstanceSerializer.INSTANCE)); + + /** + * (prediction, gradient, and hessian) of instances, sharing same indexing with {@link + * #INSTANCES}. + */ + static final Descriptor PREDS_GRADS_HESSIANS = + Descriptor.of( + "preds_grads_hessians", + new OptimizedDoublePrimitiveArraySerializer(), + new double[0]); + + /** Shuffle indices of instances used after every new tree just initialized. */ + static final Descriptor SHUFFLED_INDICES = + Descriptor.of("shuffled_indices", IntPrimitiveArraySerializer.INSTANCE); + + /** Swapped indices of instances used when {@link #SHUFFLED_INDICES} not applicable. */ + static final Descriptor SWAPPED_INDICES = + Descriptor.of("swapped_indices", IntPrimitiveArraySerializer.INSTANCE); + + /** (nodeId, featureId) pairs used to calculate histograms. */ + static final Descriptor NODE_FEATURE_PAIRS = + Descriptor.of("node_feature_pairs", IntPrimitiveArraySerializer.INSTANCE); + + /** Leaves nodes of current working tree. */ + static final Descriptor> LEAVES = + Descriptor.of( + "leaves", + new ListSerializer<>(LearningNodeSerializer.INSTANCE), + new ArrayList<>()); + + /** Nodes in current layer of current working tree. */ + static final Descriptor> LAYER = + Descriptor.of( + "layer", + new ListSerializer<>(LearningNodeSerializer.INSTANCE), + new ArrayList<>()); + + /** The root node when initializing a new tree. */ + static final Descriptor ROOT_LEARNING_NODE = + Descriptor.of("root_learning_node", LearningNodeSerializer.INSTANCE); + + /** All finished trees. */ + static final Descriptor>> ALL_TREES = + Descriptor.of( + "all_trees", + new ListSerializer<>(new ListSerializer<>(NodeSerializer.INSTANCE)), + new ArrayList<>()); + + /** Nodes in current working tree. */ + static final Descriptor> CURRENT_TREE_NODES = + Descriptor.of("current_tree_nodes", new ListSerializer<>(NodeSerializer.INSTANCE)); + + /** Indicates the necessity of initializing a new tree. */ + static final Descriptor NEED_INIT_TREE = + Descriptor.of("need_init_tree", BooleanSerializer.INSTANCE, true); + + /** Data items owned by the `PostSplits` operator. */ + public static final List> OWNED_BY_POST_SPLITS_OP = + Arrays.asList( + PREDS_GRADS_HESSIANS, + SWAPPED_INDICES, + LEAVES, + LAYER, + ALL_TREES, + CURRENT_TREE_NODES, + NEED_INIT_TREE); + + /** Indicate a new tree has been initialized. */ + static final Descriptor HAS_INITED_TREE = + Descriptor.of("has_inited_tree", BooleanSerializer.INSTANCE, false); + + /** Training context. */ + static final Descriptor TRAIN_CONTEXT = + Descriptor.of( + "train_context", + new KryoSerializer<>(TrainContext.class, new ExecutionConfig()), + new TrainContext()); + + /** Data items owned by the `CacheDataCalcLocalHists` operator. 
*/ + public static final List> OWNED_BY_CACHE_DATA_CALC_LOCAL_HISTS_OP = + Arrays.asList( + INSTANCES, + SHUFFLED_INDICES, + NODE_FEATURE_PAIRS, + ROOT_LEARNING_NODE, + HAS_INITED_TREE, + TRAIN_CONTEXT); +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java new file mode 100644 index 000000000..458ac0b9c --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.ml.common.gbt.DataUtils; +import org.apache.flink.ml.common.gbt.defs.FeatureMeta; +import org.apache.flink.ml.common.gbt.defs.Histogram; +import org.apache.flink.ml.common.gbt.defs.LearningNode; +import org.apache.flink.ml.common.gbt.defs.Slice; +import org.apache.flink.ml.common.gbt.defs.Split; +import org.apache.flink.ml.common.gbt.defs.TrainContext; +import org.apache.flink.ml.common.gbt.splitter.CategoricalFeatureSplitter; +import org.apache.flink.ml.common.gbt.splitter.ContinuousFeatureSplitter; +import org.apache.flink.ml.common.gbt.splitter.HistogramFeatureSplitter; +import org.apache.flink.util.Preconditions; + +class SplitFinder { + private final HistogramFeatureSplitter[] splitters; + private final int maxDepth; + private final int maxNumLeaves; + + public SplitFinder(TrainContext trainContext) { + FeatureMeta[] featureMetas = trainContext.featureMetas; + int numFeatures = trainContext.numFeatures; + splitters = new HistogramFeatureSplitter[numFeatures + 1]; + for (int i = 0; i < numFeatures; ++i) { + splitters[i] = + FeatureMeta.Type.CATEGORICAL == featureMetas[i].type + ? new CategoricalFeatureSplitter( + i, featureMetas[i], trainContext.strategy) + : new ContinuousFeatureSplitter( + i, featureMetas[i], trainContext.strategy); + } + // Adds an addition splitter to obtain the prediction of the node. 
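+        // The extra splitter at index `numFeatures` backs the single-bin pseudo-feature appended
+        // in TrainContextInitializer (`numFeatureBins` gets one extra entry of 1); its histogram
+        // should aggregate all instances of the node, so `bestSplit()` on it yields an invalid
+        // split that still carries the node prediction.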
+        splitters[numFeatures] =
+                new ContinuousFeatureSplitter(
+                        numFeatures,
+                        new FeatureMeta.ContinuousFeatureMeta("SPECIAL", 0, new double[0]),
+                        trainContext.strategy);
+        maxDepth = trainContext.strategy.maxDepth;
+        maxNumLeaves = trainContext.strategy.maxNumLeaves;
+    }
+
+    public Split calc(LearningNode node, int featureId, int numLeaves, Histogram histogram) {
+        Preconditions.checkState(node.depth < maxDepth || numLeaves + 2 <= maxNumLeaves);
+        Preconditions.checkState(histogram.slice.start == 0);
+        splitters[featureId].reset(
+                histogram.hists, new Slice(0, histogram.hists.length / DataUtils.BIN_SIZE));
+        return splitters[featureId].bestSplit();
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java
new file mode 100644
index 000000000..3d0bc921a
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.gbt.operators;
+
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.ml.common.gbt.GBTModelData;
+import org.apache.flink.ml.common.sharedobjects.AbstractSharedObjectsOneInputStreamOperator;
+import org.apache.flink.ml.common.sharedobjects.ReadRequest;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import java.util.Collections;
+import java.util.List;
+
+import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.ALL_TREES;
+import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.TRAIN_CONTEXT;
+
+/** Determines whether to terminate training.
*/ +public class TerminationOperator + extends AbstractSharedObjectsOneInputStreamOperator + implements IterationListener { + + private final OutputTag modelDataOutputTag; + + public TerminationOperator(OutputTag modelDataOutputTag) { + this.modelDataOutputTag = modelDataOutputTag; + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + } + + @Override + public void processElement(StreamRecord element) throws Exception {} + + @Override + public List> readRequestsInProcessElement() { + return Collections.emptyList(); + } + + @Override + public void onEpochWatermarkIncremented( + int epochWatermark, Context c, Collector collector) { + boolean terminated = + context.read(ALL_TREES.sameStep()).size() + == context.read(TRAIN_CONTEXT.sameStep()).strategy.maxIter; + // TODO: Add validation error rate + if (!terminated) { + output.collect(new StreamRecord<>(0)); + } + } + + @Override + public void onIterationTerminated(Context c, Collector collector) { + if (0 == getRuntimeContext().getIndexOfThisSubtask()) { + c.output( + modelDataOutputTag, + GBTModelData.from( + context.read(TRAIN_CONTEXT.prevStep()), + context.read(ALL_TREES.prevStep()))); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TrainContextInitializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TrainContextInitializer.java new file mode 100644 index 000000000..7c124b68e --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TrainContextInitializer.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.common.gbt.defs.BinnedInstance; +import org.apache.flink.ml.common.gbt.defs.BoostingStrategy; +import org.apache.flink.ml.common.gbt.defs.TaskType; +import org.apache.flink.ml.common.gbt.defs.TrainContext; +import org.apache.flink.ml.common.lossfunc.LogLoss; +import org.apache.flink.ml.common.lossfunc.LossFunc; +import org.apache.flink.ml.common.lossfunc.SquaredErrorLoss; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import java.util.function.Function; + +class TrainContextInitializer { + private static final Logger LOG = LoggerFactory.getLogger(TrainContextInitializer.class); + private final BoostingStrategy strategy; + + public TrainContextInitializer(BoostingStrategy strategy) { + this.strategy = strategy; + } + + /** + * Initializes local state. + * + *

Note that local state already has some properties set in advance, see GBTRunner#boost. + */ + public TrainContext init( + TrainContext trainContext, int subtaskId, int numSubtasks, BinnedInstance[] instances) { + LOG.info( + "subtaskId: {}, {} start", + subtaskId, + TrainContextInitializer.class.getSimpleName()); + + trainContext.subtaskId = subtaskId; + trainContext.numSubtasks = numSubtasks; + + int numInstances = instances.length; + int numFeatures = trainContext.featureMetas.length; + + LOG.info( + "subtaskId: {}, #samples: {}, #features: {}", subtaskId, numInstances, numFeatures); + + trainContext.strategy = strategy; + trainContext.numInstances = numInstances; + trainContext.numFeatures = numFeatures; + + trainContext.numBaggingInstances = getNumBaggingSamples(numInstances); + trainContext.numBaggingFeatures = getNumBaggingFeatures(numFeatures); + + trainContext.instanceRandomizer = new Random(subtaskId + strategy.seed); + trainContext.featureRandomizer = new Random(strategy.seed); + + trainContext.loss = getLoss(); + trainContext.prior = calcPrior(trainContext.labelSumCount); + + // A special `feature` is appended with #bins = 1 to simplify codes. + trainContext.numFeatureBins = + ArrayUtils.add( + Arrays.stream(trainContext.featureMetas) + .mapToInt(d -> d.numBins(trainContext.strategy.useMissing)) + .toArray(), + 1); + LOG.debug("Number of bins for each feature: {}", trainContext.numFeatureBins); + LOG.info("subtaskId: {}, {} end", subtaskId, TrainContextInitializer.class.getSimpleName()); + return trainContext; + } + + private int getNumBaggingSamples(int numSamples) { + return (int) Math.min(numSamples, Math.ceil(numSamples * strategy.subsamplingRate)); + } + + private int getNumBaggingFeatures(int numFeatures) { + final List supported = Arrays.asList("auto", "all", "onethird", "sqrt", "log2"); + final String errorMsg = + String.format( + "Parameter `featureSubsetStrategy` supports %s, (0.0 - 1.0], [1 - n].", + String.join(", ", supported)); + final Function clamp = + d -> Math.max(1, Math.min(d.intValue(), numFeatures)); + String featureSubsetStrategy = strategy.featureSubsetStrategy; + try { + int numBaggingFeatures = Integer.parseInt(featureSubsetStrategy); + Preconditions.checkArgument( + numBaggingFeatures >= 1 && numBaggingFeatures <= numFeatures, errorMsg); + } catch (NumberFormatException ignored) { + } + try { + double baggingRatio = Double.parseDouble(featureSubsetStrategy); + Preconditions.checkArgument(baggingRatio > 0. && baggingRatio <= 1., errorMsg); + return clamp.apply(baggingRatio * numFeatures); + } catch (NumberFormatException ignored) { + } + + Preconditions.checkArgument(supported.contains(featureSubsetStrategy), errorMsg); + switch (featureSubsetStrategy) { + case "auto": + return TaskType.CLASSIFICATION.equals(strategy.taskType) + ? 
clamp.apply(Math.sqrt(numFeatures)) + : clamp.apply(numFeatures / 3.); + case "all": + return numFeatures; + case "onethird": + return clamp.apply(numFeatures / 3.); + case "sqrt": + return clamp.apply(Math.sqrt(numFeatures)); + case "log2": + return clamp.apply(Math.log(numFeatures) / Math.log(2)); + default: + throw new IllegalArgumentException(errorMsg); + } + } + + private LossFunc getLoss() { + switch (strategy.lossType) { + case LOGISTIC: + return LogLoss.INSTANCE; + case SQUARED: + return SquaredErrorLoss.INSTANCE; + default: + throw new UnsupportedOperationException("Unsupported loss."); + } + } + + private double calcPrior(Tuple2 labelStat) { + switch (strategy.lossType) { + case LOGISTIC: + return Math.log(labelStat.f0 / (labelStat.f1 - labelStat.f0)); + case SQUARED: + return labelStat.f0 / labelStat.f1; + default: + throw new UnsupportedOperationException("Unsupported loss."); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TreeInitializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TreeInitializer.java new file mode 100644 index 000000000..4cb9090a1 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TreeInitializer.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.ml.common.gbt.DataUtils; +import org.apache.flink.ml.common.gbt.defs.LearningNode; +import org.apache.flink.ml.common.gbt.defs.Slice; +import org.apache.flink.ml.common.gbt.defs.TrainContext; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Random; +import java.util.function.Consumer; +import java.util.stream.IntStream; + +class TreeInitializer { + private static final Logger LOG = LoggerFactory.getLogger(TreeInitializer.class); + + private final int subtaskId; + private final int numInstances; + private final int numBaggingInstances; + private final int[] shuffledIndices; + private final Random instanceRandomizer; + + public TreeInitializer(TrainContext trainContext) { + subtaskId = trainContext.subtaskId; + numInstances = trainContext.numInstances; + numBaggingInstances = trainContext.numBaggingInstances; + instanceRandomizer = trainContext.instanceRandomizer; + shuffledIndices = IntStream.range(0, numInstances).toArray(); + } + + /** Calculate local histograms for nodes in current layer of tree. */ + public void init(int numTrees, Consumer shuffledIndicesSetter) { + LOG.info("subtaskId: {}, {} start", subtaskId, TreeInitializer.class.getSimpleName()); + // Initializes the root node of a new tree when last tree is finalized. 
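+        // Shuffling implements the per-tree instance bagging: after the shuffle, the first
+        // `numBaggingInstances` indices form the in-bag slice of the root learning node and the
+        // remaining indices form its out-of-bag slice (see getRootLearningNode below).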
+ DataUtils.shuffle(shuffledIndices, instanceRandomizer); + shuffledIndicesSetter.accept(shuffledIndices); + LOG.info("subtaskId: {}, initialize {}-th tree", subtaskId, numTrees + 1); + LOG.info("subtaskId: {}, {} end", this.subtaskId, TreeInitializer.class.getSimpleName()); + } + + public LearningNode getRootLearningNode() { + return new LearningNode( + 0, + new Slice(0, numBaggingInstances), + new Slice(numBaggingInstances, numInstances), + 1); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/CategoricalFeatureSplitter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/CategoricalFeatureSplitter.java new file mode 100644 index 000000000..ee0d5aa5e --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/CategoricalFeatureSplitter.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.splitter; + +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.common.gbt.defs.BoostingStrategy; +import org.apache.flink.ml.common.gbt.defs.FeatureMeta; +import org.apache.flink.ml.common.gbt.defs.HessianImpurity; +import org.apache.flink.ml.common.gbt.defs.Split; + +import org.eclipse.collections.api.list.primitive.MutableIntList; +import org.eclipse.collections.impl.list.mutable.primitive.IntArrayList; + +import java.util.BitSet; + +import static org.apache.flink.ml.common.gbt.DataUtils.BIN_SIZE; + +/** Splitter for a categorical feature using LightGBM many-vs-many split. */ +public class CategoricalFeatureSplitter extends HistogramFeatureSplitter { + + public CategoricalFeatureSplitter( + int featureId, FeatureMeta featureMeta, BoostingStrategy strategy) { + super(featureId, featureMeta, strategy); + } + + @Override + public Split.CategoricalSplit bestSplit() { + HessianImpurity total = emptyImpurity(); + HessianImpurity missing = emptyImpurity(); + countTotalMissing(total, missing); + + if (total.getNumInstances() <= minSamplesPerLeaf) { + return Split.CategoricalSplit.invalid(total.prediction()); + } + + int numBins = slice.size(); + // Sorts categories (bins) based on grads / hessians, i.e., LightGBM many-vs-many approach. + MutableIntList sortedIndices = new IntArrayList(numBins); + // A category (bin) is treated as missing values if its occurrences is smaller than a + // threshold. Currently, the threshold is 0. 
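+        // `scores[i]` is the bin's gradient sum divided by its hessian sum; ordering bins by this
+        // ratio lets the best category subset be found as a prefix of `sortedIndices`, while bins
+        // with no instances are collected into `ignoredIndices` and routed with the missing
+        // values.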
+ BitSet ignoredIndices = new BitSet(numBins); + { + double[] scores = new double[numBins]; + for (int i = 0; i < numBins; ++i) { + int index = (slice.start + i) * BIN_SIZE; + if (hists[index + 3] > 0) { + sortedIndices.add(i); + scores[i] = hists[index] / hists[index + 1]; + } else { + ignoredIndices.set(i); + missing.add( + (int) hists[index + 3], + hists[index + 2], + hists[index], + hists[index + 1]); + } + } + sortedIndices.sortThis( + (value1, value2) -> Double.compare(scores[value1], scores[value2])); + } + + Tuple3 bestSplit = + findBestSplit(sortedIndices.toArray(), total, missing); + double bestGain = bestSplit.f0; + int bestSplitIndex = bestSplit.f1; + boolean missingGoLeft = bestSplit.f2; + + if (bestGain <= Split.INVALID_GAIN || bestGain <= minInfoGain) { + return Split.CategoricalSplit.invalid(total.prediction()); + } + + // Indicates which bins should go left. + BitSet binsGoLeft = new BitSet(numBins); + if (useMissing) { + for (int i = 0; i < sortedIndices.size(); ++i) { + int binId = sortedIndices.get(i); + if (i <= bestSplitIndex) { + if (binId < featureMeta.missingBin) { + binsGoLeft.set(binId); + } else if (binId > featureMeta.missingBin) { + binsGoLeft.set(binId - 1); + } + } + } + } else { + for (int i = 0; i < sortedIndices.size(); i += 1) { + int binId = sortedIndices.get(i); + if (i <= bestSplitIndex) { + binsGoLeft.set(binId); + } + } + } + if (missingGoLeft) { + binsGoLeft.or(ignoredIndices); + } + return new Split.CategoricalSplit( + featureId, + bestGain, + featureMeta.missingBin, + missingGoLeft, + total.prediction(), + binsGoLeft); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/ContinuousFeatureSplitter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/ContinuousFeatureSplitter.java new file mode 100644 index 000000000..bee536691 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/ContinuousFeatureSplitter.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.splitter; + +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.common.gbt.defs.BoostingStrategy; +import org.apache.flink.ml.common.gbt.defs.FeatureMeta; +import org.apache.flink.ml.common.gbt.defs.HessianImpurity; +import org.apache.flink.ml.common.gbt.defs.Split; + +/** Splitter for a continuous feature. 
*/ +public final class ContinuousFeatureSplitter extends HistogramFeatureSplitter { + + public ContinuousFeatureSplitter( + int featureId, FeatureMeta featureMeta, BoostingStrategy strategy) { + super(featureId, featureMeta, strategy); + } + + @Override + public Split.ContinuousSplit bestSplit() { + HessianImpurity total = emptyImpurity(); + HessianImpurity missing = emptyImpurity(); + countTotalMissing(total, missing); + + if (total.getNumInstances() <= minSamplesPerLeaf) { + return Split.ContinuousSplit.invalid(total.prediction()); + } + + Tuple3 bestSplit = findBestSplit(slice.size(), total, missing); + double bestGain = bestSplit.f0; + int bestSplitBinId = bestSplit.f1; + boolean missingGoLeft = bestSplit.f2; + + if (bestGain <= Split.INVALID_GAIN || bestGain <= minInfoGain) { + return Split.ContinuousSplit.invalid(total.prediction()); + } + int splitPoint = + useMissing && bestSplitBinId > featureMeta.missingBin + ? bestSplitBinId - 1 + : bestSplitBinId; + return new Split.ContinuousSplit( + featureId, + bestGain, + featureMeta.missingBin, + missingGoLeft, + total.prediction(), + splitPoint, + !strategy.isInputVector, + ((FeatureMeta.ContinuousFeatureMeta) featureMeta).zeroBin); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/FeatureSplitter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/FeatureSplitter.java new file mode 100644 index 000000000..702bd2189 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/FeatureSplitter.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.splitter; + +import org.apache.flink.ml.common.gbt.defs.BoostingStrategy; +import org.apache.flink.ml.common.gbt.defs.FeatureMeta; +import org.apache.flink.ml.common.gbt.defs.Split; + +/** + * Tests if the node can be split on a given feature and obtains best split. + * + *

When testing the node, we only check internal criteria, such as minimum info gain, minium + * samples per leaf, etc. The external criteria, like maximum depth or maximum number of leaves are + * not checked. + */ +public abstract class FeatureSplitter { + protected final int featureId; + protected final FeatureMeta featureMeta; + protected final BoostingStrategy strategy; + + protected final int minSamplesPerLeaf; + protected final double minSampleRatioPerChild; + protected final double minInfoGain; + + public FeatureSplitter(int featureId, FeatureMeta featureMeta, BoostingStrategy strategy) { + this.strategy = strategy; + this.featureId = featureId; + this.featureMeta = featureMeta; + + this.minSamplesPerLeaf = strategy.minInstancesPerNode; + // TODO: not exactly the same since weights are not supported right now. + this.minSampleRatioPerChild = strategy.minWeightFractionPerNode; + this.minInfoGain = strategy.minInfoGain; + } + + public abstract Split bestSplit(); +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/HistogramFeatureSplitter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/HistogramFeatureSplitter.java new file mode 100644 index 000000000..8aa3a5f69 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/HistogramFeatureSplitter.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.splitter; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.common.gbt.defs.BoostingStrategy; +import org.apache.flink.ml.common.gbt.defs.FeatureMeta; +import org.apache.flink.ml.common.gbt.defs.HessianImpurity; +import org.apache.flink.ml.common.gbt.defs.Impurity; +import org.apache.flink.ml.common.gbt.defs.Slice; +import org.apache.flink.ml.common.gbt.defs.Split; + +import static org.apache.flink.ml.common.gbt.DataUtils.BIN_SIZE; + +/** Histogram based feature splitter. */ +public abstract class HistogramFeatureSplitter extends FeatureSplitter { + protected final boolean useMissing; + protected double[] hists; + protected Slice slice; + + public HistogramFeatureSplitter( + int featureId, FeatureMeta featureMeta, BoostingStrategy strategy) { + super(featureId, featureMeta, strategy); + this.useMissing = strategy.useMissing; + } + + protected boolean isSplitIllegal(Impurity total, Impurity left, Impurity right) { + return (minSamplesPerLeaf > left.getTotalWeights() + || minSamplesPerLeaf > right.getTotalWeights()) + || minSampleRatioPerChild > 1. * left.getNumInstances() / total.getNumInstances() + || minSampleRatioPerChild > 1. 
* right.getNumInstances() / total.getNumInstances(); + } + + protected double gain(Impurity total, Impurity left, Impurity right) { + return isSplitIllegal(total, left, right) ? Split.INVALID_GAIN : total.gain(left, right); + } + + protected void addBinToLeft(int binId, HessianImpurity left, HessianImpurity right) { + int index = (slice.start + binId) * BIN_SIZE; + left.add((int) hists[index + 3], hists[index + 2], hists[index], hists[index + 1]); + if (null != right) { + right.subtract( + (int) hists[index + 3], hists[index + 2], hists[index], hists[index + 1]); + } + } + + protected Tuple2 findBestSplitWithInitial( + int[] sortedBinIds, + HessianImpurity total, + HessianImpurity left, + HessianImpurity right) { + // Bins [0, bestSplitBinId] go left. + int bestSplitBinId = 0; + double bestGain = Split.INVALID_GAIN; + for (int i = 0; i < sortedBinIds.length; i += 1) { + int binId = sortedBinIds[i]; + if (useMissing && binId == featureMeta.missingBin) { + continue; + } + addBinToLeft(binId, left, right); + double gain = gain(total, left, right); + if (gain > bestGain && gain >= minInfoGain) { + bestGain = gain; + bestSplitBinId = i; + } + } + return Tuple2.of(bestGain, bestSplitBinId); + } + + protected Tuple2 findBestSplitWithInitial( + int numBins, HessianImpurity total, HessianImpurity left, HessianImpurity right) { + // Bins [0, bestSplitBinId] go left. + int bestSplitBinId = 0; + double bestGain = Split.INVALID_GAIN; + for (int binId = 0; binId < numBins; binId += 1) { + if (useMissing && binId == featureMeta.missingBin) { + continue; + } + addBinToLeft(binId, left, right); + double gain = gain(total, left, right); + if (gain > bestGain && gain >= minInfoGain) { + bestGain = gain; + bestSplitBinId = binId; + } + } + return Tuple2.of(bestGain, bestSplitBinId); + } + + protected Tuple3 findBestSplit( + int[] sortedBinIds, HessianImpurity total, HessianImpurity missing) { + double bestGain = Split.INVALID_GAIN; + int bestSplitBinId = 0; + boolean missingGoLeft = false; + + { + // The cases where the missing values go right, or missing values are not allowed. + HessianImpurity left = emptyImpurity(); + HessianImpurity right = (HessianImpurity) total.clone(); + Tuple2 bestSplit = + findBestSplitWithInitial(sortedBinIds, total, left, right); + if (bestSplit.f0 > bestGain) { + bestGain = bestSplit.f0; + bestSplitBinId = bestSplit.f1; + } + } + + if (useMissing && missing.getNumInstances() > 0) { + // The cases where the missing values go left. + HessianImpurity leftWithMissing = emptyImpurity().add(missing); + HessianImpurity rightWithoutMissing = (HessianImpurity) total.clone().subtract(missing); + Tuple2 bestSplitMissingGoLeft = + findBestSplitWithInitial( + sortedBinIds, total, leftWithMissing, rightWithoutMissing); + if (bestSplitMissingGoLeft.f0 > bestGain) { + bestGain = bestSplitMissingGoLeft.f0; + bestSplitBinId = bestSplitMissingGoLeft.f1; + missingGoLeft = true; + } + } + return Tuple3.of(bestGain, bestSplitBinId, missingGoLeft); + } + + protected Tuple3 findBestSplit( + int numBins, HessianImpurity total, HessianImpurity missing) { + double bestGain = Split.INVALID_GAIN; + int bestSplitBinId = 0; + boolean missingGoLeft = false; + + { + // The cases where the missing values go right, or missing values are not allowed. 
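+            // The search is run twice: this pass assumes missing values go to the right child
+            // (left starts empty, right starts as the node total); the pass below seeds the left
+            // child with the missing-value statistics instead, and the direction with the larger
+            // gain decides `missingGoLeft`.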
+ HessianImpurity left = emptyImpurity(); + HessianImpurity right = (HessianImpurity) total.clone(); + Tuple2 bestSplit = + findBestSplitWithInitial(numBins, total, left, right); + if (bestSplit.f0 > bestGain) { + bestGain = bestSplit.f0; + bestSplitBinId = bestSplit.f1; + } + } + + if (useMissing) { + // The cases where the missing values go left. + HessianImpurity leftWithMissing = emptyImpurity().add(missing); + HessianImpurity rightWithoutMissing = (HessianImpurity) total.clone().subtract(missing); + Tuple2 bestSplitMissingGoLeft = + findBestSplitWithInitial(numBins, total, leftWithMissing, rightWithoutMissing); + if (bestSplitMissingGoLeft.f0 > bestGain) { + bestGain = bestSplitMissingGoLeft.f0; + bestSplitBinId = bestSplitMissingGoLeft.f1; + missingGoLeft = true; + } + } + return Tuple3.of(bestGain, bestSplitBinId, missingGoLeft); + } + + public void reset(double[] hists, Slice slice) { + this.hists = hists; + this.slice = slice; + } + + protected void countTotalMissing(HessianImpurity total, HessianImpurity missing) { + for (int i = 0; i < slice.size(); ++i) { + addBinToLeft(i, total, null); + } + if (useMissing) { + addBinToLeft(featureMeta.missingBin, missing, null); + } + } + + protected HessianImpurity emptyImpurity() { + return new HessianImpurity(strategy.regLambda, strategy.regGamma, 0, 0, 0, 0); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/BinnedInstanceSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/BinnedInstanceSerializer.java new file mode 100644 index 000000000..f3e038143 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/BinnedInstanceSerializer.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.DoubleSerializer; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.api.common.typeutils.base.array.IntPrimitiveArraySerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.common.gbt.defs.BinnedInstance; + +import java.io.IOException; + +/** Serializer for {@link BinnedInstance}. 
*/ +public final class BinnedInstanceSerializer extends TypeSerializerSingleton { + + public static final BinnedInstanceSerializer INSTANCE = new BinnedInstanceSerializer(); + private static final long serialVersionUID = 1L; + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public BinnedInstance createInstance() { + return new BinnedInstance(); + } + + @Override + public BinnedInstance copy(BinnedInstance from) { + BinnedInstance instance = new BinnedInstance(); + instance.featureIds = null == from.featureIds ? null : from.featureIds.clone(); + instance.featureValues = from.featureValues.clone(); + instance.label = from.label; + instance.weight = from.weight; + return instance; + } + + @Override + public BinnedInstance copy(BinnedInstance from, BinnedInstance reuse) { + assert from.getClass() == reuse.getClass(); + return copy(from); + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(BinnedInstance record, DataOutputView target) throws IOException { + if (null == record.featureIds) { + target.writeBoolean(true); + } else { + target.writeBoolean(false); + IntPrimitiveArraySerializer.INSTANCE.serialize(record.featureIds, target); + } + IntPrimitiveArraySerializer.INSTANCE.serialize(record.featureValues, target); + DoubleSerializer.INSTANCE.serialize(record.label, target); + DoubleSerializer.INSTANCE.serialize(record.weight, target); + } + + @Override + public BinnedInstance deserialize(DataInputView source) throws IOException { + BinnedInstance instance = new BinnedInstance(); + if (source.readBoolean()) { + instance.featureIds = null; + } else { + instance.featureIds = IntPrimitiveArraySerializer.INSTANCE.deserialize(source); + } + instance.featureValues = IntPrimitiveArraySerializer.INSTANCE.deserialize(source); + instance.label = DoubleSerializer.INSTANCE.deserialize(source); + instance.weight = DoubleSerializer.INSTANCE.deserialize(source); + return instance; + } + + @Override + public BinnedInstance deserialize(BinnedInstance reuse, DataInputView source) + throws IOException { + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + serialize(deserialize(source), target); + } + + // ------------------------------------------------------------------------ + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new BinnedInstanceSerializerSnapshot(); + } + + /** Serializer configuration snapshot for compatibility and format evolution. */ + @SuppressWarnings("WeakerAccess") + public static final class BinnedInstanceSerializerSnapshot + extends SimpleTypeSerializerSnapshot { + + public BinnedInstanceSerializerSnapshot() { + super(BinnedInstanceSerializer::new); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/CategoricalSplitSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/CategoricalSplitSerializer.java new file mode 100644 index 000000000..13889a186 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/CategoricalSplitSerializer.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.BooleanSerializer; +import org.apache.flink.api.common.typeutils.base.DoubleSerializer; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.api.common.typeutils.base.array.BytePrimitiveArraySerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.common.gbt.defs.Split; + +import java.io.IOException; +import java.util.BitSet; + +/** Specialized serializer for {@link Split.CategoricalSplit}. */ +public final class CategoricalSplitSerializer + extends TypeSerializerSingleton { + + private static final long serialVersionUID = 1L; + + public static final CategoricalSplitSerializer INSTANCE = new CategoricalSplitSerializer(); + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public Split.CategoricalSplit createInstance() { + return new Split.CategoricalSplit(-1, Split.INVALID_GAIN, 0, false, 0., new BitSet()); + } + + @Override + public Split.CategoricalSplit copy(Split.CategoricalSplit from) { + return new Split.CategoricalSplit( + from.featureId, + from.gain, + from.missingBin, + from.missingGoLeft, + from.prediction, + from.categoriesGoLeft); + } + + @Override + public Split.CategoricalSplit copy(Split.CategoricalSplit from, Split.CategoricalSplit reuse) { + assert from.getClass() == reuse.getClass(); + return copy(from); + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(Split.CategoricalSplit record, DataOutputView target) throws IOException { + IntSerializer.INSTANCE.serialize(record.featureId, target); + DoubleSerializer.INSTANCE.serialize(record.gain, target); + IntSerializer.INSTANCE.serialize(record.missingBin, target); + BooleanSerializer.INSTANCE.serialize(record.missingGoLeft, target); + DoubleSerializer.INSTANCE.serialize(record.prediction, target); + BytePrimitiveArraySerializer.INSTANCE.serialize( + record.categoriesGoLeft.toByteArray(), target); + } + + @Override + public Split.CategoricalSplit deserialize(DataInputView source) throws IOException { + return new Split.CategoricalSplit( + IntSerializer.INSTANCE.deserialize(source), + DoubleSerializer.INSTANCE.deserialize(source), + IntSerializer.INSTANCE.deserialize(source), + BooleanSerializer.INSTANCE.deserialize(source), + DoubleSerializer.INSTANCE.deserialize(source), + BitSet.valueOf(BytePrimitiveArraySerializer.INSTANCE.deserialize(source))); + } + + @Override + public Split.CategoricalSplit deserialize(Split.CategoricalSplit reuse, DataInputView source) + throws IOException { + return deserialize(source); + } + + 
@Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + serialize(deserialize(source), target); + } + + // ------------------------------------------------------------------------ + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new CategoricalSplitSerializerSnapshot(); + } + + /** Serializer configuration snapshot for compatibility and format evolution. */ + @SuppressWarnings("WeakerAccess") + public static final class CategoricalSplitSerializerSnapshot + extends SimpleTypeSerializerSnapshot { + + public CategoricalSplitSerializerSnapshot() { + super(CategoricalSplitSerializer::new); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/ContinuousSplitSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/ContinuousSplitSerializer.java new file mode 100644 index 000000000..3e85a9354 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/ContinuousSplitSerializer.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.BooleanSerializer; +import org.apache.flink.api.common.typeutils.base.DoubleSerializer; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.common.gbt.defs.Split; + +import java.io.IOException; + +/** Specialized serializer for {@link Split.ContinuousSplit}. 
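CategoricalSplit routes a categorical value by membership in the categoriesGoLeft BitSet, and the serializer above ships that set as the byte[] produced by BitSet.toByteArray(). A small sketch of the encoding, using only the JDK:

import java.util.BitSet;

public class CategoriesGoLeftExample {
    public static void main(String[] args) {
        BitSet categoriesGoLeft = new BitSet();
        categoriesGoLeft.set(0); // category id 0 routes to the left child
        categoriesGoLeft.set(4); // category id 4 routes to the left child

        byte[] onTheWire = categoriesGoLeft.toByteArray();   // what serialize() writes
        BitSet restored = BitSet.valueOf(onTheWire);         // what deserialize() rebuilds
        System.out.println(restored.get(4));                 // true -> go left
    }
}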
*/ +public final class ContinuousSplitSerializer + extends TypeSerializerSingleton { + + private static final long serialVersionUID = 1L; + + public static final ContinuousSplitSerializer INSTANCE = new ContinuousSplitSerializer(); + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public Split.ContinuousSplit createInstance() { + return new Split.ContinuousSplit(-1, Split.INVALID_GAIN, 0, false, 0., 0., false, 0); + } + + @Override + public Split.ContinuousSplit copy(Split.ContinuousSplit from) { + return new Split.ContinuousSplit( + from.featureId, + from.gain, + from.missingBin, + from.missingGoLeft, + from.prediction, + from.threshold, + from.isUnseenMissing, + from.zeroBin); + } + + @Override + public Split.ContinuousSplit copy(Split.ContinuousSplit from, Split.ContinuousSplit reuse) { + assert from.getClass() == reuse.getClass(); + return copy(from); + } + + @Override + public int getLength() { + return 3 * IntSerializer.INSTANCE.getLength() + + 3 * DoubleSerializer.INSTANCE.getLength() + + 2 * BooleanSerializer.INSTANCE.getLength(); + } + + @Override + public void serialize(Split.ContinuousSplit record, DataOutputView target) throws IOException { + target.writeInt(record.featureId); + target.writeDouble(record.gain); + target.writeInt(record.missingBin); + target.writeBoolean(record.missingGoLeft); + target.writeDouble(record.prediction); + target.writeDouble(record.threshold); + target.writeBoolean(record.isUnseenMissing); + target.writeInt(record.zeroBin); + } + + @Override + public Split.ContinuousSplit deserialize(DataInputView source) throws IOException { + return new Split.ContinuousSplit( + source.readInt(), + source.readDouble(), + source.readInt(), + source.readBoolean(), + source.readDouble(), + source.readDouble(), + source.readBoolean(), + source.readInt()); + } + + @Override + public Split.ContinuousSplit deserialize(Split.ContinuousSplit reuse, DataInputView source) + throws IOException { + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + serialize(deserialize(source), target); + } + + // ------------------------------------------------------------------------ + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new ContinuousSplitSplitSerializerSnapshot(); + } + + /** Serializer configuration snapshot for compatibility and format evolution. */ + @SuppressWarnings("WeakerAccess") + public static final class ContinuousSplitSplitSerializerSnapshot + extends SimpleTypeSerializerSnapshot { + + public ContinuousSplitSplitSerializerSnapshot() { + super(ContinuousSplitSerializer::new); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataSerializer.java new file mode 100644 index 000000000..6c815254f --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataSerializer.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.BooleanSerializer; +import org.apache.flink.api.common.typeutils.base.DoubleSerializer; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.api.common.typeutils.base.array.BytePrimitiveArraySerializer; +import org.apache.flink.api.common.typeutils.base.array.DoublePrimitiveArraySerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.common.gbt.GBTModelData; +import org.apache.flink.ml.common.gbt.defs.Node; + +import org.eclipse.collections.impl.map.mutable.primitive.IntObjectHashMap; +import org.eclipse.collections.impl.map.mutable.primitive.ObjectIntHashMap; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.BitSet; +import java.util.List; + +/** Specialized serializer for {@link GBTModelData}. */ +public final class GBTModelDataSerializer extends TypeSerializerSingleton { + + public static final GBTModelDataSerializer INSTANCE = new GBTModelDataSerializer(); + private static final long serialVersionUID = 1L; + private static final NodeSerializer NODE_SERIALIZER = NodeSerializer.INSTANCE; + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public GBTModelData createInstance() { + return new GBTModelData(); + } + + @Override + public GBTModelData copy(GBTModelData from) { + GBTModelData record = new GBTModelData(); + record.type = from.type; + record.isInputVector = from.isInputVector; + + record.prior = from.prior; + record.stepSize = from.stepSize; + + record.allTrees = new ArrayList<>(from.allTrees.size()); + for (int i = 0; i < from.allTrees.size(); i += 1) { + record.allTrees.add(new ArrayList<>(from.allTrees.get(i))); + } + record.featureNames = new ArrayList<>(from.featureNames); + record.categoryToIdMaps = new IntObjectHashMap<>(from.categoryToIdMaps); + record.featureIdToBinEdges = new IntObjectHashMap<>(from.featureIdToBinEdges); + record.isCategorical = BitSet.valueOf(from.isCategorical.toByteArray()); + return record; + } + + @Override + public GBTModelData copy(GBTModelData from, GBTModelData reuse) { + return copy(from); + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(GBTModelData record, DataOutputView target) throws IOException { + StringSerializer.INSTANCE.serialize(record.type, target); + BooleanSerializer.INSTANCE.serialize(record.isInputVector, target); + + DoubleSerializer.INSTANCE.serialize(record.prior, target); + DoubleSerializer.INSTANCE.serialize(record.stepSize, target); + + IntSerializer.INSTANCE.serialize(record.allTrees.size(), target); + for (List treeNodes : record.allTrees) { + 
IntSerializer.INSTANCE.serialize(treeNodes.size(), target); + for (Node treeNode : treeNodes) { + NodeSerializer.INSTANCE.serialize(treeNode, target); + } + } + + IntSerializer.INSTANCE.serialize(record.featureNames.size(), target); + for (int i = 0; i < record.featureNames.size(); i += 1) { + StringSerializer.INSTANCE.serialize(record.featureNames.get(i), target); + } + + IntSerializer.INSTANCE.serialize(record.categoryToIdMaps.size(), target); + for (int featureId : record.categoryToIdMaps.keysView().toArray()) { + ObjectIntHashMap categoryToIdMap = record.categoryToIdMaps.get(featureId); + IntSerializer.INSTANCE.serialize(featureId, target); + IntSerializer.INSTANCE.serialize(categoryToIdMap.size(), target); + for (String category : categoryToIdMap.keysView()) { + StringSerializer.INSTANCE.serialize(category, target); + IntSerializer.INSTANCE.serialize(categoryToIdMap.get(category), target); + } + } + + IntSerializer.INSTANCE.serialize(record.featureIdToBinEdges.size(), target); + for (int featureId : record.featureIdToBinEdges.keysView().toArray()) { + double[] binEdges = record.featureIdToBinEdges.get(featureId); + IntSerializer.INSTANCE.serialize(featureId, target); + DoublePrimitiveArraySerializer.INSTANCE.serialize(binEdges, target); + } + + BytePrimitiveArraySerializer.INSTANCE.serialize(record.isCategorical.toByteArray(), target); + } + + @Override + public GBTModelData deserialize(DataInputView source) throws IOException { + GBTModelData record = new GBTModelData(); + + record.type = StringSerializer.INSTANCE.deserialize(source); + record.isInputVector = BooleanSerializer.INSTANCE.deserialize(source); + + record.prior = DoubleSerializer.INSTANCE.deserialize(source); + record.stepSize = DoubleSerializer.INSTANCE.deserialize(source); + + int numTrees = IntSerializer.INSTANCE.deserialize(source); + record.allTrees = new ArrayList<>(numTrees); + for (int k = 0; k < numTrees; k += 1) { + int numTreeNodes = IntSerializer.INSTANCE.deserialize(source); + List treeNodes = new ArrayList<>(numTreeNodes); + for (int i = 0; i < numTreeNodes; i += 1) { + treeNodes.add(NODE_SERIALIZER.deserialize(source)); + } + record.allTrees.add(treeNodes); + } + + int numFeatures = IntSerializer.INSTANCE.deserialize(source); + record.featureNames = new ArrayList<>(numFeatures); + for (int k = 0; k < numFeatures; k += 1) { + String featureName = StringSerializer.INSTANCE.deserialize(source); + record.featureNames.add(featureName); + } + + int numCategoricalFeatures = IntSerializer.INSTANCE.deserialize(source); + record.categoryToIdMaps = IntObjectHashMap.newMap(); + for (int k = 0; k < numCategoricalFeatures; k += 1) { + int featureId = IntSerializer.INSTANCE.deserialize(source); + int categoryToIdMapSize = IntSerializer.INSTANCE.deserialize(source); + ObjectIntHashMap categoryToIdMap = ObjectIntHashMap.newMap(); + for (int i = 0; i < categoryToIdMapSize; i += 1) { + categoryToIdMap.put( + StringSerializer.INSTANCE.deserialize(source), + IntSerializer.INSTANCE.deserialize(source)); + } + record.categoryToIdMaps.put(featureId, categoryToIdMap); + } + + int numContinuousFeatures = IntSerializer.INSTANCE.deserialize(source); + record.featureIdToBinEdges = IntObjectHashMap.newMap(); + for (int i = 0; i < numContinuousFeatures; i += 1) { + int featureId = IntSerializer.INSTANCE.deserialize(source); + double[] binEdges = DoublePrimitiveArraySerializer.INSTANCE.deserialize(source); + record.featureIdToBinEdges.put(featureId, binEdges); + } + + record.isCategorical = + 
BitSet.valueOf(BytePrimitiveArraySerializer.INSTANCE.deserialize(source)); + return record; + } + + @Override + public GBTModelData deserialize(GBTModelData reuse, DataInputView source) throws IOException { + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + serialize(deserialize(source), target); + } + + // ------------------------------------------------------------------------ + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new GBTModelDataSerializerSnapshot(); + } + + /** Serializer configuration snapshot for compatibility and format evolution. */ + @SuppressWarnings("WeakerAccess") + public static final class GBTModelDataSerializerSnapshot + extends SimpleTypeSerializerSnapshot { + + public GBTModelDataSerializerSnapshot() { + super(GBTModelDataSerializer::new); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataTypeInfo.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataTypeInfo.java new file mode 100644 index 000000000..42fc88718 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataTypeInfo.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.ml.common.gbt.GBTModelData; + +/** A {@link TypeInformation} for the {@link GBTModelData} type. 
*/ +public class GBTModelDataTypeInfo extends TypeInformation { + + public static final GBTModelDataTypeInfo INSTANCE = new GBTModelDataTypeInfo(); + + private GBTModelDataTypeInfo() {} + + @Override + public boolean isBasicType() { + return false; + } + + @Override + public boolean isTupleType() { + return false; + } + + @Override + public int getArity() { + return 2; + } + + @Override + public int getTotalFields() { + return 2; + } + + @Override + public Class getTypeClass() { + return GBTModelData.class; + } + + @Override + public boolean isKeyType() { + return false; + } + + @Override + public TypeSerializer createSerializer(ExecutionConfig executionConfig) { + return new GBTModelDataSerializer(); + } + + @Override + public String toString() { + return "GBTModelData"; + } + + @Override + public boolean equals(Object o) { + return o instanceof GBTModelDataTypeInfo; + } + + @Override + public int hashCode() { + return getClass().hashCode(); + } + + @Override + public boolean canEqual(Object o) { + return o instanceof GBTModelDataTypeInfo; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataTypeInfoFactory.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataTypeInfoFactory.java new file mode 100644 index 000000000..f32e1176e --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataTypeInfoFactory.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeinfo.TypeInfoFactory; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.ml.common.gbt.GBTModelData; + +import java.lang.reflect.Type; +import java.util.Map; + +/** + * Used by {@link TypeExtractor} to create a {@link TypeInformation} for implementations of {@link + * GBTModelData}. + */ +public class GBTModelDataTypeInfoFactory extends TypeInfoFactory { + @Override + public TypeInformation createTypeInfo( + Type t, Map> genericParameters) { + return GBTModelDataTypeInfo.INSTANCE; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramSerializer.java new file mode 100644 index 000000000..a4bb4e28e --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramSerializer.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
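The factory only takes effect once the model-data class points at it. A sketch of the usual wiring, assuming GBTModelData is annotated this way (the annotation is not visible in this diff): with @TypeInfo in place, Flink's TypeExtractor asks GBTModelDataTypeInfoFactory for the type information instead of falling back to a generic Kryo type.

import org.apache.flink.api.common.typeinfo.TypeInfo;
import org.apache.flink.ml.common.gbt.typeinfo.GBTModelDataTypeInfoFactory;

// Hypothetical illustration; the real field definitions live elsewhere in this PR.
@TypeInfo(GBTModelDataTypeInfoFactory.class)
public class GBTModelData {
    // ... fields such as type, prior, stepSize, allTrees ...
}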
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.common.gbt.defs.Histogram; +import org.apache.flink.ml.common.gbt.defs.Slice; +import org.apache.flink.ml.linalg.typeinfo.OptimizedDoublePrimitiveArraySerializer; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; + +/** Serializer for {@link Histogram}. */ +public final class HistogramSerializer extends TypeSerializerSingleton { + + public static final HistogramSerializer INSTANCE = new HistogramSerializer(); + private static final long serialVersionUID = 1L; + + private final OptimizedDoublePrimitiveArraySerializer histsSerializer = + new OptimizedDoublePrimitiveArraySerializer(); + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public Histogram createInstance() { + return new Histogram(); + } + + @Override + public Histogram copy(Histogram from) { + Histogram histogram = new Histogram(); + histogram.hists = ArrayUtils.subarray(from.hists, from.slice.start, from.slice.end); + histogram.slice.start = 0; + histogram.slice.end = from.slice.size(); + return histogram; + } + + @Override + public Histogram copy(Histogram from, Histogram reuse) { + return copy(from); + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(Histogram record, DataOutputView target) throws IOException { + // Only writes valid slice of `hists`. + histsSerializer.serialize(record.hists, record.slice.start, record.slice.size(), target); + } + + @Override + public Histogram deserialize(DataInputView source) throws IOException { + Histogram histogram = new Histogram(); + histogram.hists = histsSerializer.deserialize(source); + histogram.slice = new Slice(0, histogram.hists.length); + return histogram; + } + + @Override + public Histogram deserialize(Histogram reuse, DataInputView source) throws IOException { + reuse.hists = histsSerializer.deserialize(reuse.hists, source); + reuse.slice.start = 0; + reuse.slice.end = reuse.hists.length; + return reuse; + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + serialize(deserialize(source), target); + } + + // ------------------------------------------------------------------------ + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new HistogramSerializerSnapshot(); + } + + /** Serializer configuration snapshot for compatibility and format evolution. 
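A usage note on the slice-aware encoding above, sketched under the assumption that Histogram exposes the hists and slice fields used by the serializer: only the valid window [slice.start, slice.end) of the backing array crosses the wire, so a histogram that views a small slice of a large shared buffer shrinks on serialization, and the deserialized copy always starts at offset 0.

import org.apache.flink.core.memory.DataInputDeserializer;
import org.apache.flink.core.memory.DataOutputSerializer;
import org.apache.flink.ml.common.gbt.defs.Histogram;
import org.apache.flink.ml.common.gbt.defs.Slice;
import org.apache.flink.ml.common.gbt.typeinfo.HistogramSerializer;

public class HistogramRoundTrip {
    public static void main(String[] args) throws Exception {
        Histogram histogram = new Histogram();
        histogram.hists = new double[] {0., 0., 1., 2., 3., 4., 0., 0.};
        histogram.slice = new Slice(2, 6); // only indices [2, 6) are valid

        DataOutputSerializer out = new DataOutputSerializer(64);
        HistogramSerializer.INSTANCE.serialize(histogram, out);

        Histogram restored =
                HistogramSerializer.INSTANCE.deserialize(
                        new DataInputDeserializer(out.getCopyOfBuffer()));
        // restored.hists holds only the four valid entries and restored.slice covers [0, 4).
    }
}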
*/ + @SuppressWarnings("WeakerAccess") + public static final class HistogramSerializerSnapshot + extends SimpleTypeSerializerSnapshot { + + public HistogramSerializerSnapshot() { + super(HistogramSerializer::new); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramTypeInfo.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramTypeInfo.java new file mode 100644 index 000000000..d6b9823d1 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramTypeInfo.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.ml.common.gbt.defs.Histogram; + +/** A {@link TypeInformation} for the {@link Histogram} type. */ +public class HistogramTypeInfo extends TypeInformation { + + public static final HistogramTypeInfo INSTANCE = new HistogramTypeInfo(); + + private HistogramTypeInfo() {} + + @Override + public boolean isBasicType() { + return false; + } + + @Override + public boolean isTupleType() { + return false; + } + + @Override + public int getArity() { + return 3; + } + + @Override + public int getTotalFields() { + return 3; + } + + @Override + public Class getTypeClass() { + return Histogram.class; + } + + @Override + public boolean isKeyType() { + return false; + } + + @Override + public TypeSerializer createSerializer(ExecutionConfig executionConfig) { + return new HistogramSerializer(); + } + + @Override + public String toString() { + return "Histogram"; + } + + @Override + public boolean equals(Object o) { + return o instanceof HistogramTypeInfo; + } + + @Override + public int hashCode() { + return getClass().hashCode(); + } + + @Override + public boolean canEqual(Object o) { + return o instanceof HistogramTypeInfo; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramTypeInfoFactory.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramTypeInfoFactory.java new file mode 100644 index 000000000..e956e20a4 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramTypeInfoFactory.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeinfo.TypeInfoFactory; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.ml.common.gbt.defs.Histogram; + +import java.lang.reflect.Type; +import java.util.Map; + +/** + * Used by {@link TypeExtractor} to create a {@link TypeInformation} for implementations of {@link + * Histogram}. + */ +public class HistogramTypeInfoFactory extends TypeInfoFactory { + + @Override + public TypeInformation createTypeInfo( + Type t, Map> genericParameters) { + return HistogramTypeInfo.INSTANCE; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/LearningNodeSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/LearningNodeSerializer.java new file mode 100644 index 000000000..69136cba4 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/LearningNodeSerializer.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.common.gbt.defs.LearningNode; + +import java.io.IOException; + +/** Serializer for {@link LearningNode}. 
*/ +public final class LearningNodeSerializer extends TypeSerializerSingleton { + + public static final LearningNodeSerializer INSTANCE = new LearningNodeSerializer(); + private static final long serialVersionUID = 1L; + + private static final SliceSerializer SLICE_SERIALIZER = SliceSerializer.INSTANCE; + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public LearningNode createInstance() { + return new LearningNode(); + } + + @Override + public LearningNode copy(LearningNode from) { + LearningNode learningNode = new LearningNode(); + learningNode.nodeIndex = from.nodeIndex; + learningNode.slice = SLICE_SERIALIZER.copy(from.slice); + learningNode.oob = SLICE_SERIALIZER.copy(from.oob); + learningNode.depth = from.depth; + return learningNode; + } + + @Override + public LearningNode copy(LearningNode from, LearningNode reuse) { + assert from.getClass() == reuse.getClass(); + reuse.nodeIndex = from.nodeIndex; + SLICE_SERIALIZER.copy(from.slice, reuse.slice); + SLICE_SERIALIZER.copy(from.oob, reuse.oob); + reuse.depth = from.depth; + return reuse; + } + + @Override + public int getLength() { + return 2 * SLICE_SERIALIZER.getLength() + 2 * IntSerializer.INSTANCE.getLength(); + } + + @Override + public void serialize(LearningNode record, DataOutputView target) throws IOException { + IntSerializer.INSTANCE.serialize(record.nodeIndex, target); + SLICE_SERIALIZER.serialize(record.slice, target); + SLICE_SERIALIZER.serialize(record.oob, target); + IntSerializer.INSTANCE.serialize(record.depth, target); + } + + @Override + public LearningNode deserialize(DataInputView source) throws IOException { + LearningNode learningNode = new LearningNode(); + learningNode.nodeIndex = IntSerializer.INSTANCE.deserialize(source); + learningNode.slice = SLICE_SERIALIZER.deserialize(source); + learningNode.oob = SLICE_SERIALIZER.deserialize(source); + learningNode.depth = IntSerializer.INSTANCE.deserialize(source); + return learningNode; + } + + @Override + public LearningNode deserialize(LearningNode reuse, DataInputView source) throws IOException { + reuse.nodeIndex = IntSerializer.INSTANCE.deserialize(source); + reuse.slice = SLICE_SERIALIZER.deserialize(source); + reuse.oob = SLICE_SERIALIZER.deserialize(source); + reuse.depth = IntSerializer.INSTANCE.deserialize(source); + return reuse; + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + serialize(deserialize(source), target); + } + + // ------------------------------------------------------------------------ + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new LearningNodeSerializerSnapshot(); + } + + /** Serializer configuration snapshot for compatibility and format evolution. */ + @SuppressWarnings("WeakerAccess") + public static final class LearningNodeSerializerSnapshot + extends SimpleTypeSerializerSnapshot { + + public LearningNodeSerializerSnapshot() { + super(LearningNodeSerializer::new); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/NodeSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/NodeSerializer.java new file mode 100644 index 000000000..c6087c6d5 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/NodeSerializer.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.BooleanSerializer; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.common.gbt.defs.Node; + +import java.io.IOException; + +/** Serializer for {@link Node}. */ +public final class NodeSerializer extends TypeSerializerSingleton { + + public static final NodeSerializer INSTANCE = new NodeSerializer(); + private static final long serialVersionUID = 1L; + + private static final SplitSerializer SPLIT_SERIALIZER = SplitSerializer.INSTANCE; + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public Node createInstance() { + return new Node(); + } + + @Override + public Node copy(Node from) { + Node node = new Node(); + node.split = SPLIT_SERIALIZER.copy(from.split); + node.isLeaf = from.isLeaf; + node.left = from.left; + node.right = from.right; + return node; + } + + @Override + public Node copy(Node from, Node reuse) { + assert from.getClass() == reuse.getClass(); + reuse.split = SPLIT_SERIALIZER.copy(from.split, reuse.split); + reuse.isLeaf = from.isLeaf; + reuse.left = from.left; + reuse.right = from.right; + return reuse; + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(Node record, DataOutputView target) throws IOException { + SPLIT_SERIALIZER.serialize(record.split, target); + BooleanSerializer.INSTANCE.serialize(record.isLeaf, target); + IntSerializer.INSTANCE.serialize(record.left, target); + IntSerializer.INSTANCE.serialize(record.right, target); + } + + @Override + public Node deserialize(DataInputView source) throws IOException { + Node node = new Node(); + node.split = SPLIT_SERIALIZER.deserialize(source); + node.isLeaf = BooleanSerializer.INSTANCE.deserialize(source); + node.left = IntSerializer.INSTANCE.deserialize(source); + node.right = IntSerializer.INSTANCE.deserialize(source); + return node; + } + + @Override + public Node deserialize(Node reuse, DataInputView source) throws IOException { + reuse.split = SPLIT_SERIALIZER.deserialize(source); + reuse.isLeaf = BooleanSerializer.INSTANCE.deserialize(source); + reuse.left = IntSerializer.INSTANCE.deserialize(source); + reuse.right = IntSerializer.INSTANCE.deserialize(source); + return reuse; + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + serialize(deserialize(source), target); + } + + // 
------------------------------------------------------------------------ + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new NodeSerializerSnapshot(); + } + + /** Serializer configuration snapshot for compatibility and format evolution. */ + @SuppressWarnings("WeakerAccess") + public static final class NodeSerializerSnapshot extends SimpleTypeSerializerSnapshot { + + public NodeSerializerSnapshot() { + super(NodeSerializer::new); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/NodeTypeInfo.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/NodeTypeInfo.java new file mode 100644 index 000000000..48e8342c5 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/NodeTypeInfo.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.ml.common.gbt.defs.Node; + +/** A {@link TypeInformation} for the {@link Node} type. */ +public class NodeTypeInfo extends TypeInformation { + + public static final NodeTypeInfo INSTANCE = new NodeTypeInfo(); + + private NodeTypeInfo() {} + + @Override + public boolean isBasicType() { + return false; + } + + @Override + public boolean isTupleType() { + return false; + } + + @Override + public int getArity() { + return 4; + } + + @Override + public int getTotalFields() { + return 4; + } + + @Override + public Class getTypeClass() { + return Node.class; + } + + @Override + public boolean isKeyType() { + return false; + } + + @Override + public TypeSerializer createSerializer(ExecutionConfig executionConfig) { + return new NodeSerializer(); + } + + @Override + public String toString() { + return "Node"; + } + + @Override + public boolean equals(Object o) { + return o instanceof NodeTypeInfo; + } + + @Override + public int hashCode() { + return getClass().hashCode(); + } + + @Override + public boolean canEqual(Object o) { + return o instanceof NodeTypeInfo; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/NodeTypeInfoFactory.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/NodeTypeInfoFactory.java new file mode 100644 index 000000000..de6267ed3 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/NodeTypeInfoFactory.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
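Trees produced by this PR are stored as flat Node lists, with the int fields left and right acting as indices into the same list. A small, hedged example of walking such a list; the root sitting at index 0 is an assumption, it is not stated in this diff.

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;

import org.apache.flink.ml.common.gbt.defs.Node;

public class TreeLayout {

    /** Counts the leaves of one tree taken from GBTModelData.allTrees. */
    static int countLeaves(List<Node> tree) {
        if (tree.isEmpty()) {
            return 0;
        }
        int leaves = 0;
        Deque<Integer> stack = new ArrayDeque<>();
        stack.push(0); // assumed root index
        while (!stack.isEmpty()) {
            Node node = tree.get(stack.pop());
            if (node.isLeaf) {
                leaves += 1;
            } else {
                // Children are referenced by position in the same list.
                stack.push(node.left);
                stack.push(node.right);
            }
        }
        return leaves;
    }
}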
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeinfo.TypeInfoFactory; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.ml.common.gbt.defs.Node; + +import java.lang.reflect.Type; +import java.util.Map; + +/** + * Used by {@link TypeExtractor} to create a {@link TypeInformation} for implementations of {@link + * Node}. + */ +public class NodeTypeInfoFactory extends TypeInfoFactory { + + @Override + public TypeInformation createTypeInfo( + Type t, Map> genericParameters) { + return NodeTypeInfo.INSTANCE; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SliceSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SliceSerializer.java new file mode 100644 index 000000000..4d5e9583a --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SliceSerializer.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.common.gbt.defs.Slice; + +import java.io.IOException; + +/** Serializer for {@link Slice}. 
*/ +public final class SliceSerializer extends TypeSerializerSingleton { + + public static final SliceSerializer INSTANCE = new SliceSerializer(); + private static final long serialVersionUID = 1L; + + private static final SplitSerializer SPLIT_SERIALIZER = SplitSerializer.INSTANCE; + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public Slice createInstance() { + return new Slice(); + } + + @Override + public Slice copy(Slice from) { + Slice slice = new Slice(); + slice.start = from.start; + slice.end = from.end; + return slice; + } + + @Override + public Slice copy(Slice from, Slice reuse) { + reuse.start = from.start; + reuse.end = from.end; + return reuse; + } + + @Override + public int getLength() { + return 2 * IntSerializer.INSTANCE.getLength(); + } + + @Override + public void serialize(Slice record, DataOutputView target) throws IOException { + IntSerializer.INSTANCE.serialize(record.start, target); + IntSerializer.INSTANCE.serialize(record.end, target); + } + + @Override + public Slice deserialize(DataInputView source) throws IOException { + Slice slice = new Slice(); + slice.start = IntSerializer.INSTANCE.deserialize(source); + slice.end = IntSerializer.INSTANCE.deserialize(source); + return slice; + } + + @Override + public Slice deserialize(Slice reuse, DataInputView source) throws IOException { + reuse.start = IntSerializer.INSTANCE.deserialize(source); + reuse.end = IntSerializer.INSTANCE.deserialize(source); + return reuse; + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + serialize(deserialize(source), target); + } + + // ------------------------------------------------------------------------ + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new SliceSerializerSnapshot(); + } + + /** Serializer configuration snapshot for compatibility and format evolution. */ + @SuppressWarnings("WeakerAccess") + public static final class SliceSerializerSnapshot extends SimpleTypeSerializerSnapshot { + + public SliceSerializerSnapshot() { + super(SliceSerializer::new); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SplitSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SplitSerializer.java new file mode 100644 index 000000000..5c9efadfd --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SplitSerializer.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.common.gbt.defs.Split; + +import java.io.IOException; + +/** Specialized serializer for {@link Split}. */ +public final class SplitSerializer extends TypeSerializerSingleton { + + private static final long serialVersionUID = 1L; + + private static final CategoricalSplitSerializer CATEGORICAL_SPLIT_SERIALIZER = + CategoricalSplitSerializer.INSTANCE; + + private static final ContinuousSplitSerializer CONTINUOUS_SPLIT_SERIALIZER = + ContinuousSplitSerializer.INSTANCE; + + public static final SplitSerializer INSTANCE = new SplitSerializer(); + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public Split createInstance() { + return CATEGORICAL_SPLIT_SERIALIZER.createInstance(); + } + + @Override + public Split copy(Split from) { + if (from instanceof Split.CategoricalSplit) { + return CATEGORICAL_SPLIT_SERIALIZER.copy((Split.CategoricalSplit) from); + } else { + return CONTINUOUS_SPLIT_SERIALIZER.copy((Split.ContinuousSplit) from); + } + } + + @Override + public Split copy(Split from, Split reuse) { + assert from.getClass() == reuse.getClass(); + if (from instanceof Split.CategoricalSplit) { + return CATEGORICAL_SPLIT_SERIALIZER.copy( + (Split.CategoricalSplit) from, (Split.CategoricalSplit) reuse); + } else { + return CONTINUOUS_SPLIT_SERIALIZER.copy( + (Split.ContinuousSplit) from, (Split.ContinuousSplit) reuse); + } + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(Split record, DataOutputView target) throws IOException { + if (null == record) { + target.writeByte(0); + } else if (record instanceof Split.CategoricalSplit) { + target.writeByte(1); + CATEGORICAL_SPLIT_SERIALIZER.serialize((Split.CategoricalSplit) record, target); + } else { + target.writeByte(2); + CONTINUOUS_SPLIT_SERIALIZER.serialize((Split.ContinuousSplit) record, target); + } + } + + @Override + public Split deserialize(DataInputView source) throws IOException { + byte type = source.readByte(); + if (type == 0) { + return null; + } else if (type == 1) { + return CATEGORICAL_SPLIT_SERIALIZER.deserialize(source); + } else { + return CONTINUOUS_SPLIT_SERIALIZER.deserialize(source); + } + } + + @Override + public Split deserialize(Split reuse, DataInputView source) throws IOException { + byte type = source.readByte(); + if (type == 0) { + return null; + } + assert type == 1 && reuse instanceof Split.CategoricalSplit + || type == 2 && reuse instanceof Split.ContinuousSplit; + if (type == 1) { + return CATEGORICAL_SPLIT_SERIALIZER.deserialize((Split.CategoricalSplit) reuse, source); + } else { + return CONTINUOUS_SPLIT_SERIALIZER.deserialize((Split.ContinuousSplit) reuse, source); + } + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + serialize(deserialize(source), target); + } + + // ------------------------------------------------------------------------ + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new SplitSerializerSnapshot(); + } + + /** Serializer configuration snapshot for compatibility and format evolution. 
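SplitSerializer prefixes every record with a one-byte tag (0 for null, 1 for a categorical split, 2 for a continuous split), which is what lets a Node's split field be both polymorphic and nullable on the wire. A minimal sketch of the null case, assuming Flink's in-memory views:

import org.apache.flink.core.memory.DataInputDeserializer;
import org.apache.flink.core.memory.DataOutputSerializer;
import org.apache.flink.ml.common.gbt.defs.Split;
import org.apache.flink.ml.common.gbt.typeinfo.SplitSerializer;

public class SplitTagRoundTrip {
    public static void main(String[] args) throws Exception {
        DataOutputSerializer out = new DataOutputSerializer(16);
        // Writes only the tag byte 0, marking a null split.
        SplitSerializer.INSTANCE.serialize(null, out);

        DataInputDeserializer in = new DataInputDeserializer(out.getCopyOfBuffer());
        Split restored = SplitSerializer.INSTANCE.deserialize(in); // returns null
        System.out.println(restored == null); // true
    }
}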
*/ + @SuppressWarnings("WeakerAccess") + public static final class SplitSerializerSnapshot extends SimpleTypeSerializerSnapshot { + + public SplitSerializerSnapshot() { + super(SplitSerializer::new); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SplitTypeInfo.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SplitTypeInfo.java new file mode 100644 index 000000000..ca26987ee --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SplitTypeInfo.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.ml.common.gbt.defs.Split; + +/** A {@link TypeInformation} for the {@link Split} type. */ +public class SplitTypeInfo extends TypeInformation { + + public static final SplitTypeInfo INSTANCE = new SplitTypeInfo(); + + private SplitTypeInfo() {} + + @Override + public boolean isBasicType() { + return false; + } + + @Override + public boolean isTupleType() { + return false; + } + + @Override + public int getArity() { + return 8; + } + + @Override + public int getTotalFields() { + return 8; + } + + @Override + public Class getTypeClass() { + return Split.class; + } + + @Override + public boolean isKeyType() { + return false; + } + + @Override + public TypeSerializer createSerializer(ExecutionConfig executionConfig) { + return new SplitSerializer(); + } + + @Override + public String toString() { + return "Split"; + } + + @Override + public boolean equals(Object o) { + return o instanceof SplitTypeInfo; + } + + @Override + public int hashCode() { + return getClass().hashCode(); + } + + @Override + public boolean canEqual(Object o) { + return o instanceof SplitTypeInfo; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SplitTypeInfoFactory.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SplitTypeInfoFactory.java new file mode 100644 index 000000000..68c47f4c3 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SplitTypeInfoFactory.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeinfo.TypeInfoFactory; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.ml.common.gbt.defs.Split; + +import java.lang.reflect.Type; +import java.util.HashMap; +import java.util.Map; + +/** + * Used by {@link TypeExtractor} to create a {@link TypeInformation} for implementations of {@link + * Split}. + */ +public class SplitTypeInfoFactory extends TypeInfoFactory { + + private static final Map> fields; + + static { + fields = new HashMap<>(); + } + + @Override + public TypeInformation createTypeInfo( + Type t, Map> genericParameters) { + return SplitTypeInfo.INSTANCE; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LogLoss.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LogLoss.java new file mode 100644 index 000000000..507d21e19 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LogLoss.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.lossfunc; + +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; +import org.apache.flink.ml.linalg.DenseVector; + +import org.apache.commons.math3.analysis.function.Sigmoid; + +/** + * The loss function for binary log loss. + * + *

The binary log loss defined as -y * pred + log(1 + exp(pred)) where y is a label in {0, 1} and + * pred is the predicted logit for the sample point. + */ +public class LogLoss implements LossFunc { + + public static final LogLoss INSTANCE = new LogLoss(); + private final Sigmoid sigmoid = new Sigmoid(); + + private LogLoss() {} + + @Override + public double loss(double pred, double label) { + return -label * pred + Math.log(1 + Math.exp(pred)); + } + + @Override + public double gradient(double pred, double label) { + return sigmoid.value(pred) - label; + } + + @Override + public double hessian(double pred, double label) { + double sig = sigmoid.value(pred); + return sig * (1 - sig); + } + + @Override + public double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) { + throw new UnsupportedOperationException(); + } + + @Override + public void computeGradient( + LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) { + throw new UnsupportedOperationException(); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java index a90967a73..e1326f88e 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java @@ -48,4 +48,37 @@ public interface LossFunc extends Serializable { */ void computeGradient( LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient); + + /** + * Calculates loss given pred and label. + * + * @param pred prediction value. + * @param label label value. + * @return loss value. + */ + default double loss(double pred, double label) { + throw new UnsupportedOperationException(); + } + + /** + * Calculates value of gradient given prediction and label. + * + * @param pred prediction value. + * @param label label value. + * @return the value of gradient. + */ + default double gradient(double pred, double label) { + throw new UnsupportedOperationException(); + } + + /** + * Calculates value of second derivative, i.e. hessian, given prediction and label. + * + * @param pred prediction value. + * @param label label value. + * @return the value of second derivative, i.e. hessian. + */ + default double hessian(double pred, double label) { + throw new UnsupportedOperationException(); + } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/SquaredErrorLoss.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/SquaredErrorLoss.java new file mode 100644 index 000000000..d9c3edf69 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/SquaredErrorLoss.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.lossfunc; + +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; +import org.apache.flink.ml.linalg.DenseVector; + +/** + * Squared error loss function defined as (y - pred)^2 where y and pred are label and predictions + * for the instance respectively. + */ +public class SquaredErrorLoss implements LossFunc { + + public static final SquaredErrorLoss INSTANCE = new SquaredErrorLoss(); + + private SquaredErrorLoss() {} + + @Override + public double loss(double pred, double label) { + double error = label - pred; + return error * error; + } + + @Override + public double gradient(double pred, double label) { + return -2. * (label - pred); + } + + @Override + public double hessian(double pred, double label) { + return 2.; + } + + @Override + public double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) { + throw new UnsupportedOperationException(); + } + + @Override + public void computeGradient( + LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) { + throw new UnsupportedOperationException(); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeaturesCols.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeaturesCols.java new file mode 100644 index 000000000..842deb6ff --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeaturesCols.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.api.Stage; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringArrayParam; +import org.apache.flink.ml.param.WithParams; + +/** + * Interface for the shared featuresCols param. + * + *

{@link HasFeaturesCols} is typically used for {@link Stage}s that implement {@link + * HasLabelCol}. It is preferred to use {@link HasInputCol} for other cases. + */ +public interface HasFeaturesCols extends WithParams { + Param FEATURES_COLS = + new StringArrayParam( + "featuresCols", + "Feature column names.", + new String[] {"features"}, + ParamValidators.nonEmptyArray()); + + default String[] getFeaturesCols() { + return get(FEATURES_COLS); + } + + default T setFeaturesCols(String... value) { + return set(FEATURES_COLS, value); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasProbabilityCol.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasProbabilityCol.java new file mode 100644 index 000000000..3eba2b72f --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasProbabilityCol.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; +import org.apache.flink.ml.param.WithParams; + +/** Interface for shared param probability column. 
*/ +public interface HasProbabilityCol extends WithParams { + Param PROBABILITY_COL = + new StringParam( + "probabilityCol", + "Column name for predicted class conditional probabilities.", + "probability", + ParamValidators.notNull()); + + default String getProbabilityCol() { + return get(PROBABILITY_COL); + } + + default T setProbabilityCol(String value) { + return set(PROBABILITY_COL, value); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java index 763a0df22..07948cc27 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java @@ -36,6 +36,7 @@ import org.apache.flink.util.Collector; import org.apache.flink.util.Preconditions; +import org.apache.commons.lang3.ArrayUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -210,15 +211,26 @@ private static double[][] findBinEdgesWithQuantileStrategy( int numColumns = input.get(0).size(); int numData = input.size(); double[][] binEdges = new double[numColumns][]; - double[] features = new double[numData]; for (int columnId = 0; columnId < numColumns; columnId++) { + double[] features = new double[numData]; for (int i = 0; i < numData; i++) { features[i] = input.get(i).get(columnId); } Arrays.sort(features); + int n = numData; - if (features[0] == features[numData - 1]) { + { + int validRange = numData; + while (validRange > 0 && Double.isNaN(features[validRange - 1])) { + validRange -= 1; + } + if (validRange < numData) { + features = ArrayUtils.subarray(features, 0, validRange); + } + } + + if (features[0] == features[features.length - 1]) { LOG.warn("Feature " + columnId + " is constant and the output will all be zero."); binEdges[columnId] = new double[] {Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY}; @@ -231,7 +243,7 @@ private static double[][] findBinEdgesWithQuantileStrategy( for (int binEdgeId = 0; binEdgeId < numBins; binEdgeId++) { tempBinEdges[binEdgeId] = features[(int) (binEdgeId * width)]; } - tempBinEdges[numBins] = features[numData - 1]; + tempBinEdges[numBins] = features[features.length - 1]; } else { tempBinEdges = features; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModel.java index 03f2fc394..3bd429cf9 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModel.java @@ -154,6 +154,10 @@ public Row map(Row row) { DenseVector outputVec = inputVec.clone(); for (int i = 0; i < inputVec.size(); i++) { double targetFeature = inputVec.get(i); + if (Double.isNaN(targetFeature)) { + outputVec.set(i, binEdges[i].length - 1); + continue; + } int index = Arrays.binarySearch(binEdges[i], targetFeature); if (index < 0) { // Computes the index to insert. 
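The two KBinsDiscretizer changes above deal with NaN feature values: trailing NaNs (which Arrays.sort places at the end) are dropped before quantile bin edges are computed, and at transform time a NaN value is routed to an extra bin after the last regular one. Below is a minimal, self-contained sketch of that behaviour, not the exact model code; the post-processing of the binarySearch result is a simplified assumption here.

import java.util.Arrays;

public class NaNBinningSketch {

    // Drops the NaNs that Arrays.sort places at the end, mirroring the edge-computation change.
    static double[] dropTrailingNaNs(double[] sortedFeatures) {
        int validRange = sortedFeatures.length;
        while (validRange > 0 && Double.isNaN(sortedFeatures[validRange - 1])) {
            validRange -= 1;
        }
        return Arrays.copyOfRange(sortedFeatures, 0, validRange);
    }

    // Maps a value to a bin id; NaN goes to the dedicated extra bin, as in the model change above.
    static int assignBin(double[] binEdges, double value) {
        if (Double.isNaN(value)) {
            return binEdges.length - 1;
        }
        int index = Arrays.binarySearch(binEdges, value);
        if (index < 0) {
            index = -index - 2; // insertion point -> left bin (simplified assumption)
        }
        return Math.max(0, Math.min(index, binEdges.length - 2));
    }

    public static void main(String[] args) {
        double[] features = {2.3, 1.2, Double.NaN, 5.6, 3.4};
        Arrays.sort(features); // NaNs sort to the end
        System.out.println(Arrays.toString(dropTrailingNaNs(features))); // [1.2, 2.3, 3.4, 5.6]

        double[] binEdges = {1.2, 2.3, 5.6};
        System.out.println(assignBin(binEdges, 3.0));        // 1
        System.out.println(assignBin(binEdges, Double.NaN)); // 2, the dedicated NaN bin
    }
}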
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java new file mode 100644 index 000000000..81d06018b --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.regression.gbtregressor; + +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.gbt.GBTModelData; +import org.apache.flink.ml.common.gbt.GBTRunner; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.apache.flink.table.api.Expressions.$; + +/** + * An Estimator which implements the gradient boosting trees regression algorithm (Gradient Boosting). + * + *

The implementation has been inspired by advanced implementations such as XGBoost and LightGBM. It + * supports features such as a regularized learning objective with second-order approximation and + * a histogram-based, sparsity-aware split-finding algorithm. + * + *

The distributed implementation takes this work as a reference. Right now, we support + * horizontal partitioning of the data and row-store storage of instances. + * + *
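+ * <p>For example (a minimal sketch; the column names {@code f0}, {@code f1} and {@code label} are + * only illustrative), a regressor can be trained and applied as follows: + * + * <pre>{@code + * GBTRegressor gbtr = + *         new GBTRegressor().setFeaturesCols("f0", "f1").setLabelCol("label"); + * GBTRegressorModel model = gbtr.fit(inputTable); + * Table predictions = model.transform(inputTable)[0]; + * }</pre> + *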

NOTE: Currently, some features are not supported yet: weighted input samples, early-stopping + * with validation set, encoding with leaf ids, etc. + */ +public class GBTRegressor + implements Estimator, GBTRegressorParams { + + private final Map, Object> paramMap = new HashMap<>(); + + public GBTRegressor() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + public static GBTRegressor load(StreamTableEnvironment tEnv, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map, Object> getParamMap() { + return paramMap; + } + + @Override + public GBTRegressorModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream modelData = GBTRunner.train(inputs[0], this); + DataStream> featureImportance = + GBTRunner.getFeatureImportance(modelData); + GBTRegressorModel model = new GBTRegressorModel(); + model.setModelData( + tEnv.fromDataStream(modelData).renameColumns($("f0").as("modelData")), + tEnv.fromDataStream(featureImportance) + .renameColumns($("f0").as("featureImportance"))); + ParamUtils.updateExistingParams(model, getParamMap()); + return model; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModel.java new file mode 100644 index 000000000..0c78f151a --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModel.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.regression.gbtregressor; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.gbt.BaseGBTModel; +import org.apache.flink.ml.common.gbt.GBTModelData; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; +import org.eclipse.collections.impl.map.mutable.primitive.IntDoubleHashMap; + +import java.io.IOException; +import java.util.Collections; + +/** A Model computed by {@link GBTRegressor}. */ +public class GBTRegressorModel extends BaseGBTModel + implements GBTRegressorModelParams { + + /** + * Loads model data from path. + * + * @param tEnv A StreamTableEnvironment instance. + * @param path Model path. + * @return GBT regression model. + */ + public static GBTRegressorModel load(StreamTableEnvironment tEnv, String path) + throws IOException { + return BaseGBTModel.load(tEnv, path); + } + + @Override + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream inputStream = tEnv.toDataStream(inputs[0]); + final String broadcastModelKey = "broadcastModelKey"; + DataStream modelDataStream = GBTModelData.getModelDataStream(modelDataTable); + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.DOUBLE), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol())); + DataStream predictionResult = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(inputStream), + Collections.singletonMap(broadcastModelKey, modelDataStream), + inputList -> { + //noinspection unchecked + DataStream inputData = (DataStream) inputList.get(0); + return inputData.map( + new PredictLabelFunction(broadcastModelKey, getFeaturesCols()), + outputTypeInfo); + }); + return new Table[] {tEnv.fromDataStream(predictionResult)}; + } + + private static class PredictLabelFunction extends RichMapFunction { + + private final String broadcastModelKey; + private final String[] featuresCols; + private GBTModelData modelData; + + public PredictLabelFunction(String broadcastModelKey, String[] featuresCols) { + this.broadcastModelKey = broadcastModelKey; + this.featuresCols = featuresCols; + } + + @Override + public Row map(Row value) throws Exception { + if (null == modelData) { + modelData = + (GBTModelData) + getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0); + } + IntDoubleHashMap features = modelData.rowToFeatures(value, featuresCols); + double pred = modelData.predictRaw(features); + return Row.join(value, Row.of(pred)); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModelParams.java new file mode 100644 index 000000000..84fe9c4f8 --- 
/dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModelParams.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.regression.gbtregressor; + +import org.apache.flink.ml.common.gbt.BaseGBTModelParams; + +/** + * Parameters for {@link GBTRegressorModel}. + * + * @param The class type of this instance. + */ +public interface GBTRegressorModelParams extends BaseGBTModelParams {} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorParams.java new file mode 100644 index 000000000..ab21aee8b --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorParams.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.regression.gbtregressor; + +import org.apache.flink.ml.common.gbt.BaseGBTParams; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; + +/** + * Parameters for {@link GBTRegressor}. + * + * @param The class type of this instance. 
+ */ +public interface GBTRegressorParams extends BaseGBTParams, GBTRegressorModelParams { + Param LOSS_TYPE = + new StringParam( + "lossType", "Loss type.", "squared", ParamValidators.inArray("squared")); + + default String getLossType() { + return get(LOSS_TYPE); + } + + default T setLossType(String value) { + return set(LOSS_TYPE, value); + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java new file mode 100644 index 000000000..f0d62e1a5 --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java @@ -0,0 +1,519 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.classification.gbtclassifier.GBTClassifier; +import org.apache.flink.ml.classification.gbtclassifier.GBTClassifierModel; +import org.apache.flink.ml.common.gbt.GBTModelData; +import org.apache.flink.ml.common.gbt.defs.TaskType; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.apache.commons.lang3.ArrayUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.BitSet; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.table.api.Expressions.$; + +/** Tests {@link GBTClassifier} and {@link GBTClassifierModel}. 
*/ +public class GBTClassifierTest extends AbstractTestBase { + private static final List inputDataRows = + Arrays.asList( + Row.of(1.2, 2, null, 40., 1., 0., Vectors.dense(1.2, 2, Double.NaN)), + Row.of(2.3, 3, "b", 40., 2., 0., Vectors.dense(2.3, 3, 2.)), + Row.of(3.4, 4, "c", 40., 3., 0., Vectors.dense(3.4, 4, 3.)), + Row.of(4.5, 5, "a", 40., 4., 0., Vectors.dense(4.5, 5, 1.)), + Row.of(5.6, 2, "b", 40., 5., 0., Vectors.dense(5.6, 2, 2.)), + Row.of(null, 3, "c", 41., 1., 1., Vectors.dense(Double.NaN, 3, 3.)), + Row.of(12.8, 4, "e", 41., 2., 1., Vectors.dense(12.8, 4, 5.)), + Row.of(13.9, 2, "b", 41., 3., 1., Vectors.dense(13.9, 2, 2.)), + Row.of(14.1, 4, "a", 41., 4., 1., Vectors.dense(14.1, 4, 1.)), + Row.of(15.3, 1, "d", 41., 5., 1., Vectors.dense(15.3, 1, 4.))); + + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + List outputRows = + Arrays.asList( + Row.of( + 0.0, + Vectors.dense(2.376078066514637, -2.376078066514637), + Vectors.dense(0.914984852695779, 0.08501514730422102)), + Row.of( + 1.0, + Vectors.dense(-2.5493892913102703, 2.5493892913102703), + Vectors.dense(0.07246752402942669, 0.9275324759705733)), + Row.of( + 1.0, + Vectors.dense(-2.658830586839206, 2.658830586839206), + Vectors.dense(0.06544682253255263, 0.9345531774674474)), + Row.of( + 0.0, + Vectors.dense(2.3309355512336296, -2.3309355512336296), + Vectors.dense(0.9114069063091061, 0.08859309369089385)), + Row.of( + 1.0, + Vectors.dense(-2.6577392865785714, 2.6577392865785714), + Vectors.dense(0.06551360197733425, 0.9344863980226658)), + Row.of( + 0.0, + Vectors.dense(2.5532653631402114, -2.5532653631402114), + Vectors.dense(0.9277925785910718, 0.07220742140892823)), + Row.of( + 0.0, + Vectors.dense(2.3773197509703996, -2.3773197509703996), + Vectors.dense(0.9150813905583675, 0.0849186094416325)), + Row.of( + 1.0, + Vectors.dense(-2.132645378098387, 2.132645378098387), + Vectors.dense(0.10596411850817689, 0.8940358814918231)), + Row.of( + 0.0, + Vectors.dense(2.3105035625447106, -2.3105035625447106), + Vectors.dense(0.9097432116019103, 0.09025678839808973)), + Row.of( + 1.0, + Vectors.dense(-2.0541952729346695, 2.0541952729346695), + Vectors.dense(0.11362915817869357, 0.8863708418213064))); + private StreamTableEnvironment tEnv; + private Table inputTable; + + private static void verifyPredictionResult(Table output, List expected) throws Exception { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment(); + //noinspection unchecked + List results = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + final double delta = 1e-3; + final Comparator denseVectorComparator = + new TestUtils.DenseVectorComparatorWithDelta(delta); + final Comparator comparator = + Comparator.comparing(d -> d.getFieldAs(0)) + .thenComparing(d -> d.getFieldAs(1), denseVectorComparator) + .thenComparing(d -> d.getFieldAs(2), denseVectorComparator); + TestUtils.compareResultCollectionsWithComparator(expected, results, comparator); + } + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.getConfig().enableObjectReuse(); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + inputTable = + tEnv.fromDataStream( + env.fromCollection( + 
inputDataRows, + new RowTypeInfo( + new TypeInformation[] { + Types.DOUBLE, + Types.INT, + Types.STRING, + Types.DOUBLE, + Types.DOUBLE, + Types.DOUBLE, + VectorTypeInfo.INSTANCE + }, + new String[] { + "f0", "f1", "f2", "label", "weight", "cls_label", "vec" + }))); + } + + @Test + public void testParam() { + GBTClassifier gbtc = new GBTClassifier(); + Assert.assertArrayEquals(new String[] {"features"}, gbtc.getFeaturesCols()); + Assert.assertEquals("label", gbtc.getLabelCol()); + Assert.assertArrayEquals(new String[] {}, gbtc.getCategoricalCols()); + Assert.assertEquals("prediction", gbtc.getPredictionCol()); + + Assert.assertNull(gbtc.getLeafCol()); + Assert.assertNull(gbtc.getWeightCol()); + Assert.assertEquals(5, gbtc.getMaxDepth()); + Assert.assertEquals(32, gbtc.getMaxBins()); + Assert.assertEquals(1, gbtc.getMinInstancesPerNode()); + Assert.assertEquals(0., gbtc.getMinWeightFractionPerNode(), 1e-12); + Assert.assertEquals(0., gbtc.getMinInfoGain(), 1e-12); + Assert.assertEquals(20, gbtc.getMaxIter()); + Assert.assertEquals(.1, gbtc.getStepSize(), 1e-12); + Assert.assertEquals(GBTClassifier.class.getName().hashCode(), gbtc.getSeed()); + Assert.assertEquals(1., gbtc.getSubsamplingRate(), 1e-12); + Assert.assertEquals("auto", gbtc.getFeatureSubsetStrategy()); + Assert.assertNull(gbtc.getValidationIndicatorCol()); + Assert.assertEquals(.01, gbtc.getValidationTol(), 1e-12); + Assert.assertEquals(0., gbtc.getRegLambda(), 1e-12); + Assert.assertEquals(1., gbtc.getRegGamma(), 1e-12); + + Assert.assertEquals("logistic", gbtc.getLossType()); + Assert.assertEquals("rawPrediction", gbtc.getRawPredictionCol()); + Assert.assertEquals("probability", gbtc.getProbabilityCol()); + + gbtc.setFeaturesCols("f0", "f1", "f2") + .setLabelCol("cls_label") + .setCategoricalCols("f0", "f1") + .setPredictionCol("pred") + .setLeafCol("leaf") + .setWeightCol("weight") + .setMaxDepth(6) + .setMaxBins(64) + .setMinInstancesPerNode(2) + .setMinWeightFractionPerNode(.1) + .setMinInfoGain(.1) + .setMaxIter(10) + .setStepSize(.2) + .setSeed(123) + .setSubsamplingRate(.8) + .setFeatureSubsetStrategy("0.5") + .setValidationIndicatorCol("val") + .setValidationTol(.1) + .setRegLambda(.1) + .setRegGamma(.1) + .setRawPredictionCol("raw_pred") + .setProbabilityCol("prob"); + + Assert.assertArrayEquals(new String[] {"f0", "f1", "f2"}, gbtc.getFeaturesCols()); + Assert.assertEquals("cls_label", gbtc.getLabelCol()); + Assert.assertArrayEquals(new String[] {"f0", "f1"}, gbtc.getCategoricalCols()); + Assert.assertEquals("pred", gbtc.getPredictionCol()); + + Assert.assertEquals("leaf", gbtc.getLeafCol()); + Assert.assertEquals("weight", gbtc.getWeightCol()); + Assert.assertEquals(6, gbtc.getMaxDepth()); + Assert.assertEquals(64, gbtc.getMaxBins()); + Assert.assertEquals(2, gbtc.getMinInstancesPerNode()); + Assert.assertEquals(.1, gbtc.getMinWeightFractionPerNode(), 1e-12); + Assert.assertEquals(.1, gbtc.getMinInfoGain(), 1e-12); + Assert.assertEquals(10, gbtc.getMaxIter()); + Assert.assertEquals(.2, gbtc.getStepSize(), 1e-12); + Assert.assertEquals(123, gbtc.getSeed()); + Assert.assertEquals(.8, gbtc.getSubsamplingRate(), 1e-12); + Assert.assertEquals("0.5", gbtc.getFeatureSubsetStrategy()); + Assert.assertEquals("val", gbtc.getValidationIndicatorCol()); + Assert.assertEquals(.1, gbtc.getValidationTol(), 1e-12); + Assert.assertEquals(.1, gbtc.getRegLambda(), 1e-12); + Assert.assertEquals(.1, gbtc.getRegGamma(), 1e-12); + + Assert.assertEquals("raw_pred", gbtc.getRawPredictionCol()); + Assert.assertEquals("prob", 
gbtc.getProbabilityCol()); + } + + @Test + public void testOutputSchema() throws Exception { + GBTClassifier gbtc = + new GBTClassifier().setFeaturesCols("f0", "f1", "f2").setCategoricalCols("f2"); + GBTClassifierModel model = gbtc.fit(inputTable); + Table output = model.transform(inputTable)[0]; + Assert.assertArrayEquals( + ArrayUtils.addAll( + inputTable.getResolvedSchema().getColumnNames().toArray(new String[0]), + gbtc.getPredictionCol(), + gbtc.getRawPredictionCol(), + gbtc.getProbabilityCol()), + output.getResolvedSchema().getColumnNames().toArray(new String[0])); + } + + @Test + public void testFitAndPredict() throws Exception { + GBTClassifier gbtc = + new GBTClassifier() + .setFeaturesCols("f0", "f1", "f2") + .setCategoricalCols("f2") + .setLabelCol("cls_label") + .setRegGamma(0.) + .setMaxBins(3) + .setSeed(123); + GBTClassifierModel model = gbtc.fit(inputTable); + Table output = + model.transform(inputTable)[0].select( + $(gbtc.getPredictionCol()), + $(gbtc.getRawPredictionCol()), + $(gbtc.getProbabilityCol())); + verifyPredictionResult(output, outputRows); + } + + @Test + public void testFitAndPredictWithVectorCol() throws Exception { + GBTClassifier gbtc = + new GBTClassifier() + .setFeaturesCols("vec") + .setLabelCol("cls_label") + .setRegGamma(0.) + .setMaxBins(3) + .setSeed(123); + GBTClassifierModel model = gbtc.fit(inputTable); + Table output = + model.transform(inputTable)[0].select( + $(gbtc.getPredictionCol()), + $(gbtc.getRawPredictionCol()), + $(gbtc.getProbabilityCol())); + List outputRowsUsingVectorCol = + Arrays.asList( + Row.of( + 0.0, + Vectors.dense(1.9834935486026828, -1.9834935486026828), + Vectors.dense(0.8790530839977041, 0.12094691600229594)), + Row.of( + 1.0, + Vectors.dense(-1.9962334686995544, 1.9962334686995544), + Vectors.dense(0.11959895119804398, 0.880401048801956)), + Row.of( + 0.0, + Vectors.dense(2.2596958412285053, -2.2596958412285053), + Vectors.dense(0.9054836034255209, 0.0945163965744791)), + Row.of( + 1.0, + Vectors.dense(-2.23023965816558, 2.23023965816558), + Vectors.dense(0.09706763399626683, 0.9029323660037332)), + Row.of( + 1.0, + Vectors.dense(-2.520667396406638, 2.520667396406638), + Vectors.dense(0.0744219596185437, 0.9255780403814563)), + Row.of( + 0.0, + Vectors.dense(2.5005544570205114, -2.5005544570205114), + Vectors.dense(0.9241806803368346, 0.07581931966316532)), + Row.of( + 0.0, + Vectors.dense(2.155310746068554, -2.155310746068554), + Vectors.dense(0.8961640042377698, 0.10383599576223027)), + Row.of( + 1.0, + Vectors.dense(-2.2386996519306424, 2.2386996519306424), + Vectors.dense(0.09632867690962832, 0.9036713230903717)), + Row.of( + 0.0, + Vectors.dense(2.0375281995821273, -2.0375281995821273), + Vectors.dense(0.8846813338862343, 0.11531866611376576)), + Row.of( + 1.0, + Vectors.dense(-1.9751553623558855, 1.9751553623558855), + Vectors.dense(0.12183622723878906, 0.8781637727612109))); + verifyPredictionResult(output, outputRowsUsingVectorCol); + } + + @Test + public void testFitAndPredictWithNoCategoricalCols() throws Exception { + GBTClassifier gbtc = + new GBTClassifier() + .setFeaturesCols("f0", "f1") + .setLabelCol("cls_label") + .setRegGamma(0.) 
+ .setMaxBins(5) + .setSeed(123); + GBTClassifierModel model = gbtc.fit(inputTable); + Table output = + model.transform(inputTable)[0].select( + $(gbtc.getPredictionCol()), + $(gbtc.getRawPredictionCol()), + $(gbtc.getProbabilityCol())); + List outputRowsUsingNoCategoricalCols = + Arrays.asList( + Row.of( + 0.0, + Vectors.dense(2.34563907006811, -2.34563907006811), + Vectors.dense(0.9125869728543822, 0.0874130271456178)), + Row.of( + 1.0, + Vectors.dense(-2.3303467465269785, 2.3303467465269785), + Vectors.dense(0.0886406478666607, 0.9113593521333393)), + Row.of( + 1.0, + Vectors.dense(-2.6627806586536007, 2.6627806586536007), + Vectors.dense(0.06520563648648892, 0.9347943635135111)), + Row.of( + 0.0, + Vectors.dense(2.2219234863111987, -2.2219234863111987), + Vectors.dense(0.9022010445561748, 0.09779895544382528)), + Row.of( + 1.0, + Vectors.dense(-2.4261826518456586, 2.4261826518456586), + Vectors.dense(0.08119780449041314, 0.9188021955095869)), + Row.of( + 0.0, + Vectors.dense(2.6577392865785714, -2.6577392865785714), + Vectors.dense(0.9344863980226659, 0.06551360197733418)), + Row.of( + 1.0, + Vectors.dense(-2.6641132494818254, 2.6641132494818254), + Vectors.dense(0.0651244569774293, 0.9348755430225707)), + Row.of( + 0.0, + Vectors.dense(2.6577392865785714, -2.6577392865785714), + Vectors.dense(0.9344863980226659, 0.06551360197733418)), + Row.of( + 0.0, + Vectors.dense(2.6577392865785714, -2.6577392865785714), + Vectors.dense(0.9344863980226659, 0.06551360197733418)), + Row.of( + 1.0, + Vectors.dense(-2.4318453555603523, 2.4318453555603523), + Vectors.dense(0.08077634070928352, 0.9192236592907165))); + verifyPredictionResult(output, outputRowsUsingNoCategoricalCols); + } + + @Test + public void testEstimatorSaveLoadAndPredict() throws Exception { + GBTClassifier gbtc = + new GBTClassifier() + .setFeaturesCols("f0", "f1", "f2") + .setCategoricalCols("f2") + .setLabelCol("cls_label") + .setRegGamma(0.) + .setMaxBins(3) + .setSeed(123); + GBTClassifier loadedGbtc = + TestUtils.saveAndReload( + tEnv, gbtc, tempFolder.newFolder().getAbsolutePath(), GBTClassifier::load); + GBTClassifierModel model = loadedGbtc.fit(inputTable); + Assert.assertEquals( + Collections.singletonList("modelData"), + model.getModelData()[0].getResolvedSchema().getColumnNames()); + Assert.assertEquals( + Collections.singletonList("featureImportance"), + model.getModelData()[1].getResolvedSchema().getColumnNames()); + Table output = + model.transform(inputTable)[0].select( + $(gbtc.getPredictionCol()), + $(gbtc.getRawPredictionCol()), + $(gbtc.getProbabilityCol())); + verifyPredictionResult(output, outputRows); + } + + @Test + public void testModelSaveLoadAndPredict() throws Exception { + GBTClassifier gbtc = + new GBTClassifier() + .setFeaturesCols("f0", "f1", "f2") + .setCategoricalCols("f2") + .setLabelCol("cls_label") + .setRegGamma(0.) + .setMaxBins(3) + .setSeed(123); + GBTClassifierModel model = gbtc.fit(inputTable); + GBTClassifierModel loadedModel = + TestUtils.saveAndReload( + tEnv, + model, + tempFolder.newFolder().getAbsolutePath(), + GBTClassifierModel::load); + Table output = + loadedModel.transform(inputTable)[0].select( + $(gbtc.getPredictionCol()), + $(gbtc.getRawPredictionCol()), + $(gbtc.getProbabilityCol())); + verifyPredictionResult(output, outputRows); + } + + @Test + public void testGetModelData() throws Exception { + GBTClassifier gbtc = + new GBTClassifier() + .setFeaturesCols("f0", "f1", "f2") + .setCategoricalCols("f2") + .setLabelCol("cls_label") + .setRegGamma(0.) 
+ .setMaxBins(3) + .setSeed(123); + GBTClassifierModel model = gbtc.fit(inputTable); + Table modelDataTable = model.getModelData()[0]; + List modelDataColumnNames = modelDataTable.getResolvedSchema().getColumnNames(); + Assert.assertArrayEquals( + new String[] {"modelData"}, modelDataColumnNames.toArray(new String[0])); + + //noinspection unchecked + List modelDataRows = + IteratorUtils.toList(tEnv.toDataStream(modelDataTable).executeAndCollect()); + Assert.assertEquals(1, modelDataRows.size()); + GBTModelData modelData = modelDataRows.get(0).getFieldAs(0); + Assert.assertNotNull(modelData); + + Assert.assertEquals(TaskType.CLASSIFICATION, TaskType.valueOf(modelData.type)); + Assert.assertFalse(modelData.isInputVector); + Assert.assertEquals(0., modelData.prior, 1e-12); + Assert.assertEquals(gbtc.getStepSize(), modelData.stepSize, 1e-12); + Assert.assertEquals(gbtc.getMaxIter(), modelData.allTrees.size()); + Assert.assertEquals(gbtc.getCategoricalCols().length, modelData.categoryToIdMaps.size()); + Assert.assertEquals( + gbtc.getFeaturesCols().length - gbtc.getCategoricalCols().length, + modelData.featureIdToBinEdges.size()); + Assert.assertEquals(BitSet.valueOf(new byte[] {4}), modelData.isCategorical); + + Table featureImportanceTable = model.getModelData()[1]; + Assert.assertEquals( + Collections.singletonList("featureImportance"), + featureImportanceTable.getResolvedSchema().getColumnNames()); + //noinspection unchecked + List featureImportanceRows = + IteratorUtils.toList(tEnv.toDataStream(featureImportanceTable).executeAndCollect()); + Assert.assertEquals(1, featureImportanceRows.size()); + Map featureImportanceMap = + featureImportanceRows.get(0).getFieldAs("featureImportance"); + Assert.assertArrayEquals( + gbtc.getFeaturesCols(), featureImportanceMap.keySet().toArray(new String[0])); + } + + @Test + public void testSetModelData() throws Exception { + GBTClassifier gbtc = + new GBTClassifier() + .setFeaturesCols("f0", "f1", "f2") + .setCategoricalCols("f2") + .setLabelCol("cls_label") + .setRegGamma(0.) + .setMaxBins(3) + .setSeed(123); + GBTClassifierModel modelA = gbtc.fit(inputTable); + GBTClassifierModel modelB = new GBTClassifierModel().setModelData(modelA.getModelData()); + ParamUtils.updateExistingParams(modelB, modelA.getParamMap()); + Table output = + modelA.transform(inputTable)[0].select( + $(gbtc.getPredictionCol()), + $(gbtc.getRawPredictionCol()), + $(gbtc.getProbabilityCol())); + verifyPredictionResult(output, outputRows); + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/DataUtilsTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/DataUtilsTest.java new file mode 100644 index 000000000..567a68651 --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/DataUtilsTest.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt; + +import org.junit.Assert; +import org.junit.Test; + +/** Test {@link DataUtils}. */ +public class DataUtilsTest { + @Test + public void testFindBin() { + double[] binEdges = new double[] {1., 2., 3., 4.}; + for (int i = 0; i < binEdges.length; i += 1) { + Assert.assertEquals( + Math.min(binEdges.length - 2, i), DataUtils.findBin(binEdges, binEdges[i])); + } + double[] values = new double[] {.5, 1.5, 2.5, 3.5, 4.5}; + int[] bins = new int[] {0, 0, 1, 2, 2}; + for (int i = 0; i < values.length; i += 1) { + Assert.assertEquals(bins[i], DataUtils.findBin(binEdges, values[i])); + } + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/PreprocessTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/PreprocessTest.java new file mode 100644 index 000000000..3a8ca1573 --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/PreprocessTest.java @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.common.gbt.defs.BoostingStrategy; +import org.apache.flink.ml.common.gbt.defs.FeatureMeta; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.SinkFunction; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.test.util.TestBaseUtils; +import org.apache.flink.testutils.junit.SharedObjects; +import org.apache.flink.testutils.junit.SharedReference; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.ArrayBlockingQueue; + +/** Tests {@link Preprocess}. 
*/ +public class PreprocessTest extends AbstractTestBase { + + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + @Rule public final SharedObjects sharedObjects = SharedObjects.create(); + + private static final List inputDataRows = + Arrays.asList( + Row.of(1.2, 2, null, 40., Vectors.dense(1.2, 2, Double.NaN)), + Row.of(2.3, 3, "b", 40., Vectors.dense(2.3, 3, 2.)), + Row.of(3.4, 4, "c", 40., Vectors.dense(3.4, 4, 3.)), + Row.of(4.5, 5, "a", 40., Vectors.dense(4.5, 5, 1.)), + Row.of(5.6, 2, "b", 40., Vectors.dense(5.6, 2, 2.)), + Row.of(null, 3, "c", 41., Vectors.dense(Double.NaN, 3, 3.)), + Row.of(12.8, 4, "e", 41., Vectors.dense(12.8, 4, 5.)), + Row.of(13.9, 2, "b", 41., Vectors.dense(13.9, 2, 2.)), + Row.of(14.1, 4, "a", 41., Vectors.dense(14.1, 4, 1.)), + Row.of(15.3, 1, "d", 41., Vectors.dense(15.3, 1, 4.))); + + private StreamTableEnvironment tEnv; + private Table inputTable; + private SharedReference> actualMeta; + + // private static void verifyPredictionResult(Table output, List expected) throws + // Exception { + // StreamTableEnvironment tEnv = + // (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment(); + // //noinspection unchecked + // List results = + // IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + // final double delta = 1e-3; + // final Comparator denseVectorComparator = + // new TestUtils.DenseVectorComparatorWithDelta(delta); + // final Comparator comparator = + // Comparator.comparing(d -> d.getFieldAs(0)) + // .thenComparing(d -> d.getFieldAs(1), denseVectorComparator) + // .thenComparing(d -> d.getFieldAs(2), denseVectorComparator); + // TestUtils.compareResultCollectionsWithComparator(expected, results, comparator); + // } + + @Before + public void before() { + StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment(); + tEnv = StreamTableEnvironment.create(env); + + inputTable = + tEnv.fromDataStream( + env.fromCollection( + inputDataRows, + new RowTypeInfo( + new TypeInformation[] { + Types.DOUBLE, + Types.INT, + Types.STRING, + Types.DOUBLE, + DenseVectorTypeInfo.INSTANCE + }, + new String[] {"f0", "f1", "f2", "label", "vec"}))); + + actualMeta = sharedObjects.add(new ArrayBlockingQueue<>(8)); + } + + private static class CollectSink implements SinkFunction { + private final SharedReference> q; + + public CollectSink(SharedReference> q) { + this.q = q; + } + + @Override + public void invoke(T value, Context context) { + q.get().add(value); + } + } + + @Test + public void testPreprocessCols() throws Exception { + BoostingStrategy strategy = new BoostingStrategy(); + strategy.isInputVector = false; + strategy.featuresCols = new String[] {"f0", "f1", "f2"}; + strategy.categoricalCols = new String[] {"f2"}; + strategy.labelCol = "label"; + strategy.maxBins = 3; + Tuple2> results = + Preprocess.preprocessCols(inputTable, strategy); + + actualMeta.get().clear(); + results.f1.addSink(new CollectSink<>(actualMeta)); + //noinspection unchecked + List preprocessedRows = + IteratorUtils.toList(tEnv.toDataStream(results.f0).executeAndCollect()); + + // TODO: Correct `binEdges` of feature `f0` after FLINK-30734 resolved. 
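+        // With maxBins = 3, the quantile strategy samples roughly every (numData / maxBins)-th of the + // 10 sorted `f1` values (1, 2, 2, 2, 3, 3, 4, 4, 4, 5), i.e. indices 0, 3 and 6 here, and appends + // the maximum, giving the edges {1.0, 2.0, 4.0, 5.0}; `f2` is categorical with the 5 distinct + // values "a".."e".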
+ List expectedMeta = + Arrays.asList( + FeatureMeta.continuous("f0", 3, new double[] {1.2, 4.5, 13.9, 15.3}), + FeatureMeta.continuous("f1", 3, new double[] {1.0, 2.0, 4.0, 5.0}), + FeatureMeta.categorical("f2", 5, new String[] {"a", "b", "c", "d", "e"})); + + List expectedPreprocessedRows = + Arrays.asList( + Row.of(40.0, 0, 1, 5.0), + Row.of(40.0, 0, 1, 1.0), + Row.of(40.0, 0, 2, 2.0), + Row.of(40.0, 1, 2, 0.0), + Row.of(40.0, 1, 1, 1.0), + Row.of(41.0, 3, 1, 2.0), + Row.of(41.0, 1, 2, 4.0), + Row.of(41.0, 2, 1, 1.0), + Row.of(41.0, 2, 2, 0.0), + Row.of(41.0, 2, 0, 3.0)); + Comparator preprocessedRowComparator = + Comparator.comparing(d -> d.getFieldAs(0)) + .thenComparing(d -> d.getFieldAs(1)) + .thenComparing(d -> d.getFieldAs(2)) + .thenComparing(d -> d.getFieldAs(3)); + + TestBaseUtils.compareResultCollections( + expectedPreprocessedRows, preprocessedRows, preprocessedRowComparator); + TestBaseUtils.compareResultCollections( + expectedMeta, new ArrayList<>(actualMeta.get()), Comparator.comparing(d -> d.name)); + } + + @Test + public void testPreprocessVectorCol() throws Exception { + BoostingStrategy strategy = new BoostingStrategy(); + strategy.isInputVector = true; + strategy.featuresCols = new String[] {"vec"}; + strategy.labelCol = "label"; + strategy.maxBins = 3; + Tuple2> results = + Preprocess.preprocessVecCol(inputTable, strategy); + + actualMeta.get().clear(); + results.f1.addSink(new CollectSink<>(actualMeta)); + //noinspection unchecked + List preprocessedRows = + IteratorUtils.toList(tEnv.toDataStream(results.f0).executeAndCollect()); + + // TODO: Correct `binEdges` of feature `_vec_f0` and `_vec_f2` after FLINK-30734 resolved. + List expectedMeta = + Arrays.asList( + FeatureMeta.continuous("_vec_f0", 3, new double[] {1.2, 4.5, 13.9, 15.3}), + FeatureMeta.continuous("_vec_f1", 3, new double[] {1.0, 2.0, 4.0, 5.0}), + FeatureMeta.continuous("_vec_f2", 3, new double[] {1.0, 2.0, 3.0, 5.0})); + List expectedPreprocessedRows = + Arrays.asList( + Row.of(40.0, Vectors.dense(0, 1, 3.0)), + Row.of(40.0, Vectors.dense(0, 1, 1.0)), + Row.of(40.0, Vectors.dense(0, 2, 2.0)), + Row.of(40.0, Vectors.dense(1, 2, 0.0)), + Row.of(40.0, Vectors.dense(1, 1, 1.0)), + Row.of(41.0, Vectors.dense(3, 1, 2.0)), + Row.of(41.0, Vectors.dense(1, 2, 2.0)), + Row.of(41.0, Vectors.dense(2, 1, 1.0)), + Row.of(41.0, Vectors.dense(2, 2, 0.0)), + Row.of(41.0, Vectors.dense(2, 0, 2.0))); + + Comparator preprocessedRowComparator = + Comparator.comparing(d -> d.getFieldAs(0)) + .thenComparing(d -> d.getFieldAs(1), TestUtils::compare); + + TestBaseUtils.compareResultCollections( + expectedPreprocessedRows, preprocessedRows, preprocessedRowComparator); + TestBaseUtils.compareResultCollections( + expectedMeta, new ArrayList<>(actualMeta.get()), Comparator.comparing(d -> d.name)); + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java new file mode 100644 index 000000000..db1366d85 --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java @@ -0,0 +1,393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.regression; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.common.gbt.GBTModelData; +import org.apache.flink.ml.common.gbt.defs.TaskType; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; +import org.apache.flink.ml.regression.gbtregressor.GBTRegressor; +import org.apache.flink.ml.regression.gbtregressor.GBTRegressorModel; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.apache.commons.lang3.ArrayUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.BitSet; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.table.api.Expressions.$; + +/** Tests {@link GBTRegressor} and {@link GBTRegressorModel}. 
*/ +public class GBTRegressorTest extends AbstractTestBase { + private static final List inputDataRows = + Arrays.asList( + Row.of(1.2, 2, null, 40., 1., 0., Vectors.dense(1.2, 2, Double.NaN)), + Row.of(2.3, 3, "b", 40., 2., 0., Vectors.dense(2.3, 3, 2.)), + Row.of(3.4, 4, "c", 40., 3., 0., Vectors.dense(3.4, 4, 3.)), + Row.of(4.5, 5, "a", 40., 4., 0., Vectors.dense(4.5, 5, 1.)), + Row.of(5.6, 2, "b", 40., 5., 0., Vectors.dense(5.6, 2, 2.)), + Row.of(null, 3, "c", 41., 1., 1., Vectors.dense(Double.NaN, 3, 3.)), + Row.of(12.8, 4, "e", 41., 2., 1., Vectors.dense(12.8, 4, 5.)), + Row.of(13.9, 2, "b", 41., 3., 1., Vectors.dense(13.9, 2, 2.)), + Row.of(14.1, 4, "a", 41., 4., 1., Vectors.dense(14.1, 4, 1.)), + Row.of(15.3, 1, "d", 41., 5., 1., Vectors.dense(15.3, 1, 4.))); + + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + List outputRows = + Arrays.asList( + Row.of(40.06841194119824), + Row.of(40.94100994144195), + Row.of(40.93898887207972), + Row.of(40.14918141164082), + Row.of(40.90620397010659), + Row.of(40.06041865505043), + Row.of(40.1049148535624), + Row.of(40.88096567879293), + Row.of(40.08071914298763), + Row.of(40.86772065751431)); + + private StreamTableEnvironment tEnv; + private Table inputTable; + + private static void verifyPredictionResult(Table output, List expected) throws Exception { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment(); + //noinspection unchecked + List results = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + final double delta = 1e-9; + final Comparator comparator = + Comparator.comparing( + d -> d.getFieldAs(0), new TestUtils.DoubleComparatorWithDelta(delta)); + TestUtils.compareResultCollectionsWithComparator(expected, results, comparator); + } + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.getConfig().enableObjectReuse(); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + inputTable = + tEnv.fromDataStream( + env.fromCollection( + inputDataRows, + new RowTypeInfo( + new TypeInformation[] { + Types.DOUBLE, + Types.INT, + Types.STRING, + Types.DOUBLE, + Types.DOUBLE, + Types.DOUBLE, + VectorTypeInfo.INSTANCE + }, + new String[] { + "f0", "f1", "f2", "label", "weight", "cls_label", "vec" + }))); + } + + @Test + public void testParam() { + GBTRegressor gbtr = new GBTRegressor(); + Assert.assertArrayEquals(new String[] {"features"}, gbtr.getFeaturesCols()); + Assert.assertEquals("label", gbtr.getLabelCol()); + Assert.assertArrayEquals(new String[] {}, gbtr.getCategoricalCols()); + Assert.assertEquals("prediction", gbtr.getPredictionCol()); + + Assert.assertNull(gbtr.getLeafCol()); + Assert.assertNull(gbtr.getWeightCol()); + Assert.assertEquals(5, gbtr.getMaxDepth()); + Assert.assertEquals(32, gbtr.getMaxBins()); + Assert.assertEquals(1, gbtr.getMinInstancesPerNode()); + Assert.assertEquals(0., gbtr.getMinWeightFractionPerNode(), 1e-12); + Assert.assertEquals(0., gbtr.getMinInfoGain(), 1e-12); + Assert.assertEquals(20, gbtr.getMaxIter()); + Assert.assertEquals(.1, gbtr.getStepSize(), 1e-12); + Assert.assertEquals(GBTRegressor.class.getName().hashCode(), gbtr.getSeed()); + Assert.assertEquals(1., gbtr.getSubsamplingRate(), 1e-12); + 
Assert.assertEquals("auto", gbtr.getFeatureSubsetStrategy()); + Assert.assertNull(gbtr.getValidationIndicatorCol()); + Assert.assertEquals(.01, gbtr.getValidationTol(), 1e-12); + Assert.assertEquals(0., gbtr.getRegLambda(), 1e-12); + Assert.assertEquals(1., gbtr.getRegGamma(), 1e-12); + + Assert.assertEquals("squared", gbtr.getLossType()); + + gbtr.setFeaturesCols("f0", "f1", "f2") + .setLabelCol("label") + .setCategoricalCols("f0", "f1") + .setPredictionCol("pred") + .setLeafCol("leaf") + .setWeightCol("weight") + .setMaxDepth(6) + .setMaxBins(64) + .setMinInstancesPerNode(2) + .setMinWeightFractionPerNode(.1) + .setMinInfoGain(.1) + .setMaxIter(10) + .setStepSize(.2) + .setSeed(123) + .setSubsamplingRate(.8) + .setFeatureSubsetStrategy("0.5") + .setValidationIndicatorCol("val") + .setValidationTol(.1) + .setRegLambda(.1) + .setRegGamma(.1); + + Assert.assertArrayEquals(new String[] {"f0", "f1", "f2"}, gbtr.getFeaturesCols()); + Assert.assertEquals("label", gbtr.getLabelCol()); + Assert.assertArrayEquals(new String[] {"f0", "f1"}, gbtr.getCategoricalCols()); + Assert.assertEquals("pred", gbtr.getPredictionCol()); + + Assert.assertEquals("leaf", gbtr.getLeafCol()); + Assert.assertEquals("weight", gbtr.getWeightCol()); + Assert.assertEquals(6, gbtr.getMaxDepth()); + Assert.assertEquals(64, gbtr.getMaxBins()); + Assert.assertEquals(2, gbtr.getMinInstancesPerNode()); + Assert.assertEquals(.1, gbtr.getMinWeightFractionPerNode(), 1e-12); + Assert.assertEquals(.1, gbtr.getMinInfoGain(), 1e-12); + Assert.assertEquals(10, gbtr.getMaxIter()); + Assert.assertEquals(.2, gbtr.getStepSize(), 1e-12); + Assert.assertEquals(123, gbtr.getSeed()); + Assert.assertEquals(.8, gbtr.getSubsamplingRate(), 1e-12); + Assert.assertEquals("0.5", gbtr.getFeatureSubsetStrategy()); + Assert.assertEquals("val", gbtr.getValidationIndicatorCol()); + Assert.assertEquals(.1, gbtr.getValidationTol(), 1e-12); + Assert.assertEquals(.1, gbtr.getRegLambda(), 1e-12); + Assert.assertEquals(.1, gbtr.getRegGamma(), 1e-12); + } + + @Test + public void testOutputSchema() throws Exception { + GBTRegressor gbtr = + new GBTRegressor().setFeaturesCols("f0", "f1", "f2").setCategoricalCols("f2"); + GBTRegressorModel model = gbtr.fit(inputTable); + Table output = model.transform(inputTable)[0]; + Assert.assertArrayEquals( + ArrayUtils.addAll( + inputTable.getResolvedSchema().getColumnNames().toArray(new String[0]), + gbtr.getPredictionCol()), + output.getResolvedSchema().getColumnNames().toArray(new String[0])); + } + + @Test + public void testFitAndPredict() throws Exception { + GBTRegressor gbtr = + new GBTRegressor() + .setFeaturesCols("f0", "f1", "f2") + .setCategoricalCols("f2") + .setLabelCol("label") + .setRegGamma(0.) + .setMaxBins(3) + .setSeed(123); + GBTRegressorModel model = gbtr.fit(inputTable); + Table output = model.transform(inputTable)[0].select($(gbtr.getPredictionCol())); + verifyPredictionResult(output, outputRows); + } + + @Test + public void testFitAndPredictWithVectorCol() throws Exception { + GBTRegressor gbtr = + new GBTRegressor() + .setFeaturesCols("vec") + .setLabelCol("label") + .setRegGamma(0.) 
+ .setMaxBins(3) + .setSeed(123); + GBTRegressorModel model = gbtr.fit(inputTable); + Table output = model.transform(inputTable)[0].select($(gbtr.getPredictionCol())); + List outputRowsUsingVectorCol = + Arrays.asList( + Row.of(40.11011764668384), + Row.of(40.8838231947867), + Row.of(40.064839102170275), + Row.of(40.10374937485196), + Row.of(40.909914467915144), + Row.of(40.11472131282394), + Row.of(40.88106076252836), + Row.of(40.089859516616336), + Row.of(40.90833852360301), + Row.of(40.94920075468803)); + verifyPredictionResult(output, outputRowsUsingVectorCol); + } + + @Test + public void testFitAndPredictWithNoCategoricalCols() throws Exception { + GBTRegressor gbtr = + new GBTRegressor() + .setFeaturesCols("f0", "f1") + .setLabelCol("label") + .setRegGamma(0.) + .setMaxBins(5) + .setSeed(123); + GBTRegressorModel model = gbtr.fit(inputTable); + Table output = model.transform(inputTable)[0].select($(gbtr.getPredictionCol())); + List outputRowsUsingNoCategoricalCols = + Arrays.asList( + Row.of(40.060788327295285), + Row.of(40.92126707025628), + Row.of(40.08161253493682), + Row.of(40.916655697518976), + Row.of(40.95467692795112), + Row.of(40.070253879056665), + Row.of(40.06975535946203), + Row.of(40.923228418693306), + Row.of(40.093329043797524), + Row.of(40.923115214426424)); + verifyPredictionResult(output, outputRowsUsingNoCategoricalCols); + } + + @Test + public void testEstimatorSaveLoadAndPredict() throws Exception { + GBTRegressor gbtr = + new GBTRegressor() + .setFeaturesCols("f0", "f1", "f2") + .setCategoricalCols("f2") + .setLabelCol("label") + .setRegGamma(0.) + .setMaxBins(3) + .setSeed(123); + GBTRegressor loadedGbtr = + TestUtils.saveAndReload( + tEnv, gbtr, tempFolder.newFolder().getAbsolutePath(), GBTRegressor::load); + GBTRegressorModel model = loadedGbtr.fit(inputTable); + Assert.assertEquals( + Collections.singletonList("modelData"), + model.getModelData()[0].getResolvedSchema().getColumnNames()); + Assert.assertEquals( + Collections.singletonList("featureImportance"), + model.getModelData()[1].getResolvedSchema().getColumnNames()); + Table output = model.transform(inputTable)[0].select($(gbtr.getPredictionCol())); + verifyPredictionResult(output, outputRows); + } + + @Test + public void testModelSaveLoadAndPredict() throws Exception { + GBTRegressor gbtr = + new GBTRegressor() + .setFeaturesCols("f0", "f1", "f2") + .setCategoricalCols("f2") + .setLabelCol("label") + .setRegGamma(0.) + .setMaxBins(3) + .setSeed(123); + GBTRegressorModel model = gbtr.fit(inputTable); + GBTRegressorModel loadedModel = + TestUtils.saveAndReload( + tEnv, + model, + tempFolder.newFolder().getAbsolutePath(), + GBTRegressorModel::load); + Table output = loadedModel.transform(inputTable)[0].select($(gbtr.getPredictionCol())); + verifyPredictionResult(output, outputRows); + } + + @Test + public void testGetModelData() throws Exception { + GBTRegressor gbtr = + new GBTRegressor() + .setFeaturesCols("f0", "f1", "f2") + .setCategoricalCols("f2") + .setLabelCol("label") + .setRegGamma(0.) 
+                        .setMaxBins(3)
+                        .setSeed(123);
+        GBTRegressorModel model = gbtr.fit(inputTable);
+        Table modelDataTable = model.getModelData()[0];
+        List<String> modelDataColumnNames = modelDataTable.getResolvedSchema().getColumnNames();
+        Assert.assertArrayEquals(
+                new String[] {"modelData"}, modelDataColumnNames.toArray(new String[0]));
+
+        //noinspection unchecked
+        List<Row> modelDataRows =
+                IteratorUtils.toList(tEnv.toDataStream(modelDataTable).executeAndCollect());
+        Assert.assertEquals(1, modelDataRows.size());
+        GBTModelData modelData = modelDataRows.get(0).getFieldAs(0);
+        Assert.assertNotNull(modelData);
+
+        Assert.assertEquals(TaskType.REGRESSION, TaskType.valueOf(modelData.type));
+        Assert.assertFalse(modelData.isInputVector);
+        Assert.assertEquals(40.5, modelData.prior, .5);
+        Assert.assertEquals(gbtr.getStepSize(), modelData.stepSize, 1e-12);
+        Assert.assertEquals(gbtr.getMaxIter(), modelData.allTrees.size());
+        Assert.assertEquals(gbtr.getCategoricalCols().length, modelData.categoryToIdMaps.size());
+        Assert.assertEquals(
+                gbtr.getFeaturesCols().length - gbtr.getCategoricalCols().length,
+                modelData.featureIdToBinEdges.size());
+        Assert.assertEquals(BitSet.valueOf(new byte[] {4}), modelData.isCategorical);
+
+        Table featureImportanceTable = model.getModelData()[1];
+        Assert.assertEquals(
+                Collections.singletonList("featureImportance"),
+                featureImportanceTable.getResolvedSchema().getColumnNames());
+        //noinspection unchecked
+        List<Row> featureImportanceRows =
+                IteratorUtils.toList(tEnv.toDataStream(featureImportanceTable).executeAndCollect());
+        Assert.assertEquals(1, featureImportanceRows.size());
+        Map<String, Double> featureImportanceMap =
+                featureImportanceRows.get(0).getFieldAs("featureImportance");
+        Assert.assertArrayEquals(
+                gbtr.getFeaturesCols(), featureImportanceMap.keySet().toArray(new String[0]));
+    }
+
+    @Test
+    public void testSetModelData() throws Exception {
+        GBTRegressor gbtr =
+                new GBTRegressor()
+                        .setFeaturesCols("f0", "f1", "f2")
+                        .setCategoricalCols("f2")
+                        .setLabelCol("label")
+                        .setRegGamma(0.)
+                        .setMaxBins(3)
+                        .setSeed(123);
+        GBTRegressorModel modelA = gbtr.fit(inputTable);
+        GBTRegressorModel modelB = new GBTRegressorModel().setModelData(modelA.getModelData());
+        ParamUtils.updateExistingParams(modelB, modelA.getParamMap());
+        Table output = modelB.transform(inputTable)[0].select($(gbtr.getPredictionCol()));
+        verifyPredictionResult(output, outputRows);
+    }
+}
diff --git a/flink-ml-python/pyflink/ml/tests/test_ml_lib_completeness.py b/flink-ml-python/pyflink/ml/tests/test_ml_lib_completeness.py
index db59df0b7..ad5e6210b 100644
--- a/flink-ml-python/pyflink/ml/tests/test_ml_lib_completeness.py
+++ b/flink-ml-python/pyflink/ml/tests/test_ml_lib_completeness.py
@@ -92,7 +92,8 @@ def module(self):
         pass
 
     def exclude_java_stage(self):
-        return []
+        return ['gbtclassifier.GBTClassifier', 'gbtclassifier.GBTClassifierModel',
+                'gbtregressor.GBTRegressor', 'gbtregressor.GBTRegressorModel']
 
 
 class ClassificationCompletenessTest(CompletenessTest, MLLibTest):
diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java
index 5b6f984aa..de2a882d9 100644
--- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java
+++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java
@@ -25,7 +25,6 @@
 import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.util.Bits;
 
 import java.io.IOException;
 import java.util.Arrays;
@@ -38,7 +37,8 @@ public final class DenseVectorSerializer extends TypeSerializer<DenseVector> {
 
     private static final double[] EMPTY = new double[0];
 
-    private final byte[] buf = new byte[1024];
+    private final OptimizedDoublePrimitiveArraySerializer valuesSerializer =
+            new OptimizedDoublePrimitiveArraySerializer();
 
     @Override
     public boolean isImmutableType() {
@@ -79,53 +79,25 @@ public void serialize(DenseVector vector, DataOutputView target) throws IOExcept
         if (vector == null) {
             throw new IllegalArgumentException("The vector must not be null.");
         }
-
-        final int len = vector.values.length;
-        target.writeInt(len);
-
-        for (int i = 0; i < len; i++) {
-            Bits.putDouble(buf, (i & 127) << 3, vector.values[i]);
-            if ((i & 127) == 127) {
-                target.write(buf);
-            }
-        }
-        target.write(buf, 0, (len & 127) << 3);
+        valuesSerializer.serialize(vector.values, target);
     }
 
     @Override
     public DenseVector deserialize(DataInputView source) throws IOException {
-        int len = source.readInt();
-        double[] values = new double[len];
-        readDoubleArray(values, source, len);
-        return new DenseVector(values);
-    }
-
-    // Reads `len` double values from `source` into `dst`.
-    private void readDoubleArray(double[] dst, DataInputView source, int len) throws IOException {
-        int index = 0;
-        for (int i = 0; i < (len >> 7); i++) {
-            source.readFully(buf, 0, 1024);
-            for (int j = 0; j < 128; j++) {
-                dst[index++] = Bits.getDouble(buf, j << 3);
-            }
-        }
-        source.readFully(buf, 0, (len << 3) & 1023);
-        for (int j = 0; j < (len & 127); j++) {
-            dst[index++] = Bits.getDouble(buf, j << 3);
-        }
+        return new DenseVector(valuesSerializer.deserialize(source));
     }
 
     @Override
     public DenseVector deserialize(DenseVector reuse, DataInputView source) throws IOException {
         int len = source.readInt();
         if (len == reuse.values.length) {
-            readDoubleArray(reuse.values, source, len);
+            valuesSerializer.readDoubleArray(len, reuse.values, source);
             return reuse;
         }
 
         double[] values = new double[len];
-        readDoubleArray(values, source, len);
+        valuesSerializer.readDoubleArray(len, values, source);
         return new DenseVector(values);
     }
 
     @Override
diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/OptimizedDoublePrimitiveArraySerializer.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/OptimizedDoublePrimitiveArraySerializer.java
new file mode 100644
index 000000000..7af39064b
--- /dev/null
+++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/OptimizedDoublePrimitiveArraySerializer.java
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.linalg.typeinfo;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.ml.util.Bits;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/** A serializer for primitive double arrays that batches reads and writes through a reusable byte buffer. */
+@Internal
+public final class OptimizedDoublePrimitiveArraySerializer extends TypeSerializer<double[]> {
+
+    private static final long serialVersionUID = 1L;
+
+    private static final double[] EMPTY = new double[0];
+
+    private static final int BUFFER_SIZE = 1024;
+    private final byte[] buf = new byte[BUFFER_SIZE];
+
+    public OptimizedDoublePrimitiveArraySerializer() {}
+
+    @Override
+    public boolean isImmutableType() {
+        return false;
+    }
+
+    @Override
+    public TypeSerializer<double[]> duplicate() {
+        return new OptimizedDoublePrimitiveArraySerializer();
+    }
+
+    @Override
+    public double[] createInstance() {
+        return EMPTY;
+    }
+
+    @Override
+    public double[] copy(double[] from) {
+        double[] copy = new double[from.length];
+        System.arraycopy(from, 0, copy, 0, from.length);
+        return copy;
+    }
+
+    @Override
+    public double[] copy(double[] from, double[] reuse) {
+        return copy(from);
+    }
+
+    @Override
+    public int getLength() {
+        return -1;
+    }
+
+    @Override
+    public void serialize(double[] record, DataOutputView target) throws IOException {
+        if (record == null) {
+            throw new IllegalArgumentException("The record must not be null.");
+        }
+        serialize(record, 0, record.length, target);
+    }
+
+    public void serialize(double[] record, int start, int len, DataOutputView target)
+            throws IOException {
+        target.writeInt(len);
+        for (int i = 0; i < len; i += 1) {
+            Bits.putDouble(buf, (i & 127) << 3, record[start + i]);
+            if ((i & 127) == 127) {
+                target.write(buf);
+            }
+        }
+        target.write(buf, 0, (len & 127) << 3);
+    }
+
+    @Override
+    public double[] deserialize(DataInputView source) throws IOException {
+        final int len = source.readInt();
+        double[] result = new double[len];
+        readDoubleArray(len, result, source);
+        return result;
+    }
+
+    public void readDoubleArray(int len, double[] result, DataInputView source) throws IOException {
+        int index = 0;
+        for (int i = 0; i < (len >> 7); i++) {
+            source.readFully(buf, 0, 1024);
+            for (int j = 0; j < 128; j++) {
+                result[index++] = Bits.getDouble(buf, j << 3);
+            }
+        }
+        source.readFully(buf, 0, (len & 127) << 3);
+        for (int j = 0; j < (len & 127); j++) {
+            result[index++] = Bits.getDouble(buf, j << 3);
+        }
+    }
+
+    @Override
+    public double[] deserialize(double[] reuse, DataInputView source) throws IOException {
+        int len = source.readInt();
+        if (len == reuse.length) {
+            readDoubleArray(len, reuse, source);
+            return reuse;
+        }
+        // The length prefix has already been consumed here, so read the payload directly
+        // instead of calling deserialize(source), which would read another length.
+        double[] result = new double[len];
+        readDoubleArray(len, result, source);
+        return result;
+    }
+
+    @Override
+    public void copy(DataInputView source, DataOutputView target) throws IOException {
+        final int len = source.readInt();
+        target.writeInt(len);
+        target.write(source, len * Double.BYTES);
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        return obj instanceof OptimizedDoublePrimitiveArraySerializer;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hashCode(OptimizedDoublePrimitiveArraySerializer.class);
+    }
+
+    @Override
+    public TypeSerializerSnapshot<double[]> snapshotConfiguration() {
+        return new DoublePrimitiveArraySerializerSnapshot();
+    }
+
+    // ------------------------------------------------------------------------
+
+    /** Serializer configuration snapshot for compatibility and format evolution. */
+    @SuppressWarnings("WeakerAccess")
+    public static final class DoublePrimitiveArraySerializerSnapshot
+            extends SimpleTypeSerializerSnapshot<double[]> {
+
+        public DoublePrimitiveArraySerializerSnapshot() {
+            super(OptimizedDoublePrimitiveArraySerializer::new);
+        }
+    }
+}
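
Editor's note: the following round-trip sketch is not part of the patch; the test class name is illustrative, and it assumes flink-core's DataOutputSerializer and DataInputDeserializer are available on the test classpath. It exercises the fresh deserialization path, the same-length reuse path, and the mismatched-length reuse path, which is exactly the case the corrected deserialize(double[], DataInputView) above must handle without re-reading the length prefix.

package org.apache.flink.ml.linalg.typeinfo;

import org.apache.flink.core.memory.DataInputDeserializer;
import org.apache.flink.core.memory.DataOutputSerializer;

import org.junit.Assert;
import org.junit.Test;

/** Illustrative round-trip test for {@link OptimizedDoublePrimitiveArraySerializer}. */
public class OptimizedDoublePrimitiveArraySerializerRoundTripTest {

    @Test
    public void testRoundTrip() throws Exception {
        OptimizedDoublePrimitiveArraySerializer serializer =
                new OptimizedDoublePrimitiveArraySerializer();
        // 300 elements so that both the full 128-double buffer flushes and the tail path run.
        double[] original = new double[300];
        for (int i = 0; i < original.length; i++) {
            original[i] = i * 0.5;
        }

        DataOutputSerializer out = new DataOutputSerializer(64);
        serializer.serialize(original, out);

        // Fresh deserialization.
        DataInputDeserializer in = new DataInputDeserializer(out.getCopyOfBuffer());
        Assert.assertArrayEquals(original, serializer.deserialize(in), 0.0);

        // Reuse path with a matching length: the reuse array is filled in place.
        in = new DataInputDeserializer(out.getCopyOfBuffer());
        double[] reuse = new double[300];
        Assert.assertSame(reuse, serializer.deserialize(reuse, in));
        Assert.assertArrayEquals(original, reuse, 0.0);

        // Reuse path with a mismatched length: a new array must be returned, read from the
        // payload that follows the already-consumed length prefix.
        in = new DataInputDeserializer(out.getCopyOfBuffer());
        Assert.assertArrayEquals(original, serializer.deserialize(new double[7], in), 0.0);
    }
}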
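
Similarly, a small sketch (again illustrative, not part of the patch, and assuming DenseVectorSerializer keeps its public no-argument constructor) of how DenseVectorSerializer now delegates vector payloads to the shared double[] serializer; the reuse variant with a different-length vector is the path fixed in the hunk above.

package org.apache.flink.ml.linalg.typeinfo;

import org.apache.flink.core.memory.DataInputDeserializer;
import org.apache.flink.core.memory.DataOutputSerializer;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;

import org.junit.Assert;
import org.junit.Test;

/** Illustrative round-trip test for the delegating {@link DenseVectorSerializer}. */
public class DenseVectorSerializerRoundTripTest {

    @Test
    public void testReuseRoundTrip() throws Exception {
        DenseVectorSerializer serializer = new DenseVectorSerializer();
        DenseVector vector = Vectors.dense(1.2, 2.0, Double.NaN);

        DataOutputSerializer out = new DataOutputSerializer(64);
        serializer.serialize(vector, out);

        // Same-length reuse: the existing vector is filled and the reuse instance is returned.
        DataInputDeserializer in = new DataInputDeserializer(out.getCopyOfBuffer());
        DenseVector reuse = new DenseVector(3);
        Assert.assertSame(reuse, serializer.deserialize(reuse, in));
        Assert.assertArrayEquals(vector.values, reuse.values, 0.0);

        // Different-length reuse: a freshly allocated vector comes back instead.
        in = new DataInputDeserializer(out.getCopyOfBuffer());
        DenseVector deserialized = serializer.deserialize(new DenseVector(5), in);
        Assert.assertArrayEquals(vector.values, deserialized.values, 0.0);
    }
}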