1
1
"""Abstract optimizer with reset."""
2
2
from abc import ABC , abstractmethod
3
3
4
+ from sparse_autoencoder .tensor_types import DeadNeuronIndices
5
+
4
6
5
7
class AbstractOptimizerWithReset (ABC ):
6
- """Abstract optimizer with reset."""
8
+ """Abstract optimizer with reset.
9
+
10
+ When implementing this interface, we recommend adding a `named_parameters` argument to the
11
+ constructor, which can be obtained from `named_parameters=model.named_parameters()` by the end
12
+ user. This is so that the optimizer can find the parameters to reset.
13
+ """
7
14
8
15
@abstractmethod
9
16
def reset_state_all_parameters (self ) -> None :
@@ -13,3 +20,28 @@ def reset_state_all_parameters(self) -> None:
13
20
parameters (e.g. with activation resampling).
14
21
"""
15
22
raise NotImplementedError
23
+
24
+ @abstractmethod
25
+ def reset_neurons_state (
26
+ self ,
27
+ parameter_name : str ,
28
+ neuron_indices : DeadNeuronIndices ,
29
+ axis : int ,
30
+ parameter_group : int = 0 ,
31
+ ) -> None :
32
+ """Reset the state for specific neurons, on a specific parameter.
33
+
34
+ Args:
35
+ parameter_name: The name of the parameter. Examples from the standard sparse autoencoder
36
+ implementation include `tied_bias`, `encoder.Linear.weight`, `encoder.Linear.bias`,
37
+ `decoder.Linear.weight`, and `decoder.ConstrainedUnitNormLinear.weight`.
38
+ neuron_indices: The indices of the neurons to reset.
39
+ axis: The axis of the parameter to reset.
40
+ parameter_group: The index of the parameter group to reset (typically this is just zero,
41
+ unless you have setup multiple parameter groups for e.g. different learning rates
42
+ for different parameters).
43
+
44
+ Raises:
45
+ ValueError: If the parameter name is not found.
46
+ """
47
+ raise NotImplementedError
0 commit comments