@@ -3,9 +3,10 @@

 import torch

-from sparse_autoencoder.loss.abstract_loss import AbstractLoss
+from sparse_autoencoder.loss.abstract_loss import AbstractLoss, LossLogType, LossReductionType
 from sparse_autoencoder.tensor_types import (
     InputOutputActivationBatch,
+    ItemTensor,
     LearnedActivationBatch,
     TrainBatchStatistic,
 )
@@ -23,13 +24,21 @@ class LearnedActivationsL1Loss(AbstractLoss):
         >>> learned_activations = torch.tensor([[2.0, -3], [2.0, -3]])
         >>> unused_activations = torch.zeros_like(learned_activations)
         >>> # Returns loss and metrics to log
-        >>> l1_loss(unused_activations, learned_activations, unused_activations)
-        (tensor(0.5000), {'LearnedActivationsL1Loss': 0.5})
+        >>> l1_loss(unused_activations, learned_activations, unused_activations)[0]
+        tensor(0.5000)
     """

     l1_coefficient: float
     """L1 coefficient."""

+    def log_name(self) -> str:
+        """Log name.
+
+        Returns:
+            Name of the loss module for logging.
+        """
+        return "learned_activations_l1_loss_penalty"
+
     def __init__(self, l1_coefficient: float) -> None:
         """Initialize the absolute error loss.

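The updated doctest now checks only the returned loss tensor. As a sanity check on that value, here is a minimal standalone reproduction of the arithmetic in plain torch. The l1_loss constructor call sits above this hunk, so the coefficient of 0.1 is an assumption, inferred as the value that reproduces tensor(0.5000):

import torch

l1_coefficient = 0.1  # assumed; the construction line is not shown in this hunk
learned_activations = torch.tensor([[2.0, -3], [2.0, -3]])

# Per-item L1 norm: |2| + |-3| = 5 for each batch item.
itemwise_l1 = torch.abs(learned_activations).sum(dim=-1)  # tensor([5., 5.])

# Mean over the batch, scaled by the coefficient: 5 * 0.1 = 0.5.
penalty = itemwise_l1.mean() * l1_coefficient
print(penalty)  # tensor(0.5000)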
@@ -42,11 +51,33 @@ def __init__(self, l1_coefficient: float) -> None:
         self.l1_coefficient = l1_coefficient
         super().__init__()

-    def forward(
+    def _l1_loss(
         self,
         source_activations: InputOutputActivationBatch,  # noqa: ARG002
         learned_activations: LearnedActivationBatch,
         decoded_activations: InputOutputActivationBatch,  # noqa: ARG002
+    ) -> tuple[TrainBatchStatistic, TrainBatchStatistic]:
+        """Learned activations L1 (absolute error) loss.
+
+        Args:
+            source_activations: Source activations (input activations to the autoencoder from the
+                source model).
+            learned_activations: Learned activations (intermediate activations in the autoencoder).
+            decoded_activations: Decoded activations.
+
+        Returns:
+            Tuple of itemwise absolute loss, and itemwise absolute loss multiplied by the l1
+            coefficient.
+        """
+        absolute_loss = torch.abs(learned_activations).sum(dim=-1)
+        absolute_loss_penalty = absolute_loss * self.l1_coefficient
+        return absolute_loss, absolute_loss_penalty
+
+    def forward(
+        self,
+        source_activations: InputOutputActivationBatch,
+        learned_activations: LearnedActivationBatch,
+        decoded_activations: InputOutputActivationBatch,
     ) -> TrainBatchStatistic:
         """Learned activations L1 (absolute error) loss.

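The refactor moves the computation into a private _l1_loss helper that returns the per-item loss and its scaled penalty as a pair. A minimal sketch of the helper's arithmetic on a toy batch, with illustrative names and values only:

import torch

l1_coefficient = 0.01  # illustrative value
learned_activations = torch.tensor([[1.0, -2.0, 0.0], [0.5, 0.0, -0.5]])

# Per-item L1 norm, summed over the learned feature dimension.
absolute_loss = torch.abs(learned_activations).sum(dim=-1)  # tensor([3., 1.])

# The same per-item values scaled by the coefficient.
absolute_loss_penalty = absolute_loss * l1_coefficient  # tensor([0.0300, 0.0100])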
@@ -59,9 +90,48 @@ def forward(
         Returns:
             Loss per batch item.
         """
-        absolute_loss = torch.abs(learned_activations)
+        return self._l1_loss(source_activations, learned_activations, decoded_activations)[1]
+
+    # Override to add both the loss and the penalty to the log
+    def batch_scalar_loss_with_log(
+        self,
+        source_activations: InputOutputActivationBatch,
+        learned_activations: LearnedActivationBatch,
+        decoded_activations: InputOutputActivationBatch,
+        reduction: LossReductionType = LossReductionType.MEAN,
+    ) -> tuple[ItemTensor, LossLogType]:
+        """Learned activations L1 (absolute error) loss, with log.
+
+        Args:
+            source_activations: Source activations (input activations to the autoencoder from the
+                source model).
+            learned_activations: Learned activations (intermediate activations in the autoencoder).
+            decoded_activations: Decoded activations.
+            reduction: Loss reduction type. Typically you would choose LossReductionType.MEAN to
+                make the loss independent of the batch size.
+
+        Returns:
+            Tuple of the L1 absolute error batch scalar loss and a dict of the properties to log
+            (loss before and after the l1 coefficient).
+        """
+        absolute_loss, absolute_loss_penalty = self._l1_loss(
+            source_activations, learned_activations, decoded_activations
+        )
+
+        match reduction:
+            case LossReductionType.MEAN:
+                batch_scalar_loss = absolute_loss.mean().squeeze()
+                batch_scalar_loss_penalty = absolute_loss_penalty.mean().squeeze()
+            case LossReductionType.SUM:
+                batch_scalar_loss = absolute_loss.sum().squeeze()
+                batch_scalar_loss_penalty = absolute_loss_penalty.sum().squeeze()
+
+        metrics = {
+            "learned_activations_l1_loss": batch_scalar_loss.item(),
+            self.log_name(): batch_scalar_loss_penalty.item(),
+        }

-        return absolute_loss.sum(dim=-1) * self.l1_coefficient
+        return batch_scalar_loss_penalty, metrics

     def extra_repr(self) -> str:
         """Extra representation string."""