Commit b17e0f6

Improve optimizer documentation (#79)
1 parent a536fcc commit b17e0f6

File tree

2 files changed: +53 -2 lines
sparse_autoencoder/optimizer/__init__.py

Lines changed: 20 additions & 1 deletion
@@ -1 +1,20 @@
-"""Optimizer."""
+"""Optimizers for Sparse Autoencoders.
+
+When training a Sparse Autoencoder, it can be necessary to manually edit the model parameters
+(e.g. with neuron resampling to prevent dead neurons). When doing this, it's also necessary to
+reset the optimizer state for these parameters, as otherwise things like running averages will be
+incorrect (e.g. the running averages of the gradients and the squares of gradients with Adam).
+
+The optimizer used in the original [Towards Monosemanticity: Decomposing Language Models With
+Dictionary Learning](https://transformer-circuits.pub/2023/monosemantic-features)
+paper is available here as :class:`AdamWithReset`.
+
+To enable creating other optimizers with reset methods, we also provide the interface
+:class:`AbstractOptimizerWithReset`.
+"""
+
+from sparse_autoencoder.optimizer.abstract_optimizer import AbstractOptimizerWithReset
+from sparse_autoencoder.optimizer.adam_with_reset import AdamWithReset
+
+
+__all__ = ["AdamWithReset", "AbstractOptimizerWithReset"]

sparse_autoencoder/optimizer/abstract_optimizer.py

Lines changed: 33 additions & 1 deletion
@@ -1,9 +1,16 @@
 """Abstract optimizer with reset."""
 from abc import ABC, abstractmethod
 
+from sparse_autoencoder.tensor_types import DeadNeuronIndices
+
 
 class AbstractOptimizerWithReset(ABC):
-    """Abstract optimizer with reset."""
+    """Abstract optimizer with reset.
+
+    When implementing this interface, we recommend adding a `named_parameters` argument to the
+    constructor, which can be obtained from `named_parameters=model.named_parameters()` by the end
+    user. This is so that the optimizer can find the parameters to reset.
+    """
 
     @abstractmethod
     def reset_state_all_parameters(self) -> None:
@@ -13,3 +20,28 @@ def reset_state_all_parameters(self) -> None:
         parameters (e.g. with activation resampling).
         """
         raise NotImplementedError
+
+    @abstractmethod
+    def reset_neurons_state(
+        self,
+        parameter_name: str,
+        neuron_indices: DeadNeuronIndices,
+        axis: int,
+        parameter_group: int = 0,
+    ) -> None:
+        """Reset the state for specific neurons, on a specific parameter.
+
+        Args:
+            parameter_name: The name of the parameter. Examples from the standard sparse autoencoder
+                implementation include `tied_bias`, `encoder.Linear.weight`, `encoder.Linear.bias`,
+                `decoder.Linear.weight`, and `decoder.ConstrainedUnitNormLinear.weight`.
+            neuron_indices: The indices of the neurons to reset.
+            axis: The axis of the parameter to reset.
+            parameter_group: The index of the parameter group to reset (typically this is just zero,
+                unless you have set up multiple parameter groups for e.g. different learning rates
+                for different parameters).
+
+        Raises:
+            ValueError: If the parameter name is not found.
+        """
+        raise NotImplementedError
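
As a reference point for implementers, the following is a minimal sketch of this interface built on top of torch.optim.Adam. It is not the library's AdamWithReset: the class name SketchAdamWithReset is hypothetical, parameter_group handling is omitted for brevity, and the only state keys assumed are the standard Adam entries "exp_avg" and "exp_avg_sq". It follows the docstring's recommendation of taking named_parameters in the constructor so that parameters can be found by name.

from torch import Tensor
from torch.nn import Parameter
from torch.optim import Adam

from sparse_autoencoder.optimizer.abstract_optimizer import AbstractOptimizerWithReset


class SketchAdamWithReset(Adam, AbstractOptimizerWithReset):
    """Illustrative Adam subclass that can reset state for resampled neurons."""

    def __init__(self, params, named_parameters, **kwargs) -> None:
        super().__init__(params, **kwargs)
        # Name -> parameter map, so state can be looked up by parameter name.
        self._parameters_by_name: dict[str, Parameter] = dict(named_parameters)

    def reset_state_all_parameters(self) -> None:
        """Drop all optimizer state (e.g. momentum and variance estimates)."""
        self.state.clear()

    def reset_neurons_state(
        self,
        parameter_name: str,
        neuron_indices: Tensor,
        axis: int,
        parameter_group: int = 0,  # ignored in this sketch for simplicity
    ) -> None:
        """Zero the Adam running averages for specific neurons of one parameter."""
        if parameter_name not in self._parameters_by_name:
            raise ValueError(f"Parameter {parameter_name} not found.")
        parameter = self._parameters_by_name[parameter_name]
        state = self.state.get(parameter, {})
        for key in ("exp_avg", "exp_avg_sq"):
            if key in state:
                state[key].index_fill_(axis, neuron_indices, 0)

A user would then construct it with something like SketchAdamWithReset(model.parameters(), named_parameters=model.named_parameters(), lr=1e-3) and, after resampling dead neurons, call reset_neurons_state("encoder.Linear.weight", dead_neuron_indices, axis=0).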
