diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/anomalydetection/IsolationForestExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/anomalydetection/IsolationForestExample.java new file mode 100644 index 000000000..fa11efcf3 --- /dev/null +++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/anomalydetection/IsolationForestExample.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.examples.anomalydetection; + +import org.apache.flink.ml.anomalydetection.isolationforest.IsolationForest; +import org.apache.flink.ml.anomalydetection.isolationforest.IsolationForestModel; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; +import org.apache.flink.util.CloseableIterator; + +/** Simple program that creates an IsolationForest instance and uses it for anomaly detection. */ +public class IsolationForestExample { + public static void main(String[] args) { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + + // Generates input data. + DataStream inputStream = + env.fromElements( + Vectors.dense(1, 2), + Vectors.dense(1.1, 2), + Vectors.dense(1, 2.1), + Vectors.dense(1.1, 2.1), + Vectors.dense(0.1, 0.1)); + + Table inputTable = tEnv.fromDataStream(inputStream).as("features"); + + IsolationForest isolationForest = + new IsolationForest() + .setNumTrees(100) + .setMaxIter(10) + .setMaxSamples(256) + .setMaxFeatures(1.0); + + IsolationForestModel isolationForestModel = isolationForest.fit(inputTable); + + Table outputTable = isolationForestModel.transform(inputTable)[0]; + + // Extracts and displays the results. + for (CloseableIterator it = outputTable.execute().collect(); it.hasNext(); ) { + Row row = it.next(); + DenseVector features = (DenseVector) row.getField(isolationForest.getFeaturesCol()); + int predictId = (Integer) row.getField(isolationForest.getPredictionCol()); + System.out.printf("Features: %s \tPrediction: %s\n", features, predictId); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/anomalydetection/isolationforest/IForest.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/anomalydetection/isolationforest/IForest.java new file mode 100644 index 000000000..7f31a6952 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/anomalydetection/isolationforest/IForest.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.anomalydetection.isolationforest; + +import org.apache.flink.ml.linalg.DenseVector; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +/** Construct isolation forest. */ +public class IForest implements Serializable { + public final int numTrees; + public List iTreeList; + public Double center0; + public Double center1; + public int subSamplesSize; + + public IForest(int numTrees) { + this.numTrees = numTrees; + this.iTreeList = new ArrayList<>(256); + this.center0 = null; + this.center1 = null; + } + + public void generateIsolationForest(DenseVector[] samplesData, int[] featureIndices) { + int n = samplesData.length; + this.subSamplesSize = Math.min(256, n); + int limitHeight = (int) Math.ceil(Math.log(Math.max(subSamplesSize, 2)) / Math.log(2)); + Random randomState = new Random(System.nanoTime()); + for (int i = 0; i < numTrees; i++) { + DenseVector[] subSamples = new DenseVector[subSamplesSize]; + for (int j = 0; j < subSamplesSize; j++) { + int r = randomState.nextInt(n); + subSamples[j] = samplesData[r]; + } + ITree isolationTree = + ITree.generateIsolationTree( + subSamples, 0, limitHeight, randomState, featureIndices); + this.iTreeList.add(isolationTree); + } + } + + public DenseVector calculateScore(DenseVector[] samplesData) throws Exception { + int n = samplesData.length; + DenseVector score = new DenseVector(n); + for (int i = 0; i < n; i++) { + double pathLengthSum = 0; + for (ITree isolationTree : iTreeList) { + pathLengthSum += ITree.calculatePathLength(samplesData[i], isolationTree); + } + + double pathLengthAvg = pathLengthSum / iTreeList.size(); + double cn = ITree.calculateCn(subSamplesSize); + double index = pathLengthAvg / cn; + score.set(i, Math.pow(2, -index)); + } + + return score; + } + + public DenseVector classifyByCluster(DenseVector score, int iters) { + int scoresSize = score.size(); + this.center0 = score.get(0); // Cluster center of abnormal + this.center1 = score.get(0); // Cluster center of normal + + for (int p = 1; p < score.size(); p++) { + if (score.get(p) > center0) { + center0 = score.get(p); + } + + if (score.get(p) < center1) { + center1 = score.get(p); + } + } + + int cnt0, cnt1; + double diff0, diff1; + int[] labels = new int[scoresSize]; + + for (int i = 0; i < iters; i++) { + cnt0 = 0; + cnt1 = 0; + + for (int j = 0; j < scoresSize; j++) { + diff0 = Math.abs(score.get(j) - center0); + diff1 = Math.abs(score.get(j) - center1); + + if (diff0 < diff1) { + labels[j] = 0; + cnt0++; + } else { + labels[j] = 1; + cnt1++; + } + } + + diff0 = center0; + diff1 = center1; + + center0 = 0.0; + center1 = 0.0; + for (int k = 0; k < scoresSize; k++) { + if (labels[k] == 0) { + center0 += score.get(k); + } else { + center1 += score.get(k); + } + } + + center0 /= cnt0; + center1 /= cnt1; + + if (center0 - diff0 <= 1e-6 && center1 - diff1 <= 1e-6) { + break; + } + } + return new DenseVector(new double[] {center0, center1}); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/anomalydetection/isolationforest/ITree.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/anomalydetection/isolationforest/ITree.java new file mode 100644 index 000000000..c96f7b252 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/anomalydetection/isolationforest/ITree.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.anomalydetection.isolationforest; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.linalg.DenseVector; + +import java.io.Serializable; +import java.util.Random; + +/** Construct isolation tree. */ +public class ITree implements Serializable { + public final int attributeIndex; + public final double splitAttributeValue; + public ITree leftTree; + public ITree rightTree; + public int currentHeight; + public int leafNodesNum; + + public ITree(int attributeIndex, double splitAttributeValue) { + this.attributeIndex = attributeIndex; + this.splitAttributeValue = splitAttributeValue; + this.leftTree = null; + this.rightTree = null; + this.currentHeight = 0; + this.leafNodesNum = 1; + } + + public static ITree generateIsolationTree( + DenseVector[] samplesData, + int currentHeight, + int limitHeight, + Random randomState, + int[] featureIndices) { + int n = samplesData.length; + ITree isolationTree; + if (samplesData.length == 0) { + return null; + } else if (samplesData.length == 1 || currentHeight >= limitHeight) { + isolationTree = new ITree(0, samplesData[0].get(0)); + isolationTree.currentHeight = currentHeight; + isolationTree.leafNodesNum = samplesData.length; + return isolationTree; + } + boolean flag = true; + for (int i = 1; i < n; i++) { + if (!samplesData[i].equals(samplesData[i - 1])) { + flag = false; + break; + } + } + if (flag) { + isolationTree = new ITree(0, samplesData[0].get(0)); + isolationTree.currentHeight = currentHeight; + isolationTree.leafNodesNum = samplesData.length; + return isolationTree; + } + + Tuple2 tuple2 = + getRandomFeatureToSplit(samplesData, randomState, featureIndices); + int attributeIndex = tuple2.f0; + double splitAttributeValue = tuple2.f1; + + int leftNodesNum = 0; + int rightNodesNum = 0; + for (DenseVector datum : samplesData) { + if (datum.get(attributeIndex) < splitAttributeValue) { + leftNodesNum++; + } else { + rightNodesNum++; + } + } + + DenseVector[] leftSamples = new DenseVector[leftNodesNum]; + DenseVector[] rightSamples = new DenseVector[rightNodesNum]; + int l = 0, r = 0; + for (DenseVector samplesDatum : samplesData) { + if (samplesDatum.get(attributeIndex) < splitAttributeValue) { + leftSamples[l++] = samplesDatum; + } else { + rightSamples[r++] = samplesDatum; + } + } + + ITree root = new ITree(attributeIndex, splitAttributeValue); + root.currentHeight = currentHeight; + root.leafNodesNum = samplesData.length; + root.leftTree = + generateIsolationTree( + leftSamples, currentHeight + 1, limitHeight, randomState, featureIndices); + root.rightTree = + generateIsolationTree( + rightSamples, currentHeight + 1, limitHeight, randomState, featureIndices); + + return root; + } + + private static Tuple2 getRandomFeatureToSplit( + DenseVector[] samplesData, Random randomState, int[] featureIndices) { + int attributeIndex = featureIndices[randomState.nextInt(featureIndices.length)]; + + double maxValue = samplesData[0].get(attributeIndex); + double minValue = samplesData[0].get(attributeIndex); + for (int i = 1; i < samplesData.length; i++) { + minValue = Math.min(minValue, samplesData[i].get(attributeIndex)); + maxValue = Math.max(maxValue, samplesData[i].get(attributeIndex)); + } + double splitAttributeValue = (maxValue - minValue) * randomState.nextDouble() + minValue; + + return Tuple2.of(attributeIndex, splitAttributeValue); + } + + public static double calculatePathLength(DenseVector sampleData, ITree isolationTree) + throws Exception { + double pathLength = -1; + ITree tmpITree = isolationTree; + while (tmpITree != null) { + pathLength += 1; + if (tmpITree.leftTree == null + || tmpITree.rightTree == null + || sampleData.get(tmpITree.attributeIndex) == tmpITree.splitAttributeValue) { + break; + } else if (sampleData.get(tmpITree.attributeIndex) < tmpITree.splitAttributeValue) { + tmpITree = tmpITree.leftTree; + } else { + tmpITree = tmpITree.rightTree; + } + } + + assert tmpITree != null; + return pathLength + calculateCn(tmpITree.leafNodesNum); + } + + public static double calculateCn(double n) { + if (n <= 1) { + return 0; + } + return 2.0 * (Math.log(n - 1.0) + 0.5772156649015329) - 2.0 * (n - 1.0) / n; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/anomalydetection/isolationforest/IsolationForest.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/anomalydetection/isolationforest/IsolationForest.java new file mode 100644 index 000000000..9352bbe5a --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/anomalydetection/isolationforest/IsolationForest.java @@ -0,0 +1,394 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.anomalydetection.isolationforest; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.IterationConfig; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.ReplayableDataStreamList; +import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.common.iteration.ForwardInputsOfLastRound; +import org.apache.flink.ml.common.iteration.TerminateOnMaxIter; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** + * An Estimator which implements the Isolation Forest algorithm. + * + *

See https://en.wikipedia.org/wiki/Isolation_forest. + */ +public class IsolationForest + implements Estimator, + IsolationForestParams { + private final Map, Object> paramMap = new HashMap<>(); + + public IsolationForest() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public IsolationForestModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + + IForest iForest = new IForest(getNumTrees()); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + DataStream points = + tEnv.toDataStream(inputs[0]).map(new FormatDataMapFunction(getFeaturesCol())); + + DataStream initModelData = + selectRandomSample(points, getMaxSamples()) + .map(new InitModelData(iForest, getMaxIter(), getMaxFeatures())) + .setParallelism(1); + + DataStream finalModelData = + Iterations.iterateBoundedStreamsUntilTermination( + DataStreamList.of(initModelData), + ReplayableDataStreamList.notReplay(points), + IterationConfig.newBuilder() + .setOperatorLifeCycle( + IterationConfig.OperatorLifeCycle.ALL_ROUND) + .build(), + new IsolationForestIterationBody(getMaxIter())) + .get(0); + + Table finalModelDataTable = tEnv.fromDataStream(finalModelData); + IsolationForestModel model = new IsolationForestModel().setModelData(finalModelDataTable); + ParamUtils.updateExistingParams(model, paramMap); + return model; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static IsolationForest load(StreamTableEnvironment tEnv, String path) + throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map, Object> getParamMap() { + return paramMap; + } + + private static class FormatDataMapFunction implements MapFunction { + private final String featuresCol; + + public FormatDataMapFunction(String featuresCol) { + this.featuresCol = featuresCol; + } + + @Override + public DenseVector[] map(Row row) throws Exception { + List list = new ArrayList<>(256); + DenseVector denseVector = ((Vector) row.getField(featuresCol)).toDense(); + list.add(denseVector); + return list.toArray(new DenseVector[0]); + } + } + + private static DataStream selectRandomSample( + DataStream samplesData, int maxSamples) { + DataStream resultStream = + DataStreamUtils.mapPartition( + DataStreamUtils.sample(samplesData, maxSamples, System.nanoTime()), + (MapPartitionFunction) + (iterable, collector) -> { + Iterator samplesDataIterator = + iterable.iterator(); + List list = new ArrayList<>(); + while (samplesDataIterator.hasNext()) { + list.addAll(Arrays.asList(samplesDataIterator.next())); + } + collector.collect(list.toArray(new DenseVector[0])); + }, + Types.OBJECT_ARRAY(DenseVectorTypeInfo.INSTANCE)); + resultStream.getTransformation().setParallelism(1); + return resultStream; + } + + private static class InitModelData extends RichMapFunction { + private final IForest iForest; + private final int iters; + private final double maxFeatures; + + private InitModelData(IForest iForest, int iters, double maxFeatures) { + this.iForest = iForest; + this.iters = iters; + this.maxFeatures = maxFeatures; + } + + @Override + public IForest map(DenseVector[] denseVectors) throws Exception { + int n = denseVectors[0].size(); + int numFeatures = Math.min(n, Math.max(1, (int) (maxFeatures * n))); + + List tempList = new ArrayList<>(n); + for (int i = 0; i < n; i++) { + tempList.add(i); + } + Collections.shuffle(tempList); + + int[] featuresIndicts = new int[numFeatures]; + for (int j = 0; j < numFeatures; j++) { + featuresIndicts[j] = tempList.get(j); + } + + iForest.generateIsolationForest(denseVectors, featuresIndicts); + DenseVector scores = iForest.calculateScore(denseVectors); + iForest.classifyByCluster(scores, iters); + return iForest; + } + } + + private static class IsolationForestIterationBody implements IterationBody { + private final int iters; + + public IsolationForestIterationBody(int iters) { + this.iters = iters; + } + + @Override + public IterationBodyResult process( + DataStreamList variableStreams, DataStreamList dataStreams) { + DataStream centersData = variableStreams.get(0); + DataStream samplesData = dataStreams.get(0); + final OutputTag modelDataOutputTag = + new OutputTag("IsolationForest") {}; + + SingleOutputStreamOperator terminationCriteria = + centersData.flatMap(new TerminateOnMaxIter(iters)); + + DataStream centers = + samplesData + .connect(centersData.broadcast()) + .transform( + "CentersUpdateAccumulator", + TypeInformation.of(IForest.class), + new CentersUpdateAccumulator(modelDataOutputTag, iters)); + + DataStream newModelData = + centers.countWindowAll(centers.getParallelism()) + .reduce( + new ReduceFunction() { + @Override + public IForest reduce(IForest iForest1, IForest iForest2) + throws Exception { + if (iForest2.center0 == null + || iForest2.center1 == null) { + return iForest1; + } + return iForest2; + } + }) + .flatMap( + new FlatMapFunction() { + @Override + public void flatMap( + IForest iForest, + Collector collector) + throws Exception { + if (iForest.center0 != null + && iForest.center1 != null) { + collector.collect( + new IsolationForestModelData(iForest)); + } + } + }); + + DataStream newCenters = newModelData.map(x -> x.iForest).setParallelism(1); + + DataStream finalModelData = + newModelData.flatMap(new ForwardInputsOfLastRound<>()); + + return new IterationBodyResult( + DataStreamList.of(newCenters), + DataStreamList.of(finalModelData), + terminationCriteria); + } + } + + private static class CentersUpdateAccumulator extends AbstractStreamOperator + implements TwoInputStreamOperator, + IterationListener { + private final OutputTag modelDataOutputTag; + + private final int iters; + + private ListStateWithCache samplesData; + + private ListState samplesDataCenter; + + private ListStateWithCache samplesDataScores; + + public CentersUpdateAccumulator(OutputTag modelDataOutputTag, int iters) { + this.modelDataOutputTag = modelDataOutputTag; + this.iters = iters; + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + + samplesDataCenter = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor( + "centers", TypeInformation.of(IForest.class))); + + samplesData = + new ListStateWithCache<>( + Types.OBJECT_ARRAY(DenseVectorTypeInfo.INSTANCE) + .createSerializer(getExecutionConfig()), + getContainingTask(), + getRuntimeContext(), + context, + config.getOperatorID()); + + samplesDataScores = + new ListStateWithCache<>( + Types.OBJECT_ARRAY(DenseVectorTypeInfo.INSTANCE) + .createSerializer(getExecutionConfig()), + getContainingTask(), + getRuntimeContext(), + context, + config.getOperatorID()); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + samplesData.snapshotState(context); + } + + @Override + public void processElement1(StreamRecord streamRecord) throws Exception { + samplesData.add(streamRecord.getValue()); + } + + @Override + public void processElement2(StreamRecord streamRecord) throws Exception { + Preconditions.checkState(!samplesDataCenter.get().iterator().hasNext()); + samplesDataCenter.add(streamRecord.getValue()); + } + + @Override + public void onEpochWatermarkIncremented( + int epochWatermark, Context context, Collector collector) + throws Exception { + IForest centers = + Objects.requireNonNull( + OperatorStateUtils.getUniqueElement(samplesDataCenter, "centers") + .orElse(null)); + Iterator samplesDataIterator = samplesData.get().iterator(); + List list = new ArrayList<>(); + while (samplesDataIterator.hasNext()) { + DenseVector[] sampleData = samplesDataIterator.next(); + list.add(centers.calculateScore(sampleData)); + } + DenseVector[] scores = list.toArray(new DenseVector[0]); + samplesDataScores.add(scores); + + collector.collect(samplesDataCenter.get().iterator().next()); + samplesDataCenter.clear(); + } + + @Override + public void onIterationTerminated(Context context, Collector collector) + throws Exception { + IForest centers = + Objects.requireNonNull( + OperatorStateUtils.getUniqueElement(samplesDataCenter, "centers") + .orElse(null)); + double centers0Sum1 = 0.0; + double centers1Sum1 = 0.0; + double centers0Sum2 = 0.0; + double centers1Sum2 = 0.0; + int size1 = 0; + int size2 = 0; + Iterator samplesDataScoresIterator = samplesDataScores.get().iterator(); + while (samplesDataScoresIterator.hasNext()) { + for (DenseVector denseVector : samplesDataScoresIterator.next()) { + DenseVector denseVector1 = centers.classifyByCluster(denseVector, iters); + centers0Sum1 += denseVector1.get(0); + centers1Sum1 += denseVector1.get(1); + size1++; + } + centers0Sum2 = centers0Sum1 / size1; + centers1Sum2 = centers1Sum1 / size1; + size2++; + } + + centers.center0 = centers0Sum2 / size2; + centers.center1 = centers1Sum2 / size2; + + context.output(modelDataOutputTag, centers); + + samplesDataCenter.clear(); + samplesDataScores.clear(); + samplesData.clear(); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/anomalydetection/isolationforest/IsolationForestModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/anomalydetection/isolationforest/IsolationForestModel.java new file mode 100644 index 000000000..46061485b --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/anomalydetection/isolationforest/IsolationForestModel.java @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.anomalydetection.isolationforest; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.ml.util.RowUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * A Model which detection anomaly data using the model data computed by {@link IsolationForest}. + */ +public class IsolationForestModel + implements Model, IsolationForestModelParams { + private final Map, Object> paramMap = new HashMap<>(); + + private Table modelDataTable; + + public IsolationForestModel() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public IsolationForestModel setModelData(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + modelDataTable = inputs[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream modelDataStream = + IsolationForestModelData.getModelDataStream(modelDataTable); + + final String broadcastModelKey = "broadcastModelKey"; + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), BasicTypeInfo.INT_TYPE_INFO), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol())); + + DataStream predictionResult = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(tEnv.toDataStream(inputs[0])), + Collections.singletonMap(broadcastModelKey, modelDataStream), + inputList -> { + DataStream inputData = inputList.get(0); + return inputData.map( + new PredictLabelFunction(broadcastModelKey, getFeaturesCol()), + outputTypeInfo); + }); + return new Table[] {tEnv.fromDataStream(predictionResult)}; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + ReadWriteUtils.saveModelData( + IsolationForestModelData.getModelDataStream(modelDataTable), + path, + new IsolationForestModelData.ModelDataEncoder()); + } + + public static IsolationForestModel load(StreamTableEnvironment tEnv, String path) + throws IOException { + Table modelDataTable = + ReadWriteUtils.loadModelData( + tEnv, path, new IsolationForestModelData.ModelDataDecoder()); + + IsolationForestModel model = ReadWriteUtils.loadStageParam(path); + return model.setModelData(modelDataTable); + } + + @Override + public Map, Object> getParamMap() { + return paramMap; + } + + /** A utility function used for prediction. */ + private static class PredictLabelFunction extends RichMapFunction { + private final String broadcastModelKey; + private final String featuresCol; + private IsolationForestModelData modelData = null; + public List iTreeList; + public Double center0; + public Double center1; + public int subSamplesSize; + + public PredictLabelFunction(String broadcastModelKey, String featuresCol) { + this.broadcastModelKey = broadcastModelKey; + this.featuresCol = featuresCol; + } + + @Override + public Row map(Row dataPoint) throws Exception { + if (modelData == null) { + modelData = + (IsolationForestModelData) + getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0); + iTreeList = modelData.iForest.iTreeList; + center0 = modelData.iForest.center0; + center1 = modelData.iForest.center1; + subSamplesSize = modelData.iForest.subSamplesSize; + } + + DenseVector point = ((Vector) dataPoint.getField(featuresCol)).toDense(); + int predictId = predict(point); + return RowUtils.append(dataPoint, predictId); + } + + private int predict(DenseVector sampleData) throws Exception { + double pathLengthSum = 0; + int treesNumber = iTreeList.size(); + for (int j = 0; j < treesNumber; j++) { + pathLengthSum += ITree.calculatePathLength(sampleData, iTreeList.get(j)); + } + double pathLengthAvg = pathLengthSum / treesNumber; + double cn = ITree.calculateCn(subSamplesSize); + double score = Math.pow(2, -pathLengthAvg / cn); + + return Math.abs(score - center0) > Math.abs(score - center1) ? 1 : 0; + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/anomalydetection/isolationforest/IsolationForestModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/anomalydetection/isolationforest/IsolationForestModelData.java new file mode 100644 index 000000000..48c39015b --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/anomalydetection/isolationforest/IsolationForestModelData.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.anomalydetection.isolationforest; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.serialization.Encoder; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; + +import java.io.EOFException; +import java.io.IOException; +import java.io.OutputStream; + +/** + * Model data of {@link IsolationForestModel}. + * + *

This class also provides methods to convert model data from Table to Datastream, and classes + * to save/load model data. + */ +public class IsolationForestModelData { + + public final IForest iForest; + + public IsolationForestModelData(IForest iForest) { + this.iForest = iForest; + } + + public static DataStream getModelDataStream(Table modelData) { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment(); + return tEnv.toDataStream(modelData) + .map(x -> new IsolationForestModelData((IForest) x.getField(0))); + } + + /** Data encoder for {@link IsolationForestModelData}. */ + public static class ModelDataEncoder implements Encoder { + private final TypeSerializer pojoSerializer = + TypeInformation.of(IForest.class).createSerializer(new ExecutionConfig()); + + @Override + public void encode(IsolationForestModelData modelData, OutputStream outputStream) + throws IOException { + pojoSerializer.serialize( + modelData.iForest, new DataOutputViewStreamWrapper(outputStream)); + } + } + + /** Data decoder for {@link IsolationForestModelData}. */ + public static class ModelDataDecoder extends SimpleStreamFormat { + + @Override + public Reader createReader( + Configuration configuration, FSDataInputStream fsDataInputStream) + throws IOException { + return new Reader() { + private final TypeSerializer pojoSerializer = + TypeInformation.of(IForest.class).createSerializer(new ExecutionConfig()); + + @Override + public IsolationForestModelData read() throws IOException { + try { + DataInputViewStreamWrapper inputViewStreamWrapper = + new DataInputViewStreamWrapper(fsDataInputStream); + IForest iForest1 = pojoSerializer.deserialize(inputViewStreamWrapper); + return new IsolationForestModelData(iForest1); + } catch (EOFException e) { + return null; + } + } + + @Override + public void close() throws IOException { + fsDataInputStream.close(); + } + }; + } + + @Override + public TypeInformation getProducedType() { + return TypeInformation.of(IsolationForestModelData.class); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/anomalydetection/isolationforest/IsolationForestModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/anomalydetection/isolationforest/IsolationForestModelParams.java new file mode 100644 index 000000000..b2cdef215 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/anomalydetection/isolationforest/IsolationForestModelParams.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.anomalydetection.isolationforest; + +import org.apache.flink.ml.common.param.HasFeaturesCol; +import org.apache.flink.ml.common.param.HasMaxIter; +import org.apache.flink.ml.common.param.HasPredictionCol; +import org.apache.flink.ml.common.param.HasRawPredictionCol; +import org.apache.flink.ml.common.param.HasWindows; +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; + +/** + * Params of {@link IsolationForestModel}. + * + * @param The class of this instance. + */ +public interface IsolationForestModelParams + extends HasMaxIter, + HasFeaturesCol, + HasPredictionCol, + HasRawPredictionCol, + HasWindows { + Param NUM_TREES = new IntParam("numTrees", "The max number of ITrees to create.", 100); + + default int getNumTrees() { + return get(NUM_TREES); + } + + default T setNumTrees(int value) { + return set(NUM_TREES, value); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/anomalydetection/isolationforest/IsolationForestParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/anomalydetection/isolationforest/IsolationForestParams.java new file mode 100644 index 000000000..e21cf6633 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/anomalydetection/isolationforest/IsolationForestParams.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.anomalydetection.isolationforest; + +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; + +/** + * Params of {@link IsolationForest}. + * + * @param The class of this instance. + */ +public interface IsolationForestParams extends IsolationForestModelParams { + Param MAX_SAMPLES = + new IntParam( + "maxSamples", + "The number of samplesData to train and its max value is preferably 256.", + 256); + + Param MAX_FEATURES = + new DoubleParam( + "maxFeatures", + "The number of features used to train each tree and it is treated as a fraction in the range (0, 1.0].", + 1.0); + + default int getNumTrees() { + return get(NUM_TREES); + } + + default T setNumTrees(int value) { + return set(NUM_TREES, value); + } + + default int getMaxSamples() { + return get(MAX_SAMPLES); + } + + default T setMaxSamples(int value) { + return set(MAX_SAMPLES, value); + } + + default double getMaxFeatures() { + return get(MAX_FEATURES); + } + + default T setMaxFeatures(double value) { + return set(MAX_FEATURES, value); + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/anomalydetection/IsolationForestTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/anomalydetection/IsolationForestTest.java new file mode 100644 index 000000000..4dd762327 --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/anomalydetection/IsolationForestTest.java @@ -0,0 +1,284 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.anomalydetection; + +import org.apache.flink.ml.anomalydetection.isolationforest.IForest; +import org.apache.flink.ml.anomalydetection.isolationforest.IsolationForest; +import org.apache.flink.ml.anomalydetection.isolationforest.IsolationForestModel; +import org.apache.flink.ml.anomalydetection.isolationforest.IsolationForestModelData; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.CollectionUtils; +import org.apache.commons.collections.IteratorUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** Tests {@link IsolationForest} and {@link IsolationForestModel}. */ +public class IsolationForestTest extends AbstractTestBase { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + + private static final List DATA = + new ArrayList<>( + Arrays.asList( + Vectors.dense(4), + Vectors.dense(1), + Vectors.dense(4), + Vectors.dense(5), + Vectors.dense(3), + Vectors.dense(6), + Vectors.dense(2), + Vectors.dense(5), + Vectors.dense(6), + Vectors.dense(2), + Vectors.dense(5), + Vectors.dense(7), + Vectors.dense(1), + Vectors.dense(8), + Vectors.dense(15), + Vectors.dense(33), + Vectors.dense(4), + Vectors.dense(7), + Vectors.dense(6), + Vectors.dense(7), + Vectors.dense(8), + Vectors.dense(55))); + + private static final List> expectedGroups = + Arrays.asList( + new HashSet<>( + Arrays.asList( + Vectors.dense(4), + Vectors.dense(1), + Vectors.dense(4), + Vectors.dense(5), + Vectors.dense(3), + Vectors.dense(6), + Vectors.dense(2), + Vectors.dense(5), + Vectors.dense(6), + Vectors.dense(2), + Vectors.dense(5), + Vectors.dense(7), + Vectors.dense(1), + Vectors.dense(8), + Vectors.dense(4), + Vectors.dense(7), + Vectors.dense(6), + Vectors.dense(7), + Vectors.dense(8))), + new HashSet<>( + Arrays.asList( + Vectors.dense(15), Vectors.dense(33), Vectors.dense(55)))); + + private static final double TOLERANCE = 1e-7; + + private Table dataTable; + + @Before + public void before() { + env = TestUtils.getExecutionEnvironment(); + // Allow KryoSerializer Fallback. + env.getConfig().enableGenericTypes(); + tEnv = StreamTableEnvironment.create(env); + dataTable = tEnv.fromDataStream(env.fromCollection(DATA)).as("features"); + } + + /** + * Aggregates feature by predictions. Results are returned as a list of sets, where elements in + * the same set are features whose prediction results are the same. + * + * @param rows A list of rows containing feature and prediction columns + * @param featuresCol Name of the column in the table that contains the features + * @param predictionCol Name of the column in the table that contains the prediction result + * @return A map containing the collected results + */ + protected static List> groupFeaturesByPrediction( + List rows, String featuresCol, String predictionCol) { + Map> map = new HashMap<>(); + for (Row row : rows) { + DenseVector vector = ((Vector) row.getField(featuresCol)).toDense(); + int predict = (Integer) row.getField(predictionCol); + map.putIfAbsent(predict, new HashSet<>()); + map.get(predict).add(vector); + } + return new ArrayList<>(map.values()); + } + + @Test + public void testParam() { + IsolationForest isolationForest = new IsolationForest(); + assertEquals("features", isolationForest.getFeaturesCol()); + assertEquals("prediction", isolationForest.getPredictionCol()); + assertEquals(256, isolationForest.getMaxSamples()); + assertEquals(1.0, isolationForest.getMaxFeatures(), TOLERANCE); + assertEquals(100, isolationForest.getNumTrees()); + + isolationForest + .setFeaturesCol("test_features") + .setPredictionCol("test_prediction") + .setMaxSamples(128) + .setMaxFeatures(0.5) + .setNumTrees(90); + + assertEquals("test_features", isolationForest.getFeaturesCol()); + assertEquals("test_prediction", isolationForest.getPredictionCol()); + assertEquals(128, isolationForest.getMaxSamples()); + assertEquals(0.5, isolationForest.getMaxFeatures(), TOLERANCE); + assertEquals(90, isolationForest.getNumTrees()); + } + + @Test + public void testOutputSchema() throws Exception { + Table input = dataTable.as("test_feature"); + IsolationForest isolationForest = + new IsolationForest() + .setFeaturesCol("test_feature") + .setPredictionCol("test_prediction"); + IsolationForestModel model = isolationForest.fit(input); + Table output = model.transform(input)[0]; + + assertEquals( + Arrays.asList("test_feature", "test_prediction"), + output.getResolvedSchema().getColumnNames()); + } + + @Test + public void testFitAndPredict() throws Exception { + IsolationForest isolationForest = + new IsolationForest().setMaxSamples(256).setMaxFeatures(1.0).setNumTrees(100); + + IsolationForestModel model = isolationForest.fit(dataTable); + Table output = model.transform(dataTable)[0]; + + assertEquals( + Arrays.asList("features", "prediction"), + output.getResolvedSchema().getColumnNames()); + List results = IteratorUtils.toList(output.execute().collect()); + List> actualGroups = + groupFeaturesByPrediction( + results, + isolationForest.getFeaturesCol(), + isolationForest.getPredictionCol()); + assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups)); + } + + @Test + public void testSaveLoadAndPredict() throws Exception { + IsolationForest isolationForest = + new IsolationForest().setMaxSamples(256).setMaxFeatures(1.0).setNumTrees(100); + IsolationForest loadedIsolationForest = + TestUtils.saveAndReload( + tEnv, + isolationForest, + tempFolder.newFolder().getAbsolutePath(), + IsolationForest::load); + + IsolationForestModel model = loadedIsolationForest.fit(dataTable); + IsolationForestModel loadedModel = + TestUtils.saveAndReload( + tEnv, + model, + tempFolder.newFolder().getAbsolutePath(), + IsolationForestModel::load); + Table output = loadedModel.transform(dataTable)[0]; + assertEquals( + Arrays.asList("iForest"), + loadedModel.getModelData()[0].getResolvedSchema().getColumnNames()); + assertEquals( + Arrays.asList("features", "prediction"), + output.getResolvedSchema().getColumnNames()); + + List results = IteratorUtils.toList(output.execute().collect()); + List> actualGroups = + groupFeaturesByPrediction( + results, + isolationForest.getFeaturesCol(), + isolationForest.getPredictionCol()); + assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups)); + } + + @Test + public void testGetModelData() throws Exception { + IsolationForest isolationForest = + new IsolationForest().setMaxSamples(256).setMaxFeatures(1.0).setNumTrees(100); + IsolationForestModel model = isolationForest.fit(dataTable); + assertEquals( + Arrays.asList("iForest"), + model.getModelData()[0].getResolvedSchema().getColumnNames()); + + DataStream modelData = + IsolationForestModelData.getModelDataStream(model.getModelData()[0]); + List collectedModelData = + IteratorUtils.toList(modelData.executeAndCollect()); + + IForest iForest = collectedModelData.get(0).iForest; + + if (iForest.center0 < 0.5) { + throw new Exception("Predicted value and actual value are different."); + } + + if (iForest.center1 > 0.5) { + throw new Exception("Predicted value and actual value are different."); + } + } + + @Test + public void testSetModelData() throws Exception { + IsolationForest isolationForest = + new IsolationForest().setMaxSamples(256).setMaxFeatures(1.0).setNumTrees(100); + IsolationForestModel modelA = isolationForest.fit(dataTable); + IsolationForestModel modelB = + new IsolationForestModel().setModelData(modelA.getModelData()); + ParamUtils.updateExistingParams(modelB, modelA.getParamMap()); + + Table output = modelB.transform(dataTable)[0]; + List results = IteratorUtils.toList(output.execute().collect()); + List> actualGroups = + groupFeaturesByPrediction( + results, + isolationForest.getFeaturesCol(), + isolationForest.getPredictionCol()); + assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups)); + } +}