Skip to content

Commit 26ae485

Browse files
authored
Add save to wandb (#153)
1 parent 05ed6b3 commit 26ae485

File tree

3 files changed

+60 -13 lines changed

3 files changed

+60 -13 lines changed

sparse_autoencoder/train/pipeline.py

Lines changed: 36 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
from collections.abc import Iterable
33
from functools import partial
44
from pathlib import Path
5+
import tempfile
56
from typing import final
67
from urllib.parse import quote_plus
78

@@ -32,6 +33,9 @@
3233
from sparse_autoencoder.train.utils import get_model_device
3334

3435

36+
DEFAULT_CHECKPOINT_DIRECTORY: Path = Path(tempfile.gettempdir()) / "sparse_autoencoder"
37+
38+
3539
class Pipeline:
3640
"""Pipeline for training a Sparse Autoencoder on TransformerLens activations.
3741
@@ -90,7 +94,7 @@ def __init__(
9094
source_dataset: SourceDataset,
9195
source_model: HookedTransformer,
9296
run_name: str = "sparse_autoencoder",
93-
checkpoint_directory: Path | None = None,
97+
checkpoint_directory: Path = DEFAULT_CHECKPOINT_DIRECTORY,
9498
log_frequency: int = 100,
9599
metrics: MetricsContainer = default_metrics,
96100
source_data_batch_size: int = 12,
@@ -339,15 +343,34 @@ def validate_sae(self, validation_number_activations: int) -> None:
339343
wandb.log(data=calculated, commit=False)
340344

341345
@final
342-
def save_checkpoint(self) -> None:
343-
"""Save the model as a checkpoint."""
344-
if self.checkpoint_directory:
345-
run_name_file_system_safe = quote_plus(self.run_name)
346-
file_path: Path = (
347-
self.checkpoint_directory
348-
/ f"{run_name_file_system_safe}-{self.total_activations_trained_on}.pt"
349-
)
350-
torch.save(self.autoencoder.state_dict(), file_path)
346+
def save_checkpoint(self, *, is_final: bool = False) -> Path:
347+
"""Save the model as a checkpoint.
348+
349+
Args:
350+
is_final: Whether this is the final checkpoint.
351+
352+
Returns:
353+
Path to the saved checkpoint.
354+
"""
355+
# Create the name
356+
name: str = f"{self.run_name}_{'final' if is_final else self.total_activations_trained_on}"
357+
safe_name = quote_plus(name, safe="_")
358+
359+
# Save locally
360+
self.checkpoint_directory.mkdir(parents=True, exist_ok=True)
361+
file_path: Path = self.checkpoint_directory / f"{safe_name}.pt"
362+
torch.save(
363+
self.autoencoder.state_dict(),
364+
file_path,
365+
)
366+
367+
# Upload to wandb
368+
if wandb.run is not None:
369+
artifact = wandb.Artifact(safe_name, type="model")
370+
artifact.add_file(str(file_path))
371+
wandb.log_artifact(artifact)
372+
373+
return file_path
351374

352375
def run_pipeline(
353376
self,
@@ -440,6 +463,9 @@ def run_pipeline(
440463
# Update the progress bar
441464
progress_bar.update(store_size)
442465

466+
# Save the final checkpoint
467+
self.save_checkpoint(is_final=True)
468+
443469
@staticmethod
444470
def stateful_dataloader_iterable(
445471
dataloader: DataLoader[TorchTokenizedPrompts],

sparse_autoencoder/train/sweep.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ def run_training_pipeline(
217217
hook_point = get_act_name(
218218
hyperparameters["source_model"]["hook_site"], hyperparameters["source_model"]["hook_layer"]
219219
)
220+
220221
pipeline = Pipeline(
221222
activation_resampler=activation_resampler,
222223
autoencoder=autoencoder,

sparse_autoencoder/train/tests/test_pipeline.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Test the pipeline module."""
2-
from typing import Any
2+
from typing import TYPE_CHECKING, Any
33
from unittest.mock import MagicMock
44

55
import pytest
@@ -26,6 +26,10 @@
2626
from sparse_autoencoder.source_data.mock_dataset import MockDataset
2727

2828

29+
if TYPE_CHECKING:
30+
from pathlib import Path
31+
32+
2933
@pytest.fixture()
3034
def pipeline_fixture() -> Pipeline:
3135
"""Fixture to create a Pipeline instance for testing."""
@@ -289,6 +293,22 @@ def calculate(self, data: ValidationMetricData) -> dict[str, Any]:
289293
), "Zero ablation loss should be more than base loss."
290294

291295

296+
class TestSaveCheckpoint:
297+
"""Test the save_checkpoint method."""
298+
299+
def test_saves_locally(self, pipeline_fixture: Pipeline) -> None:
300+
"""Test that the save_checkpoint method saves the checkpoint locally."""
301+
saved_checkpoint: Path = pipeline_fixture.save_checkpoint()
302+
assert saved_checkpoint.exists(), "Checkpoint file should exist."
303+
304+
def test_saves_final(self, pipeline_fixture: Pipeline) -> None:
305+
"""Test that the save_checkpoint method saves the final checkpoint."""
306+
saved_checkpoint: Path = pipeline_fixture.save_checkpoint(is_final=True)
307+
assert (
308+
"final.pt" in saved_checkpoint.name
309+
), "Checkpoint file should be named '<run_name>_final.pt'."
310+
311+
292312
class TestRunPipeline:
293313
"""Test the run_pipeline method."""
294314

@@ -306,15 +326,15 @@ def test_run_pipeline_calls_all_methods(self, pipeline_fixture: Pipeline) -> Non
306326

307327
total_loops = 5
308328
validate_expected_calls = 2
309-
checkpoint_expected_calls = 5
329+
checkpoint_expected_calls = 6 # Includes final
310330

311331
pipeline_fixture.run_pipeline(
312332
train_batch_size=train_batch_size,
313333
max_store_size=store_size,
314334
max_activations=store_size * 5,
315335
validation_number_activations=store_size,
316336
validate_frequency=store_size * (total_loops // validate_expected_calls),
317-
checkpoint_frequency=store_size * (total_loops // checkpoint_expected_calls),
337+
checkpoint_frequency=store_size * (total_loops // checkpoint_expected_calls - 1),
318338
)
319339

320340
# Check the number of calls

0 commit comments

Comments
 (0)