Commit 5464e8f

Switch to using named tuples to improve error handling (#169)
1 parent: 29bc1ff

14 files changed: +131, -63 lines
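
This commit replaces anonymous `tuple[...]` return types with `typing.NamedTuple` classes, so callers read results by field name rather than by position and mixed-up fields become type errors instead of silent bugs. A minimal, self-contained sketch of the pattern (the names below are illustrative only, not code from this repository):

from typing import NamedTuple


class ForwardResult(NamedTuple):
    """Hypothetical result type, used here only to illustrate the pattern."""

    learned_activations: float
    decoded_activations: float


def forward() -> ForwardResult:
    """Toy stand-in for a model forward pass."""
    return ForwardResult(learned_activations=0.1, decoded_activations=0.2)


result = forward()
print(result.learned_activations)  # Read by name: fields cannot be silently swapped.
learned, decoded = result  # Still a tuple, so positional unpacking keeps working.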

sparse_autoencoder/activation_resampler/activation_resampler.py

Lines changed: 17 additions & 11 deletions
@@ -1,5 +1,5 @@
 """Activation resampler."""
-from typing import Annotated
+from typing import Annotated, NamedTuple

 from einops import rearrange
 from jaxtyping import Bool, Float, Int64
@@ -22,6 +22,15 @@
 from sparse_autoencoder.train.utils import get_model_device


+class LossInputActivationsTuple(NamedTuple):
+    """Loss and corresponding input activations tuple."""
+
+    loss_per_item: Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL)]
+    input_activations: Float[
+        Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
+    ]
+
+
 class ActivationResampler(AbstractActivationResampler):
     """Activation resampler.

@@ -182,10 +191,7 @@ def compute_loss_and_get_activations(
         autoencoder: SparseAutoencoder,
         loss_fn: AbstractLoss,
         train_batch_size: int,
-    ) -> tuple[
-        Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL)],
-        Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)],
-    ]:
+    ) -> LossInputActivationsTuple:
         """Compute the loss on a random subset of inputs.

         Motivation:
@@ -226,18 +232,18 @@ def compute_loss_and_get_activations(
             if batch_idx >= n_batches_required:
                 break

-        loss_result = torch.cat(loss_batches).to(model_device)
+        loss_per_item = torch.cat(loss_batches).to(model_device)
         input_activations = torch.cat(input_activations_batches).to(model_device)

         # Check we generated enough data
-        if len(loss_result) < n_inputs:
+        if len(loss_per_item) < n_inputs:
             error_message = (
                 f"Cannot get {n_inputs} items from the store, "
-                f"as only {len(loss_result)} were available."
+                f"as only {len(loss_per_item)} were available."
             )
             raise ValueError(error_message)

-        return loss_result, input_activations
+        return LossInputActivationsTuple(loss_per_item, input_activations)

     @staticmethod
     def assign_sampling_probabilities(
@@ -440,7 +446,7 @@ def resample_dead_neurons(

         # Compute the loss for the current model on a random subset of inputs and get the
         # activations.
-        loss, input_activations = self.compute_loss_and_get_activations(
+        loss_per_item, input_activations = self.compute_loss_and_get_activations(
             store=activation_store,
             autoencoder=autoencoder,
             loss_fn=loss_fn,
@@ -451,7 +457,7 @@ def resample_dead_neurons(
         # square of the autoencoder's loss on that input.
         sample_probabilities: Float[
             Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL)
-        ] = self.assign_sampling_probabilities(loss)
+        ] = self.assign_sampling_probabilities(loss_per_item)

         # For each dead neuron sample an input according to these probabilities.
         sampled_input: list[
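
For illustration, here is a hedged sketch of how a caller might consume the new `LossInputActivationsTuple` return value; `resampler`, `activation_store`, `autoencoder`, `loss_fn` and `batch_size` are assumed to already exist and are not defined in this commit:

# Named-field access makes intent explicit at the call site (illustrative caller code).
result = resampler.compute_loss_and_get_activations(
    store=activation_store,
    autoencoder=autoencoder,
    loss_fn=loss_fn,
    train_batch_size=batch_size,
)
per_item_loss = result.loss_per_item
activations = result.input_activations

# Positional unpacking, as used in resample_dead_neurons above, still works because
# NamedTuple instances are tuple subclasses.
loss_per_item, input_activations = result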

sparse_autoencoder/autoencoder/abstract_autoencoder.py

Lines changed: 17 additions & 6 deletions
@@ -1,16 +1,30 @@
 """Abstract Sparse Autoencoder Model."""
 from abc import ABC, abstractmethod
+from typing import NamedTuple

 from jaxtyping import Float
 from torch import Tensor
-from torch.nn import Module, Parameter
+from torch.nn import Module

 from sparse_autoencoder.autoencoder.components.abstract_decoder import AbstractDecoder
 from sparse_autoencoder.autoencoder.components.abstract_encoder import AbstractEncoder
 from sparse_autoencoder.autoencoder.components.abstract_outer_bias import AbstractOuterBias
+from sparse_autoencoder.autoencoder.types import ResetOptimizerParameterDetails
 from sparse_autoencoder.tensor_types import Axis


+class AutoencoderForwardPassResult(NamedTuple):
+    """Autoencoder Forward Pass Result."""
+
+    learned_activations: Float[
+        Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
+    ]
+
+    decoded_activations: Float[
+        Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
+    ]
+
+
 class AbstractAutoencoder(Module, ABC):
     """Abstract Sparse Autoencoder Model.

@@ -42,7 +56,7 @@ def post_decoder_bias(self) -> AbstractOuterBias:
         """Post-decoder bias."""

     @property
-    def reset_optimizer_parameter_details(self) -> list[tuple[Parameter, int]]:
+    def reset_optimizer_parameter_details(self) -> list[ResetOptimizerParameterDetails]:
         """Reset optimizer parameter details.

         Details of the parameters that should be reset in the optimizer, when resetting
@@ -63,10 +77,7 @@ def forward(
         x: Float[
             Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
         ],
-    ) -> tuple[
-        Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)],
-        Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)],
-    ]:
+    ) -> AutoencoderForwardPassResult:
         """Forward Pass.

         Args:
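
As a hedged sketch (not part of the diff), a forward pass through a concrete subclass can now be read by field name, while existing tuple unpacking keeps working; `autoencoder` and `input_batch` are assumed to be an instantiated model and a correctly shaped activation tensor:

output = autoencoder(input_batch)

# Named access documents which tensor is which.
reconstruction_error = (output.decoded_activations - input_batch).pow(2).mean()
mean_learned_activation = output.learned_activations.abs().mean()

# Backwards compatible: call sites that unpack two values need no changes.
learned_activations, decoded_activations = autoencoder(input_batch)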

sparse_autoencoder/autoencoder/components/abstract_decoder.py

Lines changed: 2 additions & 1 deletion
@@ -8,6 +8,7 @@
 from torch import Tensor
 from torch.nn import Module, Parameter

+from sparse_autoencoder.autoencoder.types import ResetOptimizerParameterDetails
 from sparse_autoencoder.tensor_types import Axis


@@ -60,7 +61,7 @@ def weight(

     @property
     @abstractmethod
-    def reset_optimizer_parameter_details(self) -> list[tuple[Parameter, int]]:
+    def reset_optimizer_parameter_details(self) -> list[ResetOptimizerParameterDetails]:
         """Reset optimizer parameter details.

         Details of the parameters that should be reset in the optimizer, when resetting

sparse_autoencoder/autoencoder/components/abstract_encoder.py

Lines changed: 2 additions & 1 deletion
@@ -8,6 +8,7 @@
 from torch import Tensor
 from torch.nn import Module, Parameter

+from sparse_autoencoder.autoencoder.types import ResetOptimizerParameterDetails
 from sparse_autoencoder.tensor_types import Axis


@@ -66,7 +67,7 @@ def bias(self) -> Float[Parameter, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEAR

     @property
     @abstractmethod
-    def reset_optimizer_parameter_details(self) -> list[tuple[Parameter, int]]:
+    def reset_optimizer_parameter_details(self) -> list[ResetOptimizerParameterDetails]:
         """Reset optimizer parameter details.

         Details of the parameters that should be reset in the optimizer, when resetting

sparse_autoencoder/autoencoder/components/linear_encoder.py

Lines changed: 6 additions & 2 deletions
@@ -10,6 +10,7 @@
 from torch.nn import Parameter, ReLU, init

 from sparse_autoencoder.autoencoder.components.abstract_encoder import AbstractEncoder
+from sparse_autoencoder.autoencoder.types import ResetOptimizerParameterDetails
 from sparse_autoencoder.tensor_types import Axis
 from sparse_autoencoder.utils.tensor_shape import shape_with_optional_dimensions

@@ -63,7 +64,7 @@ def bias(self) -> Float[Parameter, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEAR
         return self._bias

     @property
-    def reset_optimizer_parameter_details(self) -> list[tuple[Parameter, int]]:
+    def reset_optimizer_parameter_details(self) -> list[ResetOptimizerParameterDetails]:
         """Reset optimizer parameter details.

         Details of the parameters that should be reset in the optimizer, when resetting
@@ -73,7 +74,10 @@ def reset_optimizer_parameter_details(self) -> list[tuple[Parameter, int]]:
             List of tuples of the form `(parameter, axis)`, where `parameter` is the parameter to
             reset (e.g. encoder.weight), and `axis` is the axis of the parameter to reset.
         """
-        return [(self.weight, -2), (self.bias, -1)]
+        return [
+            ResetOptimizerParameterDetails(parameter=self.weight, axis=-2),
+            ResetOptimizerParameterDetails(parameter=self.bias, axis=-1),
+        ]

     activation_function: ReLU
     """Activation function."""

sparse_autoencoder/autoencoder/components/tests/test_abstract_decoder.py

Lines changed: 3 additions & 2 deletions
@@ -9,6 +9,7 @@
 from torch.nn import Parameter, init

 from sparse_autoencoder.autoencoder.components.abstract_decoder import AbstractDecoder
+from sparse_autoencoder.autoencoder.types import ResetOptimizerParameterDetails
 from sparse_autoencoder.tensor_types import Axis


@@ -46,9 +47,9 @@ def weight(
         return self._weight

     @property
-    def reset_optimizer_parameter_details(self) -> list[tuple[Parameter, int]]:
+    def reset_optimizer_parameter_details(self) -> list[ResetOptimizerParameterDetails]:
         """Reset optimizer parameter details."""
-        return [(self.weight, 1)]
+        return [ResetOptimizerParameterDetails(parameter=self.weight, axis=1)]

     def forward(
         self, x: Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT, Axis.LEARNT_FEATURE)]

sparse_autoencoder/autoencoder/components/tests/test_abstract_encoder.py

Lines changed: 6 additions & 2 deletions
@@ -11,6 +11,7 @@
 from sparse_autoencoder.autoencoder.components.abstract_encoder import (
     AbstractEncoder,
 )
+from sparse_autoencoder.autoencoder.types import ResetOptimizerParameterDetails
 from sparse_autoencoder.tensor_types import Axis


@@ -50,9 +51,12 @@ def reset_parameters(self) -> None:
         self._weight: Parameter = init.kaiming_normal_(self._weight)  # type: ignore

     @property
-    def reset_optimizer_parameter_details(self) -> list[tuple[Parameter, int]]:
+    def reset_optimizer_parameter_details(self) -> list[ResetOptimizerParameterDetails]:
         """Reset optimizer parameter details."""
-        return [(self.weight, 0), (self.bias, 0)]
+        return [
+            ResetOptimizerParameterDetails(parameter=self.weight, axis=0),
+            ResetOptimizerParameterDetails(parameter=self.bias, axis=0),
+        ]


 @pytest.fixture()
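
A hypothetical pytest-style check of the named fields (illustrative only; `encoder` stands in for the mock encoder fixture defined in this test module and is not part of this commit):

def test_reset_optimizer_parameter_details_are_named(encoder) -> None:
    """Each entry should expose `parameter` and `axis` by name (illustrative sketch)."""
    for detail in encoder.reset_optimizer_parameter_details:
        assert isinstance(detail, ResetOptimizerParameterDetails)
        assert isinstance(detail.parameter, Parameter)
        assert isinstance(detail.axis, int)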

sparse_autoencoder/autoencoder/components/unit_norm_decoder.py

Lines changed: 3 additions & 2 deletions
@@ -9,6 +9,7 @@
 from torch.nn import Parameter, init

 from sparse_autoencoder.autoencoder.components.abstract_decoder import AbstractDecoder
+from sparse_autoencoder.autoencoder.types import ResetOptimizerParameterDetails
 from sparse_autoencoder.tensor_types import Axis
 from sparse_autoencoder.utils.tensor_shape import shape_with_optional_dimensions

@@ -65,7 +66,7 @@ def weight(
         return self._weight

     @property
-    def reset_optimizer_parameter_details(self) -> list[tuple[Parameter, int]]:
+    def reset_optimizer_parameter_details(self) -> list[ResetOptimizerParameterDetails]:
         """Reset optimizer parameter details.

         Details of the parameters that should be reset in the optimizer, when resetting
@@ -75,7 +76,7 @@ def reset_optimizer_parameter_details(self) -> list[tuple[Parameter, int]]:
             List of tuples of the form `(parameter, axis)`, where `parameter` is the parameter to
             reset (e.g. encoder.weight), and `axis` is the axis of the parameter to reset.
         """
-        return [(self.weight, -1)]
+        return [ResetOptimizerParameterDetails(parameter=self.weight, axis=-1)]

     @validate_call
     def __init__(

sparse_autoencoder/autoencoder/model.py

Lines changed: 7 additions & 6 deletions
@@ -8,7 +8,10 @@
 from torch import Tensor
 from torch.nn.parameter import Parameter

-from sparse_autoencoder.autoencoder.abstract_autoencoder import AbstractAutoencoder
+from sparse_autoencoder.autoencoder.abstract_autoencoder import (
+    AbstractAutoencoder,
+    AutoencoderForwardPassResult,
+)
 from sparse_autoencoder.autoencoder.components.linear_encoder import LinearEncoder
 from sparse_autoencoder.autoencoder.components.tied_bias import TiedBias, TiedBiasPosition
 from sparse_autoencoder.autoencoder.components.unit_norm_decoder import UnitNormDecoder
@@ -139,10 +142,7 @@ def forward(
         x: Float[
             Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
         ],
-    ) -> tuple[
-        Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)],
-        Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)],
-    ]:
+    ) -> AutoencoderForwardPassResult:
         """Forward Pass.

         Args:
@@ -155,7 +155,8 @@ def forward(
         learned_activations = self._encoder(x)
         x = self._decoder(learned_activations)
         decoded_activations = self._post_decoder_bias(x)
-        return learned_activations, decoded_activations
+
+        return AutoencoderForwardPassResult(learned_activations, decoded_activations)

     def initialize_tied_parameters(self) -> None:
         """Initialize the tied parameters."""
sparse_autoencoder/autoencoder/types.py (new file)

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+"""Autoencoder types."""
+from typing import NamedTuple
+
+from torch.nn import Parameter
+
+
+class ResetOptimizerParameterDetails(NamedTuple):
+    """Reset Optimizer Parameter Details.
+
+    Details of a parameter that should be reset in the optimizer, when resetting
+    it's corresponding dictionary vectors.
+    """
+
+    parameter: Parameter
+    """Parameter to reset."""
+
+    axis: int
+    """Axis of the parameter to reset."""
