# %% [markdown]
# A test to find out whether we can pick thresholds with percentiles
# and get a good description of mask sizes.
# Started as the end of the kornia test.
# %%
import cv2
import torch
import kornia
import numpy as np
import matplotlib.pyplot as plt
import zarr
from yaml import safe_load
from pathlib import Path
from torch.utils.data import Subset
from quac.training.config import ExperimentConfig
from quac.training.data_loader import get_test_loader
import logging

# %% [markdown]
# First we load data from the `fictus` experiment.
# %% Setup
config_path = "../configs/stargan_20241013.yml"
kind = "latent"
attribution_name = "DIntegratedGradients"
source = 0
target = 1
struct = 11  # size of the structuring element for the morphological closing
# %%
# Load metadata
logging.basicConfig(level=logging.INFO)  # so the logging calls below are visible
with open(config_path, "r") as f:
    metadata = safe_load(f)
experiment = ExperimentConfig(**metadata)
experiment_dir = Path(experiment.solver.root_dir)
logging.info(f"Experiment directory {str(experiment_dir)}")

# Load the classifier
logging.info("Loading classifier")
classifier_checkpoint = Path(experiment.validation_config.classifier_checkpoint)
classifier = torch.jit.load(classifier_checkpoint)
classifier.eval()

# Load the data
logging.info("Loading input data")
data_config = experiment.test_data
if data_config is None:
    logging.warning("Test data not found in metadata, using validation data")
    data_config = experiment.validation_data
dataset = get_test_loader(
    data_config.source,
    img_size=data_config.img_size,
    mean=data_config.mean,
    std=data_config.std,
    return_dataset=True,
)
is_source = np.where(np.array(dataset.targets) == source)[0]
source_dataset = Subset(dataset, is_source)
# Get the Zarr file in which things are kept
logging.info("Loading generated data")
zarr_file = zarr.open(experiment_dir / "output.zarr", "r")
group = zarr_file[kind]
generated_images = group[f"generated_images/{source}_{target}"]
attributions = group[f"{attribution_name}/attributions/{source}_{target}"]
# %% [markdown]
# Both the attributions and the generated images have shape (N, B, C, H, W).
# This is because I made 2 generated images per input.
# For this test, I will use the first one.
# %%
generated_images = torch.from_numpy(generated_images[:, 0])
attributions = torch.from_numpy(attributions[:, 0])

# %% [markdown]
# We will apply morphological closing to the attributions.
# %%
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
classifier = classifier.to(device)  # the classifier is run on `device` below
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (struct, struct))
torch_kernel = torch.tensor(kernel).float().to(device)

# Batch the attributions so the closing fits in memory
dataloader = torch.utils.data.DataLoader(
    attributions, batch_size=32, shuffle=False, drop_last=False, pin_memory=True
)
closed_attributions = torch.cat(
    [
        kornia.morphology.closing(attr.to(device), kernel=torch_kernel)
        for attr in dataloader
    ]
)

# %% [markdown]
# Let us look at the results, to see if they make sense.
# %%
fig, (ax1, ax2) = plt.subplots(1, 2)
ax1.imshow(attributions[0].permute((1, 2, 0)))
ax2.imshow(closed_attributions[0].cpu().permute((1, 2, 0)))

# %% [markdown]
# The first thing that we want to check is whether we can use quantiles for thresholding.
# We will compare mask sizes with two different methods of choosing thresholds:
# - using a linear space of 100 thresholds between 0 and 1
# - using the 0th to 100th percentiles of the attribution
# The goal is to compare the range of mask sizes that we get; see the sketch below.

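# %% [markdown]
# As a minimal sketch of the difference (on a synthetic, hypothetical attribution
# map, not the experiment data): a fixed threshold keeps however many pixels happen
# to exceed it, while a quantile threshold pins the mask size by construction.
# %%
toy = torch.rand(1, 128, 128) ** 4  # skewed values, most pixels near zero
fixed_mask = toy >= 0.5  # size depends entirely on the value distribution
quantile_mask = toy >= torch.quantile(toy.flatten(), 0.5)  # ~50% by construction
print(fixed_mask.float().mean().item(), quantile_mask.float().mean().item())
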
# %%
linear_thresholds = torch.linspace(0, 1, 100)
linear_sizes = []
quantile_sizes = []
for threshold in linear_thresholds:
    # Fraction of pixels above a fixed threshold
    linear_sizes.append(
        (closed_attributions >= threshold).flatten(1).sum(dim=1)
        / closed_attributions.flatten(1).shape[1]
    )
    # Fraction of pixels above a per-sample quantile of the attribution
    quantile_sizes.append(
        (
            closed_attributions
            >= torch.quantile(
                closed_attributions.flatten(1),
                threshold.to(device),
                dim=1,
            )[:, None, None, None]
        )
        .flatten(1)
        .sum(dim=1)
        / closed_attributions.flatten(1).shape[1]
    )

linear_sizes = torch.stack(linear_sizes, dim=1).cpu()
quantile_sizes = torch.stack(quantile_sizes, dim=1).cpu()

# %% Plot the results
fig, ax = plt.subplots()
ax.plot(linear_thresholds, linear_sizes.mean(dim=0), label="Linear")
ax.fill_between(
    linear_thresholds,
    linear_sizes.mean(dim=0) - linear_sizes.std(dim=0),
    linear_sizes.mean(dim=0) + linear_sizes.std(dim=0),
    alpha=0.3,
)
ax.plot(linear_thresholds, quantile_sizes.mean(dim=0), label="Quantile")
ax.fill_between(
    linear_thresholds,
    quantile_sizes.mean(dim=0) - quantile_sizes.std(dim=0),
    quantile_sizes.mean(dim=0) + quantile_sizes.std(dim=0),
    alpha=0.3,
)
ax.set_ylabel("Mask size")
ax.set_xlabel("Threshold/Quantile")
ax.legend()

# %% [markdown]
# Conclusion: I get a much better spread of mask sizes using the quantiles.

# %% [markdown]
# The next thing to do is to create counterfactual images using these masks.
# Given a mask at a certain threshold, a counterfactual image is created as
# `mask_smoothed * generated_image + (1 - mask_smoothed) * source_image`,
# where `mask_smoothed` is the mask after a Gaussian blur; a single-pair sketch follows.
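# %% [markdown]
# As a sketch of that formula on a single hypothetical image pair (placeholder
# tensors, not the experiment data), with kornia's Gaussian blur as the smoothing:
# %%
x_src = torch.rand(1, 3, 128, 128)  # placeholder source image
x_gen = torch.rand(1, 3, 128, 128)  # placeholder generated image
attr = torch.rand(1, 3, 128, 128)  # placeholder attribution map
mask = (attr >= torch.quantile(attr.flatten(), 0.9)).float()
mask_smoothed = kornia.filters.gaussian_blur2d(mask, (11, 11), (2.0, 2.0))
counterfactual = mask_smoothed * x_gen + (1 - mask_smoothed) * x_src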
# %% Get source images
source_images = torch.stack([im[0] for im in source_dataset])


# %% Check that the images match
def rescale(image):
    return (image + 1) / 2


fig, axes = plt.subplots(3, 2, figsize=(6, 8))
for i, (ax1, ax2) in enumerate(axes):
    ax1.imshow(rescale(source_images[i].permute((1, 2, 0))).cpu())
    ax2.imshow(rescale(generated_images[i].permute((1, 2, 0))).cpu())
    ax1.axis("off")
    ax2.axis("off")
# %%
with torch.no_grad():
    source_classifications = torch.softmax(classifier(source_images.to(device)), dim=1)[
        :, target
    ].cpu()
# %% Create counterfactual images
from tqdm import tqdm

all_sizes = []
all_scores = []
optimal_counterfactuals = []
generated_images = generated_images.float().to(device)
source_images = source_images.float().to(device)
for threshold in tqdm(linear_thresholds):
    # Per-sample quantile thresholds, as validated above
    quantiles = torch.quantile(
        closed_attributions.flatten(1),
        threshold.to(device),
        dim=1,
    )
    masks = closed_attributions >= quantiles[:, None, None, None]
    masks = masks.float()
    kernel_size = 11
    # OpenCV's default rule for deriving sigma from the kernel size
    sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8
    # Mask size = fraction of pixels covered in any channel
    mask_sizes = masks.flatten(2).any(dim=1).sum(dim=1) / masks.flatten(2).shape[2]
    all_sizes.append(mask_sizes)
    masks = kornia.filters.gaussian_blur2d(
        masks, (kernel_size, kernel_size), (sigma, sigma)
    )
    counterfactuals = masks * generated_images + (1 - masks) * source_images
    # Classify, and score the change toward the target class
    with torch.no_grad():
        counterfactual_classifications = torch.softmax(
            classifier(counterfactuals), dim=1
        )[:, target].cpu()
    scores = counterfactual_classifications - source_classifications
    all_scores.append(scores)

all_sizes = torch.stack(all_sizes, dim=1).cpu()
all_scores = torch.stack(all_scores, dim=1).cpu()

# %% Plot the results
for size, score in zip(all_sizes, all_scores):
    plt.plot(size, score, alpha=0.5)
plt.xlabel("Mask size")
plt.ylabel("Score")

# %%
# Get QuAC scores: area under each score-vs-size curve (baseline/ground truth).
# Mask size decreases along dim=1, so trapz comes out negative; flip the sign.
auc = -torch.trapz(all_scores, all_sizes, dim=1)
plt.hist(auc, bins=20)
# %% [markdown]
# Next, I want to see where the "optimal" threshold is.
# This is calculated by taking the minimum of (mask size - score change),
# i.e. the smallest mask that still produces a large change in classification;
# a toy example follows.

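# %% [markdown]
# A toy example of that selection rule (made-up curves, not the experiment data):
# with mask sizes and score changes both in [0, 1], the optimum lands at an
# intermediate threshold where the mask is small but the score change is still large.
# %%
sizes_toy = torch.linspace(1, 0, 100)  # mask shrinks as the threshold grows
scores_toy = torch.sigmoid(torch.linspace(6, -6, 100))  # score change decays too
print("optimal threshold index:", (sizes_toy - scores_toy).argmin().item())
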
# %%
# Do all thresholds at once, instead of all samples at once
all_sizes = []
all_scores = []
optimal_counterfactuals = []
optimal_masks = []

n_samples = 100
thresholds = torch.linspace(0, 1, n_samples).to(device)
for i, (source_image, image, attribution) in enumerate(
    zip(source_images, generated_images, closed_attributions)
):
    # Use this attribution's own quantiles as thresholds,
    # consistent with the comparison above
    quantiles = torch.quantile(attribution.flatten(), thresholds)
    # Repeat n_samples times, one copy per threshold
    masks = attribution.repeat(n_samples, 1, 1, 1) >= quantiles[:, None, None, None]
    masks = masks.float()

    kernel_size = 11
    sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8
    mask_sizes = masks.flatten(2).any(dim=1).sum(dim=1) / masks.flatten(2).shape[2]
    all_sizes.append(mask_sizes)
    masks = kornia.filters.gaussian_blur2d(
        masks, (kernel_size, kernel_size), (sigma, sigma)
    )
    # Repeat the image pair n_samples times as well
    images = image.repeat(n_samples, 1, 1, 1)
    sources = source_image.repeat(n_samples, 1, 1, 1)
    counterfactuals = masks * images + (1 - masks) * sources

    # classify
    with torch.no_grad():
        counterfactual_classifications = torch.softmax(
            classifier(counterfactuals), dim=1
        )[:, target].cpu()
    scores = counterfactual_classifications - source_classifications[i]
    all_scores.append(scores)
    # Get the optimal counterfactual and the optimal mask:
    # the minimum of (mask size - score change)
    optimal_index = (mask_sizes.cpu() - scores).argmin()
    optimal_counterfactuals.append(counterfactuals[optimal_index])
    optimal_masks.append(masks[optimal_index])