|
16 | 16 | "# Quick Start Training Demo\n",
|
17 | 17 | "\n",
|
18 | 18 | "This is a quick start demo to get training a SAE right away. All you need to do is choose a few\n",
|
19 |
| - "hyperparameters (like the model to train on), and then set it off.\n", |
20 |
| - "By default it replicates Neel Nanda's\n", |
21 |
| - "[comment on the Anthropic dictionary learning\n", |
22 |
| - "paper](https://transformer-circuits.pub/2023/monosemantic-features/index.html#comment-nanda)." |
| 19 | + "hyperparameters (like the model to train on), and then set it off. By default it trains SAEs on all\n", |
| 20 | + "MLP layers from GPT2 small." |
23 | 21 | ]
|
24 | 22 | },
|
25 | 23 | {
|
|
75 | 73 | " Method,\n",
|
76 | 74 | " OptimizerHyperparameters,\n",
|
77 | 75 | " Parameter,\n",
|
| 76 | + " PipelineHyperparameters,\n", |
78 | 77 | " SourceDataHyperparameters,\n",
|
79 | 78 | " SourceModelHyperparameters,\n",
|
80 | 79 | " sweep,\n",
|
|
103 | 102 | },
|
104 | 103 | {
|
105 | 104 | "cell_type": "code",
|
106 |
| - "execution_count": 3, |
| 105 | + "execution_count": 5, |
107 | 106 | "metadata": {},
|
108 | 107 | "outputs": [
|
109 | 108 | {
|
110 |
| - "ename": "TypeError", |
111 |
| - "evalue": "SourceModelHyperparameters.__init__() got an unexpected keyword argument 'hook_layer'", |
112 |
| - "output_type": "error", |
113 |
| - "traceback": [ |
114 |
| - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", |
115 |
| - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", |
116 |
| - "Cell \u001b[0;32mIn[3], line 12\u001b[0m\n\u001b[1;32m 1\u001b[0m sweep_config \u001b[38;5;241m=\u001b[39m SweepConfig(\n\u001b[1;32m 2\u001b[0m parameters\u001b[38;5;241m=\u001b[39mHyperparameters(\n\u001b[1;32m 3\u001b[0m activation_resampler\u001b[38;5;241m=\u001b[39mActivationResamplerHyperparameters(\n\u001b[1;32m 4\u001b[0m threshold_is_dead_portion_fires\u001b[38;5;241m=\u001b[39mParameter(\u001b[38;5;241m1e-6\u001b[39m),\n\u001b[1;32m 5\u001b[0m ),\n\u001b[1;32m 6\u001b[0m loss\u001b[38;5;241m=\u001b[39mLossHyperparameters(\n\u001b[1;32m 7\u001b[0m l1_coefficient\u001b[38;5;241m=\u001b[39mParameter(\u001b[38;5;28mmax\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-2\u001b[39m, \u001b[38;5;28mmin\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m4e-3\u001b[39m),\n\u001b[1;32m 8\u001b[0m ),\n\u001b[1;32m 9\u001b[0m optimizer\u001b[38;5;241m=\u001b[39mOptimizerHyperparameters(\n\u001b[1;32m 10\u001b[0m lr\u001b[38;5;241m=\u001b[39mParameter(\u001b[38;5;28mmax\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-3\u001b[39m, \u001b[38;5;28mmin\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-5\u001b[39m),\n\u001b[1;32m 11\u001b[0m ),\n\u001b[0;32m---> 12\u001b[0m source_model\u001b[38;5;241m=\u001b[39m\u001b[43mSourceModelHyperparameters\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 13\u001b[0m \u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mParameter\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mgelu-2l\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 14\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_names\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mParameter\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmlp_out\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 15\u001b[0m \u001b[43m \u001b[49m\u001b[43mhook_layer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mParameter\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 16\u001b[0m \u001b[43m \u001b[49m\u001b[43mhook_dimension\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mParameter\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m512\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 17\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m,\n\u001b[1;32m 18\u001b[0m source_data\u001b[38;5;241m=\u001b[39mSourceDataHyperparameters(\n\u001b[1;32m 19\u001b[0m dataset_path\u001b[38;5;241m=\u001b[39mParameter(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNeelNanda/c4-code-tokenized-2b\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 20\u001b[0m ),\n\u001b[1;32m 21\u001b[0m ),\n\u001b[1;32m 22\u001b[0m method\u001b[38;5;241m=\u001b[39mMethod\u001b[38;5;241m.\u001b[39mRANDOM,\n\u001b[1;32m 23\u001b[0m )\n\u001b[1;32m 24\u001b[0m sweep_config\n", |
117 |
| - "\u001b[0;31mTypeError\u001b[0m: SourceModelHyperparameters.__init__() got an unexpected keyword argument 'hook_layer'" |
118 |
| - ] |
| 109 | + "data": { |
| 110 | + "text/plain": [ |
| 111 | + "SweepConfig(parameters=Hyperparameters(\n", |
| 112 | + " source_data=SourceDataHyperparameters(dataset_path=Parameter(value=alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2), context_size=Parameter(value=128), dataset_column_name=Parameter(value=input_ids), dataset_dir=None, dataset_files=None, pre_download=Parameter(value=False), pre_tokenized=Parameter(value=True), tokenizer_name=None)\n", |
| 113 | + " source_model=SourceModelHyperparameters(name=Parameter(value=gpt2-small), cache_names=Parameter(value=['blocks.0.hook_mlp_out', 'blocks.1.hook_mlp_out', 'blocks.2.hook_mlp_out', 'blocks.3.hook_mlp_out', 'blocks.4.hook_mlp_out', 'blocks.5.hook_mlp_out', 'blocks.6.hook_mlp_out', 'blocks.7.hook_mlp_out', 'blocks.8.hook_mlp_out', 'blocks.9.hook_mlp_out', 'blocks.10.hook_mlp_out', 'blocks.11.hook_mlp_out']), hook_dimension=Parameter(value=768), dtype=Parameter(value=float32))\n", |
| 114 | + " activation_resampler=ActivationResamplerHyperparameters(resample_interval=Parameter(value=200000000), max_n_resamples=Parameter(value=4), n_activations_activity_collate=Parameter(value=100000000), resample_dataset_size=Parameter(value=200000), threshold_is_dead_portion_fires=Parameter(value=1e-06))\n", |
| 115 | + " autoencoder=AutoencoderHyperparameters(expansion_factor=Parameter(value=2))\n", |
| 116 | + " loss=LossHyperparameters(l1_coefficient=Parameter(max=0.01, min=0.004))\n", |
| 117 | + " optimizer=OptimizerHyperparameters(lr=Parameter(max=0.001, min=1e-05), adam_beta_1=Parameter(value=0.9), adam_beta_2=Parameter(value=0.99), adam_weight_decay=Parameter(value=0.0), amsgrad=Parameter(value=False), fused=Parameter(value=False))\n", |
| 118 | + " pipeline=PipelineHyperparameters(log_frequency=Parameter(value=100), source_data_batch_size=Parameter(value=16), train_batch_size=Parameter(value=1024), max_store_size=Parameter(value=300000), max_activations=Parameter(value=1000000000), checkpoint_frequency=Parameter(value=100000000), validation_frequency=Parameter(value=100000000), validation_n_activations=Parameter(value=8192))\n", |
| 119 | + " random_seed=Parameter(value=49)\n", |
| 120 | + "), method=<Method.RANDOM: 'random'>, metric=Metric(name=train/loss/total_loss, goal=minimize), command=None, controller=None, description=None, earlyterminate=None, entity=None, imageuri=None, job=None, kind=None, name=None, program=None, project=None)" |
| 121 | + ] |
| 122 | + }, |
| 123 | + "execution_count": 5, |
| 124 | + "metadata": {}, |
| 125 | + "output_type": "execute_result" |
119 | 126 | }
|
120 | 127 | ],
|
121 | 128 | "source": [
|
| 129 | + "n_layers_gpt2_small = 12\n", |
| 130 | + "\n", |
122 | 131 | "sweep_config = SweepConfig(\n",
|
123 | 132 | " parameters=Hyperparameters(\n",
|
124 | 133 | " activation_resampler=ActivationResamplerHyperparameters(\n",
|
| 134 | + " resample_interval=Parameter(200_000_000),\n", |
| 135 | + " n_activations_activity_collate=Parameter(100_000_000),\n", |
125 | 136 | " threshold_is_dead_portion_fires=Parameter(1e-6),\n",
|
| 137 | + " max_n_resamples=Parameter(4),\n", |
| 138 | + " resample_dataset_size=Parameter(200_000),\n", |
126 | 139 | " ),\n",
|
127 | 140 | " loss=LossHyperparameters(\n",
|
128 | 141 | " l1_coefficient=Parameter(max=1e-2, min=4e-3),\n",
|
|
131 | 144 | " lr=Parameter(max=1e-3, min=1e-5),\n",
|
132 | 145 | " ),\n",
|
133 | 146 | " source_model=SourceModelHyperparameters(\n",
|
134 |
| - " name=Parameter(\"gelu-2l\"),\n", |
135 |
| - " cache_names=Parameter([\"blocks.0.hook_mlp_out\", \"blocks.1.hook_mlp_out\"]),\n", |
136 |
| - " hook_dimension=Parameter(512),\n", |
| 147 | + " name=Parameter(\"gpt2-small\"),\n", |
| 148 | + " # Train in parallel on all MLP layers\n", |
| 149 | + " cache_names=Parameter(\n", |
| 150 | + " [f\"blocks.{layer}.hook_mlp_out\" for layer in range(n_layers_gpt2_small)]\n", |
| 151 | + " ),\n", |
| 152 | + " hook_dimension=Parameter(768),\n", |
137 | 153 | " ),\n",
|
138 | 154 | " source_data=SourceDataHyperparameters(\n",
|
139 |
| - " dataset_path=Parameter(\"NeelNanda/c4-code-tokenized-2b\"),\n", |
| 155 | + " dataset_path=Parameter(\"alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2\"),\n", |
| 156 | + " context_size=Parameter(128),\n", |
| 157 | + " pre_tokenized=Parameter(value=True),\n", |
| 158 | + " ),\n", |
| 159 | + " pipeline=PipelineHyperparameters(\n", |
| 160 | + " max_activations=Parameter(1_000_000_000),\n", |
| 161 | + " checkpoint_frequency=Parameter(100_000_000),\n", |
| 162 | + " validation_frequency=Parameter(100_000_000),\n", |
| 163 | + " train_batch_size=Parameter(1024),\n", |
| 164 | + " max_store_size=Parameter(300_000),\n", |
140 | 165 | " ),\n",
|
141 | 166 | " ),\n",
|
142 | 167 | " method=Method.RANDOM,\n",
|
|
0 commit comments