Skip to content

Commit a97626d

Browse files
authored
Add abstract model class (#80)
1 parent b17e0f6 commit a97626d

28 files changed

+1009
-338
lines changed

.vscode/settings.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,4 @@
4343
"rewrap.autoWrap.enabled": true,
4444
"rewrap.wrappingColumn": 100,
4545
"python.analysis.diagnosticMode": "workspace"
46-
}
46+
}

poetry.lock

Lines changed: 133 additions & 130 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
pyright=">=1.1.334"
2525
pytest=">=7"
2626
pytest-cov=">=4"
27+
pytest-timeout="^2.2.0"
2728
ruff=">=0.1.4"
2829
syrupy="^4.6.0"
2930

@@ -116,6 +117,8 @@
116117

117118
[tool.pytest]
118119
cache_dir=".cache/pytest"
120+
durations=3
121+
timeout=60
119122

120123
[tool.pytest.ini_options]
121124
addopts="""--jaxtyping-packages=sparse_autoencoder,beartype.beartype --doctest-modules"""
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""Abstract Sparse Autoencoder Model."""
2+
from abc import ABC, abstractmethod
3+
4+
from torch.nn import Module
5+
6+
from sparse_autoencoder.autoencoder.components.abstract_decoder import AbstractDecoder
7+
from sparse_autoencoder.autoencoder.components.abstract_encoder import AbstractEncoder
8+
from sparse_autoencoder.autoencoder.components.abstract_outer_bias import AbstractOuterBias
9+
from sparse_autoencoder.tensor_types import (
10+
InputOutputActivationBatch,
11+
LearnedActivationBatch,
12+
)
13+
14+
15+
class AbstractAutoencoder(Module, ABC):
16+
"""Abstract Sparse Autoencoder Model."""
17+
18+
@property
19+
@abstractmethod
20+
def encoder(self) -> AbstractEncoder:
21+
"""Encoder."""
22+
raise NotImplementedError
23+
24+
@property
25+
@abstractmethod
26+
def decoder(self) -> AbstractDecoder:
27+
"""Decoder."""
28+
raise NotImplementedError
29+
30+
@property
31+
@abstractmethod
32+
def pre_encoder_bias(self) -> AbstractOuterBias:
33+
"""Pre-encoder bias."""
34+
raise NotImplementedError
35+
36+
@property
37+
@abstractmethod
38+
def post_decoder_bias(self) -> AbstractOuterBias:
39+
"""Post-decoder bias."""
40+
raise NotImplementedError
41+
42+
@abstractmethod
43+
def forward(
44+
self,
45+
x: InputOutputActivationBatch,
46+
) -> tuple[
47+
LearnedActivationBatch,
48+
InputOutputActivationBatch,
49+
]:
50+
"""Forward Pass.
51+
52+
Args:
53+
x: Input activations (e.g. activations from an MLP layer in a transformer model).
54+
55+
Returns:
56+
Tuple of learned activations and decoded activations.
57+
"""
58+
raise NotImplementedError
59+
60+
@abstractmethod
61+
def reset_parameters(self) -> None:
62+
"""Reset the parameters."""
63+
raise NotImplementedError
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
"""Abstract Sparse Autoencoder Model."""
2+
from abc import ABC, abstractmethod
3+
from typing import final
4+
5+
import torch
6+
from torch.nn import Module
7+
8+
from sparse_autoencoder.tensor_types import (
9+
DeadDecoderNeuronWeightUpdates,
10+
DecoderWeights,
11+
InputOutputActivationBatch,
12+
InputOutputNeuronIndices,
13+
LearnedActivationBatch,
14+
)
15+
16+
17+
class AbstractDecoder(Module, ABC):
18+
"""Abstract Decoder Module.
19+
20+
Typically includes just a :attr:`weight` parameter.
21+
"""
22+
23+
@property
24+
@abstractmethod
25+
def weight(self) -> DecoderWeights:
26+
"""Weight."""
27+
raise NotImplementedError
28+
29+
@abstractmethod
30+
def forward(
31+
self,
32+
x: LearnedActivationBatch,
33+
) -> InputOutputActivationBatch:
34+
"""Forward Pass.
35+
36+
Args:
37+
x: Learned activations.
38+
39+
Returns:
40+
Decoded activations.
41+
"""
42+
raise NotImplementedError
43+
44+
@abstractmethod
45+
def reset_parameters(self) -> None:
46+
"""Reset the parameters."""
47+
raise NotImplementedError
48+
49+
@final
50+
def update_dictionary_vectors(
51+
self,
52+
dictionary_vector_indices: InputOutputNeuronIndices,
53+
updated_weights: DeadDecoderNeuronWeightUpdates,
54+
) -> None:
55+
"""Update decoder dictionary vectors.
56+
57+
Updates the dictionary vectors (rows in the weight matrix) with the given values. Typically
58+
this is used when resampling neurons (dictionary vectors) that have died.
59+
60+
Args:
61+
dictionary_vector_indices: Indices of the dictionary vectors to update.
62+
updated_weights: Updated weights for just these dictionary vectors.
63+
"""
64+
if len(dictionary_vector_indices) == 0:
65+
return
66+
67+
with torch.no_grad():
68+
self.weight[dictionary_vector_indices, :] = updated_weights
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
"""Abstract Encoder."""
2+
from abc import ABC, abstractmethod
3+
from typing import final
4+
5+
import torch
6+
from torch.nn import Module
7+
8+
from sparse_autoencoder.tensor_types import (
9+
DeadEncoderNeuronWeightUpdates,
10+
EncoderWeights,
11+
InputOutputActivationBatch,
12+
InputOutputNeuronIndices,
13+
LearnedActivationBatch,
14+
LearntActivationVector,
15+
)
16+
17+
18+
class AbstractEncoder(Module, ABC):
19+
"""Abstract encoder module.
20+
21+
Typically includes :attr:`weights` and :attr:`bias` parameters, as well as an activation
22+
function.
23+
"""
24+
25+
@property
26+
@abstractmethod
27+
def weight(self) -> EncoderWeights:
28+
"""Weight."""
29+
raise NotImplementedError
30+
31+
@property
32+
@abstractmethod
33+
def bias(self) -> LearntActivationVector:
34+
"""Bias."""
35+
raise NotImplementedError
36+
37+
@abstractmethod
38+
def forward(self, x: InputOutputActivationBatch) -> LearnedActivationBatch:
39+
"""Forward pass.
40+
41+
Args:
42+
x: Input activations.
43+
44+
Returns:
45+
Resulting activations.
46+
"""
47+
raise NotImplementedError
48+
49+
@final
50+
def update_dictionary_vectors(
51+
self,
52+
dictionary_vector_indices: InputOutputNeuronIndices,
53+
updated_dictionary_weights: DeadEncoderNeuronWeightUpdates,
54+
) -> None:
55+
"""Update encoder dictionary vectors.
56+
57+
Updates the dictionary vectors (columns in the weight matrix) with the given values.
58+
59+
Args:
60+
dictionary_vector_indices: Indices of the dictionary vectors to update.
61+
updated_dictionary_weights: Updated weights for just these dictionary vectors.
62+
"""
63+
if len(dictionary_vector_indices) == 0:
64+
return
65+
66+
with torch.no_grad():
67+
self.weight[:, dictionary_vector_indices] = updated_dictionary_weights
68+
69+
@final
70+
def update_bias(
71+
self,
72+
update_parameter_indices: InputOutputNeuronIndices,
73+
updated_bias_features: LearntActivationVector | float,
74+
) -> None:
75+
"""Update encoder bias.
76+
77+
Args:
78+
update_parameter_indices: Indices of the bias features to update.
79+
updated_bias_features: Updated bias features for just these indices.
80+
"""
81+
if len(update_parameter_indices) == 0:
82+
return
83+
84+
with torch.no_grad():
85+
self.bias[update_parameter_indices] = updated_bias_features
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""Abstract Outer Bias.
2+
3+
This can be extended to create e.g. a pre-encoder and post-decoder bias.
4+
"""
5+
from abc import ABC, abstractmethod
6+
7+
from torch.nn import Module
8+
9+
from sparse_autoencoder.tensor_types import (
10+
InputOutputActivationBatch,
11+
InputOutputActivationVector,
12+
)
13+
14+
15+
class AbstractOuterBias(Module, ABC):
16+
"""Abstract Pre-Encoder or Post-Decoder Bias Module."""
17+
18+
@property
19+
@abstractmethod
20+
def bias(self) -> InputOutputActivationVector:
21+
"""Bias.
22+
23+
May be a reference to a bias parameter in the parent module, if using e.g. a tied bias.
24+
"""
25+
raise NotImplementedError
26+
27+
@abstractmethod
28+
def forward(
29+
self,
30+
x: InputOutputActivationBatch,
31+
) -> InputOutputActivationBatch:
32+
"""Forward Pass.
33+
34+
Args:
35+
x: Input activations (e.g. activations from an MLP layer in a transformer model).
36+
37+
Returns:
38+
Resulting activations.
39+
"""
40+
raise NotImplementedError
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
"""Linear encoder layer."""
2+
import math
3+
from typing import final
4+
5+
import einops
6+
import torch
7+
from torch.nn import Parameter, ReLU, init
8+
9+
from sparse_autoencoder.autoencoder.components.abstract_encoder import AbstractEncoder
10+
from sparse_autoencoder.tensor_types import (
11+
EncoderWeights,
12+
InputOutputActivationBatch,
13+
LearnedActivationBatch,
14+
LearntActivationVector,
15+
)
16+
17+
18+
@final
19+
class LinearEncoder(AbstractEncoder):
20+
"""Linear encoder layer."""
21+
22+
_learnt_features: int
23+
"""Number of learnt features (inputs to this layer)."""
24+
25+
_input_features: int
26+
"""Number of decoded features (outputs from this layer)."""
27+
28+
_weight: EncoderWeights
29+
30+
_bias: LearntActivationVector
31+
32+
@property
33+
def weight(self) -> EncoderWeights:
34+
"""Weight."""
35+
return self._weight
36+
37+
@property
38+
def bias(self) -> LearntActivationVector:
39+
"""Bias."""
40+
return self._bias
41+
42+
activation_function: ReLU
43+
44+
def __init__(
45+
self,
46+
input_features: int,
47+
learnt_features: int,
48+
):
49+
"""Initialize the linear encoder layer."""
50+
super().__init__()
51+
self._learnt_features = learnt_features
52+
self._input_features = input_features
53+
self.activation_function = ReLU()
54+
55+
self._weight = Parameter(
56+
torch.empty(
57+
(learnt_features, input_features),
58+
)
59+
)
60+
61+
self._bias = Parameter(torch.zeros(learnt_features))
62+
63+
self.reset_parameters()
64+
65+
def reset_parameters(self) -> None:
66+
"""Initialize or reset the parameters."""
67+
# Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
68+
# uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
69+
# https://github.com/pytorch/pytorch/issues/57109
70+
init.kaiming_uniform_(self._weight, a=math.sqrt(5))
71+
72+
# Bias (approach from nn.Linear)
73+
fan_in = self._weight.size(1)
74+
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
75+
init.uniform_(self._bias, -bound, bound)
76+
77+
def forward(self, x: InputOutputActivationBatch) -> LearnedActivationBatch:
78+
"""Forward pass.
79+
80+
Args:
81+
x: Input tensor.
82+
83+
Returns:
84+
Output of the forward pass.
85+
"""
86+
learned_activation_batch: LearnedActivationBatch = einops.einsum(
87+
x,
88+
self.weight,
89+
"batch input_output_feature, \
90+
learnt_feature_dim input_output_feature_dim \
91+
-> batch learnt_feature_dim",
92+
)
93+
94+
learned_activation_batch = einops.einsum(
95+
learned_activation_batch,
96+
self.bias,
97+
"batch learnt_feature_dim, \
98+
learnt_feature_dim -> batch learnt_feature_dim",
99+
)
100+
101+
return self.activation_function(learned_activation_batch)
102+
103+
def extra_repr(self) -> str:
104+
"""String extra representation of the module."""
105+
return f"in_features={self._input_features}, out_features={self._learnt_features}"
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# serializer version: 1
2+
# name: test_extra_repr
3+
'''
4+
LinearEncoder(
5+
in_features=10, out_features=5
6+
(activation_function): ReLU()
7+
)
8+
'''
9+
# ---

0 commit comments

Comments
 (0)