Skip to content

Commit 8f97036

Browse files
authored
Support text datasets with sweep (#151)
1 parent 06a022c commit 8f97036

File tree

4 files changed

+78
-8
lines changed

4 files changed

+78
-8
lines changed

sparse_autoencoder/train/sweep.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch
77
from transformer_lens import HookedTransformer
88
from transformer_lens.utils import get_act_name, get_device
9+
from transformers import AutoTokenizer
910
import wandb
1011

1112
from sparse_autoencoder import (
@@ -18,6 +19,8 @@
1819
PreTokenizedDataset,
1920
SparseAutoencoder,
2021
)
22+
from sparse_autoencoder.source_data.abstract_dataset import SourceDataset
23+
from sparse_autoencoder.source_data.text_dataset import TextDataset
2124
from sparse_autoencoder.train.sweep_config import (
2225
RuntimeHyperparameters,
2326
SweepConfig,
@@ -126,18 +129,41 @@ def setup_optimizer(
126129
)
127130

128131

129-
def setup_source_data(hyperparameters: RuntimeHyperparameters) -> PreTokenizedDataset:
132+
def setup_source_data(hyperparameters: RuntimeHyperparameters) -> SourceDataset:
130133
"""Setup the source data for training.
131134
132135
Args:
133136
hyperparameters: The hyperparameters dictionary.
134137
135138
Returns:
136-
PreTokenizedDataset: The initialized source data.
139+
The initialized source dataset.
140+
141+
Raises:
142+
ValueError: If pre_tokenized is False but the tokenizer name is not specified.
137143
"""
138-
return PreTokenizedDataset(
144+
if hyperparameters["source_data"]["pre_tokenized"]:
145+
return PreTokenizedDataset(
146+
dataset_path=hyperparameters["source_data"]["dataset_path"],
147+
context_size=hyperparameters["source_data"]["context_size"],
148+
dataset_dir=hyperparameters["source_data"]["dataset_dir"],
149+
dataset_files=hyperparameters["source_data"]["dataset_files"],
150+
)
151+
152+
if hyperparameters["source_data"]["tokenizer_name"] is None:
153+
error_message = (
154+
"If pre_tokenized is False, then tokenizer_name must be specified in the "
155+
"hyperparameters."
156+
)
157+
raise ValueError(error_message)
158+
159+
tokenizer = AutoTokenizer.from_pretrained(hyperparameters["source_data"]["tokenizer_name"])
160+
161+
return TextDataset(
139162
dataset_path=hyperparameters["source_data"]["dataset_path"],
140163
context_size=hyperparameters["source_data"]["context_size"],
164+
tokenizer=tokenizer,
165+
dataset_dir=hyperparameters["source_data"]["dataset_dir"],
166+
dataset_files=hyperparameters["source_data"]["dataset_files"],
141167
)
142168

143169

@@ -154,7 +180,7 @@ def run_training_pipeline(
154180
loss: LossReducer,
155181
optimizer: AdamWithReset,
156182
activation_resampler: ActivationResampler,
157-
source_data: PreTokenizedDataset,
183+
source_data: SourceDataset,
158184
run_name: str,
159185
) -> None:
160186
"""Run the training pipeline for the sparse autoencoder.

sparse_autoencoder/train/sweep_config.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,12 +174,45 @@ class SourceDataHyperparameters(NestedParameter):
174174
context_size: Parameter[int] = field(default=Parameter(DEFAULT_SOURCE_CONTEXT_SIZE))
175175
"""Context size."""
176176

177+
dataset_dir: Parameter[str | None] = field(default=Parameter(None))
178+
"""Dataset directory (within the HF dataset)."""
179+
180+
dataset_files: Parameter[str | None] = field(default=Parameter(None))
181+
"""Dataset files (within the HF dataset)."""
182+
183+
pre_tokenized: Parameter[bool] = field(default=Parameter(value=True))
184+
"""Whether the dataset is pre-tokenized."""
185+
186+
tokenizer_name: Parameter[str | None] = field(default=Parameter(None))
187+
"""Tokenizer name.
188+
189+
Only set this if the dataset is not pre-tokenized.
190+
"""
191+
192+
def __post_init__(self) -> None:
193+
"""Post initialisation checks.
194+
195+
Raises:
196+
ValueError: If there is an error in the source data hyperparameters.
197+
"""
198+
if self.pre_tokenized.value is False and self.tokenizer_name.value is None:
199+
error_message = "The tokenizer name must be specified, when `pre_tokenized` is False."
200+
raise ValueError(error_message)
201+
202+
if self.pre_tokenized.value is True and self.tokenizer_name.value is not None:
203+
error_message = "The tokenizer name must not be set, when `pre_tokenized` is True."
204+
raise ValueError(error_message)
205+
177206

178207
class SourceDataRuntimeHyperparameters(TypedDict):
179208
"""Source data runtime hyperparameters."""
180209

181-
dataset_path: str
182210
context_size: int
211+
dataset_dir: str | None
212+
dataset_files: str | None
213+
dataset_path: str
214+
pre_tokenized: bool
215+
tokenizer_name: str | None
183216

184217

185218
@dataclass(frozen=True)

sparse_autoencoder/train/tests/test_sweep.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,14 @@ def dummy_hyperparameters() -> RuntimeHyperparameters:
4646
"validation_number_activations": 1024,
4747
},
4848
"random_seed": 49,
49-
"source_data": {"context_size": 128, "dataset_path": "NeelNanda/c4-code-tokenized-2b"},
49+
"source_data": {
50+
"context_size": 128,
51+
"dataset_dir": None,
52+
"dataset_files": None,
53+
"dataset_path": "NeelNanda/c4-code-tokenized-2b",
54+
"pre_tokenized": True,
55+
"tokenizer_name": None,
56+
},
5057
"source_model": {
5158
"dtype": "float32",
5259
"hook_dimension": 512,

sparse_autoencoder/train/utils/wandb_sweep_types.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from abc import ABC
66
from dataclasses import asdict, dataclass, is_dataclass
77
from enum import Enum, auto
8-
from typing import Any, Generic, TypeVar, final
8+
from typing import Any, Generic, TypeAlias, TypeVar, final
99

1010
from strenum import LowercaseStrEnum
1111

@@ -264,7 +264,11 @@ def __repr__(self) -> str:
264264
return self.__str__()
265265

266266

267-
ParamType = TypeVar("ParamType", float, int, str)
267+
OptionalFloat: TypeAlias = float | None
268+
OptionalInt: TypeAlias = int | None
269+
OptionalStr: TypeAlias = str | None
270+
271+
ParamType = TypeVar("ParamType", float, int, str, OptionalFloat, OptionalInt, OptionalStr)
268272

269273

270274
@dataclass(frozen=True)

0 commit comments

Comments
 (0)