11"""Test the pipeline module."""
2+ from typing import Any
3+ from unittest .mock import MagicMock
24
35import pytest
46import torch
1517from sparse_autoencoder .activation_resampler .abstract_activation_resampler import (
1618 ParameterUpdateResults ,
1719)
20+ from sparse_autoencoder .activation_resampler .activation_resampler import ActivationResampler
1821from sparse_autoencoder .activation_store .tensor_store import TensorActivationStore
22+ from sparse_autoencoder .metrics .validate .abstract_validate_metric import (
23+ AbstractValidationMetric ,
24+ ValidationMetricData ,
25+ )
1926from sparse_autoencoder .source_data .mock_dataset import MockDataset
2027
2128
@@ -39,9 +46,10 @@ def pipeline_fixture() -> Pipeline:
3946 named_parameters = autoencoder .named_parameters (),
4047 )
4148 source_data = MockDataset (context_size = 100 )
49+ activation_resampler = ActivationResampler (n_learned_features = autoencoder .n_learned_features )
4250
4351 return Pipeline (
44- activation_resampler = None ,
52+ activation_resampler = activation_resampler ,
4553 autoencoder = autoencoder ,
4654 cache_name = "blocks.0.hook_mlp_out" ,
4755 layer = 0 ,
@@ -239,3 +247,86 @@ def test_optimizer_state_changed(self, pipeline_fixture: Pipeline) -> None:
239247 dtype = torch .float ,
240248 ),
241249 ), "Optimizer non-dead neuron state should not have changed after training."
250+
251+
252+ class TestValidateSAE :
253+ """Test the validate_sae method."""
254+
255+ def test_reconstruction_loss_more_than_base (self , pipeline_fixture : Pipeline ) -> None :
256+ """Test that the reconstruction loss is more than the base loss."""
257+
258+ # Create a dummy metric, so we can retrieve the stored data afterwards
259+ class StoreValidationMetric (AbstractValidationMetric ):
260+ """Dummy metric to store the data."""
261+
262+ data : ValidationMetricData | None
263+
264+ def calculate (self , data : ValidationMetricData ) -> dict [str , Any ]:
265+ """Store the data."""
266+ self .data = data
267+ return {}
268+
269+ dummy_metric = StoreValidationMetric ()
270+ pipeline_fixture .metrics .validation_metrics .append (dummy_metric )
271+
272+ # Run the validation loop
273+ store_size : int = 1000
274+ pipeline_fixture .generate_activations (store_size )
275+ pipeline_fixture .validate_sae (store_size )
276+
277+ # Check the loss
278+ assert (
279+ dummy_metric .data is not None
280+ ), "Dummy metric should have stored the data from the validation loop."
281+ assert (
282+ dummy_metric .data .source_model_loss_with_reconstruction
283+ > dummy_metric .data .source_model_loss
284+ ), "Reconstruction loss should be more than base loss."
285+
286+ assert (
287+ dummy_metric .data .source_model_loss_with_zero_ablation
288+ > dummy_metric .data .source_model_loss
289+ ), "Zero ablation loss should be more than base loss."
290+
291+
292+ class TestRunPipeline :
293+ """Test the run_pipeline method."""
294+
295+ def test_run_pipeline_calls_all_methods (self , pipeline_fixture : Pipeline ) -> None :
296+ """Test that the run_pipeline method calls all the other methods."""
297+ pipeline_fixture .validate_sae = MagicMock (spec = Pipeline .validate_sae ) # type: ignore
298+ pipeline_fixture .save_checkpoint = MagicMock (spec = Pipeline .save_checkpoint ) # type: ignore
299+ pipeline_fixture .activation_resampler .step_resampler = MagicMock ( # type: ignore
300+ spec = ActivationResampler .step_resampler , return_value = None
301+ )
302+
303+ store_size = 1000
304+ context_size = pipeline_fixture .source_dataset .context_size
305+ train_batch_size = store_size // context_size
306+
307+ total_loops = 5
308+ validate_expected_calls = 2
309+ checkpoint_expected_calls = 5
310+
311+ pipeline_fixture .run_pipeline (
312+ train_batch_size = train_batch_size ,
313+ max_store_size = store_size ,
314+ max_activations = store_size * 5 ,
315+ validation_number_activations = store_size ,
316+ validate_frequency = store_size * (total_loops // validate_expected_calls ),
317+ checkpoint_frequency = store_size * (total_loops // checkpoint_expected_calls ),
318+ )
319+
320+ # Check the number of calls
321+ assert (
322+ pipeline_fixture .validate_sae .call_count == validate_expected_calls
323+ ), f"Validate should have been called { validate_expected_calls } times."
324+
325+ assert (
326+ pipeline_fixture .save_checkpoint .call_count == checkpoint_expected_calls
327+ ), f"Checkpoint should have been called { checkpoint_expected_calls } times."
328+
329+ assert (pipeline_fixture .activation_resampler ) is not None
330+ assert (
331+ pipeline_fixture .activation_resampler .step_resampler .call_count == total_loops
332+ ), f"Resampler should have been called { total_loops } times."
0 commit comments