Commit c16afdb

Add file checkpointing (#96)
1 parent 95f19ca commit c16afdb

File tree: 11 files changed, +135 -128 lines


.gitignore

Lines changed: 3 additions & 0 deletions
@@ -135,3 +135,6 @@ dmypy.json
 
 # Generated docs
 docs/content/reference
+
+# Checkpoints directory
+.checkpoints

docs/content/demo.ipynb

Lines changed: 60 additions & 84 deletions
Large diffs are not rendered by default.

sparse_autoencoder/activation_resampler/abstract_activation_resampler.py

Lines changed: 16 additions & 4 deletions
@@ -2,6 +2,7 @@
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from typing import final
 
 from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
 from sparse_autoencoder.autoencoder.model import SparseAutoencoder
@@ -35,6 +36,21 @@ class ParameterUpdateResults:
 class AbstractActivationResampler(ABC):
     """Abstract activation resampler."""
 
+    _resample_dataset_size: int | None = None
+    """Resample dataset size.
+
+    If none, will use the train dataset size.
+    """
+
+    @final
+    def __init__(self, resample_dataset_size: int | None = None) -> None:
+        """Initialize the abstract activation resampler.
+
+        Args:
+            resample_dataset_size: Resample dataset size. If none, will use the train dataset size.
+        """
+        self._resample_dataset_size = resample_dataset_size
+
     @abstractmethod
     def resample_dead_neurons(
         self,
@@ -43,7 +59,6 @@ def resample_dead_neurons(
         autoencoder: SparseAutoencoder,
         loss_fn: AbstractLoss,
         train_batch_size: int,
-        num_inputs: int = 819_200,
     ) -> ParameterUpdateResults:
         """Resample dead neurons.
 
@@ -53,8 +68,5 @@ def resample_dead_neurons(
            autoencoder: Sparse autoencoder model.
            loss_fn: Loss function.
            train_batch_size: Train batch size (also used for resampling).
-           num_inputs: Number of input activations to use when resampling. Will be rounded down to
-               be divisible by the batch size, and cannot be larger than the number of items
-               currently in the store.
        """
        raise NotImplementedError
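
This moves the resample dataset size from a per-call num_inputs argument (previously defaulting to 819_200) to a constructor parameter shared by all resamplers. A minimal usage sketch of the new interface; the import path is inferred from the file layout in this commit:

from sparse_autoencoder.activation_resampler.activation_resampler import ActivationResampler

# Default: fall back to the train dataset size (len(store)) at resample time.
resampler = ActivationResampler()

# Or cap how many stored activations are used when resampling dead neurons.
resampler = ActivationResampler(resample_dataset_size=819_200)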

sparse_autoencoder/activation_resampler/activation_resampler.py

Lines changed: 17 additions & 14 deletions
@@ -23,6 +23,7 @@
     SampledDeadNeuronInputs,
     TrainBatchStatistic,
 )
+from sparse_autoencoder.train.utils import get_model_device
 
 
 class ActivationResampler(AbstractActivationResampler):
@@ -76,12 +77,11 @@ def get_dead_neuron_indices(
         """
         return torch.where(neuron_activity <= threshold)[0]
 
-    @staticmethod
     def compute_loss_and_get_activations(
+        self,
         store: ActivationStore,
         autoencoder: SparseAutoencoder,
         loss_fn: AbstractLoss,
-        num_inputs: int,
         train_batch_size: int,
     ) -> tuple[TrainBatchStatistic, InputOutputActivationBatch]:
         """Compute the loss on a random subset of inputs.
@@ -92,7 +92,6 @@ def compute_loss_and_get_activations(
             store: Activation store.
             autoencoder: Sparse autoencoder model.
             loss_fn: Loss function.
-            num_inputs: Number of input activations to use.
             train_batch_size: Train batch size (also used for resampling).
 
         Returns:
@@ -102,19 +101,24 @@ def compute_loss_and_get_activations(
         loss_batches: list[TrainBatchStatistic] = []
         input_activations_batches: list[InputOutputActivationBatch] = []
         dataloader = DataLoader(store, batch_size=train_batch_size)
+        num_inputs = self._resample_dataset_size or len(store)
         batches: int = num_inputs // train_batch_size
+        model_device: torch.device = get_model_device(autoencoder)
 
         for batch_idx, batch in enumerate(iter(dataloader)):
             input_activations_batches.append(batch)
-            learned_activations, reconstructed_activations = autoencoder(batch)
+            source_activations = batch.to(model_device)
+            learned_activations, reconstructed_activations = autoencoder(source_activations)
             loss_batches.append(
-                loss_fn.forward(batch, learned_activations, reconstructed_activations)
+                loss_fn.forward(
+                    source_activations, learned_activations, reconstructed_activations
+                )
             )
             if batch_idx >= batches:
                 break
 
-        loss_result = torch.cat(loss_batches)
-        input_activations = torch.cat(input_activations_batches)
+        loss_result = torch.cat(loss_batches).to(model_device)
+        input_activations = torch.cat(input_activations_batches).to(model_device)
 
         # Check we generated enough data
         if len(loss_result) < num_inputs:
@@ -188,7 +192,7 @@ def sample_input(
                 (0, input_activations.shape[-1]),
                 dtype=input_activations.dtype,
                 device=input_activations.device,
-            )
+            ).to(input_activations.device)
 
             sample_indices: LearntNeuronIndices = torch.multinomial(
                 probabilities, num_samples=num_samples
@@ -261,7 +265,6 @@ def resample_dead_neurons(
         autoencoder: SparseAutoencoder,
         loss_fn: AbstractLoss,
         train_batch_size: int,
-        num_inputs: int = 819_200,
     ) -> ParameterUpdateResults:
         """Resample dead neurons.
 
@@ -271,9 +274,6 @@ def resample_dead_neurons(
             autoencoder: Sparse autoencoder model.
             loss_fn: Loss function.
             train_batch_size: Train batch size (also used for resampling).
-            num_inputs: Number of input activations to use when resampling. Will be rounded down
-                to divisible by the batch size, and cannot be larger than the number of items
-                currently in the store.
         """
         with torch.no_grad():
             dead_neuron_indices = self.get_dead_neuron_indices(neuron_activity)
@@ -284,7 +284,6 @@ def resample_dead_neurons(
                 store=activation_store,
                 autoencoder=autoencoder,
                 loss_fn=loss_fn,
-                num_inputs=num_inputs,
                 train_batch_size=train_batch_size,
             )
 
@@ -316,7 +315,11 @@ def resample_dead_neurons(
             rescaled_sampled_input = self.renormalize_and_scale(
                 sampled_input, neuron_activity, encoder_weight
             )
-            dead_encoder_bias_updates = torch.zeros_like(dead_neuron_indices, dtype=torch.float)
+            dead_encoder_bias_updates = torch.zeros_like(
+                dead_neuron_indices,
+                dtype=dead_decoder_weight_updates.dtype,
+                device=dead_decoder_weight_updates.device,
+            )
 
             return ParameterUpdateResults(
                 dead_neuron_indices=dead_neuron_indices,
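
Two behavioural changes are bundled here. First, the number of inputs is now derived at call time (self._resample_dataset_size or len(store)), so callers no longer pass it. Second, activation batches are moved onto the autoencoder's device before the forward pass, so a CPU-resident store works with a model on CUDA. A hedged sketch of the new call; the store/autoencoder variables and the loss import path are assumptions:

from sparse_autoencoder.activation_resampler.activation_resampler import ActivationResampler
from sparse_autoencoder.loss.mse_reconstruction_loss import MSEReconstructionLoss  # assumed path

resampler = ActivationResampler()  # resample_dataset_size=None -> uses len(store)
loss, input_activations = resampler.compute_loss_and_get_activations(
    store=store,  # a filled activation store
    autoencoder=autoencoder,  # may live on GPU; batches are moved to its device
    loss_fn=MSEReconstructionLoss(),
    train_batch_size=64,
)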

sparse_autoencoder/activation_resampler/tests/test_resample_neurons.py

Lines changed: 6 additions & 6 deletions
@@ -89,11 +89,10 @@ def test_gets_loss_and_correct_activations(
     input_activations_fixture: Tensor,
 ) -> None:
     """Test it gets loss and also returns the input activations."""
-    loss, input_activations = ActivationResampler.compute_loss_and_get_activations(
+    loss, input_activations = ActivationResampler().compute_loss_and_get_activations(
         store=activation_store_fixture,
         autoencoder=autoencoder_model_fixture,
         loss_fn=MSEReconstructionLoss(),
-        num_inputs=DEFAULT_N_ITEMS,
         train_batch_size=DEFAULT_N_ITEMS,
     )
 
@@ -115,11 +114,12 @@ def test_more_items_than_in_store_error(
         ValueError,
         match=r"Cannot get \d+ items from the store, as only \d+ were available.",
     ):
-        ActivationResampler.compute_loss_and_get_activations(
+        ActivationResampler(
+            resample_dataset_size=DEFAULT_N_ITEMS + 1
+        ).compute_loss_and_get_activations(
             store=activation_store_fixture,
             autoencoder=autoencoder_model_fixture,
             loss_fn=MSEReconstructionLoss(),
-            num_inputs=DEFAULT_N_ITEMS + 1,
             train_batch_size=DEFAULT_N_ITEMS + 1,
         )
 
@@ -266,7 +266,7 @@ def test_no_changes_if_no_dead_neurons(self) -> None:
         model = SparseAutoencoder(5, 10, torch.rand(5))
 
         res = ActivationResampler().resample_dead_neurons(
-            neuron_activity, store, model, MSEReconstructionLoss(), DEFAULT_N_ITEMS, DEFAULT_N_ITEMS
+            neuron_activity, store, model, MSEReconstructionLoss(), DEFAULT_N_ITEMS
         )
 
         assert res.dead_neuron_indices.numel() == 0, "Should not have any dead neurons"
@@ -290,7 +290,7 @@ def test_updates_a_dead_neuron_parameters(self) -> None:
         # Get the current & updated parameters
         current_parameters = model.state_dict()
         updated_parameters: ParameterUpdateResults = ActivationResampler().resample_dead_neurons(
-            neuron_activity, store, model, MSEReconstructionLoss(), DEFAULT_N_ITEMS, DEFAULT_N_ITEMS
+            neuron_activity, store, model, MSEReconstructionLoss(), DEFAULT_N_ITEMS
         )
 
         # Check the updated ones have changed

sparse_autoencoder/autoencoder/components/abstract_decoder.py

Lines changed: 3 additions & 3 deletions
@@ -9,8 +9,8 @@
     DeadDecoderNeuronWeightUpdates,
     DecoderWeights,
     InputOutputActivationBatch,
-    InputOutputNeuronIndices,
     LearnedActivationBatch,
+    LearntNeuronIndices,
 )
 
 
@@ -49,7 +49,7 @@ def reset_parameters(self) -> None:
     @final
     def update_dictionary_vectors(
         self,
-        dictionary_vector_indices: InputOutputNeuronIndices,
+        dictionary_vector_indices: LearntNeuronIndices,
         updated_weights: DeadDecoderNeuronWeightUpdates,
     ) -> None:
         """Update decoder dictionary vectors.
@@ -65,4 +65,4 @@
             return
 
         with torch.no_grad():
-            self.weight[dictionary_vector_indices, :] = updated_weights
+            self.weight[:, dictionary_vector_indices] = updated_weights

sparse_autoencoder/autoencoder/components/abstract_encoder.py

Lines changed: 3 additions & 2 deletions
@@ -12,6 +12,7 @@
     InputOutputNeuronIndices,
     LearnedActivationBatch,
     LearntActivationVector,
+    LearntNeuronIndices,
 )
 
 
@@ -49,7 +50,7 @@ def forward(self, x: InputOutputActivationBatch) -> LearnedActivationBatch:
     @final
     def update_dictionary_vectors(
         self,
-        dictionary_vector_indices: InputOutputNeuronIndices,
+        dictionary_vector_indices: LearntNeuronIndices,
         updated_dictionary_weights: DeadEncoderNeuronWeightUpdates,
     ) -> None:
         """Update encoder dictionary vectors.
@@ -64,7 +65,7 @@
             return
 
         with torch.no_grad():
-            self.weight[:, dictionary_vector_indices] = updated_dictionary_weights
+            self.weight[dictionary_vector_indices, :] = updated_dictionary_weights
 
     @final
     def update_bias(
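
The index-axis swaps in the encoder and decoder above are the substantive fix: with nn.Linear-style weights, an encoder weight of shape (learnt_features, input_output_features) stores each learnt neuron's dictionary vector as a row, while the decoder weight, shaped (input_output_features, learnt_features), stores the same neuron's vector as a column. A standalone sketch of the convention this commit settles on (dimension sizes are illustrative, matching the updated tests below):

import torch

n_input_output = 4  # model activation width
n_learnt = 10  # learnt (dictionary) features

encoder_weight = torch.rand(n_learnt, n_input_output)
decoder_weight = torch.rand(n_input_output, n_learnt)

dead = torch.tensor([0, 2])  # indices of dead learnt neurons

# Encoder: a learnt neuron's dictionary vector is a row -> index rows.
encoder_weight[dead, :] = torch.rand(len(dead), n_input_output)

# Decoder: the same neuron's vector is a column -> index columns.
decoder_weight[:, dead] = torch.rand(n_input_output, len(dead))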

sparse_autoencoder/autoencoder/components/tests/test_abstract_decoder.py

Lines changed: 6 additions & 4 deletions
@@ -38,7 +38,9 @@ def forward(self, x: LearnedActivationBatch) -> InputOutputActivationBatch:
 
     def reset_parameters(self) -> None:
         """Mock reset parameters."""
-        self._weight: EncoderWeights = init.normal_(self._weight, mean=0, std=1)
+        self._weight: EncoderWeights = init.kaiming_normal_(
+            self._weight,
+        )
 
 
 @pytest.fixture()
@@ -81,10 +83,10 @@ def test_update_dictionary_vectors_with_no_neurons(mock_decoder: MockDecoder) ->
 @pytest.mark.parametrize(
     ("dictionary_vector_indices", "updates"),
     [
-        (torch.tensor([1]), torch.tensor([[0.5, 0.3, 0.2]])),  # Test with 1 neuron to update
+        (torch.tensor([1]), torch.rand(4, 1)),  # Test with 1 neuron to update
         (
             torch.tensor([0, 2]),
-            torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]),
+            torch.rand(4, 2),
         ),  # Test with 2 neurons to update
     ],
 )
@@ -98,5 +100,5 @@ def test_update_dictionary_vectors_with_neurons(
 
     # Check if the specified neurons are updated correctly
     assert torch.allclose(
-        mock_decoder.weight[dictionary_vector_indices, :], updates
+        mock_decoder.weight[:, dictionary_vector_indices], updates
     ), "update_dictionary_vectors should update the weights correctly."

sparse_autoencoder/autoencoder/components/tests/test_abstract_encoder.py

Lines changed: 4 additions & 4 deletions
@@ -45,7 +45,7 @@ def forward(self, x: LearnedActivationBatch) -> InputOutputActivationBatch:
 
     def reset_parameters(self) -> None:
         """Mock reset parameters."""
-        self._weight: EncoderWeights = init.normal_(self._weight, mean=0, std=1)
+        self._weight: EncoderWeights = init.kaiming_normal_(self._weight)
 
 
 @pytest.fixture()
@@ -89,10 +89,10 @@ def test_update_dictionary_vectors_with_no_neurons(mock_encoder: MockEncoder) ->
 @pytest.mark.parametrize(
     ("dictionary_vector_indices", "updates"),
     [
-        (torch.tensor([1]), torch.rand((3, 1))),  # Test with 1 neuron to update
+        (torch.tensor([1]), torch.rand((1, 4))),  # Test with 1 neuron to update
         (
             torch.tensor([0, 2]),
-            torch.rand((3, 2)),
+            torch.rand((2, 4)),
         ),  # Test with 2 neurons to update
     ],
 )
@@ -106,5 +106,5 @@ def test_update_dictionary_vectors_with_neurons(
 
     # Check if the specified neurons are updated correctly
     assert torch.allclose(
-        mock_encoder.weight[:, dictionary_vector_indices], updates
+        mock_encoder.weight[dictionary_vector_indices, :], updates
     ), "update_dictionary_vectors should update the weights correctly."

sparse_autoencoder/train/abstract_pipeline.py

Lines changed: 17 additions & 2 deletions
@@ -1,8 +1,10 @@
 """Abstract pipeline."""
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
+from pathlib import Path
 from typing import final
 
+import torch
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 from transformer_lens import HookedTransformer
@@ -58,6 +60,8 @@ class AbstractPipeline(ABC):
 
     progress_bar: tqdm | None
 
+    total_training_steps: int = 1
+
     @final
     def __init__(  # noqa: PLR0913
         self,
@@ -73,6 +77,7 @@ def __init__(  # noqa: PLR0913
         train_metrics: list[AbstractTrainMetric] | None = None,
         validation_metrics: list[AbstractValidationMetric] | None = None,
         source_data_batch_size: int = 12,
+        checkpoint_directory: Path | None = None,
     ):
         """Initialize the pipeline."""
         self.cache_name = cache_name
@@ -87,6 +92,7 @@ def __init__(  # noqa: PLR0913
         self.optimizer = optimizer
         self.loss = loss
         self.source_data_batch_size = source_data_batch_size
+        self.checkpoint_directory = checkpoint_directory
 
         source_dataloader = source_dataset.get_dataloader(source_data_batch_size)
         self.source_data = self.stateful_dataloader_iterable(source_dataloader)
@@ -149,10 +155,14 @@ def validate_sae(self) -> None:
         """Get validation metrics."""
         raise NotImplementedError
 
-    @abstractmethod
+    @final
     def save_checkpoint(self) -> None:
         """Save the model as a checkpoint."""
-        raise NotImplementedError
+        if self.checkpoint_directory:
+            file_path: Path = (
+                self.checkpoint_directory / f"sae_state_dict-{self.total_training_steps}.pt"
+            )
+            torch.save(self.autoencoder.state_dict(), file_path)
 
     @final
     def run_pipeline(
@@ -196,6 +206,11 @@ def run_pipeline(
             else:
                 neuron_activity = detached_neuron_activity
 
+            # Update the counters
+            last_resampled += store_size
+            last_validated += store_size
+            last_checkpoint += store_size
+
             # Resample dead neurons (if needed)
             progress_bar.set_postfix({"stage": "resample"})
             if last_resampled > resample_frequency and self.activation_resampler is not None:
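
save_checkpoint is now a concrete @final method that writes the autoencoder's state dict into checkpoint_directory (the .checkpoints directory git-ignored by this commit), named by total_training_steps. A minimal round-trip sketch; the SparseAutoencoder constructor arguments mirror the test fixtures elsewhere in this commit and are illustrative:

from pathlib import Path

import torch

from sparse_autoencoder.autoencoder.model import SparseAutoencoder

checkpoint_directory = Path(".checkpoints")
checkpoint_directory.mkdir(exist_ok=True)

# What save_checkpoint() writes, mirroring the implementation above.
autoencoder = SparseAutoencoder(5, 10, torch.rand(5))
total_training_steps = 1
file_path = checkpoint_directory / f"sae_state_dict-{total_training_steps}.pt"
torch.save(autoencoder.state_dict(), file_path)

# Restoring is standard PyTorch.
restored = SparseAutoencoder(5, 10, torch.rand(5))
restored.load_state_dict(torch.load(file_path))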
