Skip to content

Commit 3e94705

Browse files
authored
Reduce model complexity (#171)
Remove the abstract encoder/decoder classes and just have one interface for the model itself
1 parent 5464e8f commit 3e94705

21 files changed

+397
-866
lines changed

docs/content/demo.ipynb

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,8 @@
1616
"# Quick Start Training Demo\n",
1717
"\n",
1818
"This is a quick start demo to get training a SAE right away. All you need to do is choose a few\n",
19-
"hyperparameters (like the model to train on), and then set it off.\n",
20-
"By default it replicates Neel Nanda's\n",
21-
"[comment on the Anthropic dictionary learning\n",
22-
"paper](https://transformer-circuits.pub/2023/monosemantic-features/index.html#comment-nanda)."
19+
"hyperparameters (like the model to train on), and then set it off. By default it trains SAEs on all\n",
20+
"MLP layers from GPT2 small."
2321
]
2422
},
2523
{
@@ -75,6 +73,7 @@
7573
" Method,\n",
7674
" OptimizerHyperparameters,\n",
7775
" Parameter,\n",
76+
" PipelineHyperparameters,\n",
7877
" SourceDataHyperparameters,\n",
7978
" SourceModelHyperparameters,\n",
8079
" sweep,\n",
@@ -103,26 +102,40 @@
103102
},
104103
{
105104
"cell_type": "code",
106-
"execution_count": 3,
105+
"execution_count": 5,
107106
"metadata": {},
108107
"outputs": [
109108
{
110-
"ename": "TypeError",
111-
"evalue": "SourceModelHyperparameters.__init__() got an unexpected keyword argument 'hook_layer'",
112-
"output_type": "error",
113-
"traceback": [
114-
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
115-
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
116-
"Cell \u001b[0;32mIn[3], line 12\u001b[0m\n\u001b[1;32m 1\u001b[0m sweep_config \u001b[38;5;241m=\u001b[39m SweepConfig(\n\u001b[1;32m 2\u001b[0m parameters\u001b[38;5;241m=\u001b[39mHyperparameters(\n\u001b[1;32m 3\u001b[0m activation_resampler\u001b[38;5;241m=\u001b[39mActivationResamplerHyperparameters(\n\u001b[1;32m 4\u001b[0m threshold_is_dead_portion_fires\u001b[38;5;241m=\u001b[39mParameter(\u001b[38;5;241m1e-6\u001b[39m),\n\u001b[1;32m 5\u001b[0m ),\n\u001b[1;32m 6\u001b[0m loss\u001b[38;5;241m=\u001b[39mLossHyperparameters(\n\u001b[1;32m 7\u001b[0m l1_coefficient\u001b[38;5;241m=\u001b[39mParameter(\u001b[38;5;28mmax\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-2\u001b[39m, \u001b[38;5;28mmin\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m4e-3\u001b[39m),\n\u001b[1;32m 8\u001b[0m ),\n\u001b[1;32m 9\u001b[0m optimizer\u001b[38;5;241m=\u001b[39mOptimizerHyperparameters(\n\u001b[1;32m 10\u001b[0m lr\u001b[38;5;241m=\u001b[39mParameter(\u001b[38;5;28mmax\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-3\u001b[39m, \u001b[38;5;28mmin\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-5\u001b[39m),\n\u001b[1;32m 11\u001b[0m ),\n\u001b[0;32m---> 12\u001b[0m source_model\u001b[38;5;241m=\u001b[39m\u001b[43mSourceModelHyperparameters\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 13\u001b[0m \u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mParameter\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mgelu-2l\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 14\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_names\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mParameter\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmlp_out\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 15\u001b[0m \u001b[43m \u001b[49m\u001b[43mhook_layer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mParameter\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 16\u001b[0m \u001b[43m \u001b[49m\u001b[43mhook_dimension\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mParameter\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m512\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 17\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m,\n\u001b[1;32m 18\u001b[0m source_data\u001b[38;5;241m=\u001b[39mSourceDataHyperparameters(\n\u001b[1;32m 19\u001b[0m dataset_path\u001b[38;5;241m=\u001b[39mParameter(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNeelNanda/c4-code-tokenized-2b\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 20\u001b[0m ),\n\u001b[1;32m 21\u001b[0m ),\n\u001b[1;32m 22\u001b[0m method\u001b[38;5;241m=\u001b[39mMethod\u001b[38;5;241m.\u001b[39mRANDOM,\n\u001b[1;32m 23\u001b[0m )\n\u001b[1;32m 24\u001b[0m sweep_config\n",
117-
"\u001b[0;31mTypeError\u001b[0m: SourceModelHyperparameters.__init__() got an unexpected keyword argument 'hook_layer'"
118-
]
109+
"data": {
110+
"text/plain": [
111+
"SweepConfig(parameters=Hyperparameters(\n",
112+
" source_data=SourceDataHyperparameters(dataset_path=Parameter(value=alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2), context_size=Parameter(value=128), dataset_column_name=Parameter(value=input_ids), dataset_dir=None, dataset_files=None, pre_download=Parameter(value=False), pre_tokenized=Parameter(value=True), tokenizer_name=None)\n",
113+
" source_model=SourceModelHyperparameters(name=Parameter(value=gpt2-small), cache_names=Parameter(value=['blocks.0.hook_mlp_out', 'blocks.1.hook_mlp_out', 'blocks.2.hook_mlp_out', 'blocks.3.hook_mlp_out', 'blocks.4.hook_mlp_out', 'blocks.5.hook_mlp_out', 'blocks.6.hook_mlp_out', 'blocks.7.hook_mlp_out', 'blocks.8.hook_mlp_out', 'blocks.9.hook_mlp_out', 'blocks.10.hook_mlp_out', 'blocks.11.hook_mlp_out']), hook_dimension=Parameter(value=768), dtype=Parameter(value=float32))\n",
114+
" activation_resampler=ActivationResamplerHyperparameters(resample_interval=Parameter(value=200000000), max_n_resamples=Parameter(value=4), n_activations_activity_collate=Parameter(value=100000000), resample_dataset_size=Parameter(value=200000), threshold_is_dead_portion_fires=Parameter(value=1e-06))\n",
115+
" autoencoder=AutoencoderHyperparameters(expansion_factor=Parameter(value=2))\n",
116+
" loss=LossHyperparameters(l1_coefficient=Parameter(max=0.01, min=0.004))\n",
117+
" optimizer=OptimizerHyperparameters(lr=Parameter(max=0.001, min=1e-05), adam_beta_1=Parameter(value=0.9), adam_beta_2=Parameter(value=0.99), adam_weight_decay=Parameter(value=0.0), amsgrad=Parameter(value=False), fused=Parameter(value=False))\n",
118+
" pipeline=PipelineHyperparameters(log_frequency=Parameter(value=100), source_data_batch_size=Parameter(value=16), train_batch_size=Parameter(value=1024), max_store_size=Parameter(value=300000), max_activations=Parameter(value=1000000000), checkpoint_frequency=Parameter(value=100000000), validation_frequency=Parameter(value=100000000), validation_n_activations=Parameter(value=8192))\n",
119+
" random_seed=Parameter(value=49)\n",
120+
"), method=<Method.RANDOM: 'random'>, metric=Metric(name=train/loss/total_loss, goal=minimize), command=None, controller=None, description=None, earlyterminate=None, entity=None, imageuri=None, job=None, kind=None, name=None, program=None, project=None)"
121+
]
122+
},
123+
"execution_count": 5,
124+
"metadata": {},
125+
"output_type": "execute_result"
119126
}
120127
],
121128
"source": [
129+
"n_layers_gpt2_small = 12\n",
130+
"\n",
122131
"sweep_config = SweepConfig(\n",
123132
" parameters=Hyperparameters(\n",
124133
" activation_resampler=ActivationResamplerHyperparameters(\n",
134+
" resample_interval=Parameter(200_000_000),\n",
135+
" n_activations_activity_collate=Parameter(100_000_000),\n",
125136
" threshold_is_dead_portion_fires=Parameter(1e-6),\n",
137+
" max_n_resamples=Parameter(4),\n",
138+
" resample_dataset_size=Parameter(200_000),\n",
126139
" ),\n",
127140
" loss=LossHyperparameters(\n",
128141
" l1_coefficient=Parameter(max=1e-2, min=4e-3),\n",
@@ -131,12 +144,24 @@
131144
" lr=Parameter(max=1e-3, min=1e-5),\n",
132145
" ),\n",
133146
" source_model=SourceModelHyperparameters(\n",
134-
" name=Parameter(\"gelu-2l\"),\n",
135-
" cache_names=Parameter([\"blocks.0.hook_mlp_out\", \"blocks.1.hook_mlp_out\"]),\n",
136-
" hook_dimension=Parameter(512),\n",
147+
" name=Parameter(\"gpt2-small\"),\n",
148+
" # Train in parallel on all MLP layers\n",
149+
" cache_names=Parameter(\n",
150+
" [f\"blocks.{layer}.hook_mlp_out\" for layer in range(n_layers_gpt2_small)]\n",
151+
" ),\n",
152+
" hook_dimension=Parameter(768),\n",
137153
" ),\n",
138154
" source_data=SourceDataHyperparameters(\n",
139-
" dataset_path=Parameter(\"NeelNanda/c4-code-tokenized-2b\"),\n",
155+
" dataset_path=Parameter(\"alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2\"),\n",
156+
" context_size=Parameter(128),\n",
157+
" pre_tokenized=Parameter(value=True),\n",
158+
" ),\n",
159+
" pipeline=PipelineHyperparameters(\n",
160+
" max_activations=Parameter(1_000_000_000),\n",
161+
" checkpoint_frequency=Parameter(100_000_000),\n",
162+
" validation_frequency=Parameter(100_000_000),\n",
163+
" train_batch_size=Parameter(1024),\n",
164+
" max_store_size=Parameter(300_000),\n",
140165
" ),\n",
141166
" ),\n",
142167
" method=Method.RANDOM,\n",

docs/content/flexible_demo.ipynb

Lines changed: 21 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
},
5656
{
5757
"cell_type": "code",
58-
"execution_count": 1,
58+
"execution_count": null,
5959
"metadata": {},
6060
"outputs": [],
6161
"source": [
@@ -83,17 +83,9 @@
8383
},
8484
{
8585
"cell_type": "code",
86-
"execution_count": 2,
86+
"execution_count": null,
8787
"metadata": {},
88-
"outputs": [
89-
{
90-
"name": "stdout",
91-
"output_type": "stream",
92-
"text": [
93-
"Using device: mps\n"
94-
]
95-
}
96-
],
88+
"outputs": [],
9789
"source": [
9890
"import os\n",
9991
"from pathlib import Path\n",
@@ -140,7 +132,7 @@
140132
},
141133
{
142134
"cell_type": "code",
143-
"execution_count": 3,
135+
"execution_count": null,
144136
"metadata": {},
145137
"outputs": [],
146138
"source": [
@@ -197,27 +189,9 @@
197189
},
198190
{
199191
"cell_type": "code",
200-
"execution_count": 4,
192+
"execution_count": null,
201193
"metadata": {},
202-
"outputs": [
203-
{
204-
"name": "stdout",
205-
"output_type": "stream",
206-
"text": [
207-
"Loaded pretrained model gelu-2l into HookedTransformer\n"
208-
]
209-
},
210-
{
211-
"data": {
212-
"text/plain": [
213-
"'Source: gelu-2l, Hook: blocks.0.hook_mlp_out, Features: 512'"
214-
]
215-
},
216-
"execution_count": 4,
217-
"metadata": {},
218-
"output_type": "execute_result"
219-
}
220-
],
194+
"outputs": [],
221195
"source": [
222196
"# Source model setup with TransformerLens\n",
223197
"src_model = HookedTransformer.from_pretrained(\n",
@@ -255,28 +229,9 @@
255229
},
256230
{
257231
"cell_type": "code",
258-
"execution_count": 5,
232+
"execution_count": null,
259233
"metadata": {},
260-
"outputs": [
261-
{
262-
"data": {
263-
"text/plain": [
264-
"SparseAutoencoder(\n",
265-
" (_pre_encoder_bias): TiedBias(position=pre_encoder)\n",
266-
" (_encoder): LinearEncoder(\n",
267-
" in_features=512, out_features=2048\n",
268-
" (activation_function): ReLU()\n",
269-
" )\n",
270-
" (_decoder): UnitNormDecoder(in_features=2048, out_features=512)\n",
271-
" (_post_decoder_bias): TiedBias(position=post_decoder)\n",
272-
")"
273-
]
274-
},
275-
"execution_count": 5,
276-
"metadata": {},
277-
"output_type": "execute_result"
278-
}
279-
],
234+
"outputs": [],
280235
"source": [
281236
"expansion_factor = hyperparameters[\"expansion_factor\"]\n",
282237
"autoencoder = SparseAutoencoder(\n",
@@ -297,23 +252,9 @@
297252
},
298253
{
299254
"cell_type": "code",
300-
"execution_count": 6,
255+
"execution_count": null,
301256
"metadata": {},
302-
"outputs": [
303-
{
304-
"data": {
305-
"text/plain": [
306-
"LossReducer(\n",
307-
" (0): LearnedActivationsL1Loss(l1_coefficient=0.0003)\n",
308-
" (1): L2ReconstructionLoss()\n",
309-
")"
310-
]
311-
},
312-
"execution_count": 6,
313-
"metadata": {},
314-
"output_type": "execute_result"
315-
}
316-
],
257+
"outputs": [],
317258
"source": [
318259
"# We use a loss reducer, which simply adds up the losses from the underlying loss functions.\n",
319260
"loss = LossReducer(\n",
@@ -327,32 +268,9 @@
327268
},
328269
{
329270
"cell_type": "code",
330-
"execution_count": 7,
271+
"execution_count": null,
331272
"metadata": {},
332-
"outputs": [
333-
{
334-
"data": {
335-
"text/plain": [
336-
"AdamWithReset (\n",
337-
"Parameter Group 0\n",
338-
" amsgrad: False\n",
339-
" betas: (0.9, 0.999)\n",
340-
" capturable: False\n",
341-
" differentiable: False\n",
342-
" eps: 1e-08\n",
343-
" foreach: None\n",
344-
" fused: None\n",
345-
" lr: 0.0001\n",
346-
" maximize: False\n",
347-
" weight_decay: 0.0\n",
348-
")"
349-
]
350-
},
351-
"execution_count": 7,
352-
"metadata": {},
353-
"output_type": "execute_result"
354-
}
355-
],
273+
"outputs": [],
356274
"source": [
357275
"optimizer = AdamWithReset(\n",
358276
" params=autoencoder.parameters(),\n",
@@ -361,6 +279,7 @@
361279
" betas=(float(hyperparameters[\"adam_beta_1\"]), float(hyperparameters[\"adam_beta_2\"])),\n",
362280
" eps=float(hyperparameters[\"adam_epsilon\"]),\n",
363281
" weight_decay=float(hyperparameters[\"adam_weight_decay\"]),\n",
282+
" has_components_dim=True,\n",
364283
")\n",
365284
"optimizer"
366285
]
@@ -374,7 +293,7 @@
374293
},
375294
{
376295
"cell_type": "code",
377-
"execution_count": 8,
296+
"execution_count": null,
378297
"metadata": {},
379298
"outputs": [],
380299
"source": [
@@ -403,27 +322,13 @@
403322
},
404323
{
405324
"cell_type": "code",
406-
"execution_count": 9,
325+
"execution_count": null,
407326
"metadata": {},
408-
"outputs": [
409-
{
410-
"data": {
411-
"application/vnd.jupyter.widget-view+json": {
412-
"model_id": "2fe4955deca9463dbed606c9452d518e",
413-
"version_major": 2,
414-
"version_minor": 0
415-
},
416-
"text/plain": [
417-
"Resolving data files: 0%| | 0/28 [00:00<?, ?it/s]"
418-
]
419-
},
420-
"metadata": {},
421-
"output_type": "display_data"
422-
}
423-
],
327+
"outputs": [],
424328
"source": [
425329
"source_data = PreTokenizedDataset(\n",
426-
" dataset_path=\"NeelNanda/c4-code-tokenized-2b\", context_size=int(hyperparameters[\"context_size\"])\n",
330+
" dataset_path=\"alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2\",\n",
331+
" context_size=int(hyperparameters[\"context_size\"]),\n",
427332
")"
428333
]
429334
},
@@ -447,7 +352,7 @@
447352
},
448353
{
449354
"cell_type": "code",
450-
"execution_count": 10,
355+
"execution_count": null,
451356
"metadata": {},
452357
"outputs": [],
453358
"source": [
@@ -471,14 +376,14 @@
471376
},
472377
{
473378
"cell_type": "code",
474-
"execution_count": null,
379+
"execution_count": 16,
475380
"metadata": {},
476381
"outputs": [],
477382
"source": [
478383
"pipeline = Pipeline(\n",
479384
" activation_resampler=activation_resampler,\n",
480385
" autoencoder=autoencoder,\n",
481-
" cache_name=str(hyperparameters[\"source_model_hook_point\"]),\n",
386+
" cache_names=[str(hyperparameters[\"source_model_hook_point\"])],\n",
482387
" checkpoint_directory=checkpoint_path,\n",
483388
" layer=int(hyperparameters[\"source_model_hook_point_layer\"]),\n",
484389
" loss=loss,\n",
@@ -496,15 +401,6 @@
496401
" validate_frequency=int(hyperparameters[\"validation_frequency\"]),\n",
497402
")"
498403
]
499-
},
500-
{
501-
"cell_type": "code",
502-
"execution_count": null,
503-
"metadata": {},
504-
"outputs": [],
505-
"source": [
506-
"wandb.finish()"
507-
]
508404
}
509405
],
510406
"metadata": {

docs/content/index.md

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,4 @@ The library is designed to be modular. By default it takes the approach from [To
4747
Monosemanticity: Decomposing Language Models With Dictionary Learning
4848
](https://transformer-circuits.pub/2023/monosemantic-features/index.html), so you can pip install
4949
the library and get started quickly. Then when you need to customise something, you can just extend
50-
the abstract class for that component (e.g. you can extend
51-
[`AbstractEncoder`][sparse_autoencoder.autoencoder.components.abstract_encoder] if you want to
52-
customise the encoder layer, and then easily drop it in the standard
53-
[`SparseAutoencoder`][sparse_autoencoder.autoencoder.model] model to keep everything else as is.
54-
Every component is fully documented, so it's nice and easy to do this.
50+
the abstract class for that component (every component is documented so that it's easy to do this).

sparse_autoencoder/activation_resampler/tests/test_activation_resampler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def test_updates_dead_neuron_parameters(
343343
# Check the updated ones have changed
344344
for component_idx, neuron_idx in dead_neurons:
345345
# Decoder
346-
decoder_weights = current_parameters["_decoder._weight"]
346+
decoder_weights = current_parameters["decoder._weight"]
347347
current_dead_neuron_weights = decoder_weights[component_idx, neuron_idx]
348348
updated_dead_decoder_weights = parameter_updates[
349349
component_idx
@@ -353,7 +353,7 @@ def test_updates_dead_neuron_parameters(
353353
), "Dead decoder weights should have changed."
354354

355355
# Encoder
356-
current_dead_encoder_weights = current_parameters["_encoder._weight"][
356+
current_dead_encoder_weights = current_parameters["encoder._weight"][
357357
component_idx, neuron_idx
358358
]
359359
updated_dead_encoder_weights = parameter_updates[
@@ -363,7 +363,7 @@ def test_updates_dead_neuron_parameters(
363363
current_dead_encoder_weights, updated_dead_encoder_weights
364364
), "Dead encoder weights should have changed."
365365

366-
current_dead_encoder_bias = current_parameters["_encoder._bias"][
366+
current_dead_encoder_bias = current_parameters["encoder._bias"][
367367
component_idx, neuron_idx
368368
]
369369
updated_dead_encoder_bias = parameter_updates[component_idx].dead_encoder_bias_updates

0 commit comments

Comments
 (0)