Commit 05ed6b3

Add multiprocessing support to datasets (#150)
Useful where the data is sharded across files, as multiple files can be processed in parallel (when pre-downloaded).
1 parent 8f97036 commit 05ed6b3

8 files changed (+67, −38 lines)

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -180,7 +180,7 @@
convention="google"

[tool.ruff.lint.pylint]
-max-args=10
+max-args=15

[tool.pyright]
# All rules apart from base are shown explicitly below

sparse_autoencoder/optimizer/adam_with_reset.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ class AdamWithReset(Adam, AbstractOptimizerWithReset):
        The names of the parameters, so that we can find them later when resetting the state.
    """

-    def __init__(  # noqa: PLR0913 (extending existing implementation)
+    def __init__(  # (extending existing implementation)
        self,
        params: params_t,
        lr: float | Tensor = 1e-3,

sparse_autoencoder/source_data/abstract_dataset.py

Lines changed: 32 additions & 20 deletions
@@ -111,6 +111,7 @@ def __init__(
        buffer_size: int = 1000,
        dataset_dir: str | None = None,
        dataset_files: str | Sequence[str] | Mapping[str, str | Sequence[str]] | None = None,
+        n_processes_preprocessing: int | None = None,
        preprocess_batch_size: int = 1000,
        *,
        pre_download: bool = False,
@@ -135,6 +136,7 @@ def __init__(
                tokenized prompts once the preprocessing function has been applied.
            dataset_dir: Defining the `data_dir` of the dataset configuration.
            dataset_files: Path(s) to source data file(s).
+            n_processes_preprocessing: The number of processes to use for preprocessing.
            preprocess_batch_size: The batch size to use just for preprocessing the dataset (e.g.
                tokenizing prompts).
            pre_download: Whether to pre-download the whole dataset.
@@ -146,43 +148,53 @@ def __init__(

        # Load the dataset
        should_stream = not pre_download
-        loaded_dataset = load_dataset(
+        dataset = load_dataset(
            dataset_path,
            streaming=should_stream,
            split=dataset_split,
            data_dir=dataset_dir,
            data_files=dataset_files,
        )

-        # Check the dataset is a Hugging Face Dataset or IterableDataset
-        if not isinstance(loaded_dataset, Dataset) and not isinstance(
-            loaded_dataset, IterableDataset
-        ):
-            error_message = (
-                f"Expected Hugging Face dataset to be a Dataset or IterableDataset, but got "
-                f"{type(loaded_dataset)}."
-            )
-            raise TypeError(error_message)
-
-        dataset: Dataset | IterableDataset = loaded_dataset
-
        # Setup preprocessing
        existing_columns: list[str] = list(next(iter(dataset)).keys())
-        mapped_dataset = dataset.map(
-            self.preprocess,
-            batched=True,
-            batch_size=preprocess_batch_size,
-            fn_kwargs={"context_size": context_size},
-            remove_columns=existing_columns,
-        )

        if pre_download:
+            if not isinstance(dataset, Dataset):
+                error_message = (
+                    f"Expected Hugging Face dataset to be a Dataset when pre-downloading, but got "
+                    f"{type(dataset)}."
+                )
+                raise TypeError(error_message)
+
            # Download the whole dataset
+            mapped_dataset = dataset.map(
+                self.preprocess,
+                batched=True,
+                batch_size=preprocess_batch_size,
+                fn_kwargs={"context_size": context_size},
+                remove_columns=existing_columns,
+                num_proc=n_processes_preprocessing,
+            )
            self.dataset = mapped_dataset.shuffle()
        else:
            # Setup approximate shuffling. As the dataset is streamed, this just pre-downloads at
            # least `buffer_size` items and then shuffles just that buffer.
            # https://huggingface.co/docs/datasets/v2.14.5/stream#shuffle
+            if not isinstance(dataset, IterableDataset):
+                error_message = (
+                    f"Expected Hugging Face dataset to be an IterableDataset when streaming, but "
+                    f"got {type(dataset)}."
+                )
+                raise TypeError(error_message)
+
+            mapped_dataset = dataset.map(
+                self.preprocess,
+                batched=True,
+                batch_size=preprocess_batch_size,
+                fn_kwargs={"context_size": context_size},
+                remove_columns=existing_columns,
+            )
            self.dataset = mapped_dataset.shuffle(buffer_size=buffer_size)  # type: ignore

    @final
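The substance of the change is that `n_processes_preprocessing` is forwarded to `Dataset.map(num_proc=...)`. In the Hugging Face `datasets` library, `num_proc` is accepted by the non-streaming `Dataset.map` but not by `IterableDataset.map`, which is why the single shared `map` call is split into a pre-download branch (with `num_proc`) and a streaming branch (without). A minimal standalone sketch of the same pattern, using an illustrative dataset and tokenizer rather than anything from this repository:

from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")


def preprocess(batch: dict, context_size: int) -> dict:
    """Tokenize a batch of prompts into token ids of at most `context_size` tokens."""
    tokens = tokenizer(batch["text"], truncation=True, max_length=context_size)
    return {"input_ids": tokens["input_ids"]}


# pre_download=True above corresponds to streaming=False here: the data must be
# fully materialised before it can be mapped across worker processes.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train", streaming=False)

mapped = dataset.map(
    preprocess,
    batched=True,
    batch_size=1000,
    fn_kwargs={"context_size": 128},
    remove_columns=dataset.column_names,  # the repo derives this from the first example instead
    num_proc=4,  # four worker processes; most useful when the data is sharded across files
)
shuffled = mapped.shuffle()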

sparse_autoencoder/source_data/text_dataset.py

Lines changed: 3 additions & 0 deletions
@@ -72,6 +72,7 @@ def __init__(
        dataset_dir: str | None = None,
        dataset_files: str | Sequence[str] | Mapping[str, str | Sequence[str]] | None = None,
        dataset_split: str = "train",
+        n_processes_preprocessing: int | None = None,
        preprocess_batch_size: int = 1000,
        *,
        pre_download: bool = False,
@@ -94,6 +95,7 @@ def __init__(
            dataset_dir: Defining the `data_dir` of the dataset configuration.
            dataset_files: Path(s) to source data file(s).
            dataset_split: Dataset split (e.g., 'train').
+            n_processes_preprocessing: Number of processes to use for preprocessing.
            preprocess_batch_size: Batch size for preprocessing (tokenizing prompts).
            pre_download: Whether to pre-download the whole dataset.
        """
@@ -106,6 +108,7 @@ def __init__(
            dataset_files=dataset_files,
            dataset_path=dataset_path,
            dataset_split=dataset_split,
+            n_processes_preprocessing=n_processes_preprocessing,
            pre_download=pre_download,
            preprocess_batch_size=preprocess_batch_size,
        )
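At this level the new argument is simply forwarded to the parent dataset class. A hedged usage sketch follows; the class name `TextDataset` and its import path are assumptions based on this file's location, and the dataset path is a placeholder:

from transformers import AutoTokenizer

from sparse_autoencoder.source_data.text_dataset import TextDataset  # class name assumed

tokenizer = AutoTokenizer.from_pretrained("gpt2")

source_data = TextDataset(
    dataset_path="stas/openwebtext-10k",  # placeholder text dataset
    tokenizer=tokenizer,
    context_size=256,
    pre_download=True,  # multiprocessing only applies when the dataset is pre-downloaded
    n_processes_preprocessing=4,
)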

sparse_autoencoder/train/pipeline.py

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ class Pipeline:
    """Total number of activations trained on state."""

    @final
-    def __init__(  # noqa: PLR0913
+    def __init__(
        self,
        activation_resampler: AbstractActivationResampler | None,
        autoencoder: SparseAutoencoder,

sparse_autoencoder/train/sweep.py

Lines changed: 17 additions & 4 deletions
@@ -141,12 +141,24 @@ def setup_source_data(hyperparameters: RuntimeHyperparameters) -> SourceDataset:
    Raises:
        ValueError: If the tokenizer name is not specified, but pre_tokenized is False.
    """
+    dataset_dir = (
+        hyperparameters["source_data"]["dataset_dir"]
+        if "dataset_dir" in hyperparameters["source_data"]
+        else None
+    )
+
+    dataset_files = (
+        hyperparameters["source_data"]["dataset_files"]
+        if "dataset_files" in hyperparameters["source_data"]
+        else None
+    )
+
    if hyperparameters["source_data"]["pre_tokenized"]:
        return PreTokenizedDataset(
            dataset_path=hyperparameters["source_data"]["dataset_path"],
            context_size=hyperparameters["source_data"]["context_size"],
-            dataset_dir=hyperparameters["source_data"]["dataset_dir"],
-            dataset_files=hyperparameters["source_data"]["dataset_files"],
+            dataset_dir=dataset_dir,
+            dataset_files=dataset_files,
        )

    if hyperparameters["source_data"]["tokenizer_name"] is None:
@@ -162,8 +174,9 @@ def setup_source_data(hyperparameters: RuntimeHyperparameters) -> SourceDataset:
        dataset_path=hyperparameters["source_data"]["dataset_path"],
        context_size=hyperparameters["source_data"]["context_size"],
        tokenizer=tokenizer,
-        dataset_dir=hyperparameters["source_data"]["dataset_dir"],
-        dataset_files=hyperparameters["source_data"]["dataset_files"],
+        dataset_dir=dataset_dir,
+        dataset_files=dataset_files,
+        n_processes_preprocessing=4,
    )
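Two small observations on this hunk: the conditional expressions implement "use the key if present, otherwise None", and the number of preprocessing processes is hard-coded to 4 here rather than exposed as a sweep hyperparameter. The lookups are equivalent to `dict.get`, which returns `None` for missing keys (an equivalent sketch, not what the commit uses):

source_data_config = hyperparameters["source_data"]
dataset_dir = source_data_config.get("dataset_dir")      # None when the key is absent
dataset_files = source_data_config.get("dataset_files")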

sparse_autoencoder/train/sweep_config.py

Lines changed: 5 additions & 5 deletions
@@ -174,16 +174,16 @@ class SourceDataHyperparameters(NestedParameter):
    context_size: Parameter[int] = field(default=Parameter(DEFAULT_SOURCE_CONTEXT_SIZE))
    """Context size."""

-    dataset_dir: Parameter[str | None] = field(default=Parameter(None))
+    dataset_dir: Parameter[str] | None = field(default=None)
    """Dataset directory (within the HF dataset)"""

-    dataset_files: Parameter[str | None] = field(default=Parameter(None))
+    dataset_files: Parameter[str] | None = field(default=None)
    """Dataset files (within the HF dataset)."""

    pre_tokenized: Parameter[bool] = field(default=Parameter(value=True))
    """If the dataset is pre-tokenized."""

-    tokenizer_name: Parameter[str | None] = field(default=Parameter(None))
+    tokenizer_name: Parameter[str] | None = field(default=None)
    """Tokenizer name.

    Only set this if the dataset is not pre-tokenized.
@@ -195,11 +195,11 @@ def __post_init__(self) -> None:
        Raises:
            ValueError: If there is an error in the source data hyperparameters.
        """
-        if self.pre_tokenized.value is False and self.tokenizer_name.value is None:
+        if self.pre_tokenized.value is False and not isinstance(self.tokenizer_name, Parameter):
            error_message = "The tokenizer name must be specified, when `pre_tokenized` is False."
            raise ValueError(error_message)

-        if self.pre_tokenized.value is True and self.tokenizer_name.value is not None:
+        if self.pre_tokenized.value is True and isinstance(self.tokenizer_name, Parameter):
            error_message = "The tokenizer name must not be set, when `pre_tokenized` is True."
            raise ValueError(error_message)
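The optional fields now hold either a `Parameter[str]` or plain `None` (rather than always a `Parameter` whose value may be `None`), so `self.tokenizer_name.value` would raise `AttributeError` when the field is unset; the validation therefore checks whether a `Parameter` instance was supplied at all. A hedged illustration of the resulting behaviour, assuming `dataset_path` is the only other required field and using placeholder dataset paths:

# Passes validation: pre-tokenized dataset, no tokenizer name supplied.
SourceDataHyperparameters(
    dataset_path=Parameter("placeholder/pre-tokenized-dataset"),
    pre_tokenized=Parameter(value=True),
)

# Raises ValueError in __post_init__: raw text dataset but no tokenizer name.
SourceDataHyperparameters(
    dataset_path=Parameter("placeholder/raw-text-dataset"),
    pre_tokenized=Parameter(value=False),
)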

sparse_autoencoder/train/utils/wandb_sweep_types.py

Lines changed: 7 additions & 6 deletions
@@ -5,7 +5,7 @@
from abc import ABC
from dataclasses import asdict, dataclass, is_dataclass
from enum import Enum, auto
-from typing import Any, Generic, TypeAlias, TypeVar, final
+from typing import Any, Generic, TypeVar, final

from strenum import LowercaseStrEnum

@@ -264,11 +264,12 @@ def __repr__(self) -> str:
        return self.__str__()


-OptionalFloat: TypeAlias = float | None
-OptionalInt: TypeAlias = int | None
-OptionalStr: TypeAlias = str | None
-
-ParamType = TypeVar("ParamType", float, int, str, OptionalFloat, OptionalInt, OptionalStr)
+ParamType = TypeVar(
+    "ParamType",
+    float,
+    int,
+    str,
+)


@dataclass(frozen=True)
