Commit 9a0052a

Improve log names (#109)
1 parent 04fa8ea commit 9a0052a

7 files changed: +114 -15 lines changed

sparse_autoencoder/loss/abstract_loss.py
Lines changed: 10 additions & 3 deletions

@@ -37,6 +37,14 @@ class AbstractLoss(Module, ABC):
     _modules: dict[str, "AbstractLoss"]  # type: ignore[assignment] (narrowing)
     """Children loss modules."""
 
+    @abstractmethod
+    def log_name(self) -> str:
+        """Log name.
+
+        Returns:
+            Name of the loss module for logging.
+        """
+
     @abstractmethod
     def forward(
         self,
@@ -85,7 +93,6 @@ def batch_scalar_loss(
             case LossReductionType.SUM:
                 return itemwise_loss.sum().squeeze()
 
-    @final
     def batch_scalar_loss_with_log(
         self,
         source_activations: InputOutputActivationBatch,
@@ -131,8 +138,8 @@ def batch_scalar_loss_with_log(
         )
 
         # Add in the current loss module's metric
-        class_name = self.__class__.__name__
-        metrics[class_name] = current_module_loss.detach().cpu().item()
+        log_name = self.log_name()
+        metrics[log_name] = current_module_loss.detach().cpu().item()
 
         return current_module_loss, metrics
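
For reference, a minimal sketch of what the new contract asks of subclasses. Nothing below is from the commit: the ExampleLoss class and its forward body are hypothetical, written against the imports the diff itself uses.

from sparse_autoencoder.loss.abstract_loss import AbstractLoss
from sparse_autoencoder.tensor_types import (
    InputOutputActivationBatch,
    LearnedActivationBatch,
    TrainBatchStatistic,
)


class ExampleLoss(AbstractLoss):
    """Hypothetical loss, used only to illustrate the new interface."""

    def log_name(self) -> str:
        # Snake-case key, replacing the old `self.__class__.__name__` behaviour
        return "example_loss"

    def forward(
        self,
        source_activations: InputOutputActivationBatch,
        learned_activations: LearnedActivationBatch,  # noqa: ARG002
        decoded_activations: InputOutputActivationBatch,
    ) -> TrainBatchStatistic:
        # Itemwise squared reconstruction error, summed over features
        return (source_activations - decoded_activations).pow(2).sum(dim=-1)

With this in place, batch_scalar_loss_with_log logs {'example_loss': ...} rather than {'ExampleLoss': ...}, and the removal of @final above is what lets subclasses such as LearnedActivationsL1Loss (below) override the logging entirely.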
sparse_autoencoder/loss/decoded_activations_l2.py
Lines changed: 9 additions & 1 deletion

@@ -29,9 +29,17 @@ class L2ReconstructionLoss(AbstractLoss):
         >>> unused_activations = torch.zeros_like(input_activations)
         >>> # Outputs both loss and metrics to log
         >>> loss(input_activations, unused_activations, output_activations)
-        (tensor(11.), {'L2ReconstructionLoss': 11.0})
+        (tensor(11.), {'l2_reconstruction_loss': 11.0})
     """
 
+    def log_name(self) -> str:
+        """Log name.
+
+        Returns:
+            Name of the loss module for logging.
+        """
+        return "l2_reconstruction_loss"
+
     def forward(
         self,
         source_activations: InputOutputActivationBatch,
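
To see where the 11.0 in the updated doctest comes from, here is a rough reconstruction of the (truncated) setup. The tensor values are assumptions, chosen under the reading that the loss sums squared errors over the feature dimension and the default MEAN reduction then averages over batch items:

import torch

from sparse_autoencoder.loss.decoded_activations_l2 import L2ReconstructionLoss

loss = L2ReconstructionLoss()
input_activations = torch.tensor([[5.0, 4.0], [3.0, 4.0]])
output_activations = torch.tensor([[1.0, 5.0], [1.0, 5.0]])
unused_activations = torch.zeros_like(input_activations)

# Per-item squared errors sum to [17.0, 5.0]; their batch mean is 11.0, now
# logged under the snake-case key from log_name()
print(loss(input_activations, unused_activations, output_activations))
# (tensor(11.), {'l2_reconstruction_loss': 11.0})
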
sparse_autoencoder/loss/learned_activations_l1.py
Lines changed: 76 additions & 6 deletions

@@ -3,9 +3,10 @@
 
 import torch
 
-from sparse_autoencoder.loss.abstract_loss import AbstractLoss
+from sparse_autoencoder.loss.abstract_loss import AbstractLoss, LossLogType, LossReductionType
 from sparse_autoencoder.tensor_types import (
     InputOutputActivationBatch,
+    ItemTensor,
     LearnedActivationBatch,
     TrainBatchStatistic,
 )
@@ -23,13 +24,21 @@ class LearnedActivationsL1Loss(AbstractLoss):
         >>> learned_activations = torch.tensor([[2.0, -3], [2.0, -3]])
         >>> unused_activations = torch.zeros_like(learned_activations)
         >>> # Returns loss and metrics to log
-        >>> l1_loss(unused_activations, learned_activations, unused_activations)
-        (tensor(0.5000), {'LearnedActivationsL1Loss': 0.5})
+        >>> l1_loss(unused_activations, learned_activations, unused_activations)[0]
+        tensor(0.5000)
     """
 
     l1_coefficient: float
     """L1 coefficient."""
 
+    def log_name(self) -> str:
+        """Log name.
+
+        Returns:
+            Name of the loss module for logging.
+        """
+        return "learned_activations_l1_loss_penalty"
+
     def __init__(self, l1_coefficient: float) -> None:
         """Initialize the absolute error loss.
@@ -42,11 +51,33 @@ def __init__(self, l1_coefficient: float) -> None:
         self.l1_coefficient = l1_coefficient
         super().__init__()
 
-    def forward(
+    def _l1_loss(
         self,
         source_activations: InputOutputActivationBatch,  # noqa: ARG002
         learned_activations: LearnedActivationBatch,
         decoded_activations: InputOutputActivationBatch,  # noqa: ARG002
+    ) -> tuple[TrainBatchStatistic, TrainBatchStatistic]:
+        """Learned activations L1 (absolute error) loss.
+
+        Args:
+            source_activations: Source activations (input activations to the autoencoder from the
+                source model).
+            learned_activations: Learned activations (intermediate activations in the autoencoder).
+            decoded_activations: Decoded activations.
+
+        Returns:
+            Tuple of itemwise absolute loss, and itemwise absolute loss multiplied by the l1
+                coefficient.
+        """
+        absolute_loss = torch.abs(learned_activations).sum(dim=-1)
+        absolute_loss_penalty = absolute_loss * self.l1_coefficient
+        return absolute_loss, absolute_loss_penalty
+
+    def forward(
+        self,
+        source_activations: InputOutputActivationBatch,
+        learned_activations: LearnedActivationBatch,
+        decoded_activations: InputOutputActivationBatch,
     ) -> TrainBatchStatistic:
         """Learned activations L1 (absolute error) loss.
@@ -59,9 +90,48 @@ def forward(
         Returns:
             Loss per batch item.
         """
-        absolute_loss = torch.abs(learned_activations)
+        return self._l1_loss(source_activations, learned_activations, decoded_activations)[1]
+
+    # Override to add both the loss and the penalty to the log
+    def batch_scalar_loss_with_log(
+        self,
+        source_activations: InputOutputActivationBatch,
+        learned_activations: LearnedActivationBatch,
+        decoded_activations: InputOutputActivationBatch,
+        reduction: LossReductionType = LossReductionType.MEAN,
+    ) -> tuple[ItemTensor, LossLogType]:
+        """Learned activations L1 (absolute error) loss, with log.
+
+        Args:
+            source_activations: Source activations (input activations to the autoencoder from the
+                source model).
+            learned_activations: Learned activations (intermediate activations in the autoencoder).
+            decoded_activations: Decoded activations.
+            reduction: Loss reduction type. Typically you would choose LossReductionType.MEAN to
+                make the loss independent of the batch size.
+
+        Returns:
+            Tuple of the L1 absolute error batch scalar loss and a dict of the properties to log
+                (loss before and after the l1 coefficient).
+        """
+        absolute_loss, absolute_loss_penalty = self._l1_loss(
+            source_activations, learned_activations, decoded_activations
+        )
+
+        match reduction:
+            case LossReductionType.MEAN:
+                batch_scalar_loss = absolute_loss.mean().squeeze()
+                batch_scalar_loss_penalty = absolute_loss_penalty.mean().squeeze()
+            case LossReductionType.SUM:
+                batch_scalar_loss = absolute_loss.sum().squeeze()
+                batch_scalar_loss_penalty = absolute_loss_penalty.sum().squeeze()
+
+        metrics = {
+            "learned_activations_l1_loss": batch_scalar_loss.item(),
+            self.log_name(): batch_scalar_loss_penalty.item(),
+        }
 
-        return absolute_loss.sum(dim=-1) * self.l1_coefficient
+        return batch_scalar_loss_penalty, metrics
 
     def extra_repr(self) -> str:
         """Extra representation string."""
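
A worked sketch of the new dual logging, reusing the activations from the doctest above. The 0.1 coefficient is an inference, not shown in the diff: each item's absolute sum is |2| + |-3| = 5.0, so a coefficient of 0.1 reproduces the 0.5 the doctest prints.

import torch

from sparse_autoencoder.loss.learned_activations_l1 import LearnedActivationsL1Loss

l1_loss = LearnedActivationsL1Loss(l1_coefficient=0.1)
learned_activations = torch.tensor([[2.0, -3.0], [2.0, -3.0]])
unused_activations = torch.zeros_like(learned_activations)

loss, metrics = l1_loss.batch_scalar_loss_with_log(
    unused_activations, learned_activations, unused_activations
)
# With MEAN reduction the raw L1 loss is 5.0 and the penalty is 5.0 * 0.1:
# metrics == {'learned_activations_l1_loss': 5.0,
#             'learned_activations_l1_loss_penalty': 0.5}
# The returned scalar is the penalty term, i.e. what the optimizer trains on.
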
sparse_autoencoder/loss/reducer.py
Lines changed: 8 additions & 0 deletions

@@ -38,6 +38,14 @@ class LossReducer(AbstractLoss):
     _modules: dict[str, "AbstractLoss"]
     """Children loss modules."""
 
+    def log_name(self) -> str:
+        """Log name.
+
+        Returns:
+            Name of the loss module for logging.
+        """
+        return "total_loss"
+
     def __init__(
         self,
         *loss_modules: AbstractLoss,
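
Putting the renamed keys together: a sketch of a typical composition, assuming (this is not shown in the diff) that LossReducer sums its child losses and merges their logged metrics with its own entry:

from sparse_autoencoder.loss.decoded_activations_l2 import L2ReconstructionLoss
from sparse_autoencoder.loss.learned_activations_l1 import LearnedActivationsL1Loss
from sparse_autoencoder.loss.reducer import LossReducer

loss = LossReducer(
    L2ReconstructionLoss(),
    LearnedActivationsL1Loss(l1_coefficient=0.1),
)
# The merged log would then read along the lines of:
# {'l2_reconstruction_loss': ..., 'learned_activations_l1_loss': ...,
#  'learned_activations_l1_loss_penalty': ..., 'total_loss': ...}
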
sparse_autoencoder/loss/tests/test_abstract_loss.py
Lines changed: 6 additions & 4 deletions

@@ -23,6 +23,10 @@ def forward(
         # Simple dummy implementation for testing
         return torch.tensor([1.0, 2.0, 3.0])
 
+    def log_name(self) -> str:
+        """Log name."""
+        return "dummy"
+
 
 @pytest.fixture()
 def dummy_loss() -> DummyLoss:
@@ -61,15 +65,13 @@ def test_batch_scalar_loss_with_log(dummy_loss: DummyLoss) -> None:
     _loss, log = dummy_loss.batch_scalar_loss_with_log(
         source_activations, learned_activations, decoded_activations
     )
-    assert "DummyLoss" in log
     expected = 2.0  # Mean of [1.0, 2.0, 3.0]
-    assert log["DummyLoss"] == expected
+    assert log["dummy"] == expected
 
 
 def test_call_method(dummy_loss: DummyLoss) -> None:
     """Test the call method."""
     source_activations = learned_activations = decoded_activations = torch.ones((1, 3))
     _loss, log = dummy_loss(source_activations, learned_activations, decoded_activations)
-    assert "DummyLoss" in log
     expected = 2.0  # Mean of [1.0, 2.0, 3.0]
-    assert log["DummyLoss"] == expected
+    assert log["dummy"] == expected
sparse_autoencoder/metrics/train/feature_density.py
Lines changed: 4 additions & 0 deletions

@@ -21,6 +21,10 @@ class TrainBatchFeatureDensityMetric(AbstractTrainMetric):
     Percentage of samples in which each feature was active (i.e. the neuron has "fired"), in a
     training batch.
 
+    Generally we want a small number of features to be active in each batch, so the average
+    feature density should be low. By contrast, a high average feature density means the features
+    are not sparse enough.
+
     Warning:
         This is not the same as the feature density of the entire training set. Its main use is
         tracking the progress of training.
sparse_autoencoder/train/abstract_pipeline.py
Lines changed: 1 addition & 1 deletion

@@ -264,7 +264,7 @@ def run_pipeline(
             last_checkpoint += num_activation_vectors_in_store
             total_activations += num_activation_vectors_in_store
             if wandb.run is not None:
-                wandb.log({"total_activations": total_activations}, commit=False)
+                wandb.log({"activations_generated": total_activations}, commit=False)
 
             # Train
             progress_bar.set_postfix({"stage": "train"})