diff --git a/src/example/org/deidentifier/arx/examples/Example61.java b/src/example/org/deidentifier/arx/examples/Example61.java new file mode 100644 index 0000000000..710244d5af --- /dev/null +++ b/src/example/org/deidentifier/arx/examples/Example61.java @@ -0,0 +1,192 @@ +/* + * ARX: Powerful Data Anonymization + * Copyright 2012 - 2021 Fabian Prasser and contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.deidentifier.arx.examples; + +import java.io.File; +import java.io.FilenameFilter; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.text.ParseException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import org.deidentifier.arx.ARXAnonymizer; +import org.deidentifier.arx.ARXClassificationConfiguration; +import org.deidentifier.arx.ARXConfiguration; +import org.deidentifier.arx.ARXResult; +import org.deidentifier.arx.AttributeType; +import org.deidentifier.arx.AttributeType.Hierarchy; +import org.deidentifier.arx.Data; +import org.deidentifier.arx.DataHandle; +import org.deidentifier.arx.DataSubset; +import org.deidentifier.arx.DataType; +import org.deidentifier.arx.aggregates.ClassificationConfigurationLogisticRegression; +import org.deidentifier.arx.criteria.Inclusion; +import org.deidentifier.arx.criteria.KAnonymity; +import org.deidentifier.arx.io.CSVHierarchyInput; +import org.deidentifier.arx.metric.Metric; + +/** + * This class implements an example on how to compare data mining performance + * using a training and a test set + * @author Fabian Prasser + * @author Ibhraheem Al-Dhamari + */ +public class Example61 extends Example { + + /** + * Loads a dataset from disk + * @param dataset + * @return + * @throws IOException + */ + public static Data createData(final String dataset) throws IOException { + + // Load data + Data data = Data.create("data/" + dataset + ".csv", StandardCharsets.UTF_8, ';'); + + // Read generalization hierarchies + FilenameFilter hierarchyFilter = new FilenameFilter() { + @Override + public boolean accept(File dir, String name) { + if (name.matches(dataset + "_hierarchy_(.)+.csv")) { + return true; + } else { + return false; + } + } + }; + + // Create definition + File testDir = new File("data/"); + File[] genHierFiles = testDir.listFiles(hierarchyFilter); + Pattern pattern = Pattern.compile("_hierarchy_(.*?).csv"); + for (File file : genHierFiles) { + Matcher matcher = pattern.matcher(file.getName()); + if (matcher.find()) { + CSVHierarchyInput hier = new CSVHierarchyInput(file, StandardCharsets.UTF_8, ';'); + String attributeName = matcher.group(1); + data.getDefinition().setAttributeType(attributeName, Hierarchy.create(hier.getHierarchy())); + } + } + + return data; + } + + /** + * Gets a set of random record indices for this dataset + * @param data + * @param sampleFraction + * @return + */ + public static Set getRandomSample(Data data, double sampleFraction) { + + // Create list + int rows = data.getHandle().getNumRows(); + List list = new ArrayList<>(); + for (int i = 0; i < rows; ++i) { + list.add(i); + } + + // Shuffle + Collections.shuffle(list, new Random(0xDEADBEEF)); + + // Select sample and create set + return new HashSet(list.subList(0, (int) Math.round((double) rows * sampleFraction))); + } + + /** + * Entry point. + * + * @param args the arguments + * @throws ParseException + * @throws IOException + */ + public static void main(String[] args) throws ParseException, IOException { + + Data data = createData("adult"); + data.getDefinition().setAttributeType("marital-status", AttributeType.INSENSITIVE_ATTRIBUTE); + data.getDefinition().setDataType("age", DataType.INTEGER); + data.getDefinition().setResponseVariable("marital-status", true); + + // Size of training set + double trainingSetSize = 0.8d; + + // Create sample + Set trainingSetIndices = getRandomSample(data, trainingSetSize); + DataSubset trainingSet = DataSubset.create(data, trainingSetIndices); + + // Configure anonymization + ARXAnonymizer anonymizer = new ARXAnonymizer(); + ARXConfiguration config = ARXConfiguration.create(); + config.addPrivacyModel(new KAnonymity(5)); + config.addPrivacyModel(new Inclusion(trainingSet)); + config.setSuppressionLimit(1d); + config.setQualityModel(Metric.createClassificationMetric()); + + // Start anonymization process + ARXResult result = anonymizer.anonymize(data, config); + DataHandle output = result.getOutput(); + + // Run evaluation using k-fold cross validation + System.out.println("----------------------------------------"); + System.out.println("Evaluation using k-fold cross validation"); + System.out.println("----------------------------------------"); + evaluate(output, false); + + // Run evaluation using test/training set + System.out.println("--------------------------------------"); + System.out.println("Evaluation using test and training set"); + System.out.println("--------------------------------------"); + evaluate(output, true); + } + + /** + * Run evaluations + * @param data + * @param useTestTrainingSet + * @throws ParseException + */ + private static void evaluate(DataHandle data, boolean useTestTrainingSet) throws ParseException { + + // Specify + String[] features = new String[] { + "sex", + "age", + "race", + "education", + "native-country", + "workclass", + "occupation", + "salary-class" + }; + + String clazz = "marital-status"; + + // Perform measurement + ClassificationConfigurationLogisticRegression logisticClassifier = ARXClassificationConfiguration.createLogisticRegression(); + logisticClassifier.setUseTrainingTestSet(useTestTrainingSet); + System.out.println(data.getStatistics().getClassificationPerformance(features, clazz, logisticClassifier)); + } +} \ No newline at end of file diff --git a/src/gui/org/deidentifier/arx/gui/model/ModelClassification.java b/src/gui/org/deidentifier/arx/gui/model/ModelClassification.java index a6622dc46b..1e68603531 100644 --- a/src/gui/org/deidentifier/arx/gui/model/ModelClassification.java +++ b/src/gui/org/deidentifier/arx/gui/model/ModelClassification.java @@ -119,6 +119,11 @@ public boolean isModified() { * @param configCurrent */ public void setCurrentConfiguration(ARXClassificationConfiguration configCurrent){ + if (!(configCurrent == config || + configCurrent == configNaiveBayes || + configCurrent == configRandomForest )) { + throw new IllegalArgumentException("Unknown configuration object"); + } this.configCurrent = configCurrent; } @@ -143,7 +148,7 @@ public void setMaxRecords(Integer t) { this.configRandomForest.setMaxRecords(t); this.setModified(); } - + /** * TODO: Ugly hack to set base-parameters for all methods * @param t @@ -154,7 +159,7 @@ public void setNumFolds(Integer t) { this.configRandomForest.setNumFolds(t); this.setModified(); } - + /** * Sets a feature scaling function * @param attribute @@ -175,6 +180,17 @@ public void setUnmodified() { getRandomForestConfiguration().setUnmodified(); } + /** + * TODO: Ugly hack to set base-parameters for all methods + * @param value + */ + public void setUseTrainingTestSet(boolean value) { + this.config.setUseTrainingTestSet(value); + this.configNaiveBayes.setUseTrainingTestSet(value); + this.configRandomForest.setUseTrainingTestSet(value); + this.setModified(); + } + /** * TODO: Ugly hack to set base-parameters for all methods * @param t diff --git a/src/gui/org/deidentifier/arx/gui/resources/crossKFold.png b/src/gui/org/deidentifier/arx/gui/resources/crossKFold.png new file mode 100644 index 0000000000..eb608e2553 Binary files /dev/null and b/src/gui/org/deidentifier/arx/gui/resources/crossKFold.png differ diff --git a/src/gui/org/deidentifier/arx/gui/resources/messages.properties b/src/gui/org/deidentifier/arx/gui/resources/messages.properties index e6584b9084..2512136ee0 100644 --- a/src/gui/org/deidentifier/arx/gui/resources/messages.properties +++ b/src/gui/org/deidentifier/arx/gui/resources/messages.properties @@ -1048,6 +1048,7 @@ DialogProperties.19=Vector length DialogProperties.2=Performance DialogProperties.20=Prior function DialogProperties.21=Configuration +DialogProperties.22=Use test and training set DialogProperties.3=Visualization DialogProperties.4=Default DialogProperties.5=Metadata diff --git a/src/gui/org/deidentifier/arx/gui/resources/tickKFold.png b/src/gui/org/deidentifier/arx/gui/resources/tickKFold.png new file mode 100644 index 0000000000..bc8be08c78 Binary files /dev/null and b/src/gui/org/deidentifier/arx/gui/resources/tickKFold.png differ diff --git a/src/gui/org/deidentifier/arx/gui/view/impl/menu/DialogProperties.java b/src/gui/org/deidentifier/arx/gui/view/impl/menu/DialogProperties.java index 4f3808dbb7..b1b1705b89 100644 --- a/src/gui/org/deidentifier/arx/gui/view/impl/menu/DialogProperties.java +++ b/src/gui/org/deidentifier/arx/gui/view/impl/menu/DialogProperties.java @@ -380,6 +380,10 @@ private void createTabUtility(PreferencesDialog window) { window.addPreference(new PreferenceInteger(Resources.getMessage("DialogProperties.19"), 10, Integer.MAX_VALUE, ARXClassificationConfiguration.DEFAULT_VECTOR_LENGTH) { //$NON-NLS-1$ protected Integer getValue() { return model.getClassificationModel().getCurrentConfiguration().getVectorLength(); } protected void setValue(Object t) { model.getClassificationModel().setVectorLength((Integer)t); }}); + + window.addPreference(new PreferenceBoolean(Resources.getMessage("DialogProperties.22"), ARXClassificationConfiguration.DEFAULT_TEST_TRAINING_SET, true) { //$NON-NLS-1$ + protected Boolean getValue() { return model.getClassificationModel().getCurrentConfiguration().isUseTrainingTestSet(); } + protected void setValue(Object t) { model.getClassificationModel().setUseTrainingTestSet((Boolean)t); }}); } /** diff --git a/src/gui/org/deidentifier/arx/gui/view/impl/utility/ViewStatisticsClassification.java b/src/gui/org/deidentifier/arx/gui/view/impl/utility/ViewStatisticsClassification.java index 93e7cac055..7778247663 100644 --- a/src/gui/org/deidentifier/arx/gui/view/impl/utility/ViewStatisticsClassification.java +++ b/src/gui/org/deidentifier/arx/gui/view/impl/utility/ViewStatisticsClassification.java @@ -25,6 +25,7 @@ import org.deidentifier.arx.ARXClassificationConfiguration; import org.deidentifier.arx.ARXFeatureScaling; +import org.deidentifier.arx.DataHandle; import org.deidentifier.arx.aggregates.StatisticsBuilderInterruptible; import org.deidentifier.arx.aggregates.StatisticsClassification; import org.deidentifier.arx.aggregates.StatisticsClassification.ROCCurve; @@ -999,13 +1000,32 @@ protected void doReset() { @Override protected void doUpdate(final AnalysisContextClassification context) { - // The statistics builder - final StatisticsBuilderInterruptible builder = context.handle.getStatistics().getInterruptibleInstance(); + // Classification configuration final String[] features = context.model.getSelectedFeaturesAsArray(); final String[] targetVariables = context.model.getSelectedClassesAsArray(); - final ARXClassificationConfiguration config = context.model.getClassificationModel().getCurrentConfiguration(); final ARXFeatureScaling scaling = context.model.getClassificationModel().getFeatureScaling(); + // Make sure that an analysis is done through the UI, even if when training/test set is selected and + // the configuration is non-optimal + DataHandle handle = context.handle; + ARXClassificationConfiguration config = context.model.getClassificationModel().getCurrentConfiguration(); + if (config.isUseTrainingTestSet() && !handle.isSubsetAvailable()) { + + // Try to fix by switching to the superset + if (handle.isSupersetAvailable()) { + handle = handle.getSupersetHandle(); + + // Fix by switching to k-fold cross validation + } else { + config = config.clone(); + config.setUseTrainingTestSet(false); + } + } + + // Obtain statistics builder + final StatisticsBuilderInterruptible builder = handle.getStatistics().getInterruptibleInstance(); + final ARXClassificationConfiguration _config = config; + // Break, if nothing do if (context.model.getSelectedFeatures().isEmpty() || context.model.getSelectedClasses().isEmpty()) { @@ -1114,7 +1134,7 @@ public void run() throws InterruptedException { // Compute StatisticsClassification result = builder.getClassificationPerformance(features, targetVariable, - config, + _config, scaling); progress++; if (stopped) { diff --git a/src/main/org/deidentifier/arx/ARXClassificationConfiguration.java b/src/main/org/deidentifier/arx/ARXClassificationConfiguration.java index cd89cf4998..bdce0b393e 100644 --- a/src/main/org/deidentifier/arx/ARXClassificationConfiguration.java +++ b/src/main/org/deidentifier/arx/ARXClassificationConfiguration.java @@ -30,7 +30,18 @@ public abstract class ARXClassificationConfiguration> implements Serializable, Cloneable { /** SVUID */ - private static final long serialVersionUID = -8751059558718015927L; + private static final long serialVersionUID = -8751059558718015927L; + /** Default value */ + public static final boolean DEFAULT_DETERMINISTIC = true; + /** Default value */ + public static final int DEFAULT_MAX_RECORDS = 100000; + /** Default value */ + public static final int DEFAULT_NUMBER_OF_FOLDS = 10; + /** Default value */ + public static final int DEFAULT_VECTOR_LENGTH = 1000; + /** Default value */ + public static final boolean DEFAULT_TEST_TRAINING_SET = false; + /** * Creates a new instance for logistic regression classifiers * @return @@ -52,15 +63,6 @@ public static ClassificationConfigurationNaiveBayes createNaiveBayes() { public static ClassificationConfigurationRandomForest createRandomForest() { return ClassificationConfigurationRandomForest.create(); } - - /** Default value */ - public static final boolean DEFAULT_DETERMINISTIC = true; - /** Default value */ - public static final int DEFAULT_MAX_RECORDS = 100000; - /** Default value */ - public static final int DEFAULT_NUMBER_OF_FOLDS = 10; - /** Default value */ - public static final int DEFAULT_VECTOR_LENGTH = 1000; /** Deterministic */ private boolean deterministic = DEFAULT_DETERMINISTIC; @@ -74,7 +76,9 @@ public static ClassificationConfigurationRandomForest createRandomForest() { private int vectorLength = DEFAULT_VECTOR_LENGTH; /** Modified */ private boolean modified = false; - + /** Training/test set */ + private Boolean useTrainingTestSet = false; + /** * Creates a new instance with default settings */ @@ -89,13 +93,15 @@ public ARXClassificationConfiguration() { * @param numberOfFolds * @param seed * @param vectorLength + * @param useTrainingTestSet */ - protected ARXClassificationConfiguration(boolean deterministic, int maxRecords, int numberOfFolds, long seed, int vectorLength) { + protected ARXClassificationConfiguration(boolean deterministic, int maxRecords, int numberOfFolds, long seed, int vectorLength, boolean useTrainingTestSet) { this.deterministic = deterministic; this.maxRecords = maxRecords; this.numberOfFolds = numberOfFolds; this.seed = seed; this.vectorLength = vectorLength; + this.useTrainingTestSet = useTrainingTestSet; } @Override @@ -145,6 +151,17 @@ public boolean isModified() { return modified; } + /** + * Returns whether to use a training and a test set + * @return + */ + public boolean isUseTrainingTestSet() { + if (this.useTrainingTestSet == null) { + this.useTrainingTestSet = DEFAULT_TEST_TRAINING_SET; + } + return this.useTrainingTestSet; + } + /** * Parses another configuration * @param config @@ -155,6 +172,7 @@ public void parse(ARXClassificationConfiguration config) { this.setNumFolds(config.numberOfFolds); this.setSeed((int)config.seed); this.setVectorLength(config.vectorLength); + this.setUseTrainingTestSet(config.useTrainingTestSet); } /** @@ -228,6 +246,16 @@ public void setUnmodified() { this.modified = false; } + /** + * Sets whether to use a training and a test set + * @param value + */ + @SuppressWarnings("unchecked") + public T setUseTrainingTestSet(boolean value) { + this.useTrainingTestSet = value; + return (T)this; + } + /** * @param vectorLength the vectorLength to set */ diff --git a/src/main/org/deidentifier/arx/DataHandle.java b/src/main/org/deidentifier/arx/DataHandle.java index a3d6d4e60f..2d4e7beba0 100644 --- a/src/main/org/deidentifier/arx/DataHandle.java +++ b/src/main/org/deidentifier/arx/DataHandle.java @@ -85,8 +85,11 @@ public abstract class DataHandle { /** The current registry. */ protected DataRegistry registry = null; - /** The current research subset. */ + /** The current subset. */ protected DataHandle subset = null; + + /** The current superset. */ + protected DataHandle superset = null; /** * Returns the name of the specified column. @@ -538,6 +541,27 @@ public RiskEstimateBuilder getRiskEstimator(ARXPopulationModel model, Set iterator(); - + /** * Releases this handle and all associated resources. If a input handle is released all associated results are released * as well. @@ -1012,7 +1052,7 @@ protected int internalCompare(final int row1, * @return the string */ protected abstract String internalGetValue(int row, int col, boolean ignoreSuppression); - + /** * Returns whether this is an outlier regarding the given columns. If no columns have been * specified, true will be returned. @@ -1039,7 +1079,7 @@ protected int internalCompare(final int row1, protected boolean isAnonymous() { return false; } - + /** * Sets the current header * @param header @@ -1052,6 +1092,7 @@ protected void setHeader(String[] header) { } } + /** * Updates the registry. * @@ -1060,13 +1101,22 @@ protected void setHeader(String[] header) { protected void setRegistry(DataRegistry registry) { this.registry = registry; } - + /** * Sets the subset. * * @param handle the new view */ - protected void setView(DataHandle handle) { + protected void setSubset(DataHandle handle) { subset = handle; } + + /** + * Sets the superset. + * + * @param handle the new view + */ + protected void setSuperset(DataHandle handle) { + superset = handle; + } } diff --git a/src/main/org/deidentifier/arx/DataHandleInternal.java b/src/main/org/deidentifier/arx/DataHandleInternal.java index dfdec3c536..cfe4bca096 100644 --- a/src/main/org/deidentifier/arx/DataHandleInternal.java +++ b/src/main/org/deidentifier/arx/DataHandleInternal.java @@ -230,14 +230,14 @@ public DataHandleInternal getSuperset() { public String getValue(int row, int column) { return handle.getValue(row, column); } - + /** * Gets the value */ public String getValue(final int row, final int col, final boolean ignoreSuppression) { return handle.internalGetValue(row, col, ignoreSuppression); } - + /** * Returns the internal id of the given value * @param column @@ -255,7 +255,7 @@ public int getValueIdentifier(int column, String value) { public DataHandleInternal getView() { return new DataHandleInternal(handle.getView()); } - + /** * Returns whether the handle is anonymous * @return @@ -303,4 +303,12 @@ public boolean isOutput() { return this.handle instanceof DataHandleOutput; } } + + /** + * Returns whether a subset is available + * @return + */ + public boolean isSubsetAvailable() { + return this.getNumRows() != this.getView().getNumRows(); + } } diff --git a/src/main/org/deidentifier/arx/DataRegistry.java b/src/main/org/deidentifier/arx/DataRegistry.java index 7a2d54221a..9a3030f602 100644 --- a/src/main/org/deidentifier/arx/DataRegistry.java +++ b/src/main/org/deidentifier/arx/DataRegistry.java @@ -236,7 +236,10 @@ protected void createInputSubset(ARXConfiguration config){ } else { this.inputSubset = null; } - this.input.setView(this.inputSubset); + this.input.setSubset(this.inputSubset); + if (this.inputSubset != null) { + this.inputSubset.setSuperset(this.input); + } } /** @@ -251,7 +254,10 @@ protected void createOutputSubset(ARXNode node, ARXConfiguration config){ } else { this.outputSubset.remove(node); } - this.output.get(node).setView(this.outputSubset.get(node)); + this.output.get(node).setSubset(this.outputSubset.get(node)); + if (this.outputSubset.get(node) != null) { + this.outputSubset.get(node).setSuperset(this.output.get(node)); + } } /** diff --git a/src/main/org/deidentifier/arx/aggregates/ClassificationConfigurationLogisticRegression.java b/src/main/org/deidentifier/arx/aggregates/ClassificationConfigurationLogisticRegression.java index 0b315e53fc..db554874b9 100644 --- a/src/main/org/deidentifier/arx/aggregates/ClassificationConfigurationLogisticRegression.java +++ b/src/main/org/deidentifier/arx/aggregates/ClassificationConfigurationLogisticRegression.java @@ -102,6 +102,7 @@ private ClassificationConfigurationLogisticRegression(){ * @param numberOfFolds * @param deterministic * @param prior + * @param useTrainingTestSet */ protected ClassificationConfigurationLogisticRegression(double alpha, double decayExponent, @@ -113,8 +114,9 @@ protected ClassificationConfigurationLogisticRegression(double alpha, int seed, int numberOfFolds, boolean deterministic, - PriorFunction prior) { - super(deterministic, maxRecords, numberOfFolds, seed, vectorLength); + PriorFunction prior, + boolean useTrainingTestSet) { + super(deterministic, maxRecords, numberOfFolds, seed, vectorLength, useTrainingTestSet); this.alpha = alpha; this.decayExponent = decayExponent; this.lambda = lambda; @@ -140,7 +142,8 @@ public ClassificationConfigurationLogisticRegression clone() { seed, numberOfFolds, deterministic, - prior); + prior, + super.isUseTrainingTestSet()); } /** diff --git a/src/main/org/deidentifier/arx/aggregates/ClassificationConfigurationNaiveBayes.java b/src/main/org/deidentifier/arx/aggregates/ClassificationConfigurationNaiveBayes.java index 5ea45f6dd6..062c9ca6ea 100644 --- a/src/main/org/deidentifier/arx/aggregates/ClassificationConfigurationNaiveBayes.java +++ b/src/main/org/deidentifier/arx/aggregates/ClassificationConfigurationNaiveBayes.java @@ -70,6 +70,7 @@ private ClassificationConfigurationNaiveBayes(){ * @param vectorLength * @param type * @param sigma + * @param useTrainingTestSet */ protected ClassificationConfigurationNaiveBayes(boolean deterministic, int maxRecords, @@ -77,8 +78,9 @@ protected ClassificationConfigurationNaiveBayes(boolean deterministic, long seed, int vectorLength, Type type, - double sigma) { - super(deterministic, maxRecords, numberOfFolds, seed, vectorLength); + double sigma, + boolean useTrainingTestSet) { + super(deterministic, maxRecords, numberOfFolds, seed, vectorLength, useTrainingTestSet); this.type = type; this.sigma = sigma; } @@ -91,7 +93,8 @@ public ClassificationConfigurationNaiveBayes clone() { super.getSeed(), super.getVectorLength(), type, - sigma); + sigma, + super.isUseTrainingTestSet()); } /** diff --git a/src/main/org/deidentifier/arx/aggregates/ClassificationConfigurationRandomForest.java b/src/main/org/deidentifier/arx/aggregates/ClassificationConfigurationRandomForest.java index 4cac0bcfb9..c69ee207b6 100644 --- a/src/main/org/deidentifier/arx/aggregates/ClassificationConfigurationRandomForest.java +++ b/src/main/org/deidentifier/arx/aggregates/ClassificationConfigurationRandomForest.java @@ -99,6 +99,7 @@ private ClassificationConfigurationRandomForest(){ * @param maximumNumberOfLeafNodes * @param subsample * @param splitRule + * @param useTrainingTestSet */ protected ClassificationConfigurationRandomForest(boolean deterministic, int maxRecords, @@ -110,8 +111,9 @@ protected ClassificationConfigurationRandomForest(boolean deterministic, int minimumSizeOfLeafNodes, int maximumNumberOfLeafNodes, double subsample, - SplitRule splitRule) { - super(deterministic, maxRecords, numberOfFolds, seed, vectorLength); + SplitRule splitRule, + boolean useTrainingTestSet) { + super(deterministic, maxRecords, numberOfFolds, seed, vectorLength, useTrainingTestSet); this.numberOfTrees = numberOfTrees; this.numberOfVariablesToSplit = numberOfVariablesToSplit; this.minimumSizeOfLeafNodes = minimumSizeOfLeafNodes; @@ -132,7 +134,8 @@ public ClassificationConfigurationRandomForest clone() { minimumSizeOfLeafNodes, maximumNumberOfLeafNodes, subsample, - splitRule); + splitRule, + super.isUseTrainingTestSet()); } /** diff --git a/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java b/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java index c47868b007..dad04b0511 100644 --- a/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java +++ b/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java @@ -28,6 +28,7 @@ import org.deidentifier.arx.ARXClassificationConfiguration; import org.deidentifier.arx.ARXFeatureScaling; import org.deidentifier.arx.DataHandleInternal; +import org.deidentifier.arx.RowSet; import org.deidentifier.arx.aggregates.classification.ClassificationDataSpecification; import org.deidentifier.arx.aggregates.classification.ClassificationMethod; import org.deidentifier.arx.aggregates.classification.ClassificationResult; @@ -266,8 +267,6 @@ private static ClassificationMethod getClassifier(WrappedBoolean interrupt, private final WrappedInteger progress; /** Number of classes */ private int numClasses; - /** Number of samples*/ - private int numSamples; /** Random */ private final Random random; /** Measurements */ @@ -325,9 +324,6 @@ private static ClassificationMethod getClassifier(WrappedBoolean interrupt, this.interrupt = interrupt; this.progress = progress; - // Number of records to consider - this.numSamples = getNumSamples(inputHandle.getNumRows(), config); - // Initialize random if (!config.isDeterministic()) { this.random = new Random(); @@ -342,28 +338,54 @@ private static ClassificationMethod getClassifier(WrappedBoolean interrupt, features, clazz, interrupt); - + + // TODO: Check whether training set is available + if (config.isUseTrainingTestSet() && !inputHandle.isSubsetAvailable()) { + throw new IllegalArgumentException("Training and test set can only be used with a subset"); + } + // Number of class values this.numClasses = specification.classMap.size(); + + // Determine folds and samples to consider + int numRecords = getNumSamples(inputHandle, config); + boolean useTrainingTestSet = config.isUseTrainingTestSet(); + List> folds = null; + if (useTrainingTestSet) { + + // Use test and training set + folds = getTestTrainingFolds(inputHandle, numRecords); + + } else { + + // K-fold cross-validation + int k = numRecords > config.getNumFolds() ? config.getNumFolds() : numRecords; + folds = getFolds(inputHandle.getNumRows(), numRecords, k); + } - // Train and evaluate - int k = numSamples > config.getNumFolds() ? config.getNumFolds() : numSamples; - List> folds = getFolds(inputHandle.getNumRows(), numSamples, k); + // Determine number of classifications + int numClassifications = useTrainingTestSet ? folds.get(0).size() : numRecords; // Track int classifications = 0; - double total = 100d / ((double)numSamples * (double)folds.size()); + double total = useTrainingTestSet ? (100d / (double)folds.get(0).size() + folds.get(1).size()) : + (100d / ((double)numClassifications * (double)folds.size())); double done = 0d; // ROC - double[] inputConfidences = new double[numSamples * ( 1 + numClasses)]; - double[] outputConfidences = (inputHandle == outputHandle) ? null : new double[numSamples * ( 1 + numClasses)]; - double[] zerorConfidences = new double[numSamples * ( 1 + numClasses)]; + double[] inputConfidences = new double[numClassifications * ( 1 + numClasses)]; + double[] outputConfidences = (inputHandle == outputHandle) ? null : new double[numClassifications * ( 1 + numClasses)]; + double[] zerorConfidences = new double[numClassifications * ( 1 + numClasses)]; int confidencesIndex = 0; // For each fold as a validation set for (int evaluationFold = 0; evaluationFold < folds.size(); evaluationFold++) { + // Just run one iteration if using test and traning set + if (useTrainingTestSet && evaluationFold > 0) { + break; + } + // Create classifiers ClassificationMethod inputClassifier = getClassifier(interrupt, specification, config, inputHandle); ClassificationMethod inputZeroR = new MultiClassZeroR(interrupt, specification); @@ -511,38 +533,6 @@ private static ClassificationMethod getClassifier(WrappedBoolean interrupt, this.numMeasurements = classifications; } - /** - * Calculate brier score. - * @param confidences - * @param handle - * @param specification - * @return - */ - private double calculateBrierScore(double[] confidences, DataHandleInternal handle, ClassificationDataSpecification specification) { - // Brier score - double brier = 0d; - int column = specification.classIndex; - int records = 0; - - // For each record - for (int i = 0; i < confidences.length; i += (numClasses + 1)) { - - // Prepare - int row = (int) confidences[i]; - int correctIndex = specification.classMap.get(handle.getValue(row, column, true)); - - // Calculate for this record - int offset = 0; - for (int j = i + 1; j < i + numClasses + 1; j++) { - brier += Math.pow(confidences[j] - (((offset++) == correctIndex) ? 1 : 0), 2); - } - - // Count - records++; - } - return brier / (double) records; - } - /** * Returns the resulting accuracy. Obtained by generating a * classification model from the output (or input) dataset. @@ -552,7 +542,7 @@ private double calculateBrierScore(double[] confidences, DataHandleInternal hand public double getAccuracy() { return this.accuracy; } - + /** * Returns the average error, defined as avg(1d-probability-of-correct-result) for * each classification event. @@ -562,14 +552,6 @@ public double getAccuracy() { public double getAverageError() { return this.averageError; } - - /** - * Returns the brier score of the ZeroR classifier. - * @return - */ - public double getZerorBrierScore() { - return zerorBrierScore; - } /** * Returns the brier score of the classifier trained on output data. @@ -578,14 +560,6 @@ public double getZerorBrierScore() { public double getBrierScore() { return brierScore; } - - /** - * Returns the brier score of the classifier trained on input data. - * @return - */ - public double getOriginalBrierScore() { - return originalBrierScore; - } /** * Returns the brier skill score, defined as 1-(brier output/brier input) @@ -594,7 +568,7 @@ public double getOriginalBrierScore() { public double getBrierSkillScore() { return brierScore == 0d ? 0d : (1 - brierScore / originalBrierScore); } - + /** * Returns the set of class attributes * @return @@ -602,7 +576,7 @@ public double getBrierSkillScore() { public Set getClassValues() { return this.originalROC.keySet(); } - + /** * Returns the number of classes * @return @@ -610,7 +584,7 @@ public Set getClassValues() { public int getNumClasses() { return this.numClasses; } - + /** * Returns the number of measurements * @return @@ -618,7 +592,7 @@ public int getNumClasses() { public int getNumMeasurements() { return this.numMeasurements; } - + /** * Returns the maximal accuracy. Obtained by generating a * classification model from the input dataset. @@ -640,21 +614,20 @@ public double getOriginalAverageError() { } /** - * Returns the ROC curve for this class value calculated using the one-vs-all approach. - * @param clazz + * Returns the brier score of the classifier trained on input data. * @return */ - public ROCCurve getOriginalROCCurve(String clazz) { - return this.originalROC.get(clazz); + public double getOriginalBrierScore() { + return originalBrierScore; } - + /** * Returns the ROC curve for this class value calculated using the one-vs-all approach. * @param clazz * @return */ - public ROCCurve getZeroRROCCurve(String clazz) { - return this.zerorROC.get(clazz); + public ROCCurve getOriginalROCCurve(String clazz) { + return this.originalROC.get(clazz); } /** @@ -665,7 +638,7 @@ public ROCCurve getZeroRROCCurve(String clazz) { public ROCCurve getROCCurve(String clazz) { return this.ROC.get(clazz); } - + /** * Returns the minimal accuracy. Obtained by training a * ZeroR classifier on the input dataset. @@ -675,7 +648,7 @@ public ROCCurve getROCCurve(String clazz) { public double getZeroRAccuracy() { return this.zeroRAccuracy; } - + /** * Returns the average error, defined as avg(1d-probability-of-correct-result) for * each classification event. @@ -686,6 +659,23 @@ public double getZeroRAverageError() { return this.zeroRAverageError; } + /** + * Returns the brier score of the ZeroR classifier. + * @return + */ + public double getZerorBrierScore() { + return zerorBrierScore; + } + + /** + * Returns the ROC curve for this class value calculated using the one-vs-all approach. + * @param clazz + * @return + */ + public ROCCurve getZeroRROCCurve(String clazz) { + return this.zerorROC.get(clazz); + } + @Override public String toString() { StringBuilder builder = new StringBuilder(); @@ -707,7 +697,39 @@ public String toString() { builder.append("}"); return builder.toString(); } - + + /** + * Calculate brier score. + * @param confidences + * @param handle + * @param specification + * @return + */ + private double calculateBrierScore(double[] confidences, DataHandleInternal handle, ClassificationDataSpecification specification) { + // Brier score + double brier = 0d; + int column = specification.classIndex; + int records = 0; + + // For each record + for (int i = 0; i < confidences.length; i += (numClasses + 1)) { + + // Prepare + int row = (int) confidences[i]; + int correctIndex = specification.classMap.get(handle.getValue(row, column, true)); + + // Calculate for this record + int offset = 0; + for (int j = i + 1; j < i + numClasses + 1; j++) { + brier += Math.pow(confidences[j] - (((offset++) == correctIndex) ? 1 : 0), 2); + } + + // Count + records++; + } + return brier / (double) records; + } + /** * Checks whether an interruption happened. */ @@ -716,7 +738,7 @@ private void checkInterrupt() { throw new ComputationInterruptedException("Interrupted"); } } - + /** * Creates the folds * @param numRecords @@ -765,20 +787,59 @@ private List> getFolds(int numRecords, int numSamples, int k) { rows = null; return folds; } - + /** - * Returns the number of samples as the minimum of actual number of rows in - * the dataset and maximal number of rows as specified in config. + * Returns the number of samples as the minimum of actual number of records in + * the dataset and maximal number of records as specified in the config. * - * @param numRows + * @param data * @param config * @return */ - private int getNumSamples(int numRows, ARXClassificationConfiguration config) { - int numSamples = numRows; + private int getNumSamples(DataHandleInternal data, ARXClassificationConfiguration config) { + int numSamples = data.getNumRows(); if (config.getMaxRecords() > 0) { numSamples = Math.min(config.getMaxRecords(), numSamples); } return numSamples; } + + /** + * Returns two folds. First fold for testing, second fold for training + * @param data + * @param maxRows + * @return + */ + private List> getTestTrainingFolds(DataHandleInternal data, int maxRows) { + + // Calculate sizes of sets + int trainingSetSize = Math.min(data.getView().getNumRows(), maxRows); + int testSetSize = Math.min((data.getNumRows() - data.getView().getNumRows()), maxRows); + + // Prepare indexes of training records + List trainingRecords = new ArrayList<>(); + List testRecords = new ArrayList<>(); + RowSet subset = data.getSubset().getSet(); + for (int row = 0; row < data.getNumRows(); row++) { + if (subset.contains(row)) { + trainingRecords.add(row); + } else { + testRecords.add(row); + } + } + + // Shuffle + Collections.shuffle(trainingRecords, random); + Collections.shuffle(testRecords, random); + + // Extract sets of adequate size + trainingRecords = trainingRecords.subList(0, trainingSetSize); + testRecords = testRecords.subList(0, testSetSize); + + // Return + ArrayList> folds = new ArrayList<>(); + folds.add(testRecords); + folds.add(trainingRecords); + return folds; + } }