Skip to content

[FLINK-33003] Support isolation forest algorithm in Flink ML #253

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.ml.examples.anomalydetection;

import org.apache.flink.ml.anomalydetection.isolationforest.IsolationForest;
import org.apache.flink.ml.anomalydetection.isolationforest.IsolationForestModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/**
 * Simple program that creates an IsolationForest instance and uses it for anomaly detection.
 *
 * <p>The program trains the model on five hard-coded 2-d points, applies the fitted model to the
 * same points and prints every point together with its predicted label.
 */
public class IsolationForestExample {
    /**
     * Runs the example end to end.
     *
     * @param args ignored
     * @throws Exception if closing the result iterator fails
     */
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data.
        DataStream<DenseVector> inputStream =
                env.fromElements(
                        Vectors.dense(1, 2),
                        Vectors.dense(1.1, 2),
                        Vectors.dense(1, 2.1),
                        Vectors.dense(1.1, 2.1),
                        Vectors.dense(0.1, 0.1));

        Table inputTable = tEnv.fromDataStream(inputStream).as("features");

        IsolationForest isolationForest =
                new IsolationForest()
                        .setNumTrees(100)
                        .setMaxIter(10)
                        .setMaxSamples(256)
                        .setMaxFeatures(1.0);

        IsolationForestModel isolationForestModel = isolationForest.fit(inputTable);

        Table outputTable = isolationForestModel.transform(inputTable)[0];

