@@ -111,6 +111,7 @@ def __init__(
111
111
buffer_size : int = 1000 ,
112
112
dataset_dir : str | None = None ,
113
113
dataset_files : str | Sequence [str ] | Mapping [str , str | Sequence [str ]] | None = None ,
114
+ n_processes_preprocessing : int | None = None ,
114
115
preprocess_batch_size : int = 1000 ,
115
116
* ,
116
117
pre_download : bool = False ,
@@ -135,6 +136,7 @@ def __init__(
135
136
tokenized prompts once the preprocessing function has been applied.
136
137
dataset_dir: Defining the `data_dir` of the dataset configuration.
137
138
dataset_files: Path(s) to source data file(s).
139
+ n_processes_preprocessing: The number of processes to use for preprocessing.
138
140
preprocess_batch_size: The batch size to use just for preprocessing the dataset (e.g.
139
141
tokenizing prompts).
140
142
pre_download: Whether to pre-download the whole dataset.
@@ -146,43 +148,53 @@ def __init__(
146
148
147
149
# Load the dataset
148
150
should_stream = not pre_download
149
- loaded_dataset = load_dataset (
151
+ dataset = load_dataset (
150
152
dataset_path ,
151
153
streaming = should_stream ,
152
154
split = dataset_split ,
153
155
data_dir = dataset_dir ,
154
156
data_files = dataset_files ,
155
157
)
156
158
157
- # Check the dataset is a Hugging Face Dataset or IterableDataset
158
- if not isinstance (loaded_dataset , Dataset ) and not isinstance (
159
- loaded_dataset , IterableDataset
160
- ):
161
- error_message = (
162
- f"Expected Hugging Face dataset to be a Dataset or IterableDataset, but got "
163
- f"{ type (loaded_dataset )} ."
164
- )
165
- raise TypeError (error_message )
166
-
167
- dataset : Dataset | IterableDataset = loaded_dataset
168
-
169
159
# Setup preprocessing
170
160
existing_columns : list [str ] = list (next (iter (dataset )).keys ())
171
- mapped_dataset = dataset .map (
172
- self .preprocess ,
173
- batched = True ,
174
- batch_size = preprocess_batch_size ,
175
- fn_kwargs = {"context_size" : context_size },
176
- remove_columns = existing_columns ,
177
- )
178
161
179
162
if pre_download :
163
+ if not isinstance (dataset , Dataset ):
164
+ error_message = (
165
+ f"Expected Hugging Face dataset to be a Dataset when pre-downloading, but got "
166
+ f"{ type (dataset )} ."
167
+ )
168
+ raise TypeError (error_message )
169
+
180
170
# Download the whole dataset
171
+ mapped_dataset = dataset .map (
172
+ self .preprocess ,
173
+ batched = True ,
174
+ batch_size = preprocess_batch_size ,
175
+ fn_kwargs = {"context_size" : context_size },
176
+ remove_columns = existing_columns ,
177
+ num_proc = n_processes_preprocessing ,
178
+ )
181
179
self .dataset = mapped_dataset .shuffle ()
182
180
else :
183
181
# Setup approximate shuffling. As the dataset is streamed, this just pre-downloads at
184
182
# least `buffer_size` items and then shuffles just that buffer.
185
183
# https://huggingface.co/docs/datasets/v2.14.5/stream#shuffle
184
+ if not isinstance (dataset , IterableDataset ):
185
+ error_message = (
186
+ f"Expected Hugging Face dataset to be an IterableDataset when streaming, but "
187
+ f"got { type (dataset )} ."
188
+ )
189
+ raise TypeError (error_message )
190
+
191
+ mapped_dataset = dataset .map (
192
+ self .preprocess ,
193
+ batched = True ,
194
+ batch_size = preprocess_batch_size ,
195
+ fn_kwargs = {"context_size" : context_size },
196
+ remove_columns = existing_columns ,
197
+ )
186
198
self .dataset = mapped_dataset .shuffle (buffer_size = buffer_size ) # type: ignore
187
199
188
200
@final
0 commit comments