|
16 | 16 | },
|
17 | 17 | {
|
18 | 18 | "cell_type": "code",
|
19 |
| - "execution_count": 1, |
| 19 | + "execution_count": 5, |
20 | 20 | "metadata": {},
|
21 |
| - "outputs": [], |
| 21 | + "outputs": [ |
| 22 | + { |
| 23 | + "name": "stdout", |
| 24 | + "output_type": "stream", |
| 25 | + "text": [ |
| 26 | + "The autoreload extension is already loaded. To reload it, use:\n", |
| 27 | + " %reload_ext autoreload\n" |
| 28 | + ] |
| 29 | + } |
| 30 | + ], |
22 | 31 | "source": [
|
23 | 32 | "# Autoreload\n",
|
24 | 33 | "%load_ext autoreload\n",
|
|
27 | 36 | },
|
28 | 37 | {
|
29 | 38 | "cell_type": "code",
|
30 |
| - "execution_count": 2, |
| 39 | + "execution_count": 6, |
31 | 40 | "metadata": {},
|
32 | 41 | "outputs": [],
|
33 | 42 | "source": [
|
|
36 | 45 | "from transformer_lens import HookedTransformer\n",
|
37 | 46 | "from transformer_lens.utils import get_device\n",
|
38 | 47 | "from transformers import PreTrainedTokenizerBase\n",
|
39 |
| - "import torch\n", |
40 |
| - "import wandb" |
| 48 | + "import torch" |
41 | 49 | ]
|
42 | 50 | },
|
43 | 51 | {
|
44 | 52 | "cell_type": "code",
|
45 |
| - "execution_count": 3, |
| 53 | + "execution_count": 7, |
46 | 54 | "metadata": {},
|
47 | 55 | "outputs": [],
|
48 | 56 | "source": [
|
|
58 | 66 | },
|
59 | 67 | {
|
60 | 68 | "cell_type": "code",
|
61 |
| - "execution_count": 4, |
| 69 | + "execution_count": 9, |
62 | 70 | "metadata": {},
|
63 |
| - "outputs": [ |
64 |
| - { |
65 |
| - "name": "stdout", |
66 |
| - "output_type": "stream", |
67 |
| - "text": [ |
68 |
| - "Loaded pretrained model solu-1l into HookedTransformer\n" |
69 |
| - ] |
70 |
| - }, |
71 |
| - { |
72 |
| - "data": { |
73 |
| - "text/plain": [ |
74 |
| - "2048" |
75 |
| - ] |
76 |
| - }, |
77 |
| - "execution_count": 4, |
78 |
| - "metadata": {}, |
79 |
| - "output_type": "execute_result" |
80 |
| - } |
81 |
| - ], |
| 71 | + "outputs": [], |
82 | 72 | "source": [
|
83 | 73 | "src_model = HookedTransformer.from_pretrained(\"solu-1l\", dtype=\"float32\")\n",
|
84 | 74 | "src_d_mlp: int = src_model.cfg.d_mlp # type: ignore\n",
|
|
94 | 84 | },
|
95 | 85 | {
|
96 | 86 | "cell_type": "code",
|
97 |
| - "execution_count": 5, |
| 87 | + "execution_count": null, |
98 | 88 | "metadata": {},
|
99 |
| - "outputs": [ |
100 |
| - { |
101 |
| - "name": "stdout", |
102 |
| - "output_type": "stream", |
103 |
| - "text": [ |
104 |
| - "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", |
105 |
| - "To disable this warning, you can either:\n", |
106 |
| - "\t- Avoid using `tokenizers` before the fork if possible\n", |
107 |
| - "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" |
108 |
| - ] |
109 |
| - }, |
110 |
| - { |
111 |
| - "data": { |
112 |
| - "application/vnd.jupyter.widget-view+json": { |
113 |
| - "model_id": "a1ce590449484e1788109c4f13a2e8bf", |
114 |
| - "version_major": 2, |
115 |
| - "version_minor": 0 |
116 |
| - }, |
117 |
| - "text/plain": [ |
118 |
| - "Resolving data files: 0%| | 0/30 [00:00<?, ?it/s]" |
119 |
| - ] |
120 |
| - }, |
121 |
| - "metadata": {}, |
122 |
| - "output_type": "display_data" |
123 |
| - } |
124 |
| - ], |
| 89 | + "outputs": [], |
125 | 90 | "source": [
|
126 | 91 | "tokenizer: PreTrainedTokenizerBase = src_model.tokenizer # type: ignore\n",
|
127 |
| - "source_data = PileUncopyrightedDataset(tokenizer=tokenizer)\n", |
128 |
| - "src_dataloader = source_data.get_dataloader(batch_size=8)" |
| 92 | + "source_data = PileUncopyrightedDataset(tokenizer=tokenizer)" |
129 | 93 | ]
|
130 | 94 | },
|
131 | 95 | {
|
|
137 | 101 | },
|
138 | 102 | {
|
139 | 103 | "cell_type": "code",
|
140 |
| - "execution_count": 6, |
| 104 | + "execution_count": null, |
141 | 105 | "metadata": {},
|
142 | 106 | "outputs": [],
|
143 | 107 | "source": [
|
|
154 | 118 | },
|
155 | 119 | {
|
156 | 120 | "cell_type": "code",
|
157 |
| - "execution_count": 7, |
| 121 | + "execution_count": null, |
158 | 122 | "metadata": {},
|
159 |
| - "outputs": [ |
160 |
| - { |
161 |
| - "data": { |
162 |
| - "text/plain": [ |
163 |
| - "SparseAutoencoder(\n", |
164 |
| - " (encoder): Sequential(\n", |
165 |
| - " (0): TiedBias(position=TiedBiasPosition.PRE_ENCODER)\n", |
166 |
| - " (1): ConstrainedUnitNormLinear(in_features=2048, out_features=16384, bias=True)\n", |
167 |
| - " (2): ReLU()\n", |
168 |
| - " )\n", |
169 |
| - " (decoder): Sequential(\n", |
170 |
| - " (0): ConstrainedUnitNormLinear(in_features=16384, out_features=2048, bias=False)\n", |
171 |
| - " (1): TiedBias(position=TiedBiasPosition.POST_DECODER)\n", |
172 |
| - " )\n", |
173 |
| - ")" |
174 |
| - ] |
175 |
| - }, |
176 |
| - "execution_count": 7, |
177 |
| - "metadata": {}, |
178 |
| - "output_type": "execute_result" |
179 |
| - } |
180 |
| - ], |
| 123 | + "outputs": [], |
181 | 124 | "source": [
|
182 | 125 | "autoencoder = SparseAutoencoder(src_d_mlp, src_d_mlp * 8, torch.zeros(src_d_mlp))\n",
|
183 | 126 | "autoencoder"
|
|
199 | 142 | },
|
200 | 143 | {
|
201 | 144 | "cell_type": "code",
|
202 |
| - "execution_count": 8, |
| 145 | + "execution_count": null, |
203 | 146 | "metadata": {},
|
204 | 147 | "outputs": [],
|
205 | 148 | "source": [
|
|
208 | 151 | },
|
209 | 152 | {
|
210 | 153 | "cell_type": "code",
|
211 |
| - "execution_count": 9, |
| 154 | + "execution_count": null, |
212 | 155 | "metadata": {},
|
213 |
| - "outputs": [ |
214 |
| - { |
215 |
| - "data": { |
216 |
| - "application/vnd.jupyter.widget-view+json": { |
217 |
| - "model_id": "309fbf4a29a147ada581ba09b0cff34d", |
218 |
| - "version_major": 2, |
219 |
| - "version_minor": 0 |
220 |
| - }, |
221 |
| - "text/plain": [ |
222 |
| - "Generate/Train Cycles: 0it [00:00, ?it/s]" |
223 |
| - ] |
224 |
| - }, |
225 |
| - "metadata": {}, |
226 |
| - "output_type": "display_data" |
227 |
| - }, |
228 |
| - { |
229 |
| - "data": { |
230 |
| - "application/vnd.jupyter.widget-view+json": { |
231 |
| - "model_id": "a26f99ac95d44bf196f1d5fe70bafbe9", |
232 |
| - "version_major": 2, |
233 |
| - "version_minor": 0 |
234 |
| - }, |
235 |
| - "text/plain": [ |
236 |
| - "Generate Activations: 0%| | 0/1000000 [00:00<?, ?it/s]" |
237 |
| - ] |
238 |
| - }, |
239 |
| - "metadata": {}, |
240 |
| - "output_type": "display_data" |
241 |
| - }, |
242 |
| - { |
243 |
| - "data": { |
244 |
| - "application/vnd.jupyter.widget-view+json": { |
245 |
| - "model_id": "5cd07ef70a1f4b4c97cd2828f4cfd745", |
246 |
| - "version_major": 2, |
247 |
| - "version_minor": 0 |
248 |
| - }, |
249 |
| - "text/plain": [ |
250 |
| - "Train Autoencoder: 0%| | 0/1000000 [00:00<?, ?it/s]" |
251 |
| - ] |
252 |
| - }, |
253 |
| - "metadata": {}, |
254 |
| - "output_type": "display_data" |
255 |
| - }, |
256 |
| - { |
257 |
| - "name": "stderr", |
258 |
| - "output_type": "stream", |
259 |
| - "text": [ |
260 |
| - "/Users/alan/Documents/Repos/sparse_autoencoder/.venv/lib/python3.11/site-packages/torch/autograd/__init__.py:251: UserWarning: The operator 'aten::sgn.out' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:13.)\n", |
261 |
| - " Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n" |
262 |
| - ] |
263 |
| - }, |
264 |
| - { |
265 |
| - "data": { |
266 |
| - "application/vnd.jupyter.widget-view+json": { |
267 |
| - "model_id": "242a91de8f694f64a7d04e93f25b95dc", |
268 |
| - "version_major": 2, |
269 |
| - "version_minor": 0 |
270 |
| - }, |
271 |
| - "text/plain": [ |
272 |
| - "Generate Activations: 0%| | 0/1000000 [00:00<?, ?it/s]" |
273 |
| - ] |
274 |
| - }, |
275 |
| - "metadata": {}, |
276 |
| - "output_type": "display_data" |
277 |
| - }, |
278 |
| - { |
279 |
| - "data": { |
280 |
| - "application/vnd.jupyter.widget-view+json": { |
281 |
| - "model_id": "5337eac728eb4ced9590c001e20e53ed", |
282 |
| - "version_major": 2, |
283 |
| - "version_minor": 0 |
284 |
| - }, |
285 |
| - "text/plain": [ |
286 |
| - "Train Autoencoder: 0%| | 0/1000000 [00:00<?, ?it/s]" |
287 |
| - ] |
288 |
| - }, |
289 |
| - "metadata": {}, |
290 |
| - "output_type": "display_data" |
291 |
| - } |
292 |
| - ], |
| 156 | + "outputs": [], |
293 | 157 | "source": [
|
294 | 158 | "pipeline(\n",
|
295 | 159 | " src_model=src_model,\n",
|
296 | 160 | " src_model_activation_hook_point=\"blocks.0.mlp.hook_post\",\n",
|
297 | 161 | " src_model_activation_layer=0,\n",
|
298 |
| - " src_dataloader=src_dataloader,\n", |
| 162 | + " source_dataset=source_data,\n", |
299 | 163 | " activation_store=store,\n",
|
300 | 164 | " num_activations_before_training=max_items,\n",
|
301 | 165 | " autoencoder=autoencoder,\n",
|
|
0 commit comments