|
107 | 107 | " # and we have found that 4x is a good starting point.\n",
|
108 | 108 | " \"expansion_factor\": 4,\n",
|
109 | 109 | " # L1 coefficient is the coefficient of the L1 regularization term (used to encourage sparsity).\n",
|
110 |
| - " \"l1_coefficient\": 3e-4,\n", |
| 110 | + " \"l1_coefficient\": 1e-3,\n", |
111 | 111 | " # Adam parameters (set to the default ones here)\n",
|
112 |
| - " \"lr\": 1e-4,\n", |
| 112 | + " \"lr\": 3e-4,\n", |
113 | 113 | " \"adam_beta_1\": 0.9,\n",
|
114 | 114 | " \"adam_beta_2\": 0.999,\n",
|
115 | 115 | " \"adam_epsilon\": 1e-8,\n",
|
116 | 116 | " \"adam_weight_decay\": 0.0,\n",
|
117 | 117 | " # Batch sizes\n",
|
118 | 118 | " \"train_batch_size\": 4096,\n",
|
119 | 119 | " \"context_size\": 128,\n",
|
| 120 | + " # Source model hook point\n", |
| 121 | + " \"source_model_name\": \"gelu-2l\",\n", |
| 122 | + " \"source_model_dtype\": \"float32\",\n", |
| 123 | + " \"source_model_hook_point\": \"blocks.0.hook_mlp_out\",\n", |
| 124 | + " \"source_model_hook_point_layer\": 0,\n", |
| 125 | + " # Train pipeline parameters\n", |
| 126 | + " \"max_store_size\": 384 * 4096 * 2,\n", |
| 127 | + " \"max_activations\": 2_000_000_000,\n", |
| 128 | + " \"resample_frequency\": 122_880_000,\n", |
| 129 | + " \"checkpoint_frequency\": 100_000_000,\n", |
| 130 | + " \"validation_frequency\": 384 * 4096 * 2 * 100, # Every 100 generations\n", |
120 | 131 | "}"
|
121 | 132 | ]
|
122 | 133 | },
|
|
141 | 152 | },
|
142 | 153 | {
|
143 | 154 | "cell_type": "code",
|
144 |
| - "execution_count": 3, |
| 155 | + "execution_count": 4, |
145 | 156 | "metadata": {},
|
146 | 157 | "outputs": [
|
147 | 158 | {
|
|
157 | 168 | "'Source: gelu-2l, Hook: blocks.0.hook_mlp_out, Features: 512'"
|
158 | 169 | ]
|
159 | 170 | },
|
160 |
| - "execution_count": 3, |
| 171 | + "execution_count": 4, |
161 | 172 | "metadata": {},
|
162 | 173 | "output_type": "execute_result"
|
163 | 174 | }
|
164 | 175 | ],
|
165 | 176 | "source": [
|
166 | 177 | "# Source model setup with TransformerLens\n",
|
167 |
| - "src_model_name = \"gelu-2l\"\n", |
168 |
| - "src_model = HookedTransformer.from_pretrained(src_model_name, dtype=\"float32\")\n", |
| 178 | + "src_model = HookedTransformer.from_pretrained(\n", |
| 179 | + " str(hyperparameters[\"source_model_name\"]), dtype=str(hyperparameters[\"source_model_dtype\"])\n", |
| 180 | + ")\n", |
169 | 181 | "\n",
|
170 | 182 | "# Details about the activations we'll train the sparse autoencoder on\n",
|
171 |
| - "src_model_activation_hook_point = \"blocks.0.hook_mlp_out\"\n", |
172 |
| - "src_model_activation_layer = 0\n", |
173 | 183 | "autoencoder_input_dim: int = src_model.cfg.d_model # type: ignore (TransformerLens typing is currently broken)\n",
|
174 | 184 | "\n",
|
175 |
| - "f\"Source: {src_model_name}, Hook: {src_model_activation_hook_point}, \\\n", |
| 185 | + "f\"Source: {hyperparameters['source_model_name']}, \\\n", |
| 186 | + " Hook: {hyperparameters['source_model_hook_point']}, \\\n", |
176 | 187 | " Features: {autoencoder_input_dim}\""
|
177 | 188 | ]
|
178 | 189 | },
|
|
199 | 210 | },
|
200 | 211 | {
|
201 | 212 | "cell_type": "code",
|
202 |
| - "execution_count": 4, |
| 213 | + "execution_count": 5, |
203 | 214 | "metadata": {},
|
204 | 215 | "outputs": [
|
205 | 216 | {
|
|
216 | 227 | ")"
|
217 | 228 | ]
|
218 | 229 | },
|
219 |
| - "execution_count": 4, |
| 230 | + "execution_count": 5, |
220 | 231 | "metadata": {},
|
221 | 232 | "output_type": "execute_result"
|
222 | 233 | }
|
|
244 | 255 | },
|
245 | 256 | {
|
246 | 257 | "cell_type": "code",
|
247 |
| - "execution_count": 5, |
| 258 | + "execution_count": 6, |
248 | 259 | "metadata": {},
|
249 | 260 | "outputs": [
|
250 | 261 | {
|
251 | 262 | "data": {
|
252 | 263 | "text/plain": [
|
253 | 264 | "LossReducer(\n",
|
254 |
| - " (0): LearnedActivationsL1Loss(l1_coefficient=0.0003)\n", |
| 265 | + " (0): LearnedActivationsL1Loss(l1_coefficient=0.001)\n", |
255 | 266 | " (1): L2ReconstructionLoss()\n",
|
256 | 267 | ")"
|
257 | 268 | ]
|
258 | 269 | },
|
259 |
| - "execution_count": 5, |
| 270 | + "execution_count": 6, |
260 | 271 | "metadata": {},
|
261 | 272 | "output_type": "execute_result"
|
262 | 273 | }
|
|
265 | 276 | "# We use a loss reducer, which simply adds up the losses from the underlying loss functions.\n",
|
266 | 277 | "loss = LossReducer(\n",
|
267 | 278 | " LearnedActivationsL1Loss(\n",
|
268 |
| - " l1_coefficient=hyperparameters[\"l1_coefficient\"],\n", |
| 279 | + " l1_coefficient=float(hyperparameters[\"l1_coefficient\"]),\n", |
269 | 280 | " ),\n",
|
270 | 281 | " L2ReconstructionLoss(),\n",
|
271 | 282 | ")\n",
|
|
274 | 285 | },
|
275 | 286 | {
|
276 | 287 | "cell_type": "code",
|
277 |
| - "execution_count": 6, |
| 288 | + "execution_count": 7, |
278 | 289 | "metadata": {},
|
279 | 290 | "outputs": [
|
280 | 291 | {
|
|
289 | 300 | " eps: 1e-08\n",
|
290 | 301 | " foreach: None\n",
|
291 | 302 | " fused: None\n",
|
292 |
| - " lr: 0.0001\n", |
| 303 | + " lr: 0.0003\n", |
293 | 304 | " maximize: False\n",
|
294 | 305 | " weight_decay: 0.0\n",
|
295 | 306 | ")"
|
296 | 307 | ]
|
297 | 308 | },
|
298 |
| - "execution_count": 6, |
| 309 | + "execution_count": 7, |
299 | 310 | "metadata": {},
|
300 | 311 | "output_type": "execute_result"
|
301 | 312 | }
|
|
304 | 315 | "optimizer = AdamWithReset(\n",
|
305 | 316 | " params=autoencoder.parameters(),\n",
|
306 | 317 | " named_parameters=autoencoder.named_parameters(),\n",
|
307 |
| - " lr=hyperparameters[\"lr\"],\n", |
308 |
| - " betas=(hyperparameters[\"adam_beta_1\"], hyperparameters[\"adam_beta_2\"]),\n", |
309 |
| - " eps=hyperparameters[\"adam_epsilon\"],\n", |
310 |
| - " weight_decay=hyperparameters[\"adam_weight_decay\"],\n", |
| 318 | + " lr=float(hyperparameters[\"lr\"]),\n", |
| 319 | + " betas=(float(hyperparameters[\"adam_beta_1\"]), float(hyperparameters[\"adam_beta_2\"])),\n", |
| 320 | + " eps=float(hyperparameters[\"adam_epsilon\"]),\n", |
| 321 | + " weight_decay=float(hyperparameters[\"adam_weight_decay\"]),\n", |
311 | 322 | ")\n",
|
312 | 323 | "optimizer"
|
313 | 324 | ]
|
|
321 | 332 | },
|
322 | 333 | {
|
323 | 334 | "cell_type": "code",
|
324 |
| - "execution_count": 7, |
| 335 | + "execution_count": 8, |
325 | 336 | "metadata": {},
|
326 | 337 | "outputs": [],
|
327 | 338 | "source": [
|
|
345 | 356 | },
|
346 | 357 | {
|
347 | 358 | "cell_type": "code",
|
348 |
| - "execution_count": 8, |
| 359 | + "execution_count": 9, |
349 | 360 | "metadata": {},
|
350 | 361 | "outputs": [
|
351 | 362 | {
|
352 | 363 | "data": {
|
353 | 364 | "application/vnd.jupyter.widget-view+json": {
|
354 |
| - "model_id": "4bdf3ebe364243bd8f881933e56c997d", |
| 365 | + "model_id": "75e636ebb9e04b279c7216c74496538d", |
355 | 366 | "version_major": 2,
|
356 | 367 | "version_minor": 0
|
357 | 368 | },
|
|
390 | 401 | },
|
391 | 402 | {
|
392 | 403 | "cell_type": "code",
|
393 |
| - "execution_count": 9, |
| 404 | + "execution_count": 10, |
394 | 405 | "metadata": {},
|
395 | 406 | "outputs": [],
|
396 | 407 | "source": [
|
|
400 | 411 | },
|
401 | 412 | {
|
402 | 413 | "cell_type": "code",
|
403 |
| - "execution_count": 10, |
| 414 | + "execution_count": 11, |
404 | 415 | "metadata": {},
|
405 | 416 | "outputs": [
|
406 | 417 | {
|
|
426 | 437 | {
|
427 | 438 | "data": {
|
428 | 439 | "text/html": [
|
429 |
| - "Run data is saved locally in <code>.cache/wandb/run-20231126_122954-xsruek7y</code>" |
| 440 | + "Run data is saved locally in <code>.cache/wandb/run-20231126_184500-2fnpg8zi</code>" |
430 | 441 | ],
|
431 | 442 | "text/plain": [
|
432 | 443 | "<IPython.core.display.HTML object>"
|
|
438 | 449 | {
|
439 | 450 | "data": {
|
440 | 451 | "text/html": [
|
441 |
| - "Syncing run <strong><a href='https://wandb.ai/alan-cooney/sparse-autoencoder/runs/xsruek7y' target=\"_blank\">vivid-totem-95</a></strong> to <a href='https://wandb.ai/alan-cooney/sparse-autoencoder' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>" |
| 452 | + "Syncing run <strong><a href='https://wandb.ai/alan-cooney/sparse-autoencoder/runs/2fnpg8zi' target=\"_blank\">prime-star-105</a></strong> to <a href='https://wandb.ai/alan-cooney/sparse-autoencoder' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>" |
442 | 453 | ],
|
443 | 454 | "text/plain": [
|
444 | 455 | "<IPython.core.display.HTML object>"
|
|
462 | 473 | {
|
463 | 474 | "data": {
|
464 | 475 | "text/html": [
|
465 |
| - " View run at <a href='https://wandb.ai/alan-cooney/sparse-autoencoder/runs/xsruek7y' target=\"_blank\">https://wandb.ai/alan-cooney/sparse-autoencoder/runs/xsruek7y</a>" |
| 476 | + " View run at <a href='https://wandb.ai/alan-cooney/sparse-autoencoder/runs/2fnpg8zi' target=\"_blank\">https://wandb.ai/alan-cooney/sparse-autoencoder/runs/2fnpg8zi</a>" |
466 | 477 | ],
|
467 | 478 | "text/plain": [
|
468 | 479 | "<IPython.core.display.HTML object>"
|
|
474 | 485 | {
|
475 | 486 | "data": {
|
476 | 487 | "text/html": [
|
477 |
| - "<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">Display W&B run</button><iframe src='https://wandb.ai/alan-cooney/sparse-autoencoder/runs/xsruek7y?jupyter=true' style='border:none;width:100%;height:420px;display:none;'></iframe>" |
| 488 | + "<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">Display W&B run</button><iframe src='https://wandb.ai/alan-cooney/sparse-autoencoder/runs/2fnpg8zi?jupyter=true' style='border:none;width:100%;height:420px;display:none;'></iframe>" |
478 | 489 | ],
|
479 | 490 | "text/plain": [
|
480 |
| - "<wandb.sdk.wandb_run.Run at 0x2ff1cbcd0>" |
| 491 | + "<wandb.sdk.wandb_run.Run at 0x3154cec10>" |
481 | 492 | ]
|
482 | 493 | },
|
483 |
| - "execution_count": 10, |
| 494 | + "execution_count": 11, |
484 | 495 | "metadata": {},
|
485 | 496 | "output_type": "execute_result"
|
486 | 497 | }
|
|
496 | 507 | },
|
497 | 508 | {
|
498 | 509 | "cell_type": "code",
|
499 |
| - "execution_count": 11, |
| 510 | + "execution_count": 12, |
500 | 511 | "metadata": {},
|
501 | 512 | "outputs": [
|
502 | 513 | {
|
503 | 514 | "data": {
|
504 | 515 | "application/vnd.jupyter.widget-view+json": {
|
505 |
| - "model_id": "e1e6fa019f524f3da19708a4eda9b349", |
| 516 | + "model_id": "1322f5e5dd5c4507a6eca9aa1f010882", |
506 | 517 | "version_major": 2,
|
507 | 518 | "version_minor": 0
|
508 | 519 | },
|
|
526 | 537 | "pipeline = Pipeline(\n",
|
527 | 538 | " activation_resampler=activation_resampler,\n",
|
528 | 539 | " autoencoder=autoencoder,\n",
|
529 |
| - " cache_name=src_model_activation_hook_point,\n", |
| 540 | + " cache_name=str(hyperparameters[\"source_model_hook_point\"]),\n", |
530 | 541 | " checkpoint_directory=checkpoint_path,\n",
|
531 |
| - " layer=src_model_activation_layer,\n", |
| 542 | + " layer=int(hyperparameters[\"source_model_hook_point_layer\"]),\n", |
532 | 543 | " loss=loss,\n",
|
533 | 544 | " optimizer=optimizer,\n",
|
534 | 545 | " source_data_batch_size=6,\n",
|
|
538 | 549 | "\n",
|
539 | 550 | "pipeline.run_pipeline(\n",
|
540 | 551 | " train_batch_size=int(hyperparameters[\"train_batch_size\"]),\n",
|
541 |
| - " max_store_size=384 * 4096 * 2,\n", |
542 |
| - " # Sizes for demo purposes (you probably want to scale these by 10x)\n", |
543 |
| - " max_activations=2_000_000_000,\n", |
544 |
| - " resample_frequency=122_880_000,\n", |
545 |
| - " checkpoint_frequency=100_000_000,\n", |
| 552 | + " max_store_size=int(hyperparameters[\"max_store_size\"]),\n", |
| 553 | + " max_activations=int(hyperparameters[\"max_activations\"]),\n", |
| 554 | + " resample_frequency=int(hyperparameters[\"resample_frequency\"]),\n", |
| 555 | + " checkpoint_frequency=int(hyperparameters[\"checkpoint_frequency\"]),\n", |
| 556 | + " validate_frequency=int(hyperparameters[\"validation_frequency\"]),\n", |
546 | 557 | ")"
|
547 | 558 | ]
|
548 | 559 | },
|
|
554 | 565 | "source": [
|
555 | 566 | "wandb.finish()"
|
556 | 567 | ]
|
557 |
| - }, |
558 |
| - { |
559 |
| - "cell_type": "markdown", |
560 |
| - "metadata": {}, |
561 |
| - "source": [ |
562 |
| - "## Training Advice\n", |
563 |
| - "\n", |
564 |
| - "-- Unfinished --\n", |
565 |
| - "\n", |
566 |
| - "- Check recovery loss is low while sparsity is low as well (<20 L1) usually.\n", |
567 |
| - "- Can't be sure features are useful until you dig into them more. " |
568 |
| - ] |
569 | 568 | }
|
570 | 569 | ],
|
571 | 570 | "metadata": {
|
|
0 commit comments