Skip to content

Commit a2349eb

Browse files
authored
Add feature density metric (#11)
1 parent d896390 commit a2349eb

File tree

3 files changed

+97
-3
lines changed

3 files changed

+97
-3
lines changed
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
"""Feature density metrics & histogram."""
2+
3+
import einops
4+
from jaxtyping import Float
5+
from numpy import histogram
6+
import numpy as np
7+
from numpy.typing import NDArray
8+
import torch
9+
from torch import Tensor
10+
import wandb
11+
12+
13+
def calc_feature_density(
14+
activations: Float[Tensor, "sample activation"], threshold: float = 0.001
15+
) -> Float[Tensor, " activation"]:
16+
"""Count how many times each feature was active.
17+
18+
Percentage of samples in which each feature was active (i.e. the neuron has "fired").
19+
20+
Example:
21+
>>> import torch
22+
>>> activations = torch.tensor([[0.5, 0.5, 0.0], [0.5, 0.0, 0.0001]])
23+
>>> calc_feature_density(activations).tolist()
24+
[1.0, 0.5, 0.0]
25+
26+
Args:
27+
activations: Sample of cached activations (the Autoencoder's learned features).
28+
threshold: Threshold for considering a feature active (i.e. the neuron has "fired"). This
29+
should be close to zero.
30+
31+
Returns:
32+
Number of times each feature was active in a sample.
33+
"""
34+
has_fired: Float[Tensor, "sample activation"] = torch.gt(activations, threshold).to(
35+
# Use float as einops requires this (64 as some features are very sparse)
36+
dtype=torch.float64
37+
)
38+
39+
return einops.reduce(has_fired, "sample activation -> activation", "mean")
40+
41+
42+
def wandb_feature_density_histogram(
43+
feature_density: Float[Tensor, " activation"],
44+
) -> wandb.Histogram:
45+
"""Create a W&B histogram of the feature density.
46+
47+
This can be logged with Weights & Biases using e.g. `wandb.log({"feature_density_histogram":
48+
wandb_feature_density_histogram(feature_density)})`.
49+
50+
Args:
51+
feature_density: Number of times each feature was active in a sample. Can be calculated
52+
using :func:`feature_activity_count`.
53+
54+
Returns:
55+
Weights & Biases histogram for logging with `wandb.log`.
56+
"""
57+
numpy_feature_density: NDArray[np.float_] = feature_density.detach().cpu().numpy()
58+
59+
bins, values = histogram(numpy_feature_density, bins="auto")
60+
return wandb.Histogram(np_histogram=(bins, values))
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
"""Test the feature density metric."""
2+
3+
import torch
4+
5+
from sparse_autoencoder.train.metrics.feature_density import (
6+
calc_feature_density,
7+
wandb_feature_density_histogram,
8+
)
9+
10+
11+
def test_calc_feature_density() -> None:
12+
"""Check that the feature density matches an alternative way of doing the calc."""
13+
activations = torch.tensor([[0.5, 0.5, 0.0], [0.5, 0.0, 0.0001], [0.0, 0.1, 0.0]])
14+
15+
# Use different approach to check
16+
threshold = 0.01
17+
above_threshold = activations > threshold
18+
expected = above_threshold.sum(dim=0, dtype=torch.float64) / above_threshold.shape[0]
19+
20+
res = calc_feature_density(activations)
21+
assert torch.allclose(res, expected), "Output does not match the expected result."
22+
23+
24+
def test_wandb_feature_density_histogram() -> None:
25+
"""Check the Weights & Biases Histogram is created correctly."""
26+
feature_density = torch.tensor([0.001, 0.001, 0.001, 0.5, 0.5, 1.0])
27+
res = wandb_feature_density_histogram(feature_density)
28+
29+
# Check 0.001 is in the first bin 3 times
30+
expected_first_bin_value = 3
31+
assert res.histogram[0] == expected_first_bin_value

sparse_autoencoder/train/train_autoencoder.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,14 @@ def train_autoencoder(
6565
l1_loss_learned_activations,
6666
sweep_parameters.l1_coefficient,
6767
)
68-
# TODO: Log dead neurons metric (get_frequencies in Neel's code)
68+
69+
# TODO: Store the learned activations (default every 25k steps)
6970

7071
# Backwards pass
7172
total_loss.backward()
7273

7374
optimizer.step()
7475

75-
# TODO: Enable neuron resampling here
76-
7776
# Log
7877
if step % log_interval == 0 and wandb.run is not None:
7978
wandb.log(
@@ -84,6 +83,10 @@ def train_autoencoder(
8483
},
8584
)
8685

86+
# TODO: Get the feature density & also log to wandb
87+
88+
# TODO: Apply neuron resampling if enabled
89+
8790
progress_bar.update(batch_size)
8891

8992
progress_bar.close()

0 commit comments

Comments
 (0)