1
1
"""Test the pipeline module."""
2
+ from typing import Any
3
+ from unittest .mock import MagicMock
2
4
3
5
import pytest
4
6
import torch
15
17
from sparse_autoencoder .activation_resampler .abstract_activation_resampler import (
16
18
ParameterUpdateResults ,
17
19
)
20
+ from sparse_autoencoder .activation_resampler .activation_resampler import ActivationResampler
18
21
from sparse_autoencoder .activation_store .tensor_store import TensorActivationStore
22
+ from sparse_autoencoder .metrics .validate .abstract_validate_metric import (
23
+ AbstractValidationMetric ,
24
+ ValidationMetricData ,
25
+ )
19
26
from sparse_autoencoder .source_data .mock_dataset import MockDataset
20
27
21
28
@@ -39,9 +46,10 @@ def pipeline_fixture() -> Pipeline:
39
46
named_parameters = autoencoder .named_parameters (),
40
47
)
41
48
source_data = MockDataset (context_size = 100 )
49
+ activation_resampler = ActivationResampler (n_learned_features = autoencoder .n_learned_features )
42
50
43
51
return Pipeline (
44
- activation_resampler = None ,
52
+ activation_resampler = activation_resampler ,
45
53
autoencoder = autoencoder ,
46
54
cache_name = "blocks.0.hook_mlp_out" ,
47
55
layer = 0 ,
@@ -239,3 +247,86 @@ def test_optimizer_state_changed(self, pipeline_fixture: Pipeline) -> None:
239
247
dtype = torch .float ,
240
248
),
241
249
), "Optimizer non-dead neuron state should not have changed after training."
250
+
251
+
252
class TestValidateSAE:
    """Test the validate_sae method."""

    def test_reconstruction_loss_more_than_base(self, pipeline_fixture: Pipeline) -> None:
        """Test that the reconstruction and zero-ablation losses exceed the base loss.

        A recording metric is appended to the pipeline's validation metrics so that whatever
        `ValidationMetricData` it receives while `validate_sae` runs can be inspected afterwards.
        """

        # Create a dummy metric, so we can retrieve the stored data afterwards
        class StoreValidationMetric(AbstractValidationMetric):
            """Dummy metric to store the data."""

            # Default to None: a bare annotation creates no attribute, so without a default the
            # `is not None` check below would raise AttributeError (rather than fail with its
            # message) if `calculate` were never called.
            data: ValidationMetricData | None = None

            def calculate(self, data: ValidationMetricData) -> dict[str, Any]:
                """Store the data and report no metrics."""
                self.data = data
                return {}

        dummy_metric = StoreValidationMetric()
        pipeline_fixture.metrics.validation_metrics.append(dummy_metric)

        # Run the validation loop
        store_size: int = 1000
        pipeline_fixture.generate_activations(store_size)
        pipeline_fixture.validate_sae(store_size)

        # Check the loss
        assert (
            dummy_metric.data is not None
        ), "Dummy metric should have stored the data from the validation loop."
        assert (
            dummy_metric.data.source_model_loss_with_reconstruction
            > dummy_metric.data.source_model_loss
        ), "Reconstruction loss should be more than base loss."

        assert (
            dummy_metric.data.source_model_loss_with_zero_ablation
            > dummy_metric.data.source_model_loss
        ), "Zero ablation loss should be more than base loss."
290
+
291
+
292
class TestRunPipeline:
    """Test the run_pipeline method."""

    def test_run_pipeline_calls_all_methods(self, pipeline_fixture: Pipeline) -> None:
        """Test that the run_pipeline method calls all the other methods.

        `validate_sae`, `save_checkpoint` and the resampler's `step_resampler` are replaced with
        mocks, the pipeline is run for `total_loops` stores' worth of activations, and the call
        count of each mock is compared against the expected number.
        """
        # Check the resampler exists before patching it, so a misconfigured fixture fails here
        # with a clear message instead of an AttributeError on `None`.
        assert (
            pipeline_fixture.activation_resampler is not None
        ), "Pipeline fixture should have an activation resampler."

        pipeline_fixture.validate_sae = MagicMock(spec=Pipeline.validate_sae)  # type: ignore
        pipeline_fixture.save_checkpoint = MagicMock(spec=Pipeline.save_checkpoint)  # type: ignore
        pipeline_fixture.activation_resampler.step_resampler = MagicMock(  # type: ignore
            spec=ActivationResampler.step_resampler, return_value=None
        )

        store_size = 1000
        context_size = pipeline_fixture.source_dataset.context_size
        train_batch_size = store_size // context_size

        total_loops = 5
        validate_expected_calls = 2
        checkpoint_expected_calls = 5

        pipeline_fixture.run_pipeline(
            train_batch_size=train_batch_size,
            max_store_size=store_size,
            # One full store per loop, so the activation budget fixes the number of loops.
            max_activations=store_size * total_loops,
            validation_number_activations=store_size,
            # Frequencies are in activations; integer division spreads the expected number of
            # calls across the loops (5 // 2 = every 2 stores, 5 // 5 = every store).
            validate_frequency=store_size * (total_loops // validate_expected_calls),
            checkpoint_frequency=store_size * (total_loops // checkpoint_expected_calls),
        )

        # Check the number of calls
        assert (
            pipeline_fixture.validate_sae.call_count == validate_expected_calls
        ), f"Validate should have been called {validate_expected_calls} times."

        assert (
            pipeline_fixture.save_checkpoint.call_count == checkpoint_expected_calls
        ), f"Checkpoint should have been called {checkpoint_expected_calls} times."

        assert (
            pipeline_fixture.activation_resampler.step_resampler.call_count == total_loops
        ), f"Resampler should have been called {total_loops} times."
0 commit comments