-
Notifications
You must be signed in to change notification settings - Fork 94
[FLINK-33003] Support isolation forest algorithm in Flink ML #253
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.flink.ml.examples.anomalydetection; | ||
|
||
import org.apache.flink.ml.anomalydetection.isolationforest.IsolationForest; | ||
import org.apache.flink.ml.anomalydetection.isolationforest.IsolationForestModel; | ||
import org.apache.flink.ml.linalg.DenseVector; | ||
import org.apache.flink.ml.linalg.Vectors; | ||
import org.apache.flink.streaming.api.datastream.DataStream; | ||
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; | ||
import org.apache.flink.table.api.Table; | ||
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; | ||
import org.apache.flink.types.Row; | ||
import org.apache.flink.util.CloseableIterator; | ||
|
||
/** Simple program that creates an IsolationForest instance and uses it for anomaly detection. */ | ||
public class IsolationForestExample { | ||
public static void main(String[] args) { | ||
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); | ||
StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); | ||
|
||
// Generates input data. | ||
DataStream<DenseVector> inputStream = | ||
env.fromElements( | ||
Vectors.dense(1, 2), | ||
Vectors.dense(1.1, 2), | ||
Vectors.dense(1, 2.1), | ||
Vectors.dense(1.1, 2.1), | ||
Vectors.dense(0.1, 0.1)); | ||
|
||
Table inputTable = tEnv.fromDataStream(inputStream).as("features"); | ||
|
||
IsolationForest isolationForest = | ||
new IsolationForest() | ||
.setNumTrees(100) | ||
.setMaxIter(10) | ||
.setMaxSamples(256) | ||
.setMaxFeatures(1.0); | ||
|
||
IsolationForestModel isolationForestModel = isolationForest.fit(inputTable); | ||
|
||
Table outputTable = isolationForestModel.transform(inputTable)[0]; | ||
|
||
// Extracts and displays the results. | ||
for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) { | ||
Row row = it.next(); | ||
DenseVector features = (DenseVector) row.getField(isolationForest.getFeaturesCol()); | ||
int predictId = (Integer) row.getField(isolationForest.getPredictionCol()); | ||
System.out.printf("Features: %s \tPrediction: %s\n", features, predictId); | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.flink.ml.anomalydetection.isolationforest; | ||
|
||
import org.apache.flink.ml.linalg.DenseVector; | ||
|
||
import java.io.Serializable; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.Random; | ||
|
||
/** Construct isolation forest. */ | ||
public class IForest implements Serializable { | ||
public final int numTrees; | ||
public List<ITree> iTreeList; | ||
public Double center0; | ||
public Double center1; | ||
public int subSamplesSize; | ||
|
||
public IForest(int numTrees) { | ||
this.numTrees = numTrees; | ||
this.iTreeList = new ArrayList<>(256); | ||
this.center0 = null; | ||
this.center1 = null; | ||
} | ||
|
||
public void generateIsolationForest(DenseVector[] samplesData, int[] featureIndices) { | ||
int n = samplesData.length; | ||
this.subSamplesSize = Math.min(256, n); | ||
int limitHeight = (int) Math.ceil(Math.log(Math.max(subSamplesSize, 2)) / Math.log(2)); | ||
Random randomState = new Random(System.nanoTime()); | ||
for (int i = 0; i < numTrees; i++) { | ||
DenseVector[] subSamples = new DenseVector[subSamplesSize]; | ||
for (int j = 0; j < subSamplesSize; j++) { | ||
int r = randomState.nextInt(n); | ||
subSamples[j] = samplesData[r]; | ||
} | ||
ITree isolationTree = | ||
ITree.generateIsolationTree( | ||
subSamples, 0, limitHeight, randomState, featureIndices); | ||
this.iTreeList.add(isolationTree); | ||
} | ||
} | ||
|
||
public DenseVector calculateScore(DenseVector[] samplesData) throws Exception { | ||
int n = samplesData.length; | ||
DenseVector score = new DenseVector(n); | ||
for (int i = 0; i < n; i++) { | ||
double pathLengthSum = 0; | ||
for (ITree isolationTree : iTreeList) { | ||
pathLengthSum += ITree.calculatePathLength(samplesData[i], isolationTree); | ||
} | ||
|
||
double pathLengthAvg = pathLengthSum / iTreeList.size(); | ||
double cn = ITree.calculateCn(subSamplesSize); | ||
double index = pathLengthAvg / cn; | ||
score.set(i, Math.pow(2, -index)); | ||
} | ||
|
||
return score; | ||
} | ||
|
||
public DenseVector classifyByCluster(DenseVector score, int iters) { | ||
int scoresSize = score.size(); | ||
this.center0 = score.get(0); // Cluster center of abnormal | ||
this.center1 = score.get(0); // Cluster center of normal | ||
|
||
for (int p = 1; p < score.size(); p++) { | ||
if (score.get(p) > center0) { | ||
center0 = score.get(p); | ||
} | ||
|
||
if (score.get(p) < center1) { | ||
center1 = score.get(p); | ||
} | ||
} | ||
|
||
int cnt0, cnt1; | ||
double diff0, diff1; | ||
int[] labels = new int[scoresSize]; | ||
|
||
for (int i = 0; i < iters; i++) { | ||
cnt0 = 0; | ||
cnt1 = 0; | ||
|
||
for (int j = 0; j < scoresSize; j++) { | ||
diff0 = Math.abs(score.get(j) - center0); | ||
diff1 = Math.abs(score.get(j) - center1); | ||
|
||
if (diff0 < diff1) { | ||
labels[j] = 0; | ||
cnt0++; | ||
} else { | ||
labels[j] = 1; | ||
cnt1++; | ||
} | ||
} | ||
|
||
diff0 = center0; | ||
diff1 = center1; | ||
|
||
center0 = 0.0; | ||
center1 = 0.0; | ||
for (int k = 0; k < scoresSize; k++) { | ||
if (labels[k] == 0) { | ||
center0 += score.get(k); | ||
} else { | ||
center1 += score.get(k); | ||
} | ||
} | ||
|
||
center0 /= cnt0; | ||
center1 /= cnt1; | ||
|
||
if (center0 - diff0 <= 1e-6 && center1 - diff1 <= 1e-6) { | ||
break; | ||
} | ||
} | ||
return new DenseVector(new double[] {center0, center1}); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.flink.ml.anomalydetection.isolationforest; | ||
|
||
import org.apache.flink.api.java.tuple.Tuple2; | ||
import org.apache.flink.ml.linalg.DenseVector; | ||
|
||
import java.io.Serializable; | ||
import java.util.Random; | ||
|
||
/** Construct isolation tree. */ | ||
public class ITree implements Serializable { | ||
public final int attributeIndex; | ||
public final double splitAttributeValue; | ||
public ITree leftTree; | ||
public ITree rightTree; | ||
public int currentHeight; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given that we only mutate Same for the other three non-final variables. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good for me, I will make it final. |
||
public int leafNodesNum; | ||
|
||
public ITree(int attributeIndex, double splitAttributeValue) { | ||
this.attributeIndex = attributeIndex; | ||
this.splitAttributeValue = splitAttributeValue; | ||
this.leftTree = null; | ||
this.rightTree = null; | ||
this.currentHeight = 0; | ||
this.leafNodesNum = 1; | ||
} | ||
|
||
public static ITree generateIsolationTree( | ||
DenseVector[] samplesData, | ||
int currentHeight, | ||
int limitHeight, | ||
Random randomState, | ||
int[] featureIndices) { | ||
int n = samplesData.length; | ||
ITree isolationTree; | ||
if (samplesData.length == 0) { | ||
return null; | ||
} else if (samplesData.length == 1 || currentHeight >= limitHeight) { | ||
isolationTree = new ITree(0, samplesData[0].get(0)); | ||
isolationTree.currentHeight = currentHeight; | ||
isolationTree.leafNodesNum = samplesData.length; | ||
return isolationTree; | ||
} | ||
boolean flag = true; | ||
for (int i = 1; i < n; i++) { | ||
if (!samplesData[i].equals(samplesData[i - 1])) { | ||
flag = false; | ||
break; | ||
} | ||
} | ||
if (flag) { | ||
isolationTree = new ITree(0, samplesData[0].get(0)); | ||
isolationTree.currentHeight = currentHeight; | ||
isolationTree.leafNodesNum = samplesData.length; | ||
return isolationTree; | ||
} | ||
|
||
Tuple2<Integer, Double> tuple2 = | ||
getRandomFeatureToSplit(samplesData, randomState, featureIndices); | ||
int attributeIndex = tuple2.f0; | ||
double splitAttributeValue = tuple2.f1; | ||
|
||
int leftNodesNum = 0; | ||
int rightNodesNum = 0; | ||
for (DenseVector datum : samplesData) { | ||
if (datum.get(attributeIndex) < splitAttributeValue) { | ||
leftNodesNum++; | ||
} else { | ||
rightNodesNum++; | ||
} | ||
} | ||
|
||
DenseVector[] leftSamples = new DenseVector[leftNodesNum]; | ||
DenseVector[] rightSamples = new DenseVector[rightNodesNum]; | ||
int l = 0, r = 0; | ||
for (DenseVector samplesDatum : samplesData) { | ||
if (samplesDatum.get(attributeIndex) < splitAttributeValue) { | ||
leftSamples[l++] = samplesDatum; | ||
} else { | ||
rightSamples[r++] = samplesDatum; | ||
} | ||
} | ||
|
||
ITree root = new ITree(attributeIndex, splitAttributeValue); | ||
root.currentHeight = currentHeight; | ||
root.leafNodesNum = samplesData.length; | ||
root.leftTree = | ||
generateIsolationTree( | ||
leftSamples, currentHeight + 1, limitHeight, randomState, featureIndices); | ||
root.rightTree = | ||
generateIsolationTree( | ||
rightSamples, currentHeight + 1, limitHeight, randomState, featureIndices); | ||
|
||
return root; | ||
} | ||
|
||
private static Tuple2<Integer, Double> getRandomFeatureToSplit( | ||
DenseVector[] samplesData, Random randomState, int[] featureIndices) { | ||
int attributeIndex = featureIndices[randomState.nextInt(featureIndices.length)]; | ||
|
||
double maxValue = samplesData[0].get(attributeIndex); | ||
double minValue = samplesData[0].get(attributeIndex); | ||
for (int i = 1; i < samplesData.length; i++) { | ||
minValue = Math.min(minValue, samplesData[i].get(attributeIndex)); | ||
maxValue = Math.max(maxValue, samplesData[i].get(attributeIndex)); | ||
} | ||
double splitAttributeValue = (maxValue - minValue) * randomState.nextDouble() + minValue; | ||
|
||
return Tuple2.of(attributeIndex, splitAttributeValue); | ||
} | ||
|
||
public static double calculatePathLength(DenseVector sampleData, ITree isolationTree) | ||
throws Exception { | ||
double pathLength = -1; | ||
ITree tmpITree = isolationTree; | ||
while (tmpITree != null) { | ||
pathLength += 1; | ||
if (tmpITree.leftTree == null | ||
|| tmpITree.rightTree == null | ||
|| sampleData.get(tmpITree.attributeIndex) == tmpITree.splitAttributeValue) { | ||
break; | ||
} else if (sampleData.get(tmpITree.attributeIndex) < tmpITree.splitAttributeValue) { | ||
tmpITree = tmpITree.leftTree; | ||
} else { | ||
tmpITree = tmpITree.rightTree; | ||
} | ||
} | ||
|
||
assert tmpITree != null; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In general we only use Given that we will have NullPointerException in the line below if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, I will delete it. |
||
return pathLength + calculateCn(tmpITree.leafNodesNum); | ||
} | ||
|
||
public static double calculateCn(double n) { | ||
if (n <= 1) { | ||
return 0; | ||
} | ||
return 2.0 * (Math.log(n - 1.0) + 0.5772156649015329) - 2.0 * (n - 1.0) / n; | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nits: would it be more consistent with existing code to use
subSamplesSize
directly here?Same for other usages of
this.*
outside constructor.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, I will fix it.