Skip to content

Commit 9f62039

Browse files
authored
Add more pipeline tests (#146)
1 parent 8409136 commit 9f62039

File tree

13 files changed

+214
-59
lines changed

13 files changed

+214
-59
lines changed

.vscode/settings.json

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,8 @@
4141
"python.testing.pytestEnabled": true,
4242
"rewrap.autoWrap.enabled": true,
4343
"rewrap.wrappingColumn": 100,
44-
"python.analysis.diagnosticMode": "workspace"
44+
"pylint.ignorePatterns": [
45+
"*"
46+
]
4547
}
48+

sparse_autoencoder/activation_resampler/abstract_activation_resampler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from abc import ABC, abstractmethod
44
from dataclasses import dataclass
55

6-
from jaxtyping import Float, Int
6+
from jaxtyping import Float, Int, Int64
77
from torch import Tensor
88

99
from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
@@ -16,7 +16,7 @@
1616
class ParameterUpdateResults:
1717
"""Parameter update results from resampling dead neurons."""
1818

19-
dead_neuron_indices: Int[Tensor, Axis.LEARNT_FEATURE_IDX]
19+
dead_neuron_indices: Int64[Tensor, Axis.LEARNT_FEATURE_IDX]
2020
"""Dead neuron indices."""
2121

2222
dead_encoder_weight_updates: Float[

sparse_autoencoder/activation_resampler/activation_resampler.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Activation resampler."""
22
from einops import rearrange
3-
from jaxtyping import Bool, Float, Int
3+
from jaxtyping import Bool, Float, Int64
44
import torch
55
from torch import Tensor
66
from torch.nn import Parameter
@@ -139,13 +139,13 @@ def __init__(
139139
self.neuron_activity_window_end = resample_interval
140140
self.neuron_activity_window_start = resample_interval - n_activations_activity_collate
141141
self._max_n_resamples = max_n_resamples
142-
self._collated_neuron_activity = torch.zeros(n_learned_features, dtype=torch.int32)
142+
self._collated_neuron_activity = torch.zeros(n_learned_features, dtype=torch.int64)
143143
self._resample_dataset_size = resample_dataset_size
144144
self._threshold_is_dead_portion_fires = threshold_is_dead_portion_fires
145145

146146
def _get_dead_neuron_indices(
147147
self,
148-
) -> Int[Tensor, Axis.LEARNT_FEATURE_IDX]:
148+
) -> Int64[Tensor, Axis.LEARNT_FEATURE_IDX]:
149149
"""Identify the indices of neurons that are dead.
150150
151151
Identifies any neurons that have fired less than the threshold portion of the collated
@@ -171,7 +171,7 @@ def _get_dead_neuron_indices(
171171
self._collated_neuron_activity <= threshold_is_dead_number_fires
172172
)[0]
173173

174-
return dead_indices.to(dtype=torch.int)
174+
return dead_indices.to(dtype=torch.int64)
175175

176176
def compute_loss_and_get_activations(
177177
self,
@@ -299,15 +299,15 @@ def sample_input(
299299
device=input_activations.device,
300300
).to(input_activations.device)
301301

302-
sample_indices: Int[Tensor, Axis.LEARNT_FEATURE_IDX] = torch.multinomial(
302+
sample_indices: Int64[Tensor, Axis.LEARNT_FEATURE_IDX] = torch.multinomial(
303303
probabilities, num_samples=num_samples
304304
)
305305
return input_activations[sample_indices, :]
306306

307307
@staticmethod
308308
def renormalize_and_scale(
309309
sampled_input: Float[Tensor, Axis.names(Axis.DEAD_FEATURE, Axis.INPUT_OUTPUT_FEATURE)],
310-
neuron_activity: Int[Tensor, Axis.LEARNT_FEATURE],
310+
neuron_activity: Int64[Tensor, Axis.LEARNT_FEATURE],
311311
encoder_weight: Float[
312312
Parameter, Axis.names(Axis.LEARNT_FEATURE, Axis.INPUT_OUTPUT_FEATURE)
313313
],
@@ -447,7 +447,7 @@ def resample_dead_neurons(
447447

448448
def step_resampler(
449449
self,
450-
batch_neuron_activity: Int[Tensor, Axis.LEARNT_FEATURE],
450+
batch_neuron_activity: Int64[Tensor, Axis.LEARNT_FEATURE],
451451
activation_store: ActivationStore,
452452
autoencoder: SparseAutoencoder,
453453
loss_fn: AbstractLoss,

sparse_autoencoder/activation_resampler/tests/test_activation_resampler.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Tests for the resample_neurons module."""
22

3-
from jaxtyping import Float, Int
3+
from jaxtyping import Float, Int, Int64
44
import pytest
55
import torch
66
from torch import Tensor
@@ -256,7 +256,7 @@ class TestRenormalizeAndScale:
256256
@staticmethod
257257
def calculate_expected_output(
258258
sampled_input: Float[Tensor, Axis.names(Axis.DEAD_FEATURE, Axis.INPUT_OUTPUT_FEATURE)],
259-
neuron_activity: Int[Tensor, Axis.LEARNT_FEATURE],
259+
neuron_activity: Int64[Tensor, Axis.LEARNT_FEATURE],
260260
encoder_weight: Float[
261261
Parameter, Axis.names(Axis.LEARNT_FEATURE, Axis.INPUT_OUTPUT_FEATURE)
262262
],
@@ -288,7 +288,7 @@ def test_basic_renormalization(self) -> None:
288288
sampled_input: Float[
289289
Tensor, Axis.names(Axis.DEAD_FEATURE, Axis.INPUT_OUTPUT_FEATURE)
290290
] = torch.tensor([[3.0, 4.0, 5.0]])
291-
neuron_activity: Int[Tensor, Axis.LEARNT_FEATURE] = torch.tensor([1, 0, 1, 0, 1])
291+
neuron_activity: Int64[Tensor, Axis.LEARNT_FEATURE] = torch.tensor([1, 0, 1, 0, 1])
292292
encoder_weight: Float[
293293
Parameter, Axis.names(Axis.LEARNT_FEATURE, Axis.INPUT_OUTPUT_FEATURE)
294294
] = Parameter(torch.ones((DEFAULT_N_LEARNED_FEATURES, DEFAULT_N_INPUT_FEATURES)))
@@ -323,7 +323,7 @@ def test_no_changes_if_no_dead_neurons(
323323
self, full_activation_store: ActivationStore, autoencoder_model: SparseAutoencoder
324324
) -> None:
325325
"""Check it doesn't change anything if there are no dead neurons."""
326-
neuron_activity = torch.ones(DEFAULT_N_LEARNED_FEATURES, dtype=torch.int32)
326+
neuron_activity = torch.ones(DEFAULT_N_LEARNED_FEATURES, dtype=torch.int64)
327327
resampler = ActivationResampler(
328328
resample_interval=10,
329329
n_activations_activity_collate=10,
@@ -348,7 +348,7 @@ def test_updates_a_dead_neuron_parameters(
348348
self, full_activation_store: ActivationStore, autoencoder_model: SparseAutoencoder
349349
) -> None:
350350
"""Check it updates a dead neuron's parameters."""
351-
neuron_activity = torch.ones(DEFAULT_N_LEARNED_FEATURES, dtype=torch.int32)
351+
neuron_activity = torch.ones(DEFAULT_N_LEARNED_FEATURES, dtype=torch.int64)
352352
dead_neuron_idx = 2
353353
neuron_activity[dead_neuron_idx] = 0
354354

@@ -395,19 +395,19 @@ class TestStepResampler:
395395
@pytest.mark.parametrize(
396396
("neuron_activity", "threshold", "expected_indices"),
397397
[
398-
(torch.tensor([1, 0, 3, 9, 0]), 0.0, torch.tensor([1, 4], dtype=torch.int)),
398+
(torch.tensor([1, 0, 3, 9, 0]), 0.0, torch.tensor([1, 4], dtype=torch.int64)),
399399
(
400400
torch.tensor([1, 2, 3, 4, 5]),
401401
0.0,
402-
torch.tensor([], dtype=torch.int),
402+
torch.tensor([], dtype=torch.int64),
403403
),
404-
(torch.tensor([1, 0, 3, 9, 0]), 0.1, torch.tensor([0, 1, 4], dtype=torch.int)),
405-
(torch.tensor([1, 2, 3, 4, 5]), 0.1, torch.tensor([0], dtype=torch.int)),
404+
(torch.tensor([1, 0, 3, 9, 0]), 0.1, torch.tensor([0, 1, 4], dtype=torch.int64)),
405+
(torch.tensor([1, 2, 3, 4, 5]), 0.1, torch.tensor([0], dtype=torch.int64)),
406406
],
407407
)
408408
def test_gets_dead_neuron_indices(
409409
self,
410-
neuron_activity: Int[Tensor, Axis.LEARNT_FEATURE],
410+
neuron_activity: Int64[Tensor, Axis.LEARNT_FEATURE],
411411
threshold: float,
412412
expected_indices: Tensor,
413413
full_activation_store: ActivationStore,
@@ -463,7 +463,7 @@ def test_max_updates(
463463
) -> None:
464464
"""Check if max_updates, resample_interval and n_steps_collate are respected."""
465465
# Create neuron activity to log (with one dead neuron)
466-
neuron_activity_batch_size_1 = torch.ones(DEFAULT_N_LEARNED_FEATURES, dtype=torch.int32)
466+
neuron_activity_batch_size_1 = torch.ones(DEFAULT_N_LEARNED_FEATURES, dtype=torch.int64)
467467
neuron_activity_batch_size_1[2] = 0
468468

469469
resampler = ActivationResampler(

sparse_autoencoder/autoencoder/components/abstract_decoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from abc import ABC, abstractmethod
33
from typing import final
44

5-
from jaxtyping import Float, Int
5+
from jaxtyping import Float, Int64
66
import torch
77
from torch import Tensor
88
from torch.nn import Module, Parameter
@@ -61,7 +61,7 @@ def reset_parameters(self) -> None:
6161
@final
6262
def update_dictionary_vectors(
6363
self,
64-
dictionary_vector_indices: Int[Tensor, Axis.LEARNT_FEATURE_IDX],
64+
dictionary_vector_indices: Int64[Tensor, Axis.LEARNT_FEATURE_IDX],
6565
updated_weights: Float[Tensor, Axis.names(Axis.INPUT_OUTPUT_FEATURE, Axis.DEAD_FEATURE)],
6666
) -> None:
6767
"""Update decoder dictionary vectors.

sparse_autoencoder/autoencoder/components/abstract_encoder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from abc import ABC, abstractmethod
33
from typing import final
44

5-
from jaxtyping import Float, Int
5+
from jaxtyping import Float, Int64
66
import torch
77
from torch import Tensor
88
from torch.nn import Module, Parameter
@@ -62,7 +62,7 @@ def forward(
6262
@final
6363
def update_dictionary_vectors(
6464
self,
65-
dictionary_vector_indices: Int[Tensor, Axis.LEARNT_FEATURE_IDX],
65+
dictionary_vector_indices: Int64[Tensor, Axis.LEARNT_FEATURE_IDX],
6666
updated_dictionary_weights: Float[
6767
Tensor, Axis.names(Axis.DEAD_FEATURE, Axis.INPUT_OUTPUT_FEATURE)
6868
],
@@ -84,7 +84,7 @@ def update_dictionary_vectors(
8484
@final
8585
def update_bias(
8686
self,
87-
update_parameter_indices: Int[Tensor, Axis.INPUT_OUTPUT_FEATURE],
87+
update_parameter_indices: Int64[Tensor, Axis.INPUT_OUTPUT_FEATURE],
8888
updated_bias_features: Float[Tensor, Axis.LEARNT_FEATURE] | float,
8989
) -> None:
9090
"""Update encoder bias.

sparse_autoencoder/autoencoder/components/tests/test_abstract_decoder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import final
44

5-
from jaxtyping import Float, Int
5+
from jaxtyping import Float, Int64
66
import pytest
77
import torch
88
from torch import Tensor
@@ -72,9 +72,9 @@ def test_update_dictionary_vectors_with_no_neurons(mock_decoder: MockDecoder) ->
7272
"""Test update_dictionary_vectors with 0 neurons to update."""
7373
original_weight = mock_decoder.weight.clone() # Save original weight for comparison
7474

75-
dictionary_vector_indices: Int[Tensor, Axis.INPUT_OUTPUT_FEATURE] = torch.empty(
75+
dictionary_vector_indices: Int64[Tensor, Axis.INPUT_OUTPUT_FEATURE] = torch.empty(
7676
0,
77-
dtype=torch.int, # Empty tensor with 1 dimension
77+
dtype=torch.int64, # Empty tensor with 1 dimension
7878
)
7979
updates: Float[Tensor, Axis.names(Axis.INPUT_OUTPUT_FEATURE, Axis.DEAD_FEATURE)] = torch.empty(
8080
(0, 0),
@@ -101,7 +101,7 @@ def test_update_dictionary_vectors_with_no_neurons(mock_decoder: MockDecoder) ->
101101
)
102102
def test_update_dictionary_vectors_with_neurons(
103103
mock_decoder: MockDecoder,
104-
dictionary_vector_indices: Int[Tensor, Axis.INPUT_OUTPUT_FEATURE],
104+
dictionary_vector_indices: Int64[Tensor, Axis.INPUT_OUTPUT_FEATURE],
105105
updates: Float[Tensor, Axis.names(Axis.INPUT_OUTPUT_FEATURE, Axis.DEAD_FEATURE)],
106106
) -> None:
107107
"""Test update_dictionary_vectors with 1 or 2 neurons to update."""

sparse_autoencoder/autoencoder/components/tests/test_abstract_encoder.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import final
44

5-
from jaxtyping import Float, Int
5+
from jaxtyping import Float, Int64
66
import pytest
77
import torch
88
from torch import Tensor
@@ -74,11 +74,11 @@ def test_update_dictionary_vectors_with_no_neurons(mock_encoder: MockEncoder) ->
7474
torch.random.manual_seed(0)
7575
original_weight = mock_encoder.weight.clone() # Save original weight for comparison
7676

77-
dictionary_vector_indices: Int[Tensor, Axis.INPUT_OUTPUT_FEATURE] = torch.empty(
77+
dictionary_vector_indices: Int64[Tensor, Axis.INPUT_OUTPUT_FEATURE] = torch.empty(
7878
0,
79-
dtype=torch.int, # Empty tensor with 1 dimension
79+
dtype=torch.int64, # Empty tensor with 1 dimension
8080
)
81-
updates: Int[Tensor, Axis.INPUT_OUTPUT_FEATURE] = torch.empty(
81+
updates: Float[Tensor, Axis.INPUT_OUTPUT_FEATURE] = torch.empty(
8282
(0, 0),
8383
dtype=torch.float, # Empty tensor with 2 dimensions
8484
)
@@ -103,8 +103,8 @@ def test_update_dictionary_vectors_with_no_neurons(mock_encoder: MockEncoder) ->
103103
)
104104
def test_update_dictionary_vectors_with_neurons(
105105
mock_encoder: MockEncoder,
106-
dictionary_vector_indices: Int[Tensor, Axis.INPUT_OUTPUT_FEATURE],
107-
updates: Int[Tensor, Axis.INPUT_OUTPUT_FEATURE],
106+
dictionary_vector_indices: Int64[Tensor, Axis.INPUT_OUTPUT_FEATURE],
107+
updates: Float[Tensor, Axis.INPUT_OUTPUT_FEATURE],
108108
) -> None:
109109
"""Test update_dictionary_vectors with 1 or 2 neurons to update."""
110110
mock_encoder.update_dictionary_vectors(dictionary_vector_indices, updates)

sparse_autoencoder/metrics/train/neuron_activity_metric.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66
from typing import Any
77

8-
from jaxtyping import Int
8+
from jaxtyping import Int64
99
import numpy as np
1010
from numpy.typing import NDArray
1111
import torch
@@ -43,7 +43,7 @@ class NeuronActivityHorizonData:
4343
_steps_since_last_calculated: int
4444
"""Steps since last calculated."""
4545

46-
_neuron_activity: Int[Tensor, Axis.LEARNT_FEATURE]
46+
_neuron_activity: Int64[Tensor, Axis.LEARNT_FEATURE]
4747
"""Neuron activity since inception."""
4848

4949
_thresholds: list[float]
@@ -52,8 +52,8 @@ class NeuronActivityHorizonData:
5252
@property
5353
def _dead_count(self) -> int:
5454
"""Dead count."""
55-
dead_bool_mask: Int[Tensor, Axis.LEARNT_FEATURE] = self._neuron_activity == 0
56-
count_dead: Int[Tensor, Axis.SINGLE_ITEM] = dead_bool_mask.sum()
55+
dead_bool_mask: Int64[Tensor, Axis.LEARNT_FEATURE] = self._neuron_activity == 0
56+
count_dead: Int64[Tensor, Axis.SINGLE_ITEM] = dead_bool_mask.sum()
5757
return int(count_dead.item())
5858

5959
@property
@@ -64,8 +64,8 @@ def _dead_fraction(self) -> float:
6464
@property
6565
def _alive_count(self) -> int:
6666
"""Alive count."""
67-
alive_bool_mask: Int[Tensor, Axis.LEARNT_FEATURE] = self._neuron_activity > 0
68-
count_alive: Int[Tensor, Axis.SINGLE_ITEM] = alive_bool_mask.sum()
67+
alive_bool_mask: Int64[Tensor, Axis.LEARNT_FEATURE] = self._neuron_activity > 0
68+
count_alive: Int64[Tensor, Axis.SINGLE_ITEM] = alive_bool_mask.sum()
6969
return int(count_alive.item())
7070

7171
def _almost_dead(self, threshold: float) -> int | None:
@@ -74,10 +74,10 @@ def _almost_dead(self, threshold: float) -> int | None:
7474
if threshold_in_activations < 1:
7575
return None
7676

77-
almost_dead_bool_mask: Int[Tensor, Axis.LEARNT_FEATURE] = (
77+
almost_dead_bool_mask: Int64[Tensor, Axis.LEARNT_FEATURE] = (
7878
self._neuron_activity < threshold_in_activations
7979
)
80-
count_almost_dead: Int[Tensor, Axis.SINGLE_ITEM] = almost_dead_bool_mask.sum()
80+
count_almost_dead: Int64[Tensor, Axis.SINGLE_ITEM] = almost_dead_bool_mask.sum()
8181
return int(count_almost_dead.item())
8282

8383
@property
@@ -134,14 +134,14 @@ def __init__(
134134
thresholds: Thresholds for almost dead neurons.
135135
"""
136136
self._steps_since_last_calculated = 0
137-
self._neuron_activity = torch.zeros(number_learned_features, dtype=torch.int)
137+
self._neuron_activity = torch.zeros(number_learned_features, dtype=torch.int64)
138138
self._thresholds = thresholds
139139

140140
# Get a precise activation_horizon
141141
self._horizon_steps = approximate_activation_horizon // train_batch_size
142142
self._horizon_number_activations = self._horizon_steps * train_batch_size
143143

144-
def step(self, neuron_activity: Int[Tensor, Axis.LEARNT_FEATURE]) -> dict[str, Any]:
144+
def step(self, neuron_activity: Int64[Tensor, Axis.LEARNT_FEATURE]) -> dict[str, Any]:
145145
"""Step the neuron activity horizon data.
146146
147147
Args:
@@ -231,7 +231,7 @@ def calculate(self, data: TrainMetricData) -> dict[str, Any]:
231231
log = {}
232232

233233
for horizon_data in self._data:
234-
fired_count: Int[Tensor, Axis.LEARNT_FEATURE] = (
234+
fired_count: Int64[Tensor, Axis.LEARNT_FEATURE] = (
235235
(data.learned_activations > 0).sum(dim=0).detach().cpu()
236236
)
237237
horizon_specific_log = horizon_data.step(fired_count)

sparse_autoencoder/optimizer/abstract_optimizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from abc import ABC, abstractmethod
33
from typing import TypeAlias
44

5-
from jaxtyping import Int
5+
from jaxtyping import Int64
66
from torch import Tensor
77
from torch.nn.parameter import Parameter
88
from torch.optim import Optimizer
@@ -33,7 +33,7 @@ def reset_state_all_parameters(self) -> None:
3333
def reset_neurons_state(
3434
self,
3535
parameter: Parameter,
36-
neuron_indices: Int[Tensor, Axis.LEARNT_FEATURE_IDX],
36+
neuron_indices: Int64[Tensor, Axis.LEARNT_FEATURE_IDX],
3737
axis: int,
3838
) -> None:
3939
"""Reset the state for specific neurons, on a specific parameter.

0 commit comments

Comments
 (0)