Fix Confidence Adjustment for Larger Shingle Sizes (#407)
* Fix Confidence Adjustment for Larger Shingle Sizes

This PR makes further adjustments to the confidence calculation issue discussed in PR 405. While PR 405 resolved the issue for a shingle size of 4, it did not achieve the same results for larger shingle sizes such as 8.

Key Changes
1. Refinement of seenValues Calculation:
* Previously, the formula increased confidence even as numImputed (the number of imputations seen) grew, because seenValues (all values seen) also grew.
* This PR fixes the issue by counting only non-imputed values as seenValues.
2. Upper Bound for numImputed:
* numImputed is now upper bounded by the shingle size.
* The impute fraction, calculated as numberOfImputed * 1.0 / shingleSize, therefore never exceeds 1.
3. Decrementing numberOfImputed:
* numberOfImputed is now decremented when there is no imputation.
* Previously, numberOfImputed remained unchanged when there was an imputation, because an increment and a decrement cancelled out, keeping the imputation fraction constant. This PR ensures the imputation fraction accurately reflects the current state, so the forest update decision, which relies on that fraction, functions correctly: the forest is updated only when the imputation fraction is below the threshold of 0.5. A simplified sketch of the adjusted bookkeeping follows this list.
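
The sketch below ties the three changes together. It is a minimal illustration, not the actual implementation: the wrapper class and the onPoint driver are hypothetical, while numberOfImputed, shingleSize, valuesSeen, and useImputedFraction mirror the names in the diff below, and updateAllowed is a simplified form of the real method (which also updates dataQuality and checks internalTimeStamp).

// Hypothetical wrapper; only the field names and the capping/decrement logic
// mirror the actual diff.
public class ImputationBookkeepingSketch {
    int shingleSize = 8;
    int numberOfImputed = 0;         // recent imputations still in the shingle
    long valuesSeen = 0;             // change 1: counts only non-imputed values
    double useImputedFraction = 0.5; // forest-update threshold

    // Called once per processed point; isImputed marks a fully imputed tuple.
    void onPoint(boolean isImputed) {
        if (isImputed) {
            // change 2: cap numberOfImputed so the fraction cannot exceed 1
            numberOfImputed = Math.min(numberOfImputed + 1, shingleSize);
        } else {
            // change 3: decrement when there is no imputation
            if (numberOfImputed > 0) {
                numberOfImputed = numberOfImputed - 1;
            }
            ++valuesSeen;
        }
    }

    // Simplified: the forest is updated only when the imputation fraction
    // stays below the 0.5 threshold.
    boolean updateAllowed() {
        double fraction = numberOfImputed * 1.0 / shingleSize;
        if (fraction > 1) {
            fraction = 1;
        }
        return fraction < useImputedFraction;
    }
}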

Testing
* Added test scenarios with various shingle sizes to verify the changes; a trimmed sketch of the parameterization follows.
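
For reference, the new parameterization in MissingValueTest is sketched below, trimmed from the diff that follows (the test body is elided):

import java.util.stream.Stream;

import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.ArgumentsProvider;
import org.junit.jupiter.params.provider.ArgumentsSource;

import com.amazon.randomcutforest.config.ImputationMethod;

// Crosses each imputation method with example shingle sizes 4, 8, and 16.
private static class EnumAndValueProvider implements ArgumentsProvider {
    @Override
    public Stream<? extends Arguments> provideArguments(ExtensionContext context) {
        return Stream.of(ImputationMethod.PREVIOUS, ImputationMethod.ZERO, ImputationMethod.FIXED_VALUES)
                .flatMap(method -> Stream.of(4, 8, 16)
                        .map(shingleSize -> Arguments.of(method, shingleSize)));
    }
}

@ParameterizedTest
@ArgumentsSource(EnumAndValueProvider.class)
public void testConfidence(ImputationMethod method, int shingleSize) {
    // build a forest with the given shingleSize and imputation method, stream
    // data with missing values, and assert on confidence and anomaly grades
}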

Signed-off-by: Kaituo Li <[email protected]>

* added comment

Signed-off-by: Kaituo Li <[email protected]>

---------

Signed-off-by: Kaituo Li <[email protected]>
kaituo authored Aug 1, 2024
1 parent 07aab4a commit f2984b5
Showing 3 changed files with 120 additions and 52 deletions.
@@ -80,44 +80,6 @@ public float[] getScaledShingledInput(double[] inputPoint, long timestamp, int[]
return point;
}

/**
* the timestamps are now used to calculate the number of imputed tuples in the
* shingle
*
* @param timestamp the timestamp of the current input
*/
@Override
protected void updateTimestamps(long timestamp) {
/*
* For imputations done on timestamps other than the current one (specified by
* the timestamp parameter), the timestamp of the imputed tuple matches that of
* the input tuple, and we increment numberOfImputed. For imputations done at
* the current timestamp (if all input values are missing), the timestamp of the
* imputed tuple is the current timestamp, and we increment numberOfImputed.
*
* To check if imputed values are still present in the shingle, we use the first
* condition (previousTimeStamps[0] == previousTimeStamps[1]). This works
* because previousTimeStamps has a size equal to the shingle size and is filled
* with the current timestamp. However, there are scenarios where we might miss
* decrementing numberOfImputed:
*
* 1. Not all values in the shingle are imputed. 2. We accumulated
* numberOfImputed when the current timestamp had missing values.
*
* As a result, this could cause the data quality measure to decrease
* continuously since we are always counting missing values that should
* eventually be reset to zero. The second condition <pre> timestamp >
* previousTimeStamps[previousTimeStamps.length-1] && numberOfImputed > 0 </pre>
* will decrement numberOfImputed when we move to a new timestamp, provided
* numberOfImputed is greater than zero.
*/
if (previousTimeStamps[0] == previousTimeStamps[1]
|| (timestamp > previousTimeStamps[previousTimeStamps.length - 1] && numberOfImputed > 0)) {
numberOfImputed = numberOfImputed - 1;
}
super.updateTimestamps(timestamp);
}

/**
* decides if the forest should be updated, this is needed for imputation on the
* fly. The main goal of this function is to avoid runaway sequences where a
@@ -128,7 +90,10 @@ protected void updateTimestamps(long timestamp) {
*/
protected boolean updateAllowed() {
double fraction = numberOfImputed * 1.0 / (shingleSize);
if (numberOfImputed == shingleSize - 1 && previousTimeStamps[0] != previousTimeStamps[1]
if (fraction > 1) {
fraction = 1;
}
if (numberOfImputed >= shingleSize - 1 && previousTimeStamps[0] != previousTimeStamps[1]
&& (transformMethod == DIFFERENCE || transformMethod == NORMALIZE_DIFFERENCE)) {
// this shingle is disconnected from the previously seen values
// these transformations will have little meaning
@@ -144,10 +109,57 @@ protected boolean updateAllowed() {
// two different points).
return false;
}

dataQuality[0].update(1 - fraction);
return (fraction < useImputedFraction && internalTimeStamp >= shingleSize);
}

@Override
protected void updateTimestamps(long timestamp) {
/*
* For imputations done on timestamps other than the current one (specified by
* the timestamp parameter), the timestamp of the imputed tuple matches that of
* the input tuple, and we increment numberOfImputed. For imputations done at
* the current timestamp (if all input values are missing), the timestamp of the
* imputed tuple is the current timestamp, and we increment numberOfImputed.
*
* To check if imputed values are still present in the shingle, we use the
* condition (previousTimeStamps[0] == previousTimeStamps[1]). This works
* because previousTimeStamps has a size equal to the shingle size and is filled
* with the current timestamp.
*
* For example, if the last 10 values were imputed and the shingle size is 8,
* the condition will most likely return false until all 10 imputed values are
* removed from the shingle.
*
* However, there are scenarios where we might miss decrementing
* numberOfImputed:
*
* 1. Not all values in the shingle are imputed. 2. We accumulated
* numberOfImputed when the current timestamp had missing values.
*
* As a result, this could cause the data quality measure to decrease
* continuously since we are always counting missing values that should
* eventually be reset to zero. To address the issue, we add code in method
* updateForest to decrement numberOfImputed when we move to a new timestamp,
* provided there is no imputation. This ensures the imputation fraction does
* not increase as long as the imputation is continuing. This also ensures that
* the forest update decision, which relies on the imputation fraction,
* functions correctly. The forest is updated only when the imputation fraction
* is below the threshold of 0.5.
*
* Also, why can't we combine the decrement code between updateTimestamps and
* updateForest together? This would cause Consistency.ImputeTest to fail when
* testing with and without imputation, as the RCF scores would not change. The
* method updateTimestamps is used in other places (e.g., updateState and
* dischargeInitial), not only in updateForest.
*/
if (previousTimeStamps[0] == previousTimeStamps[1]) {
numberOfImputed = numberOfImputed - 1;
}
super.updateTimestamps(timestamp);
}

/**
* the following function mutates the forest, the lastShingledPoint,
* lastShingledInput as well as previousTimeStamps, and adds the shingled input
@@ -168,7 +180,13 @@ void updateForest(boolean changeForest, double[] input, long timestamp, RandomCu
updateShingle(input, scaledInput);
updateTimestamps(timestamp);
if (isFullyImputed) {
numberOfImputed = numberOfImputed + 1;
// numberOfImputed is now capped at the shingle size to ensure that the impute
// fraction, calculated as numberOfImputed * 1.0 / shingleSize, does not exceed 1.
numberOfImputed = Math.min(numberOfImputed + 1, shingleSize);
} else if (numberOfImputed > 0) {
// Decrement numberOfImputed when the new value is not imputed
numberOfImputed = numberOfImputed - 1;
}
if (changeForest) {
if (forest.isInternalShinglingEnabled()) {
@@ -190,7 +208,14 @@ public void update(double[] point, float[] rcfPoint, long timestamp, int[] missi
return;
}
generateShingle(point, timestamp, missing, getTimeFactor(timeStampDeviations[1]), true, forest);
++valuesSeen;
// The confidence formula depends on numImputed (the number of recent
// imputations seen) and seenValues (all values seen). To ensure confidence
// decreases when numImputed increases, we count only non-imputed values as
// seenValues.
if (missing == null || missing.length != point.length) {
++valuesSeen;
}
}

protected double getTimeFactor(Deviation deviation) {
@@ -961,7 +961,8 @@ public void setLastScore(double[] score) {
}

void validateIgnore(double[] shift, int length) {
checkArgument(shift.length == length, () -> String.format(Locale.ROOT, "has to be of length %d but is %d", length, shift.length));
checkArgument(shift.length == length,
() -> String.format(Locale.ROOT, "has to be of length %d but is %d", length, shift.length));
for (double element : shift) {
checkArgument(element >= 0, "has to be non-negative");
}
@@ -19,24 +19,38 @@
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Random;
import java.util.stream.Stream;

import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.ArgumentsProvider;
import org.junit.jupiter.params.provider.ArgumentsSource;

import com.amazon.randomcutforest.config.ForestMode;
import com.amazon.randomcutforest.config.ImputationMethod;
import com.amazon.randomcutforest.config.Precision;
import com.amazon.randomcutforest.config.TransformMethod;

public class MissingValueTest {
private static class EnumAndValueProvider implements ArgumentsProvider {
@Override
public Stream<? extends Arguments> provideArguments(ExtensionContext context) {
return Stream.of(ImputationMethod.PREVIOUS, ImputationMethod.ZERO, ImputationMethod.FIXED_VALUES)
.flatMap(method -> Stream.of(4, 8, 16) // Example shingle sizes
.map(shingleSize -> Arguments.of(method, shingleSize)));
}
}

@ParameterizedTest
@EnumSource(ImputationMethod.class)
public void testConfidence(ImputationMethod method) {
@ArgumentsSource(EnumAndValueProvider.class)
public void testConfidence(ImputationMethod method, int shingleSize) {
// Create and populate a random cut forest

int shingleSize = 4;
int numberOfTrees = 50;
int sampleSize = 256;
Precision precision = Precision.FLOAT_32;
@@ -45,11 +59,19 @@ public void testConfidence(ImputationMethod method) {
long count = 0;

int dimensions = baseDimensions * shingleSize;
ThresholdedRandomCutForest forest = new ThresholdedRandomCutForest.Builder<>().compact(true)
ThresholdedRandomCutForest.Builder forestBuilder = new ThresholdedRandomCutForest.Builder<>().compact(true)
.dimensions(dimensions).randomSeed(0).numberOfTrees(numberOfTrees).shingleSize(shingleSize)
.sampleSize(sampleSize).precision(precision).anomalyRate(0.01).imputationMethod(method)
.fillValues(new double[] { 3 }).forestMode(ForestMode.STREAMING_IMPUTE)
.transformMethod(TransformMethod.NORMALIZE).autoAdjust(true).build();
.forestMode(ForestMode.STREAMING_IMPUTE).transformMethod(TransformMethod.NORMALIZE).autoAdjust(true);

if (method == ImputationMethod.FIXED_VALUES) {
// We cannot pass fillValues when the method is not FIXED_VALUES; otherwise,
// the fill values would be imputed regardless of the imputation method.
forestBuilder.fillValues(new double[] { 3 });
}

ThresholdedRandomCutForest forest = forestBuilder.build();

// Define the size and range
int size = 400;
@@ -75,18 +97,38 @@
float[] rcfPoint = result.getRCFPoint();
double scale = result.getScale()[0];
double shift = result.getShift()[0];
double[] actual = new double[] { (rcfPoint[3] * scale) + shift };
double[] actual = new double[] { (rcfPoint[shingleSize - 1] * scale) + shift };
if (method == ImputationMethod.ZERO) {
assertEquals(0, actual[0], 0.001d);
if (count == 300) {
assertTrue(result.getAnomalyGrade() > 0);
}
} else if (method == ImputationMethod.FIXED_VALUES) {
assertEquals(3.0d, actual[0], 0.001d);
if (count == 300) {
assertTrue(result.getAnomalyGrade() > 0);
}
} else if (method == ImputationMethod.PREVIOUS) {
assertEquals(0, result.getAnomalyGrade(), 0.001d,
"count: " + count + " actual: " + Arrays.toString(actual));
}
} else {
AnomalyDescriptor result = forest.process(point, newStamp);
if ((count > 100 && count < 300) || count >= 326) {
// After 325, there is a period of decreasing confidence, after which
// confidence starts increasing again. Since the exact point where confidence
// resumes increasing is not known, we start checking the behavior after
// 325 + shingleSize.
int backupPoint = 325 + shingleSize;
if ((count > 100 && count < 300) || count >= backupPoint) {
// The first 65+ observations give 0 confidence.
// Confidence starts increasing after the first observed point.
assertTrue(result.getDataConfidence() > lastConfidence);
assertTrue(result.getDataConfidence() > lastConfidence,
String.format(Locale.ROOT, "count: %d, confidence: %f, last confidence: %f", count,
result.getDataConfidence(), lastConfidence));
} else if (count < 325 && count > 300) {
assertTrue(result.getDataConfidence() < lastConfidence,
String.format(Locale.ROOT, "count: %d, confidence: %f, last confidence: %f", count,
result.getDataConfidence(), lastConfidence));
}
lastConfidence = result.getDataConfidence();
}