Commit ad028e6

datatrove is all you need
1 parent 6771639 commit ad028e6

9 files changed, +207 −125 lines changed


examples/config_nanoset.yaml

+12 −12
@@ -7,25 +7,25 @@ checkpoints:
 data_stages:
 - data:
     dataset:
-      dataset_path: datasets/testing_alpaca_small_input_ids.npy
+      dataset_folder: datasets/c4-es/tokenized
     num_loading_workers: 1
     seed: 42
   name: General purpose training (Single dataset)
   start_training_step: 1
 - data:
     dataset:
-      dataset_path:
-      - datasets/yelp_review_full_input_ids.npy
-      - datasets/testing_alpaca_small_input_ids.npy
+      dataset_folder:
+      - datasets/SlimPajama-6B/tokenized
+      - datasets/c4-es/tokenized
     num_loading_workers: 1
     seed: 42
   name: Second purpose training (> 1 dataset)
   start_training_step: 15
 - data:
     dataset:
-      dataset_path:
-        datasets/testing_alpaca_small_input_ids.npy: 0.8
-        datasets/yelp_review_full_input_ids.npy: 0.2
+      dataset_folder:
+        datasets/SlimPajama-6B/tokenized: 0.8
+        datasets/c4-es/tokenized: 0.2
     num_loading_workers: 1
     seed: 42
   name: Third purpose training (Blended dataset)
@@ -57,7 +57,7 @@ model:
     initializer_range: 0.02
     intermediate_size: 64
     is_llama_config: true
-    max_position_embeddings: 256
+    max_position_embeddings: 1024
     num_attention_heads: 4
     num_hidden_layers: 2
     num_key_value_heads: 4
@@ -67,7 +67,7 @@ model:
     rope_scaling: null
     tie_word_embeddings: true
     use_cache: true
-    vocab_size: 32000
+    vocab_size: 50257
 optimizer:
   accumulate_grad_in_fp32: true
   clip_grad: 1.0
@@ -88,11 +88,11 @@ optimizer:
   weight_decay: 0.01
   zero_stage: 0
 parallelism:
-  dp: 2
+  dp: 1
   expert_parallel_size: 1
   pp: 1
   pp_engine: 1f1b
-  tp: 2
+  tp: 1
   tp_linear_async_communication: true
   tp_mode: REDUCE_SCATTER
 profiler: null
@@ -105,6 +105,6 @@ tokens:
  limit_test_batches: 0
  limit_val_batches: 0
  micro_batch_size: 2
-  sequence_length: 128
+  sequence_length: 1024
  train_steps: 200
  val_check_interval: -1
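
The three `dataset_folder` shapes above (single folder, list of folders, weighted dict) are normalized by `NanosetDatasetsArgs.__post_init__` (see src/nanotron/config/config.py below). A minimal sketch of the resulting fields, not part of this commit:

from nanotron.config.config import NanosetDatasetsArgs

single = NanosetDatasetsArgs(dataset_folder="datasets/c4-es/tokenized")
# -> dataset_folder == ["datasets/c4-es/tokenized"], dataset_weights == [1]

unweighted = NanosetDatasetsArgs(
    dataset_folder=["datasets/SlimPajama-6B/tokenized", "datasets/c4-es/tokenized"]
)
# -> dataset_weights is None: consume all samples from every folder

blended = NanosetDatasetsArgs(
    dataset_folder={"datasets/SlimPajama-6B/tokenized": 0.8, "datasets/c4-es/tokenized": 0.2}
)
# -> dataset_folder == [the two folders], dataset_weights == [0.8, 0.2]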

pyproject.toml

+1 −1
@@ -49,7 +49,7 @@ fast-modeling = [
 
 nanosets = [
     "transformers",
-    "datasets",
+    "datatrove[io,processing]",
     "numba",
 ]
 
run_train.py

+3 −3
@@ -143,17 +143,17 @@ def get_dataloader_from_data_stage(
     elif isinstance(data.dataset, NanosetDatasetsArgs):
         # Get tokenizer cardinality
         tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path)
-        token_dtype = np.int32 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else np.uint16
+        token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2
         del tokenizer
         # Create Nanoset
         from nanotron.data.nanoset import Nanoset
 
         with main_rank_first(trainer.parallel_context.world_pg):
             train_dataset = Nanoset(
-                dataset_paths=data.dataset.dataset_path,
+                dataset_folders=data.dataset.dataset_folder,
                 dataset_weights=data.dataset.dataset_weights,
                 sequence_length=trainer.sequence_length,
-                token_dtype=token_dtype,
+                token_size=token_size,
                 train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size,
                 random_seed=data.seed,
             )
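
A quick illustration of the new `token_size` rule, assuming a GPT-2 tokenizer (vocab size 50257, matching the updated example config); this sketch is not part of the commit:

import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # len(tokenizer) == 50257
token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2
print(token_size)  # 2 -> tokens are stored as 2-byte (uint16) values, since 50257 <= 65536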

src/nanotron/config/config.py

+8 −8
@@ -93,18 +93,18 @@ def __post_init__(self):
 
 @dataclass
 class NanosetDatasetsArgs:
-    dataset_path: Union[str, dict, List[str]]
+    dataset_folder: Union[str, dict, List[str]]
 
     def __post_init__(self):
-        if isinstance(self.dataset_path, str):  # Case 1: 1 Dataset file
-            self.dataset_path = [self.dataset_path]
+        if isinstance(self.dataset_folder, str):  # Case 1: 1 Dataset file
+            self.dataset_folder = [self.dataset_folder]
             self.dataset_weights = [1]
-        elif isinstance(self.dataset_path, List):  # Case 2: > 1 Dataset file
+        elif isinstance(self.dataset_folder, List):  # Case 2: > 1 Dataset file
             self.dataset_weights = None  # Set to None so we consume all the samples randomly
-        elif isinstance(self.dataset_path, dict):  # Case 3: dict with > 1 dataset_path and weights
-            tmp_dataset_path = self.dataset_path.copy()
-            self.dataset_path = list(tmp_dataset_path.keys())
-            self.dataset_weights = list(tmp_dataset_path.values())
+        elif isinstance(self.dataset_folder, dict):  # Case 3: dict with > 1 dataset_folder and weights
+            tmp_dataset_folder = self.dataset_folder.copy()
+            self.dataset_folder = list(tmp_dataset_folder.keys())
+            self.dataset_weights = list(tmp_dataset_folder.values())
 
 
 @dataclass

src/nanotron/data/collator.py

+80 −0
@@ -0,0 +1,80 @@
+import dataclasses
+from typing import Dict, List, Union
+
+import numpy as np
+import torch
+from nanotron import distributed as dist
+from nanotron.parallel.context import ParallelContext
+from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
+
+
+@dataclasses.dataclass
+class NanosetDataCollatorForCLM:
+    """
+    Data collator used for causal language modeling with Nanosets dataset.
+
+    - input_pp_rank: Discards last input id token
+    - output_pp_rank: Discards first label id token
+    - other pp ranks: Don't have data. Instead, we use `TensorPointer` to point to the rank having the data.
+    """
+
+    sequence_length: int
+    input_pp_rank: int
+    output_pp_rank: int
+    parallel_context: ParallelContext
+
+    def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
+        # Process the case when current rank doesn't require data. We return `TensorPointer` that points to ranks having the data.
+        current_pp_rank = dist.get_rank(self.parallel_context.pp_pg)
+        if current_pp_rank not in [
+            self.input_pp_rank,
+            self.output_pp_rank,
+        ]:
+            assert all(len(example) == 0 for example in examples)
+            return {
+                "input_ids": TensorPointer(group_rank=self.input_pp_rank),
+                "input_mask": TensorPointer(group_rank=self.input_pp_rank),
+                "label_ids": TensorPointer(group_rank=self.output_pp_rank),
+                "label_mask": TensorPointer(group_rank=self.output_pp_rank),
+            }
+
+        # Make sure we load only what's necessary, ie we only load a `input_ids` column.
+        assert all(list(example.keys()) == ["input_ids"] for example in examples)
+
+        # TODO @nouamanetazi: Is it better to have examples as np.array or torch.Tensor?
+        input_ids = torch.vstack([examples[i]["input_ids"] for i in range(len(examples))])  # (b, s)
+        batch_size, expanded_input_length = input_ids.shape
+
+        result: Dict[str, Union[torch.LongTensor, TensorPointer]] = {}
+
+        result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank)
+        result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank)
+        result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank)
+        result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank)
+
+        assert (
+            expanded_input_length == self.sequence_length + 1
+        ), f"Samples should be of length {self.sequence_length + 1} (seq_len+1), but got {expanded_input_length}"
+
+        # Process inputs: last token is the label
+        if current_pp_rank == self.input_pp_rank:
+            result["input_ids"] = input_ids[:, :-1]
+            result["input_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool)
+
+        # Process labels: shift them to the left
+        if current_pp_rank == self.output_pp_rank:
+            result["label_ids"] = input_ids[:, 1:]
+            result["label_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool)
+
+        if isinstance(result["input_ids"], torch.Tensor) and result["input_ids"].shape[-1] != self.sequence_length:
+            raise ValueError(
+                f"`labels` are incorrectly preprocessed. `labels` length is {result['input_ids'].shape[-1]}, but should be"
+                f" {self.sequence_length}."
+            )
+        if isinstance(result["label_ids"], torch.Tensor) and result["label_ids"].shape[-1] != self.sequence_length:
+            raise ValueError(
+                f"`labels` are incorrectly preprocessed. `labels` length is {result['label_ids'].shape[-1]}, but should be"
+                f" {self.sequence_length}."
+            )
+
+        return result
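
The collator expects each example to be an {"input_ids": ...} sample of length sequence_length + 1 and produces the shifted inputs and labels. A toy illustration of that shift (assumed values, not part of the commit):

import torch

sequence_length = 4
sample = torch.arange(sequence_length + 1).unsqueeze(0)  # shape (1, 5): [[0, 1, 2, 3, 4]]
input_ids = sample[:, :-1]  # [[0, 1, 2, 3]] -> fed to the model
label_ids = sample[:, 1:]   # [[1, 2, 3, 4]] -> targets, shifted left by one
assert input_ids.shape[-1] == label_ids.shape[-1] == sequence_length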

src/nanotron/data/dataloader_builder.py

+2 −2
@@ -1,7 +1,7 @@
 import nanotron.distributed as dist
 from nanotron import logging
+from nanotron.data.collator import NanosetDataCollatorForCLM
 from nanotron.dataloader import (
-    DataCollatorForCLM,
     EmptyInfiniteDataset,
     get_dataloader_worker_init,
     get_sampler,
@@ -32,7 +32,7 @@ def build_nanoset_dataloader(
         # No need to spawn a lot of workers, we can just use main
         dataloader_num_workers = 0
 
-    data_collator = DataCollatorForCLM(
+    data_collator = NanosetDataCollatorForCLM(
         sequence_length=sequence_length,
         input_pp_rank=input_pp_rank,
         output_pp_rank=output_pp_rank,

src/nanotron/data/nanoset.py

+34 −41
@@ -1,7 +1,10 @@
+import os
+import warnings
 from typing import Dict, List, Tuple, Union
 
 import numpy as np
 import torch
+from datatrove.utils.dataset import DatatroveFolderDataset
 from nanotron import logging
 from nanotron.data.utils import count_dataset_indexes, normalize
 from nanotron.logging import log_rank
@@ -15,49 +18,61 @@ class Nanoset(torch.utils.data.Dataset):
     The Nanoset dataset
 
     Args:
-        dataset_paths (List[str]): List of paths to tokenized datasets
+        dataset_folders (List[str]): List of folders with tokenized datasets
         dataset_weights (List[float]): List with the weights for weighted datasets. If None, consume all samples from all datasets without weighting. Weights are normalized in __init__
         sequence_length (int): Sequence length of the built samples
-        token_dtype (Union[np.uint16, np.int32]): dtype of the tokens stored in the processed dataset files. np.uin16 for vocab sizes < 65535, np.int32 otherwise
+        token_size (int): Number of bytes for the tokens stored in the processed dataset files. 2 for vocab sizes < 65535, 4 otherwise
         train_split_num_samples (int): Number of samples the dataset needs. It's the training steps * global batch size
     """
 
     def __init__(
         self,
-        dataset_paths: List[str],
+        dataset_folders: List[str],
         dataset_weights: Union[List[float], None],
         sequence_length: int,
-        token_dtype: Union[np.uint16, np.int32],
+        token_size: int,
         train_split_num_samples: int,
         random_seed: int = 1234,
     ) -> None:
 
+        # Assertions
+        if isinstance(dataset_folders, str):
+            warnings.warn("dataset_folders should be of type List[str] but str was provided. Converting to List[str]")
+            dataset_folders = [dataset_folders]
+
         # Init
-        self.dataset_paths = dataset_paths
+        self.dataset_folders = dataset_folders
         self.dataset_weights = dataset_weights
         self.sequence_length = sequence_length
-        self.token_dtype = token_dtype
+        self.token_size = token_size
         self.train_split_num_samples = train_split_num_samples
         self.random_seed = random_seed
+        self.datatrove_datasets = []
+        for dataset_folder in self.dataset_folders:
+            self.datatrove_datasets.append(
+                DatatroveFolderDataset(
+                    folder_path=dataset_folder,
+                    filename_pattern=os.path.join(dataset_folder, "*.ds"),
+                    seq_len=sequence_length,
+                    recursive=False,
+                    token_size=token_size,
+                    shuffle=True,
+                )
+            )
 
         # Build Nanoset Index
         ## To build the index we need the length of each dataset
-        self.dataset_lengths = []
-        for dataset_path in self.dataset_paths:
-            self.dataset_buffer_mmap = np.memmap(dataset_path, mode="r", order="C", dtype=self.token_dtype)
-            self.dataset_buffer = memoryview(self.dataset_buffer_mmap)
-            dataset_number_of_tokens = int(len(self.dataset_buffer))
-            number_of_samples = int(
-                (dataset_number_of_tokens - 1) / sequence_length
-            )  # Discard last sample if length < sequence_length
-            self.dataset_lengths.append(number_of_samples)
+        self.dataset_lengths = [len(datatrove_dataset) for datatrove_dataset in self.datatrove_datasets]
         ## Set dataset weights
         if (
             self.dataset_weights is None
         ):  # Case of training with > 1 datasets without weighting them: Consume both datasets entirely on each epoch
             self.dataset_weights = normalize(self.dataset_lengths)
         else:
             self.dataset_weights = normalize(dataset_weights)
+        assert len(dataset_folders) == len(
+            self.dataset_weights
+        ), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided."
         ## Build dataset index and dataset sample index
         self.dataset_index, self.dataset_sample_index = self.build_nanoset_index()
 
@@ -79,25 +94,12 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
             idx (int): The index into the dataset
 
         Returns:
-            Dict[str, numpy.ndarray]: The input ids wrapped in a dictionary
+            Dict[str, torch.LongTensor]: The input ids wrapped in a dictionary
         """
-
         dataset = self.dataset_index[idx]
         dataset_sample = self.dataset_sample_index[idx]
 
-        # Rebuild the memmap in every access to free memory
-        # https://stackoverflow.com/a/61472122
-        self.dataset_buffer_mmap = np.memmap(self.dataset_paths[dataset], mode="r", order="C", dtype=self.token_dtype)
-        self.dataset_buffer = memoryview(self.dataset_buffer_mmap)
-
-        # uint16 -> 2 bytes per token, int32 -> 4 bytes per token
-        offset = dataset_sample * self.sequence_length * (np.iinfo(self.token_dtype).bits / 8)
-        input_ids_tokens = np.frombuffer(
-            self.dataset_buffer, dtype=self.token_dtype, count=(self.sequence_length + 1), offset=int(offset)
-        )
-
-        # Return tokens as np.int32 as Torch can't handle uint16
-        return {"input_ids": input_ids_tokens.astype(np.int32)}
+        return self.datatrove_datasets[dataset][dataset_sample]
 
     def build_nanoset_index(self) -> np.ndarray:
         """
@@ -124,15 +126,6 @@ def build_nanoset_index(self) -> np.ndarray:
 
         return dataset_index, dataset_sample_index
 
-    def __del__(self) -> None:
-        """
-        Clean up Nanoset
-        """
-
-        if hasattr(self, "dataset_buffer_mmap"):
-            self.dataset_buffer_mmap._mmap.close()
-        del self.dataset_buffer_mmap
-
     def print_nanoset_info(self):
 
         log_rank(f"> Total number of samples: {len(self)}", logger=logger, level=logging.INFO, rank=0)
@@ -141,10 +134,10 @@ def print_nanoset_info(self):
         )
 
         # Print samples from each dataset + weight
-        dataset_sample_count = count_dataset_indexes(self.dataset_index, len(self.dataset_paths))
+        dataset_sample_count = count_dataset_indexes(self.dataset_index, len(self.dataset_folders))
         for index, sample_count in enumerate(dataset_sample_count):
            log_rank(
-                f"> Total number of samples from the {self.dataset_paths[index].rsplit('/', 1)[-1]} dataset: {sample_count} ({round(normalize(dataset_sample_count).tolist()[index], 2)})",
+                f"> Total number of samples from the {self.dataset_folders[index]} dataset: {sample_count} ({round(normalize(dataset_sample_count).tolist()[index], 2)})",
                logger=logger,
                level=logging.INFO,
                rank=0,
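
A minimal usage sketch of the reworked Nanoset (illustrative only; the folder paths and values are taken from the example config above, and the folders are assumed to contain datatrove-tokenized *.ds files):

from nanotron.data.nanoset import Nanoset

train_dataset = Nanoset(
    dataset_folders=["datasets/SlimPajama-6B/tokenized", "datasets/c4-es/tokenized"],
    dataset_weights=[0.8, 0.2],
    sequence_length=1024,
    token_size=2,                     # 2 bytes per token, since vocab_size 50257 <= 65536
    train_split_num_samples=200 * 2,  # train_steps * global batch size (illustrative values)
    random_seed=42,
)
sample = train_dataset[0]  # {"input_ids": tensor of length sequence_length + 1}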
