Skip to content

Commit 04fa8ea

Browse files
authored
Make the geometric median dataset optional (#110)
1 parent e5e27d6 commit 04fa8ea

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

sparse_autoencoder/autoencoder/model.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def __init__(
7070
self,
7171
n_input_features: int,
7272
n_learned_features: int,
73-
geometric_median_dataset: InputOutputActivationVector,
73+
geometric_median_dataset: InputOutputActivationVector | None = None,
7474
) -> None:
7575
"""Initialize the Sparse Autoencoder Model.
7676
@@ -88,8 +88,12 @@ def __init__(
8888

8989
# Store the geometric median of the dataset (so that we can reset parameters). This is not a
9090
# parameter itself (the tied bias parameter is used for that), so gradients are disabled.
91-
self.geometric_median_dataset = geometric_median_dataset.clone()
92-
self.geometric_median_dataset.requires_grad = False
91+
if geometric_median_dataset is not None:
92+
self.geometric_median_dataset = geometric_median_dataset.clone()
93+
self.geometric_median_dataset.requires_grad = False
94+
else:
95+
self.geometric_median_dataset = torch.zeros(n_input_features)
96+
self.geometric_median_dataset.requires_grad = False
9397

9498
# Initialize the tied bias
9599
self.tied_bias = Parameter(torch.empty(n_input_features))

0 commit comments

Comments
 (0)