
Commit 1e26c8f

Add stateful iterator to the pipeline (#55)
1 parent: ae0ddb5 · commit: 1e26c8f

10 files changed: +121 additions, −194 deletions


.vscode/cspell.json

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@
     "runcap",
     "sharded",
     "snapshottest",
+    "solu",
     "tqdm",
     "transformer_lens",
     "typecheck",

demo.ipynb

Lines changed: 26 additions & 162 deletions
@@ -16,9 +16,18 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 5,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "The autoreload extension is already loaded. To reload it, use:\n",
+      "  %reload_ext autoreload\n"
+     ]
+    }
+   ],
    "source": [
     "# Autoreload\n",
     "%load_ext autoreload\n",
@@ -27,7 +36,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -36,13 +45,12 @@
     "from transformer_lens import HookedTransformer\n",
     "from transformer_lens.utils import get_device\n",
     "from transformers import PreTrainedTokenizerBase\n",
-    "import torch\n",
-    "import wandb"
+    "import torch"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -58,27 +66,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 9,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Loaded pretrained model solu-1l into HookedTransformer\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       "2048"
-      ]
-     },
-     "execution_count": 4,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
     "src_model = HookedTransformer.from_pretrained(\"solu-1l\", dtype=\"float32\")\n",
     "src_d_mlp: int = src_model.cfg.d_mlp  # type: ignore\n",
@@ -94,38 +84,12 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
-      "To disable this warning, you can either:\n",
-      "\t- Avoid using `tokenizers` before the fork if possible\n",
-      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
-     ]
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "a1ce590449484e1788109c4f13a2e8bf",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
     "tokenizer: PreTrainedTokenizerBase = src_model.tokenizer  # type: ignore\n",
-    "source_data = PileUncopyrightedDataset(tokenizer=tokenizer)\n",
-    "src_dataloader = source_data.get_dataloader(batch_size=8)"
+    "source_data = PileUncopyrightedDataset(tokenizer=tokenizer)"
    ]
   },
   {
@@ -137,7 +101,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -154,30 +118,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "SparseAutoencoder(\n",
-       "  (encoder): Sequential(\n",
-       "    (0): TiedBias(position=TiedBiasPosition.PRE_ENCODER)\n",
-       "    (1): ConstrainedUnitNormLinear(in_features=2048, out_features=16384, bias=True)\n",
-       "    (2): ReLU()\n",
-       "  )\n",
-       "  (decoder): Sequential(\n",
-       "    (0): ConstrainedUnitNormLinear(in_features=16384, out_features=2048, bias=False)\n",
-       "    (1): TiedBias(position=TiedBiasPosition.POST_DECODER)\n",
-       "  )\n",
-       ")"
-      ]
-     },
-     "execution_count": 7,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
     "autoencoder = SparseAutoencoder(src_d_mlp, src_d_mlp * 8, torch.zeros(src_d_mlp))\n",
     "autoencoder"
@@ -199,7 +142,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -208,94 +151,15 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "309fbf4a29a147ada581ba09b0cff34d",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Generate/Train Cycles: 0it [00:00, ?it/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "a26f99ac95d44bf196f1d5fe70bafbe9",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "5cd07ef70a1f4b4c97cd2828f4cfd745",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/Users/alan/Documents/Repos/sparse_autoencoder/.venv/lib/python3.11/site-packages/torch/autograd/__init__.py:251: UserWarning: The operator 'aten::sgn.out' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:13.)\n",
-      "  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass\n"
-     ]
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "242a91de8f694f64a7d04e93f25b95dc",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "5337eac728eb4ced9590c001e20e53ed",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
     "pipeline(\n",
     "    src_model=src_model,\n",
     "    src_model_activation_hook_point=\"blocks.0.mlp.hook_post\",\n",
     "    src_model_activation_layer=0,\n",
-    "    src_dataloader=src_dataloader,\n",
+    "    source_dataset=source_data,\n",
     "    activation_store=store,\n",
     "    num_activations_before_training=max_items,\n",
     "    autoencoder=autoencoder,\n",
sparse_autoencoder/source_data/abstract_dataset.py

Lines changed: 16 additions & 6 deletions
@@ -3,6 +3,8 @@
 from typing import Any, Generic, TypedDict, TypeVar, final

 from datasets import IterableDataset, load_dataset
+from jaxtyping import Int
+from torch import Tensor
 from torch.utils.data import DataLoader
 from torch.utils.data import Dataset as TorchDataset

@@ -11,12 +13,18 @@
 """A tokenized prompt."""


-class PreprocessTokenizedPrompts(TypedDict):
-    """Preprocess tokenized prompts return type."""
+class TokenizedPrompts(TypedDict):
+    """Tokenized prompts."""

     input_ids: list[TokenizedPrompt]


+class TorchTokenizedPrompts(TypedDict):
+    """Tokenized prompts prepared for PyTorch."""
+
+    input_ids: Int[Tensor, "batch pos"]
+
+
 HuggingFaceDatasetItem = TypeVar("HuggingFaceDatasetItem", bound=Any)
 """Hugging face dataset item typed dict.

@@ -65,7 +73,7 @@ def preprocess(
         source_batch: HuggingFaceDatasetItem,
         *,
         context_size: int,
-    ) -> PreprocessTokenizedPrompts:
+    ) -> TokenizedPrompts:
         """Preprocess function.

         Takes a `preprocess_batch_size` ($m$) batch of source data (which may e.g. include string
@@ -119,6 +127,8 @@ def __init__(
             preprocess_batch_size: The batch size to use just for preprocessing the dataset (e.g.
                 tokenizing prompts).
         """
+        self.context_size = context_size
+
         # Load the dataset
         dataset: IterableDataset = load_dataset(dataset_path, streaming=True, split=dataset_split)  # type: ignore

@@ -154,7 +164,7 @@ def __next__(self) -> Any:  # noqa: ANN401
         return next(iter(self))

     @final
-    def get_dataloader(self, batch_size: int) -> DataLoader:
+    def get_dataloader(self, batch_size: int) -> DataLoader[TorchTokenizedPrompts]:
         """Get a PyTorch DataLoader.

         Args:
@@ -163,9 +173,9 @@ def get_dataloader(self, batch_size: int) -> DataLoader:
         Returns:
             PyTorch DataLoader.
         """
-        torch_dataset: TorchDataset = self.dataset.with_format("torch")  # type: ignore
+        torch_dataset: TorchDataset[TorchTokenizedPrompts] = self.dataset.with_format("torch")  # type: ignore

-        return DataLoader(
+        return DataLoader[TorchTokenizedPrompts](
             torch_dataset,
             batch_size=batch_size,
             # Shuffle is most efficiently done with the `shuffle` method on the dataset itself, not
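
The renamed `TokenizedPrompts` and the new `TorchTokenizedPrompts` split the data contract in two: `preprocess` returns plain Python lists, while the DataLoader yields collated integer tensors of shape (batch, pos). A minimal sketch of consuming that contract, assuming these import paths and that `SourceDataset` subclasses are constructed as in the demo notebook:

from jaxtyping import Int
from torch import Tensor

from sparse_autoencoder.source_data.abstract_dataset import (
    SourceDataset,
    TorchTokenizedPrompts,
)


def first_batch(dataset: SourceDataset, batch_size: int = 8) -> Int[Tensor, "batch pos"]:
    """Pull one collated batch, showing the TorchTokenizedPrompts contract."""
    dataloader = dataset.get_dataloader(batch_size=batch_size)
    batch: TorchTokenizedPrompts = next(iter(dataloader))
    # with_format("torch") means input_ids arrives as an integer tensor,
    # not the list[TokenizedPrompt] that preprocess() returns.
    return batch["input_ids"]

# e.g. tokens = first_batch(PileUncopyrightedDataset(tokenizer=tokenizer))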

sparse_autoencoder/source_data/c4_pre_tokenized.py

Lines changed: 2 additions & 2 deletions
@@ -7,8 +7,8 @@
 from typing import TypedDict, final

 from sparse_autoencoder.source_data.abstract_dataset import (
-    PreprocessTokenizedPrompts,
     SourceDataset,
+    TokenizedPrompts,
 )


@@ -33,7 +33,7 @@ def preprocess(
         source_batch: NeelC4SourceDataBatch,
         *,
         context_size: int,
-    ) -> PreprocessTokenizedPrompts:
+    ) -> TokenizedPrompts:
         """Preprocess a batch of prompts.

         As this dataset is already tokenized, all this does is split up each item based on the
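
The docstring says the pre-tokenized C4 items are simply split up by `context_size`. A hedged sketch of that chunking with a hypothetical helper name (the real method body is outside this diff):

def chunk_tokens(tokens: list[int], context_size: int) -> list[list[int]]:
    """Split one pre-tokenized prompt into fixed-size context windows."""
    n_chunks = len(tokens) // context_size  # keep full chunks only
    return [
        tokens[i * context_size : (i + 1) * context_size]
        for i in range(n_chunks)
    ]

# e.g. chunk_tokens(list(range(10)), 4) -> [[0, 1, 2, 3], [4, 5, 6, 7]]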

sparse_autoencoder/source_data/pile_uncopyrighted.py

Lines changed: 2 additions & 2 deletions
@@ -4,8 +4,8 @@
 from transformers import PreTrainedTokenizerBase

 from sparse_autoencoder.source_data.abstract_dataset import (
-    PreprocessTokenizedPrompts,
     SourceDataset,
+    TokenizedPrompts,
 )


@@ -33,7 +33,7 @@ def preprocess(
         source_batch: PileUncopyrightedSourceDataBatch,
         *,
         context_size: int,
-    ) -> PreprocessTokenizedPrompts:
+    ) -> TokenizedPrompts:
         """Preprocess a batch of prompts.

         For each prompt's `text`, tokenize it and chunk into a list of tokenized prompts of length
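
Here `preprocess` tokenizes each prompt's `text` and chunks the result to `context_size`. A hedged sketch of that behaviour with a hypothetical helper (the actual implementation is outside this diff); the returned dict matches the `TokenizedPrompts` shape:

from transformers import PreTrainedTokenizerBase


def preprocess_texts(
    texts: list[str],
    tokenizer: PreTrainedTokenizerBase,
    context_size: int,
) -> dict[str, list[list[int]]]:
    """Tokenize each prompt and chunk into fixed-length tokenized prompts."""
    chunks: list[list[int]] = []
    for text in texts:
        token_ids: list[int] = tokenizer(text)["input_ids"]
        # Keep only full-length chunks of `context_size` tokens.
        for start in range(0, len(token_ids) - context_size + 1, context_size):
            chunks.append(token_ids[start : start + context_size])
    return {"input_ids": chunks}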
