Skip to content

Commit a495994

Browse files
authored
Improve pipeline test coverage (#148)
1 parent 4e488b0 commit a495994

File tree

3 files changed

+104
-6
lines changed

3 files changed

+104
-6
lines changed

sparse_autoencoder/source_data/mock_dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,4 @@ def __init__(
174174
dataset_split: Dataset split (e.g. `train`).
175175
"""
176176
self.dataset = ConsecutiveIntHuggingFaceDataset(context_size=context_size) # type: ignore
177+
self.context_size = context_size

sparse_autoencoder/train/pipeline.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,8 @@ def validate_sae(self, validation_number_activations: int) -> None:
335335
)
336336
for metric in self.metrics.validation_metrics:
337337
calculated = metric.calculate(validation_data)
338-
wandb.log(data=calculated, commit=False)
338+
if wandb.run is not None:
339+
wandb.log(data=calculated, commit=False)
339340

340341
@final
341342
def save_checkpoint(self) -> None:
@@ -411,10 +412,15 @@ def run_pipeline(
411412
)
412413

413414
if parameter_updates is not None:
414-
wandb.log(
415-
{"resample/dead_neurons": len(parameter_updates.dead_neuron_indices)},
416-
commit=False,
417-
)
415+
if wandb.run is not None:
416+
wandb.log(
417+
{
418+
"resample/dead_neurons": len(
419+
parameter_updates.dead_neuron_indices
420+
)
421+
},
422+
commit=False,
423+
)
418424

419425
# Update the parameters
420426
self.update_parameters(parameter_updates)

sparse_autoencoder/train/tests/test_pipeline.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
"""Test the pipeline module."""
2+
from typing import Any
3+
from unittest.mock import MagicMock
24

35
import pytest
46
import torch
@@ -15,7 +17,12 @@
1517
from sparse_autoencoder.activation_resampler.abstract_activation_resampler import (
1618
ParameterUpdateResults,
1719
)
20+
from sparse_autoencoder.activation_resampler.activation_resampler import ActivationResampler
1821
from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
22+
from sparse_autoencoder.metrics.validate.abstract_validate_metric import (
23+
AbstractValidationMetric,
24+
ValidationMetricData,
25+
)
1926
from sparse_autoencoder.source_data.mock_dataset import MockDataset
2027

2128

@@ -39,9 +46,10 @@ def pipeline_fixture() -> Pipeline:
3946
named_parameters=autoencoder.named_parameters(),
4047
)
4148
source_data = MockDataset(context_size=100)
49+
activation_resampler = ActivationResampler(n_learned_features=autoencoder.n_learned_features)
4250

4351
return Pipeline(
44-
activation_resampler=None,
52+
activation_resampler=activation_resampler,
4553
autoencoder=autoencoder,
4654
cache_name="blocks.0.hook_mlp_out",
4755
layer=0,
@@ -239,3 +247,86 @@ def test_optimizer_state_changed(self, pipeline_fixture: Pipeline) -> None:
239247
dtype=torch.float,
240248
),
241249
), "Optimizer non-dead neuron state should not have changed after training."
250+
251+
252+
class TestValidateSAE:
    """Test the validate_sae method."""

    def test_reconstruction_loss_more_than_base(self, pipeline_fixture: Pipeline) -> None:
        """Test that the reconstruction loss is more than the base loss.

        A recording metric is appended to the pipeline's validation metrics so the
        `ValidationMetricData` passed to `calculate` can be inspected after the run.

        Args:
            pipeline_fixture: Pipeline under test, provided by the module fixture.
        """

        # Create a dummy metric, so we can retrieve the stored data afterwards
        class StoreValidationMetric(AbstractValidationMetric):
            """Dummy metric to store the data."""

            # Default to None: a bare annotation creates no attribute, so without
            # this the "is not None" assertion below would be pre-empted by an
            # AttributeError whenever `calculate` is never invoked.
            data: ValidationMetricData | None = None

            def calculate(self, data: ValidationMetricData) -> dict[str, Any]:
                """Store the data and report no metrics."""
                self.data = data
                return {}

        dummy_metric = StoreValidationMetric()
        pipeline_fixture.metrics.validation_metrics.append(dummy_metric)

        # Run the validation loop
        store_size: int = 1000
        pipeline_fixture.generate_activations(store_size)
        pipeline_fixture.validate_sae(store_size)

        # Check the loss
        assert (
            dummy_metric.data is not None
        ), "Dummy metric should have stored the data from the validation loop."
        assert (
            dummy_metric.data.source_model_loss_with_reconstruction
            > dummy_metric.data.source_model_loss
        ), "Reconstruction loss should be more than base loss."

        assert (
            dummy_metric.data.source_model_loss_with_zero_ablation
            > dummy_metric.data.source_model_loss
        ), "Zero ablation loss should be more than base loss."
290+
291+
292+
class TestRunPipeline:
    """Test the run_pipeline method."""

    def test_run_pipeline_calls_all_methods(self, pipeline_fixture: Pipeline) -> None:
        """Test that the run_pipeline method calls all the other methods.

        `validate_sae`, `save_checkpoint` and the resampler's `step_resampler` are
        replaced with mocks, and their call counts are checked after the run.

        Args:
            pipeline_fixture: Pipeline under test, provided by the module fixture.
        """
        # Guard before first use: the resampler is dereferenced immediately below,
        # so checking it afterwards could never report a missing resampler.
        assert (pipeline_fixture.activation_resampler) is not None
        resampler = pipeline_fixture.activation_resampler

        pipeline_fixture.validate_sae = MagicMock(spec=Pipeline.validate_sae)  # type: ignore
        pipeline_fixture.save_checkpoint = MagicMock(spec=Pipeline.save_checkpoint)  # type: ignore
        resampler.step_resampler = MagicMock(  # type: ignore
            spec=ActivationResampler.step_resampler, return_value=None
        )

        store_size = 1000
        context_size = pipeline_fixture.source_dataset.context_size
        train_batch_size = store_size // context_size

        total_loops = 5
        validate_expected_calls = 2
        checkpoint_expected_calls = 5

        # Frequencies are expressed in activations. With integer division,
        # 5 // 2 == 2 stores between validations and 5 // 5 == 1 store between
        # checkpoints; the assertions below expect 2 and 5 calls respectively.
        pipeline_fixture.run_pipeline(
            train_batch_size=train_batch_size,
            max_store_size=store_size,
            # Derived from total_loops so the loop count has one source of truth.
            max_activations=store_size * total_loops,
            validation_number_activations=store_size,
            validate_frequency=store_size * (total_loops // validate_expected_calls),
            checkpoint_frequency=store_size * (total_loops // checkpoint_expected_calls),
        )

        # Check the number of calls
        assert (
            pipeline_fixture.validate_sae.call_count == validate_expected_calls
        ), f"Validate should have been called {validate_expected_calls} times."

        assert (
            pipeline_fixture.save_checkpoint.call_count == checkpoint_expected_calls
        ), f"Checkpoint should have been called {checkpoint_expected_calls} times."

        assert (
            resampler.step_resampler.call_count == total_loops
        ), f"Resampler should have been called {total_loops} times."

0 commit comments

Comments
 (0)