File tree Expand file tree Collapse file tree 1 file changed +7
-3
lines changed
sparse_autoencoder/autoencoder Expand file tree Collapse file tree 1 file changed +7
-3
lines changed Original file line number Diff line number Diff line change @@ -70,7 +70,7 @@ def __init__(
70
70
self ,
71
71
n_input_features : int ,
72
72
n_learned_features : int ,
73
- geometric_median_dataset : InputOutputActivationVector ,
73
+ geometric_median_dataset : InputOutputActivationVector | None = None ,
74
74
) -> None :
75
75
"""Initialize the Sparse Autoencoder Model.
76
76
@@ -88,8 +88,12 @@ def __init__(
88
88
89
89
# Store the geometric median of the dataset (so that we can reset parameters). This is not a
90
90
# parameter itself (the tied bias parameter is used for that), so gradients are disabled.
91
- self .geometric_median_dataset = geometric_median_dataset .clone ()
92
- self .geometric_median_dataset .requires_grad = False
91
+ if geometric_median_dataset is not None :
92
+ self .geometric_median_dataset = geometric_median_dataset .clone ()
93
+ self .geometric_median_dataset .requires_grad = False
94
+ else :
95
+ self .geometric_median_dataset = torch .zeros (n_input_features )
96
+ self .geometric_median_dataset .requires_grad = False
93
97
94
98
# Initialize the tied bias
95
99
self .tied_bias = Parameter (torch .empty (n_input_features ))
You can’t perform that action at this time.
0 commit comments