        // Extracts and displays the results. The iterator is AutoCloseable, so it is opened in a
        // try-with-resources block to make sure it is closed once all rows have been printed.
        try (CloseableIterator<Row> it = outputTable.execute().collect()) {
            while (it.hasNext()) {
                Row row = it.next();
                DenseVector features =
                        (DenseVector) row.getField(isolationForest.getFeaturesCol());
                int predictId = (Integer) row.getField(isolationForest.getPredictionCol());
                System.out.printf("Features: %s \tPrediction: %s\n", features, predictId);
            }
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.ml.anomalydetection.isolationforest;

import org.apache.flink.ml.linalg.DenseVector;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/** An isolation forest: builds a list of isolation trees and scores samples with them. */
public class IForest implements Serializable {
/** Number of isolation trees that generateIsolationForest builds. */
public final int numTrees;
/** Trees of this forest; generateIsolationForest appends one entry per tree. */
public List<ITree> iTreeList;
/** Center of the "abnormal" score cluster; null until classifyByCluster has run. */
public Double center0;
/** Center of the "normal" score cluster; null until classifyByCluster has run. */
public Double center1;
/** Number of points each tree is grown from; set by generateIsolationForest. */
public int subSamplesSize;

/**
 * Creates a forest that holds no trees yet and has no cluster centers.
 *
 * @param numTrees number of trees that a later call to generateIsolationForest will build
 */
public IForest(int numTrees) {
    iTreeList = new ArrayList<>(256);
    center0 = null;
    center1 = null;
    this.numTrees = numTrees;
}

/**
 * Builds {@code numTrees} isolation trees and appends them to {@code iTreeList}.
 *
 * <p>Every tree is grown from its own random sub-sample of {@code samplesData}.
 *
 * @param samplesData points to draw the per-tree sub-samples from
 * @param featureIndices feature indices passed on to the tree builder
 */
public void generateIsolationForest(DenseVector[] samplesData, int[] featureIndices) {
int n = samplesData.length;
// Each tree is grown from at most 256 points.
this.subSamplesSize = Math.min(256, n);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nits: would it be more consistent with existing code to use subSamplesSize directly here?

Same for other usages of this.* outside constructor.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I will fix it.

// Height limit handed to the tree builder: ceil(log2(max(subSamplesSize, 2))).
int limitHeight = (int) Math.ceil(Math.log(Math.max(subSamplesSize, 2)) / Math.log(2));
// Seeded from the system clock, so the generated trees differ from run to run.
Random randomState = new Random(System.nanoTime());
for (int i = 0; i < numTrees; i++) {
DenseVector[] subSamples = new DenseVector[subSamplesSize];
// Draw subSamplesSize points uniformly at random; the same point may be drawn repeatedly.
for (int j = 0; j < subSamplesSize; j++) {
int r = randomState.nextInt(n);
subSamples[j] = samplesData[r];
}
// The root of every tree starts at height 0.
ITree isolationTree =
ITree.generateIsolationTree(
subSamples, 0, limitHeight, randomState, featureIndices);
this.iTreeList.add(isolationTree);
}
}

/**
 * Computes the anomaly score of every sample as {@code 2^(-avgPathLength / c(subSamplesSize))},
 * where {@code avgPathLength} is the mean of the sample's path lengths over all trees in
 * {@code iTreeList}.
 *
 * @param samplesData samples to score
 * @return a vector holding one score per sample, in input order
 * @throws Exception declared for compatibility with {@code ITree.calculatePathLength}
 */
public DenseVector calculateScore(DenseVector[] samplesData) throws Exception {
    int n = samplesData.length;
    DenseVector score = new DenseVector(n);

    // The normaliser depends only on the sub-sample size, so compute it once for all samples.
    double cn = ITree.calculateCn(subSamplesSize);

    for (int i = 0; i < n; i++) {
        double pathLengthSum = 0;
        for (ITree isolationTree : iTreeList) {
            pathLengthSum += ITree.calculatePathLength(samplesData[i], isolationTree);
        }

        double pathLengthAvg = pathLengthSum / iTreeList.size();
        // c(n) is 0 when the trees were grown from at most one point. Dividing by it would
        // produce NaN or Infinity, so fall back to a normalised depth of 1 (score 0.5).
        double index = cn > 0 ? pathLengthAvg / cn : 1.0;
        score.set(i, Math.pow(2, -index));
    }

    return score;
}

/**
 * Splits the scores into two clusters with a 1-d two-means iteration and stores the resulting
 * centers in {@code center0} (abnormal) and {@code center1} (normal).
 *
 * <p>The centers start at the largest and the smallest score. Each iteration assigns every
 * score to the nearer center (ties go to the normal cluster) and moves each center to the mean
 * of its scores. Iteration stops after {@code iters} rounds or once neither center moved by
 * more than 1e-6.
 *
 * @param score anomaly scores; must contain at least one element
 * @param iters maximum number of iterations
 * @return a 2-element vector {@code [center0, center1]}
 */
public DenseVector classifyByCluster(DenseVector score, int iters) {
    final double tolerance = 1e-6;
    int scoresSize = score.size();

    double abnormalCenter = score.get(0); // Cluster center of abnormal
    double normalCenter = score.get(0); // Cluster center of normal
    for (int p = 1; p < scoresSize; p++) {
        double value = score.get(p);
        if (value > abnormalCenter) {
            abnormalCenter = value;
        }
        if (value < normalCenter) {
            normalCenter = value;
        }
    }

    for (int i = 0; i < iters; i++) {
        int cnt0 = 0;
        int cnt1 = 0;
        double sum0 = 0.0;
        double sum1 = 0.0;

        // Assign every score to the nearer center and accumulate the per-cluster sums.
        for (int j = 0; j < scoresSize; j++) {
            double value = score.get(j);
            if (Math.abs(value - abnormalCenter) < Math.abs(value - normalCenter)) {
                sum0 += value;
                cnt0++;
            } else {
                sum1 += value;
                cnt1++;
            }
        }

        double previousAbnormal = abnormalCenter;
        double previousNormal = normalCenter;

        // An empty cluster (e.g. when all scores are equal) keeps its previous center instead
        // of becoming NaN through a 0/0 division.
        if (cnt0 > 0) {
            abnormalCenter = sum0 / cnt0;
        }
        if (cnt1 > 0) {
            normalCenter = sum1 / cnt1;
        }

        // Compare the absolute movement; a signed difference would treat any downward move of
        // a center as convergence.
        if (Math.abs(abnormalCenter - previousAbnormal) <= tolerance
                && Math.abs(normalCenter - previousNormal) <= tolerance) {
            break;
        }
    }

    center0 = abnormalCenter;
    center1 = normalCenter;
    return new DenseVector(new double[] {abnormalCenter, normalCenter});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.ml.anomalydetection.isolationforest;

import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.linalg.DenseVector;

import java.io.Serializable;
import java.util.Random;

/** A node of an isolation tree, together with the static helpers that build and traverse it. */
public class ITree implements Serializable {
/** Index of the feature that is compared against splitAttributeValue at this node. */
public final int attributeIndex;
/** Split threshold: strictly smaller feature values go to leftTree, all others to rightTree. */
public final double splitAttributeValue;
/** Subtree for samples below the threshold; null unless the tree builder assigns it. */
public ITree leftTree;
/** Subtree for the remaining samples; null unless the tree builder assigns it. */
public ITree rightTree;
/** Height at which this node was created; the builder passes currentHeight + 1 to children. */
public int currentHeight;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that we only mutate currentHeight right after it is constructed, it should be possible and better to make it final.

Same for the other three non-final variables.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good to me, I will make it final.

/** Number of training samples that reached this node when the tree was built. */
public int leafNodesNum;

/**
 * Creates a node without children, at height 0 and with a sample count of 1.
 *
 * @param attributeIndex index of the feature this node splits on
 * @param splitAttributeValue threshold the feature value is compared against
 */
public ITree(int attributeIndex, double splitAttributeValue) {
    leafNodesNum = 1;
    currentHeight = 0;
    rightTree = null;
    leftTree = null;
    this.splitAttributeValue = splitAttributeValue;
    this.attributeIndex = attributeIndex;
}

/**
 * Recursively grows an isolation tree from the given samples.
 *
 * <p>A node becomes a leaf when it holds a single sample, when {@code currentHeight} has
 * reached {@code limitHeight}, or when all of its samples are equal. Otherwise a random feature
 * and threshold are drawn, samples strictly below the threshold go to the left child and all
 * others go to the right child.
 *
 * @param samplesData samples that reached this node
 * @param currentHeight height of the node being created
 * @param limitHeight height at which growth stops
 * @param randomState random source used to pick the split
 * @param featureIndices candidate feature indices for the split
 * @return the created node, or {@code null} when {@code samplesData} is empty
 */
public static ITree generateIsolationTree(
        DenseVector[] samplesData,
        int currentHeight,
        int limitHeight,
        Random randomState,
        int[] featureIndices) {
    final int sampleCount = samplesData.length;
    if (sampleCount == 0) {
        return null;
    }

    // Decide whether growth stops here; the pairwise comparison is only needed when neither
    // the size nor the height criterion already applies.
    boolean stopHere = sampleCount == 1 || currentHeight >= limitHeight;
    if (!stopHere) {
        stopHere = true;
        for (int idx = 1; idx < sampleCount && stopHere; idx++) {
            stopHere = samplesData[idx].equals(samplesData[idx - 1]);
        }
    }
    if (stopHere) {
        ITree leaf = new ITree(0, samplesData[0].get(0));
        leaf.currentHeight = currentHeight;
        leaf.leafNodesNum = sampleCount;
        return leaf;
    }

    Tuple2<Integer, Double> split =
            getRandomFeatureToSplit(samplesData, randomState, featureIndices);
    int attributeIndex = split.f0;
    double splitAttributeValue = split.f1;

    // First pass: size the two partitions.
    int leftCount = 0;
    for (DenseVector sample : samplesData) {
        if (sample.get(attributeIndex) < splitAttributeValue) {
            leftCount++;
        }
    }

    // Second pass: fill them, keeping the original sample order on both sides.
    DenseVector[] leftSamples = new DenseVector[leftCount];
    DenseVector[] rightSamples = new DenseVector[sampleCount - leftCount];
    int leftPos = 0;
    int rightPos = 0;
    for (DenseVector sample : samplesData) {
        boolean goesLeft = sample.get(attributeIndex) < splitAttributeValue;
        if (goesLeft) {
            leftSamples[leftPos++] = sample;
        } else {
            rightSamples[rightPos++] = sample;
        }
    }

    ITree node = new ITree(attributeIndex, splitAttributeValue);
    node.currentHeight = currentHeight;
    node.leafNodesNum = sampleCount;

    // The left subtree is built before the right one so the random source is consumed in a
    // fixed order.
    int childHeight = currentHeight + 1;
    node.leftTree =
            generateIsolationTree(
                    leftSamples, childHeight, limitHeight, randomState, featureIndices);
    node.rightTree =
            generateIsolationTree(
                    rightSamples, childHeight, limitHeight, randomState, featureIndices);
    return node;
}

/**
 * Picks a random feature from {@code featureIndices} and a random threshold between the
 * smallest and the largest value that feature takes in {@code samplesData}.
 *
 * @param samplesData samples to split; must contain at least one element
 * @param randomState random source; one int and then one double are drawn from it
 * @param featureIndices candidate feature indices
 * @return the chosen feature index and the split threshold
 */
private static Tuple2<Integer, Double> getRandomFeatureToSplit(
        DenseVector[] samplesData, Random randomState, int[] featureIndices) {
    int chosenFeature = featureIndices[randomState.nextInt(featureIndices.length)];

    double highest = samplesData[0].get(chosenFeature);
    double lowest = samplesData[0].get(chosenFeature);
    for (int idx = 1; idx < samplesData.length; idx++) {
        double value = samplesData[idx].get(chosenFeature);
        lowest = Math.min(lowest, value);
        highest = Math.max(highest, value);
    }

    double range = highest - lowest;
    double threshold = range * randomState.nextDouble() + lowest;
    return Tuple2.of(chosenFeature, threshold);
}

/**
 * Returns the path length of a sample in the given tree: the number of edges walked from the
 * root to the node where the walk stops, plus {@code calculateCn} of that node's
 * {@code leafNodesNum}.
 *
 * @param sampleData sample to route through the tree
 * @param isolationTree root of the tree; a null root fails at the assert or the final return
 * @return the adjusted path length
 * @throws Exception declared in the signature; no checked exception is thrown in this body
 */
public static double calculatePathLength(DenseVector sampleData, ITree isolationTree)
throws Exception {
// Starts at -1 so that visiting the root alone gives a length of 0.
double pathLength = -1;
ITree tmpITree = isolationTree;
while (tmpITree != null) {
pathLength += 1;
// Stop at a node that lacks a child, or when the sample's value equals the split value.
// NOTE(review): the tree builder sends values equal to the split value to the right
// child, so stopping on equality here looks inconsistent with it - confirm the intent.
if (tmpITree.leftTree == null
|| tmpITree.rightTree == null
|| sampleData.get(tmpITree.attributeIndex) == tmpITree.splitAttributeValue) {
break;
} else if (sampleData.get(tmpITree.attributeIndex) < tmpITree.splitAttributeValue) {
tmpITree = tmpITree.leftTree;
} else {
tmpITree = tmpITree.rightTree;
}
}

// tmpITree can only be null here when the root passed in was null.
assert tmpITree != null;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general we only use assert in unit tests.

Given that we will have NullPointerException in the line below if tmpITree == null, it seems simpler to just remove this line.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I will delete it.

// Add the adjustment for the samples that were left unseparated at the stopping node.
return pathLength + calculateCn(tmpITree.leafNodesNum);
}

/**
 * Returns c(n) = 2 * H(n - 1) - 2 * (n - 1) / n, the normalising path-length term of the
 * isolation forest algorithm, where H(i) is the i-th harmonic number.
 *
 * <p>For n greater than 2, H(n - 1) is approximated by ln(n - 1) plus the Euler-Mascheroni
 * constant. That approximation is poor for n = 2 (it yields about 0.154), so the exact value
 * c(2) = 2 * H(1) - 1 = 1 is returned for that case.
 *
 * @param n number of samples
 * @return 0 for n &lt;= 1, 1 for n == 2, otherwise the approximated c(n)
 */
public static double calculateCn(double n) {
    if (n <= 1) {
        return 0;
    }
    if (n == 2) {
        // H(1) is exactly 1, hence c(2) = 2 * 1 - 2 * 1 / 2 = 1.
        return 1;
    }
    return 2.0 * (Math.log(n - 1.0) + 0.5772156649015329) - 2.0 * (n - 1.0) / n;
}
}
Loading