
Commit b2c821f

Add replace and zero activations hooks (#111)
1 parent 9a0052a · commit b2c821f

5 files changed: +181 -26 lines changed
sparse_autoencoder/source_model/replace_activations_hook.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
"""Replace activations hook."""
from typing import TYPE_CHECKING

from torch import Tensor
from transformer_lens.hook_points import HookPoint

from sparse_autoencoder.autoencoder.abstract_autoencoder import AbstractAutoencoder


if TYPE_CHECKING:
    from sparse_autoencoder.tensor_types import InputOutputActivationBatch


def replace_activations_hook(
    value: Tensor,
    hook: HookPoint,  # noqa: ARG001
    sparse_autoencoder: AbstractAutoencoder,
) -> Tensor:
    """Replace activations hook.

    Args:
        value: The activations to replace.
        hook: The hook point.
        sparse_autoencoder: The sparse autoencoder. This should be pre-initialised with
            `functools.partial`.

    Returns:
        Replaced activations.
    """
    # Squash to just have a "*items" and a "batch" dimension
    original_shape = value.shape
    squashed_value: InputOutputActivationBatch = value.view(-1, value.size(-1))

    # Get the output activations from a forward pass of the SAE
    _learned_activations, output_activations = sparse_autoencoder.forward(squashed_value)

    # Reshape to the original shape
    return output_activations.view(*original_shape)
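
A minimal usage sketch follows (not part of the commit diff): it shows how this hook might be attached with `functools.partial`, as the docstring suggests. The model, SAE sizing, and the `blocks.0.hook_mlp_out` hook point are taken from the test added later in this commit; treat them as illustrative choices rather than prescribed values.

from functools import partial

from transformer_lens import HookedTransformer

from sparse_autoencoder.autoencoder.model import SparseAutoencoder
from sparse_autoencoder.source_model.replace_activations_hook import replace_activations_hook

# Load a small source model and an (untrained) SAE sized to the hooked activation width.
model = HookedTransformer.from_pretrained("tiny-stories-1M")
autoencoder = SparseAutoencoder(model.cfg.d_model, model.cfg.d_model * 2)

# Forward pass in which the hooked activations are replaced by the SAE's reconstruction.
loss = model.run_with_hooks(
    model.to_tokens("Hello world"),
    return_type="loss",
    fwd_hooks=[
        (
            "blocks.0.hook_mlp_out",
            partial(replace_activations_hook, sparse_autoencoder=autoencoder),
        )
    ],
)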

sparse_autoencoder/source_model/store_activations_hook.py

Lines changed: 27 additions & 26 deletions
@@ -15,32 +15,33 @@ def store_activations_hook(
     Useful for getting just the specific activations wanted, rather than the full cache.
 
     Example:
-    First we'll need a source model from TransformerLens and an activation store.
-
-    >>> from functools import partial
-    >>> from transformer_lens import HookedTransformer
-    >>> from sparse_autoencoder.activation_store.list_store import ListActivationStore
-    >>> store = ListActivationStore()
-    >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
-    Loaded pretrained model tiny-stories-1M into HookedTransformer
-
-    Next we can add the hook to specific neurons (in this case the first MLP neurons), and create
-    the tokens for a forward pass.
-
-    >>> model.add_hook(
-    ...     "blocks.0.mlp.hook_post", partial(store_activations_hook, store=store)
-    ... )
-    >>> tokens = model.to_tokens("Hello world")
-    >>> tokens.shape
-    torch.Size([1, 3])
-
-    Then when we run the model, we should get one activation vector for each token (as we just have
-    one batch item). Note we also set `stop_at_layer=1` as we don't need the logits or any other
-    activations after the hook point that we've specified (in this case the first MLP layer).
-
-    >>> _output = model.forward("Hello world", stop_at_layer=1) # Change this layer as required
-    >>> len(store)
-    3
+        First we'll need a source model from TransformerLens and an activation store.
+
+        >>> from functools import partial
+        >>> from transformer_lens import HookedTransformer
+        >>> from sparse_autoencoder.activation_store.list_store import ListActivationStore
+        >>> store = ListActivationStore()
+        >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
+        Loaded pretrained model tiny-stories-1M into HookedTransformer
+
+        Next we can add the hook to specific neurons (in this case the first MLP neurons), and
+        create the tokens for a forward pass.
+
+        >>> model.add_hook(
+        ...     "blocks.0.mlp.hook_post", partial(store_activations_hook, store=store)
+        ... )
+        >>> tokens = model.to_tokens("Hello world")
+        >>> tokens.shape
+        torch.Size([1, 3])
+
+        Then when we run the model, we should get one activation vector for each token (as we just
+        have one batch item). Note we also set `stop_at_layer=1` as we don't need the logits or any
+        other activations after the hook point that we've specified (in this case the first MLP
+        layer).
+
+        >>> _output = model.forward("Hello world", stop_at_layer=1) # Change this layer as required
+        >>> len(store)
+        3
 
     Args:
         value: The activations to store.
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
"""Replace activations hook tests."""
from functools import partial

import torch
from transformer_lens import HookedTransformer

from sparse_autoencoder.autoencoder.model import SparseAutoencoder
from sparse_autoencoder.source_model.replace_activations_hook import replace_activations_hook
from sparse_autoencoder.tensor_types import BatchTokenizedPrompts


def test_hook_replaces_activations() -> None:
    """Test that the hook replaces activations."""
    torch.random.manual_seed(0)
    source_model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu")
    autoencoder = SparseAutoencoder(source_model.cfg.d_model, source_model.cfg.d_model * 2)

    tokens: BatchTokenizedPrompts = source_model.to_tokens("Hello world")
    loss_without_hook = source_model.forward(tokens, return_type="loss")
    loss_with_hook = source_model.run_with_hooks(
        tokens,
        return_type="loss",
        fwd_hooks=[
            (
                "blocks.0.hook_mlp_out",
                partial(replace_activations_hook, sparse_autoencoder=autoencoder),
            )
        ],
    )

    # Check it decreases performance (as the SAE is untrained, so it will output nonsense).
    assert torch.all(torch.gt(loss_with_hook, loss_without_hook))
Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
"""Test the zero ablate hook."""
import pytest
import torch
from transformer_lens.hook_points import HookPoint

from sparse_autoencoder.source_model.zero_ablate_hook import zero_ablate_hook


class MockHookPoint(HookPoint):
    """Mock HookPoint class."""


@pytest.fixture()
def mock_hook_point() -> MockHookPoint:
    """Fixture to provide a mock HookPoint instance."""
    return MockHookPoint()


def test_zero_ablate_hook_with_standard_tensor(mock_hook_point: MockHookPoint) -> None:
    """Test zero_ablate_hook with a standard tensor.

    Args:
        mock_hook_point: A mock HookPoint instance.
    """
    value = torch.ones(3, 4)
    expected = torch.zeros(3, 4)
    result = zero_ablate_hook(value, mock_hook_point)
    assert torch.equal(result, expected), "The output tensor should contain only zeros."


@pytest.mark.parametrize("shape", [(10,), (5, 5), (2, 3, 4)])
def test_zero_ablate_hook_with_various_shapes(
    mock_hook_point: MockHookPoint, shape: tuple[int, ...]
) -> None:
    """Test zero_ablate_hook with tensors of various shapes.

    Args:
        mock_hook_point: A mock HookPoint instance.
        shape: A tuple representing the shape of the tensor.
    """
    value = torch.ones(*shape)
    expected = torch.zeros(*shape)
    result = zero_ablate_hook(value, mock_hook_point)
    assert torch.equal(
        result, expected
    ), f"The output tensor should be of shape {shape} with zeros."


def test_float_dtype_maintained(mock_hook_point: MockHookPoint) -> None:
    """Test that the float dtype is maintained.

    Args:
        mock_hook_point: A mock HookPoint instance.
    """
    value = torch.ones(3, 4, dtype=torch.float)
    result = zero_ablate_hook(value, mock_hook_point)
    assert result.dtype == torch.float, "The output tensor should be of dtype float."
sparse_autoencoder/source_model/zero_ablate_hook.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
"""Zero ablate hook."""
import torch
from torch import Tensor
from transformer_lens.hook_points import HookPoint


def zero_ablate_hook(
    value: Tensor,
    hook: HookPoint,  # noqa: ARG001
) -> Tensor:
    """Zero ablate hook.

    Args:
        value: The activations to zero ablate.
        hook: The hook point.

    Example:
        >>> dummy_hook_point = HookPoint()
        >>> value = torch.ones(2, 3)
        >>> zero_ablate_hook(value, dummy_hook_point)
        tensor([[0., 0., 0.],
                [0., 0., 0.]])

    Returns:
        Zeroed activations.
    """
    return torch.zeros_like(value)
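
As a companion sketch (again, not part of the commit diff), the zero ablate hook can be passed straight to `run_with_hooks`, since it takes no extra arguments. The model and the `blocks.0.hook_mlp_out` hook point below are illustrative, mirroring the replace-hook test in this commit.

from transformer_lens import HookedTransformer

from sparse_autoencoder.source_model.zero_ablate_hook import zero_ablate_hook

model = HookedTransformer.from_pretrained("tiny-stories-1M")
tokens = model.to_tokens("Hello world")

# Baseline loss without any intervention.
baseline_loss = model.forward(tokens, return_type="loss")

# Loss with the first block's MLP output zero-ablated.
ablated_loss = model.run_with_hooks(
    tokens,
    return_type="loss",
    fwd_hooks=[("blocks.0.hook_mlp_out", zero_ablate_hook)],
)

# Zeroing the activations usually degrades performance, so ablated_loss is typically
# higher than baseline_loss.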

0 commit comments