
Commit a8fe5d0

Add a model reconstruction validation metric (#112)
1 parent: b2c821f

8 files changed: +269 -64 lines

docs/content/demo.ipynb

Lines changed: 53 additions & 54 deletions
@@ -107,16 +107,27 @@
 " # and we have found that 4x is a good starting point.\n",
 " \"expansion_factor\": 4,\n",
 " # L1 coefficient is the coefficient of the L1 regularization term (used to encourage sparsity).\n",
-" \"l1_coefficient\": 3e-4,\n",
+" \"l1_coefficient\": 1e-3,\n",
 " # Adam parameters (set to the default ones here)\n",
-" \"lr\": 1e-4,\n",
+" \"lr\": 3e-4,\n",
 " \"adam_beta_1\": 0.9,\n",
 " \"adam_beta_2\": 0.999,\n",
 " \"adam_epsilon\": 1e-8,\n",
 " \"adam_weight_decay\": 0.0,\n",
 " # Batch sizes\n",
 " \"train_batch_size\": 4096,\n",
 " \"context_size\": 128,\n",
+" # Source model hook point\n",
+" \"source_model_name\": \"gelu-2l\",\n",
+" \"source_model_dtype\": \"float32\",\n",
+" \"source_model_hook_point\": \"blocks.0.hook_mlp_out\",\n",
+" \"source_model_hook_point_layer\": 0,\n",
+" # Train pipeline parameters\n",
+" \"max_store_size\": 384 * 4096 * 2,\n",
+" \"max_activations\": 2_000_000_000,\n",
+" \"resample_frequency\": 122_880_000,\n",
+" \"checkpoint_frequency\": 100_000_000,\n",
+" \"validation_frequency\": 384 * 4096 * 2 * 100, # Every 100 generations\n",
 "}"
 ]
 },
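The hunk above folds the source-model and pipeline settings into the single hyperparameters dict, so one mapping holds every setting for the run. Because the dict now mixes strings, floats, and ints, later cells cast each value at the point of use. A minimal sketch of that pattern (illustrative only; keys abbreviated from the diff):

    # One flat mapping holds every setting; the cost is an explicit cast at
    # each use site, because the value type is a union.
    hyperparameters: dict[str, str | float | int] = {
        "l1_coefficient": 1e-3,
        "lr": 3e-4,
        "source_model_name": "gelu-2l",
    }
    lr: float = float(hyperparameters["lr"])  # narrows str | float | int to float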
@@ -141,7 +152,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": 4,
 "metadata": {},
 "outputs": [
 {
@@ -157,22 +168,22 @@
 "'Source: gelu-2l, Hook: blocks.0.hook_mlp_out, Features: 512'"
 ]
 },
-"execution_count": 3,
+"execution_count": 4,
 "metadata": {},
 "output_type": "execute_result"
 }
 ],
 "source": [
 "# Source model setup with TransformerLens\n",
-"src_model_name = \"gelu-2l\"\n",
-"src_model = HookedTransformer.from_pretrained(src_model_name, dtype=\"float32\")\n",
+"src_model = HookedTransformer.from_pretrained(\n",
+" str(hyperparameters[\"source_model_name\"]), dtype=str(hyperparameters[\"source_model_dtype\"])\n",
+")\n",
 "\n",
 "# Details about the activations we'll train the sparse autoencoder on\n",
-"src_model_activation_hook_point = \"blocks.0.hook_mlp_out\"\n",
-"src_model_activation_layer = 0\n",
 "autoencoder_input_dim: int = src_model.cfg.d_model # type: ignore (TransformerLens typing is currently broken)\n",
 "\n",
-"f\"Source: {src_model_name}, Hook: {src_model_activation_hook_point}, \\\n",
+"f\"Source: {hyperparameters['source_model_name']}, \\\n",
+" Hook: {hyperparameters['source_model_hook_point']}, \\\n",
 " Features: {autoencoder_input_dim}\""
 ]
 },
@@ -199,7 +210,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": 5,
 "metadata": {},
 "outputs": [
 {
@@ -216,7 +227,7 @@
 ")"
 ]
 },
-"execution_count": 4,
+"execution_count": 5,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -244,19 +255,19 @@
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": 6,
 "metadata": {},
 "outputs": [
 {
 "data": {
 "text/plain": [
 "LossReducer(\n",
-" (0): LearnedActivationsL1Loss(l1_coefficient=0.0003)\n",
+" (0): LearnedActivationsL1Loss(l1_coefficient=0.001)\n",
 " (1): L2ReconstructionLoss()\n",
 ")"
 ]
 },
-"execution_count": 5,
+"execution_count": 6,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -265,7 +276,7 @@
 "# We use a loss reducer, which simply adds up the losses from the underlying loss functions.\n",
 "loss = LossReducer(\n",
 " LearnedActivationsL1Loss(\n",
-" l1_coefficient=hyperparameters[\"l1_coefficient\"],\n",
+" l1_coefficient=float(hyperparameters[\"l1_coefficient\"]),\n",
 " ),\n",
 " L2ReconstructionLoss(),\n",
 ")\n",
@@ -274,7 +285,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 6,
+"execution_count": 7,
 "metadata": {},
 "outputs": [
 {
@@ -289,13 +300,13 @@
 " eps: 1e-08\n",
 " foreach: None\n",
 " fused: None\n",
-" lr: 0.0001\n",
+" lr: 0.0003\n",
 " maximize: False\n",
 " weight_decay: 0.0\n",
 ")"
 ]
 },
-"execution_count": 6,
+"execution_count": 7,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -304,10 +315,10 @@
 "optimizer = AdamWithReset(\n",
 " params=autoencoder.parameters(),\n",
 " named_parameters=autoencoder.named_parameters(),\n",
-" lr=hyperparameters[\"lr\"],\n",
-" betas=(hyperparameters[\"adam_beta_1\"], hyperparameters[\"adam_beta_2\"]),\n",
-" eps=hyperparameters[\"adam_epsilon\"],\n",
-" weight_decay=hyperparameters[\"adam_weight_decay\"],\n",
+" lr=float(hyperparameters[\"lr\"]),\n",
+" betas=(float(hyperparameters[\"adam_beta_1\"]), float(hyperparameters[\"adam_beta_2\"])),\n",
+" eps=float(hyperparameters[\"adam_epsilon\"]),\n",
+" weight_decay=float(hyperparameters[\"adam_weight_decay\"]),\n",
 ")\n",
 "optimizer"
 ]
@@ -321,7 +332,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 7,
+"execution_count": 8,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -345,13 +356,13 @@
 },
 {
 "cell_type": "code",
-"execution_count": 8,
+"execution_count": 9,
 "metadata": {},
 "outputs": [
 {
 "data": {
 "application/vnd.jupyter.widget-view+json": {
-"model_id": "4bdf3ebe364243bd8f881933e56c997d",
+"model_id": "75e636ebb9e04b279c7216c74496538d",
 "version_major": 2,
 "version_minor": 0
 },
@@ -390,7 +401,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 9,
+"execution_count": 10,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -400,7 +411,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 10,
+"execution_count": 11,
 "metadata": {},
 "outputs": [
 {
@@ -426,7 +437,7 @@
 {
 "data": {
 "text/html": [
-"Run data is saved locally in <code>.cache/wandb/run-20231126_122954-xsruek7y</code>"
+"Run data is saved locally in <code>.cache/wandb/run-20231126_184500-2fnpg8zi</code>"
 ],
 "text/plain": [
 "<IPython.core.display.HTML object>"
@@ -438,7 +449,7 @@
 {
 "data": {
 "text/html": [
-"Syncing run <strong><a href='https://wandb.ai/alan-cooney/sparse-autoencoder/runs/xsruek7y' target=\"_blank\">vivid-totem-95</a></strong> to <a href='https://wandb.ai/alan-cooney/sparse-autoencoder' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
+"Syncing run <strong><a href='https://wandb.ai/alan-cooney/sparse-autoencoder/runs/2fnpg8zi' target=\"_blank\">prime-star-105</a></strong> to <a href='https://wandb.ai/alan-cooney/sparse-autoencoder' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
 ],
 "text/plain": [
 "<IPython.core.display.HTML object>"
@@ -462,7 +473,7 @@
 {
 "data": {
 "text/html": [
-" View run at <a href='https://wandb.ai/alan-cooney/sparse-autoencoder/runs/xsruek7y' target=\"_blank\">https://wandb.ai/alan-cooney/sparse-autoencoder/runs/xsruek7y</a>"
+" View run at <a href='https://wandb.ai/alan-cooney/sparse-autoencoder/runs/2fnpg8zi' target=\"_blank\">https://wandb.ai/alan-cooney/sparse-autoencoder/runs/2fnpg8zi</a>"
 ],
 "text/plain": [
 "<IPython.core.display.HTML object>"
@@ -474,13 +485,13 @@
 {
 "data": {
 "text/html": [
-"<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">Display W&B run</button><iframe src='https://wandb.ai/alan-cooney/sparse-autoencoder/runs/xsruek7y?jupyter=true' style='border:none;width:100%;height:420px;display:none;'></iframe>"
+"<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">Display W&B run</button><iframe src='https://wandb.ai/alan-cooney/sparse-autoencoder/runs/2fnpg8zi?jupyter=true' style='border:none;width:100%;height:420px;display:none;'></iframe>"
 ],
 "text/plain": [
-"<wandb.sdk.wandb_run.Run at 0x2ff1cbcd0>"
+"<wandb.sdk.wandb_run.Run at 0x3154cec10>"
 ]
 },
-"execution_count": 10,
+"execution_count": 11,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -496,13 +507,13 @@
 },
 {
 "cell_type": "code",
-"execution_count": 11,
+"execution_count": 12,
 "metadata": {},
 "outputs": [
 {
 "data": {
 "application/vnd.jupyter.widget-view+json": {
-"model_id": "e1e6fa019f524f3da19708a4eda9b349",
+"model_id": "1322f5e5dd5c4507a6eca9aa1f010882",
 "version_major": 2,
 "version_minor": 0
 },
@@ -526,9 +537,9 @@
 "pipeline = Pipeline(\n",
 " activation_resampler=activation_resampler,\n",
 " autoencoder=autoencoder,\n",
-" cache_name=src_model_activation_hook_point,\n",
+" cache_name=str(hyperparameters[\"source_model_hook_point\"]),\n",
 " checkpoint_directory=checkpoint_path,\n",
-" layer=src_model_activation_layer,\n",
+" layer=int(hyperparameters[\"source_model_hook_point_layer\"]),\n",
 " loss=loss,\n",
 " optimizer=optimizer,\n",
 " source_data_batch_size=6,\n",
@@ -538,11 +549,11 @@
 "\n",
 "pipeline.run_pipeline(\n",
 " train_batch_size=int(hyperparameters[\"train_batch_size\"]),\n",
-" max_store_size=384 * 4096 * 2,\n",
-" # Sizes for demo purposes (you probably want to scale these by 10x)\n",
-" max_activations=2_000_000_000,\n",
-" resample_frequency=122_880_000,\n",
-" checkpoint_frequency=100_000_000,\n",
+" max_store_size=int(hyperparameters[\"max_store_size\"]),\n",
+" max_activations=int(hyperparameters[\"max_activations\"]),\n",
+" resample_frequency=int(hyperparameters[\"resample_frequency\"]),\n",
+" checkpoint_frequency=int(hyperparameters[\"checkpoint_frequency\"]),\n",
+" validate_frequency=int(hyperparameters[\"validation_frequency\"]),\n",
 ")"
 ]
 },
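For scale: one store generation is 384 * 4096 * 2 = 3,145,728 activations, so the new validate_frequency of 384 * 4096 * 2 * 100 triggers validation roughly every 315 million activations. A quick arithmetic check (plain Python, not part of the diff):

    # Validation cadence implied by the hyperparameters.
    generation_size = 384 * 4096 * 2  # activations per store generation
    assert generation_size == 3_145_728
    assert generation_size * 100 == 314_572_800  # activations between validations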
@@ -554,18 +565,6 @@
 "source": [
 "wandb.finish()"
 ]
-},
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": [
-"## Training Advice\n",
-"\n",
-"-- Unfinished --\n",
-"\n",
-"- Check recovery loss is low while sparsity is low as well (<20 L1) usually.\n",
-"- Can't be sure features are useful until you dig into them more. "
-]
 }
 ],
 "metadata": {

sparse_autoencoder/metrics/metrics_container.py

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,7 @@
 from sparse_autoencoder.metrics.train.capacity import CapacityMetric
 from sparse_autoencoder.metrics.train.feature_density import TrainBatchFeatureDensityMetric
 from sparse_autoencoder.metrics.validate.abstract_validate_metric import AbstractValidationMetric
+from sparse_autoencoder.metrics.validate.model_reconstruction_score import ModelReconstructionScore
 
 
 @dataclass
@@ -33,5 +34,6 @@ class MetricsContainer:
 default_metrics = MetricsContainer(
     train_metrics=[TrainBatchFeatureDensityMetric(), CapacityMetric()],
     resample_metrics=[NeuronActivityMetric()],
+    validation_metrics=[ModelReconstructionScore()],
 )
 """Default metrics container."""

sparse_autoencoder/metrics/validate/abstract_validate_metric.py

Lines changed: 6 additions & 2 deletions
@@ -3,14 +3,18 @@
 from dataclasses import dataclass
 from typing import Any
 
+from sparse_autoencoder.tensor_types import ValidationStatistics
+
 
 @dataclass
 class ValidationMetricData:
     """Validation metric data."""
 
-    source_model_loss: float
+    source_model_loss: ValidationStatistics
+
+    source_model_loss_with_reconstruction: ValidationStatistics
 
-    autoencoder_loss: float
+    source_model_loss_with_zero_ablation: ValidationStatistics
 
 
 class AbstractValidationMetric(ABC):
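The dataclass now carries three loss statistics instead of two scalars: the source model's ordinary loss, its loss with the autoencoder's reconstruction spliced in at the hook point, and its loss with that activation zero-ablated. These are exactly the inputs a reconstruction score needs. A plausible way to combine them, commonly used to evaluate sparse autoencoders (an assumed formula; ModelReconstructionScore's implementation is not shown in this diff):

    def reconstruction_score(
        source_loss: float,
        loss_with_reconstruction: float,
        loss_with_zero_ablation: float,
    ) -> float:
        """Fraction of the zero-ablation loss gap recovered by the autoencoder.

        1.0 means splicing in the reconstruction leaves the model's loss
        unchanged; 0.0 means it is no better than zero-ablating the hook point.
        """
        return (loss_with_zero_ablation - loss_with_reconstruction) / (
            loss_with_zero_ablation - source_loss
        )

For example, losses of 3.0 (source), 3.2 (with reconstruction), and 5.0 (zero ablation) give (5.0 - 3.2) / (5.0 - 3.0) = 0.9.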
