
Commit 55566bb

Implement Capacity Metric and W&B Histogram Logging (#48)
1 parent 7464977 commit 55566bb

File tree

5 files changed: +142 -1 lines changed


.vscode/cspell.json

Lines changed: 3 additions & 0 deletions

@@ -23,6 +23,7 @@
 "dunder",
 "earlyterminate",
 "einops",
+"einsum",
 "endoftext",
 "gelu",
 "githistory",
@@ -55,6 +56,7 @@
 "penality",
 "polysemantic",
 "polysemantically",
+"polysemanticity",
 "precommit",
 "pyproject",
 "pyright",
@@ -69,6 +71,7 @@
 "randperm",
 "relu",
 "resid",
+"rtol",
 "runcap",
 "sharded",
 "snapshottest",

poetry.lock

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default.
sparse_autoencoder/train/metrics/capacity.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
"""Capacity metrics for sets of learned features."""
import einops
from jaxtyping import Float
from numpy import histogram
import numpy as np
from numpy.typing import NDArray
import torch
from torch import Tensor
import wandb


def calc_capacities(features: Float[Tensor, "n_feats feat_dim"]) -> Float[Tensor, " n_feats"]:
    """Calculate capacities.

    Measure the capacity of a set of features as defined in [Polysemanticity and Capacity in
    Neural Networks](https://arxiv.org/pdf/2210.01892.pdf).

    Capacity intuitively measures the 'proportion of a dimension' assigned to a feature.
    Formally it is the ratio of the squared dot product of a feature with itself to the sum of
    its squared dot products with all features.

    If the features are orthogonal, the capacity is 1. If they are all the same, the capacity is
    1/n.

    Example:
        >>> import torch
        >>> orthogonal_features = torch.tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
        >>> orthogonal_caps = calc_capacities(orthogonal_features)
        >>> orthogonal_caps
        tensor([1., 1., 1.])

    Args:
        features: A collection of features.

    Returns:
        A 1D tensor of capacities, where each element is the capacity of the corresponding
        feature.
    """
    # Squared dot products between every pair of features (n_feats x n_feats).
    squared_dot_products = (
        einops.einsum(
            features, features, "n_feats1 feat_dim, n_feats2 feat_dim -> n_feats1 n_feats2"
        )
        ** 2
    )
    sum_of_sq_dot = squared_dot_products.sum(dim=-1)
    return torch.diag(squared_dot_products) / sum_of_sq_dot


def wandb_capacities_histogram(
    capacities: Float[Tensor, " n_feats"],
) -> wandb.Histogram:
    """Create a W&B histogram of the capacities.

    This can be logged with Weights & Biases using e.g. `wandb.log({"capacities_histogram":
    wandb_capacities_histogram(capacities)})`.

    Args:
        capacities: Capacity of each feature. Can be calculated using :func:`calc_capacities`.

    Returns:
        Weights & Biases histogram for logging with `wandb.log`.
    """
    numpy_capacities: NDArray[np.float_] = capacities.detach().cpu().numpy()

    # `numpy.histogram` returns (counts, bin_edges), which is the pair `np_histogram` expects.
    counts, bin_edges = histogram(numpy_capacities, bins=20, range=(0, 1))
    return wandb.Histogram(np_histogram=(counts, bin_edges))
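A minimal usage sketch (not part of the commit), assuming the learned feature directions are the rows of a decoder weight matrix with shape (n_feats, feat_dim) and that a Weights & Biases run is already active; `decoder_weight` and the project name are hypothetical:

import torch
import wandb

from sparse_autoencoder.train.metrics.capacity import calc_capacities, wandb_capacities_histogram

wandb.init(project="sparse-autoencoder-demo")  # hypothetical project name

# Stand-in for the autoencoder's learned feature directions (n_feats x feat_dim).
decoder_weight = torch.randn(512, 64)

capacities = calc_capacities(decoder_weight)
wandb.log({"capacities_histogram": wandb_capacities_histogram(capacities)})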
Syrupy snapshot file

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
# serializer version: 1
# name: test_wandb_capacity_histogram
  list([
    0,
    0,
    1,
    0,
    0,
    0,
    0,
    0,
    0,
    0,
    1,
    0,
    0,
    0,
    0,
    0,
    0,
    0,
    0,
    3,
  ])
# ---
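The 20 counts above are what `numpy.histogram` produces for the test capacities [0.5, 0.1, 1, 1, 1] with 20 bins over the range (0, 1): 0.1 falls in the third bin (index 2), 0.5 in the eleventh (index 10), and the three values of 1.0 in the final, right-closed bin (index 19). A quick check (not part of the commit):

import numpy as np

capacities = np.array([0.5, 0.1, 1.0, 1.0, 1.0])
counts, bin_edges = np.histogram(capacities, bins=20, range=(0, 1))
print(counts.tolist())
# [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 3]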
Test file for the capacity metrics

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
"""Tests for the capacity calculation and histogram creation."""

import math

from jaxtyping import Float
import pytest
from syrupy.session import SnapshotSession
import torch
from torch import Tensor

from sparse_autoencoder.train.metrics.capacity import calc_capacities, wandb_capacities_histogram


@pytest.mark.parametrize(
    ("features", "expected_capacities"),
    [
        (
            torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]),
            torch.tensor([1.0, 1.0]),
        ),
        (
            torch.tensor([[-0.8, -0.8, -0.8], [-0.8, -0.8, -0.8]]),
            torch.ones(2) / 2,
        ),
        (
            torch.tensor(
                [[1.0, 0.0, 0.0], [1 / math.sqrt(2), 1 / math.sqrt(2), 0.0], [0.0, 0.0, 1.0]]
            ),
            torch.tensor([2 / 3, 2 / 3, 1.0]),
        ),
    ],
)
def test_calc_capacities(
    features: Float[Tensor, "n_feats feat_dim"], expected_capacities: Float[Tensor, " n_feats"]
) -> None:
    """Check that the capacity calculation is correct."""
    capacities = calc_capacities(features)
    assert torch.allclose(
        capacities, expected_capacities, rtol=1e-3
    ), "Capacity calculation is incorrect."


def test_wandb_capacity_histogram(snapshot: SnapshotSession) -> None:
    """Check the Weights & Biases histogram is created correctly."""
    capacities = torch.tensor([0.5, 0.1, 1, 1, 1])
    res = wandb_capacities_histogram(capacities)

    assert res.histogram == snapshot
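The third parametrized case can be checked by hand: for f1 = [1, 0, 0], f2 = [1/√2, 1/√2, 0] and f3 = [0, 0, 1], the squared dot products of f1 with (f1, f2, f3) are 1, 1/2 and 0, giving a capacity of 1 / (1 + 1/2) = 2/3; f2 gets 2/3 by symmetry, and f3 is orthogonal to both, so its capacity is 1. A quick reproduction of that expectation with a plain matrix product instead of einops (not part of the commit):

import math

import torch

features = torch.tensor(
    [[1.0, 0.0, 0.0], [1 / math.sqrt(2), 1 / math.sqrt(2), 0.0], [0.0, 0.0, 1.0]]
)
squared_dot_products = (features @ features.T) ** 2  # pairwise squared dot products
capacities = torch.diag(squared_dot_products) / squared_dot_products.sum(dim=-1)
print(capacities)  # tensor([0.6667, 0.6667, 1.0000])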
