
Commit fc018da

feat: 🚧 New evaluation script - on zarr

1 parent e05fadd commit fc018da

File tree

5 files changed: +448 -9 lines changed

fictus/05_evaluate.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
import cv2
import dask.array as da
import kornia.morphology as morph
import logging
import numpy as np
from pathlib import Path
from quac.training.config import ExperimentConfig
from quac.training.data_loader import get_test_loader
import torch
from torch.utils.data import Subset
from yaml import safe_load
import zarr


def evaluate(
    config_path: str,
    kind: str,
    batch_size: int,
    num_samples: int,
    attribution_name: str,
    struct_size: int = 11,
):
    # Load metadata
    with open(config_path, "r") as f:
        metadata = safe_load(f)
    experiment = ExperimentConfig(**metadata)
    experiment_dir = Path(experiment.solver.root_dir)
    logging.info(f"Experiment directory {str(experiment_dir)}")

    # Load the classifier
    logging.info("Loading classifier")
    classifier_checkpoint = Path(experiment.validation_config.classifier_checkpoint)
    classifier = torch.jit.load(classifier_checkpoint)
    classifier.eval()

    # Load the data
    logging.info("Loading input data")
    data_config = experiment.test_data
    if data_config is None:
        logging.warning("Test data not found in metadata, using validation data")
        data_config = experiment.validation_data
    dataset = get_test_loader(
        data_config.source,
        img_size=data_config.img_size,
        mean=data_config.mean,
        std=data_config.std,
        return_dataset=True,
    )
    # Get the Zarr file in which things are kept
    logging.info("Loading generated data")
    zarr_file = zarr.open(experiment_dir / "output.zarr", "a")
    group = zarr_file[kind]
    generated_images = group["generated_images"]
    method_group = group[attribution_name]

    # Attributions
    attributions = method_group["attributions"]

    # Morphological closing kernel
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (struct_size, struct_size))
    kernel = torch.tensor(kernel).float()

    for source, source_name in enumerate(dataset.classes):
        is_source = np.where(np.array(dataset.targets) == source)[0]
        source_dataset = Subset(dataset, is_source)
        logging.info(f"Length of source dataset {len(source_dataset)}")
        for target, target_name in enumerate(dataset.classes):
            if source == target:
                continue
            logging.info(f"Running for source {source_name} target {target_name}")
            generated_image_array = da.from_zarr(
                generated_images[f"{source}_{target}"]
            )  # N, B, C, H, W
            attribution_array = da.from_zarr(
                attributions[f"{source}_{target}"]
            )  # N, B, C, H, W
            for i in range(num_samples):
                for j in range(len(source_dataset)):
                    # Get source image
                    image = source_dataset[j][0]
                    # Get generated image
                    generated = generated_image_array[j, i]
                    # Get attribution
                    attribution = attribution_array[j, i]
                    # TODO evaluate
                    # Get array w/ threshold, mask-size, delta-f
                    # TODO compute scores?
                    # TODO store optimal mask
                    pass
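
The evaluation body is still a TODO in this commit. Below is a minimal sketch of what the inner loop might compute, based on the quantile-threshold experiment prototyped in the next file; the helper `evaluate_one`, its signature, the fixed 100-step threshold grid, and the blur parameters are all assumptions, not part of this commit.

import kornia.filters
import kornia.morphology as morph
import torch


def evaluate_one(image, generated, attribution, classifier, target, kernel, device):
    """Illustrative sketch: score one sample over 100 quantile thresholds.

    Assumes `image`, `generated`, and `attribution` are (C, H, W) float tensors,
    and that `kernel` (the closing structuring element) is already on `device`.
    """
    source = image[None].to(device)
    generated = generated[None].to(device)
    # Close small holes in the attribution map before thresholding
    closed = morph.closing(attribution[None].to(device), kernel=kernel)[0]
    thresholds = torch.linspace(0.0, 1.0, 100, device=device)
    # Quantile thresholds give a better spread of mask sizes (see the notebook below)
    quantiles = torch.quantile(closed.flatten(), thresholds)
    masks = (closed[None] >= quantiles[:, None, None, None]).float()  # (T, C, H, W)
    mask_sizes = masks.flatten(2).any(dim=1).sum(dim=1) / masks.flatten(2).shape[2]
    # Smooth the masks, then composite counterfactuals from generated and source
    masks = kornia.filters.gaussian_blur2d(masks, (11, 11), (2.0, 2.0))
    counterfactuals = masks * generated + (1 - masks) * source
    with torch.no_grad():
        f_counterfactual = torch.softmax(classifier(counterfactuals), dim=1)[:, target]
        f_source = torch.softmax(classifier(source), dim=1)[0, target]
    delta_f = (f_counterfactual - f_source).cpu()  # classification change per threshold
    return thresholds.cpu(), mask_sizes.cpu(), delta_f

From the returned (threshold, mask-size, delta-f) triplets, a QuAC-style score could then be computed per sample, as in the notebook's AUC cell below.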
Lines changed: 266 additions & 0 deletions
@@ -0,0 +1,266 @@
# %% [markdown]
# A test to find out whether we can pick thresholds with percentiles
# and get a good description of mask sizes.
# Started as the end of the kornia test.
# %%
import cv2
import torch
import kornia
import numpy as np
import matplotlib.pyplot as plt
import zarr
from yaml import safe_load
from pathlib import Path
from torch.utils.data import Subset
from quac.training.config import ExperimentConfig
from quac.training.data_loader import get_test_loader
import logging

# %% [markdown]
# First we load data from the `fictus` experiment.
# %% Setup
config_path = "../configs/stargan_20241013.yml"
kind = "latent"
attribution_name = "DIntegratedGradients"
source = 0
target = 1
struct = 11  # size of the structuring element for the morphological closing
# %%
# Load metadata
with open(config_path, "r") as f:
    metadata = safe_load(f)
experiment = ExperimentConfig(**metadata)
experiment_dir = Path(experiment.solver.root_dir)
logging.info(f"Experiment directory {str(experiment_dir)}")

# Load the classifier
logging.info("Loading classifier")
classifier_checkpoint = Path(experiment.validation_config.classifier_checkpoint)
classifier = torch.jit.load(classifier_checkpoint)
classifier.eval()

# Load the data
logging.info("Loading input data")
data_config = experiment.test_data
if data_config is None:
    logging.warning("Test data not found in metadata, using validation data")
    data_config = experiment.validation_data
dataset = get_test_loader(
    data_config.source,
    img_size=data_config.img_size,
    mean=data_config.mean,
    std=data_config.std,
    return_dataset=True,
)
is_source = np.where(np.array(dataset.targets) == source)[0]
source_dataset = Subset(dataset, is_source)
# Get the Zarr file in which things are kept
logging.info("Loading generated data")
zarr_file = zarr.open(experiment_dir / "output.zarr", "r")
group = zarr_file[kind]
generated_images = group[f"generated_images/{source}_{target}"]
attributions = group[f"{attribution_name}/attributions/{source}_{target}"]
# %% [markdown]
# Both the attributions and the generated images have shape (N, B, C, H, W).
# This is because I made 2 generated images per input.
# For this test, I will use the first one.
# %%
generated_images = torch.from_numpy(generated_images[:, 0])
attributions = torch.from_numpy(attributions[:, 0])

# %% [markdown]
# We will apply morphological closing to the attributions.
# %%
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (struct, struct))
torch_kernel = torch.tensor(kernel).float().to(device)

dataloader = torch.utils.data.DataLoader(
    attributions, batch_size=32, shuffle=False, drop_last=False, pin_memory=True
)
closed_attributions = torch.cat(
    [
        kornia.morphology.closing(attr.to(device), kernel=torch_kernel)
        for attr in dataloader
    ]
)


# %% [markdown]
# Let us look at the results, to see if they make sense.
# %%
fig, (ax1, ax2) = plt.subplots(1, 2)
ax1.imshow(attributions[0].permute((1, 2, 0)))
ax2.imshow(closed_attributions[0].cpu().permute((1, 2, 0)))


# %% [markdown]
# The first thing that we want to check is whether we can use quantiles for thresholding.
# We will compare mask sizes with two different methods of choosing thresholds:
# - using a linear space of 100 thresholds between 0 and 1
# - using the 0th to 100th percentiles of the attribution
# The goal is to compare the range of mask sizes that we get.

# %%
linear_thresholds = torch.linspace(0, 1, 100)
linear_sizes = []
quantile_sizes = []
for threshold in linear_thresholds:
    linear_sizes.append(
        (closed_attributions >= threshold).flatten(1).sum(dim=1)
        / closed_attributions.flatten(1).shape[1]
    )
    quantile_sizes.append(
        (
            closed_attributions
            >= torch.quantile(
                closed_attributions.flatten(1),
                threshold.to(device),
                dim=1,
            )[:, None, None, None]
        )
        .flatten(1)
        .sum(dim=1)
        / closed_attributions.flatten(1).shape[1]
    )

linear_sizes = torch.stack(linear_sizes, dim=1).cpu()
quantile_sizes = torch.stack(quantile_sizes, dim=1).cpu()

# %% Plot the results
fig, ax = plt.subplots()
ax.plot(linear_thresholds, linear_sizes.mean(dim=0), label="Linear")
ax.fill_between(
    linear_thresholds,
    linear_sizes.mean(dim=0) - linear_sizes.std(dim=0),
    linear_sizes.mean(dim=0) + linear_sizes.std(dim=0),
    alpha=0.3,
)
ax.plot(linear_thresholds, quantile_sizes.mean(dim=0), label="Quantile")
ax.fill_between(
    linear_thresholds,
    quantile_sizes.mean(dim=0) - quantile_sizes.std(dim=0),
    quantile_sizes.mean(dim=0) + quantile_sizes.std(dim=0),
    alpha=0.3,
)
ax.set_ylabel("Mask size")
ax.set_xlabel("Threshold/Quantile")
ax.legend()

# %% [markdown]
# Conclusion - I get a much better spread of mask sizes using the quantiles.

# %% [markdown]
# The next thing to do is to create counterfactual images using these masks.
# Given a mask at a certain threshold, a counterfactual image is created as:
# mask_smoothed * generated_image + (1 - mask_smoothed) * source_image
# where mask_smoothed is the mask after a Gaussian blur.
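# %% [markdown]
# A toy sanity check of this compositing rule (an illustrative sketch with
# constant tensors, independent of the data above): a smoothed mask value of
# 0.25 should blend each pixel 25% of the way toward the generated image.
# %%
toy_source = torch.zeros(1, 3, 4, 4)  # stand-in source image
toy_generated = torch.ones(1, 3, 4, 4)  # stand-in generated image
toy_mask = torch.full((1, 1, 4, 4), 0.25)  # stand-in smoothed mask
toy_counterfactual = toy_mask * toy_generated + (1 - toy_mask) * toy_source
assert torch.allclose(toy_counterfactual, torch.full((1, 3, 4, 4), 0.25))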
# %% Get source images
source_images = torch.stack([im[0] for im in source_dataset])


# %% Check that the images match
def rescale(image):
    return (image + 1) / 2


fig, axes = plt.subplots(3, 2, figsize=(6, 8))
for i, (ax1, ax2) in enumerate(axes):
    ax1.imshow(rescale(source_images[i].permute((1, 2, 0))).cpu())
    ax2.imshow(rescale(generated_images[i].permute((1, 2, 0))).cpu())
    ax1.axis("off")
    ax2.axis("off")
# %%
with torch.no_grad():
    source_classifications = torch.softmax(classifier(source_images.to(device)), dim=1)[
        :, target
    ].cpu()
# %% Create counterfactual images
from tqdm import tqdm

all_sizes = []
all_scores = []
optimal_counterfactuals = []
generated_images = generated_images.float().to(device)
source_images = source_images.float().to(device)
for threshold in tqdm(linear_thresholds):
    quantiles = torch.quantile(
        closed_attributions.flatten(1),
        threshold.to(device),
        dim=1,
    )
    masks = closed_attributions >= quantiles[:, None, None, None]
    masks = masks.float()
    kernel_size = 11
    sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8
    mask_sizes = masks.flatten(2).any(dim=1).sum(dim=1) / masks.flatten(2).shape[2]
    all_sizes.append(mask_sizes)
    masks = kornia.filters.gaussian_blur2d(
        masks, (kernel_size, kernel_size), (sigma, sigma)
    )
    counterfactuals = masks * generated_images + (1 - masks) * source_images
    # Classify
    with torch.no_grad():
        counterfactual_classifications = torch.softmax(
            classifier(counterfactuals), dim=1
        )[:, target].cpu()
    scores = counterfactual_classifications - source_classifications
    all_scores.append(scores)

all_sizes = torch.stack(all_sizes, dim=1).cpu()
all_scores = torch.stack(all_scores, dim=1).cpu()

# %% Plot the results
for size, score in zip(all_sizes, all_scores):
    plt.plot(size, score, alpha=0.5)
plt.xlabel("Mask size")
plt.ylabel("Score")

# %%
# Get QuAC scores
# Mask sizes shrink as the quantile grows, so the curve runs right-to-left;
# negating the integral gives a positive area under the score-vs-size curve.
auc = -torch.trapz(all_scores, all_sizes, dim=1)  # baseline/ground truth
plt.hist(auc, bins=20)
# %% [markdown]
# Next, I want to see where the "optimal" threshold is.
# This is calculated by taking the minimum of (mask size - score change).

# %%
# Do all thresholds at once, instead of all samples at once
all_sizes = []
all_scores = []
optimal_counterfactuals = []
optimal_masks = []

n_samples = 100
thresholds = torch.linspace(0, 1, n_samples, device=device)

for idx, (source_image, image, attribution) in enumerate(
    zip(source_images, generated_images, closed_attributions)
):
    # Repeat the attribution n_samples times, one copy per threshold
    attribution = attribution.repeat(n_samples, 1, 1, 1)
    masks = attribution >= thresholds[:, None, None, None]
    masks = masks.float()

    kernel_size = 11
    sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8
    mask_sizes = masks.flatten(2).any(dim=1).sum(dim=1) / masks.flatten(2).shape[2]
    all_sizes.append(mask_sizes)
    masks = kornia.filters.gaussian_blur2d(
        masks, (kernel_size, kernel_size), (sigma, sigma)
    )
    # Repeat the generated and source images n_samples times as well
    images = image.repeat(n_samples, 1, 1, 1)
    source_image = source_image.repeat(n_samples, 1, 1, 1)
    counterfactuals = masks * images + (1 - masks) * source_image

    # Classify
    with torch.no_grad():
        counterfactual_classifications = torch.softmax(
            classifier(counterfactuals), dim=1
        )[:, target].cpu()
    scores = counterfactual_classifications - source_classifications[idx]
    all_scores.append(scores)
    # Get the optimal counterfactual and the optimal mask:
    # the minimum of (mask size - score change)
    optimal_index = (mask_sizes.cpu() - scores).argmin()
    optimal_counterfactuals.append(counterfactuals[optimal_index])
    optimal_masks.append(masks[optimal_index])

fictus/configs/stargan.yml

Lines changed: 9 additions & 8 deletions
@@ -1,11 +1,12 @@
-project: "fictus-stargan"
-name: "stargan_v0"
-notes: "Stargan training on fictus dataset"
-tags:
-  - fictus
-  - stargan
-  - training
-  - quac
+log:
+  project: "fictus-stargan"
+  name: "stargan_v0"
+  notes: "Stargan training on fictus dataset"
+  tags:
+    - fictus
+    - stargan
+    - training
+    - quac
 
 data:
   source: "/nrs/funke/adjavond/data/fictus/aggregatum/train"
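
The logging settings are now nested under a `log:` key. A minimal sketch of how the restructured file would be consumed, mirroring the loading code in fictus/05_evaluate.py above (assuming `ExperimentConfig` accepts the nested `log` section as a sub-config):

from quac.training.config import ExperimentConfig
from yaml import safe_load

# Load the restructured YAML; metadata["log"] now carries project/name/notes/tags
with open("fictus/configs/stargan.yml", "r") as f:
    metadata = safe_load(f)
experiment = ExperimentConfig(**metadata)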

pyproject.toml

Lines changed: 3 additions & 1 deletion
@@ -19,7 +19,9 @@ dynamic = ["version"]
 dependencies = [
     "matplotlib",
     "GitPython",
-    "pydantic"
+    "pydantic",
+    "kornia",
+    "scikit-image"
 ]
 
 [project.optional-dependencies]

0 commit comments