Commit 5542f56

Allow users to customize dataloader (#843)
Land #836
1 parent 0b0931c commit 5542f56

12 files changed: +155, -105 lines

scripts/estimate/estimation.py

Lines changed: 1 addition & 4 deletions
@@ -16,10 +16,8 @@
 
 from torchtitan import utils
 from torchtitan.config_manager import JobConfig
-from torchtitan.datasets import build_tokenizer
 from torchtitan.float8 import Float8Handler
 from torchtitan.logging import init_logger, logger
-from torchtitan.models import model_name_to_tokenizer
 from torchtitan.optimizer import build_lr_schedulers, build_optimizers
 from torchtitan.parallelisms import ParallelDims
 from torchtitan.train_spec import get_train_spec
@@ -83,8 +81,7 @@ def estimate_memory(job_config: JobConfig):
     model_name = job_config.model.name

     # build tokenizer
-    tokenizer_type = model_name_to_tokenizer[model_name]
-    tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
+    tokenizer = train_spec.tokenizer_cls(job_config.model.tokenizer_path)

     train_context = utils.get_train_context(
         parallel_dims.loss_parallel_enabled,

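With this change, the script resolves the tokenizer through the train spec instead of the removed model_name_to_tokenizer registry. A minimal sketch of the new calling pattern (the spec name and tokenizer path are illustrative, not taken from this diff):

from torchtitan.train_spec import get_train_spec

# Assumes a "llama3" spec is registered (see torchtitan/models/llama/__init__.py below)
# and that a tokenizer model file exists at this path.
train_spec = get_train_spec("llama3")
tokenizer = train_spec.tokenizer_cls("./tests/assets/test_tiktoken.model")
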
scripts/generate/test_generate.py

Lines changed: 1 addition & 6 deletions
@@ -27,10 +27,8 @@
 
 from torchtitan import utils
 from torchtitan.config_manager import JobConfig
-from torchtitan.datasets import build_tokenizer
 from torchtitan.logging import init_logger, logger
 from torchtitan.metrics import build_device_memory_monitor
-from torchtitan.models import model_name_to_tokenizer
 from torchtitan.parallelisms import ParallelDims

 from torchtitan.train_spec import get_train_spec
@@ -108,10 +106,7 @@ def test_generate(
     logger.info(f"World Size: {world_size}, Local Rank: {local_rank} on {device}")

     # Tokenizer setup
-    tokenizer = build_tokenizer(
-        model_name_to_tokenizer[train_spec.name], config.model.tokenizer_path
-    )
-
+    tokenizer = train_spec.tokenizer_cls(config.model.tokenizer_path)
     model_config = train_spec.config[config.model.flavor]
     model_config.norm_type = config.model.norm_type
     model_config.max_seq_len = config.training.seq_len

tests/unit_tests/test_dataset_checkpointing.py

Lines changed: 6 additions & 6 deletions
@@ -5,8 +5,8 @@
 # LICENSE file in the root directory of this source tree.

 import torch
-from torchtitan.datasets.hf_datasets import build_hf_data_loader
-from torchtitan.datasets.tokenizer import build_tokenizer
+from torchtitan.datasets.hf_datasets import build_hf_dataloader
+from torchtitan.datasets.tokenizer import TikTokenizer


 class TestDatasetCheckpointing:
@@ -41,13 +41,13 @@ def test_c4_resumption(self):
     def _build_dataloader(
         self, dataset_name, dataset_path, batch_size, seq_len, world_size, rank
     ):
-        tokenizer = build_tokenizer("tiktoken", "./tests/assets/test_tiktoken.model")
-        return build_hf_data_loader(
+        tokenizer = TikTokenizer("./tests/assets/test_tiktoken.model")
+        return build_hf_dataloader(
             dataset_name=dataset_name,
             dataset_path=dataset_path,
             tokenizer=tokenizer,
             batch_size=1,
             seq_len=1024,
-            world_size=4,
-            rank=0,
+            dp_world_size=4,
+            dp_rank=0,
         )

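The renamed keywords (dp_rank, dp_world_size) match the new build_hf_dataloader signature. A sketch of the save/resume round-trip this test exercises, assuming the c4_test asset shipped with the repo is available locally:

from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.datasets.tokenizer import TikTokenizer

tokenizer = TikTokenizer("./tests/assets/test_tiktoken.model")

def make_loader():
    return build_hf_dataloader(
        dataset_name="c4_test",
        dataset_path="./tests/assets/c4_test",
        tokenizer=tokenizer,
        batch_size=1,
        seq_len=1024,
        dp_rank=0,
        dp_world_size=4,
    )

dl = make_loader()
it = iter(dl)
next(it)                       # consume one batch
state = dl.state_dict()        # {"dp_rank_0": <pickled loader state>, "world_size": 4}

resumed = make_loader()
resumed.load_state_dict(state) # resumes after the consumed batch
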
tests/unit_tests/test_train_spec.py

Lines changed: 6 additions & 0 deletions
@@ -10,6 +10,8 @@
 import torch
 import torch.nn as nn
 from torchtitan.config_manager import JobConfig
+from torchtitan.datasets import build_hf_dataloader
+from torchtitan.datasets.tokenizer import TikTokenizer
 from torchtitan.models.llama import parallelize_llama, pipeline_llama
 from torchtitan.optimizer import (
     build_lr_schedulers,
@@ -60,6 +62,8 @@ def test_register_train_spec(self):
             pipelining_fn=pipeline_llama,
             build_optimizers_fn=build_optimizers,
             build_lr_schedulers_fn=build_lr_schedulers,
+            build_dataloader_fn=build_hf_dataloader,
+            tokenizer_cls=TikTokenizer,
         )
         register_train_spec(spec)
         new_spec = get_train_spec("fake")
@@ -78,6 +82,8 @@ def test_optim_hook(self):
             pipelining_fn=pipeline_llama,
             build_optimizers_fn=fake_build_optimizers,
             build_lr_schedulers_fn=build_lr_schedulers,
+            build_dataloader_fn=build_hf_dataloader,
+            tokenizer_cls=TikTokenizer,
         )
         register_train_spec(spec)
         new_spec = get_train_spec("fake2")

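The two new TrainSpec fields are the extension point this PR adds: build_dataloader_fn can be any callable that returns a BaseDataLoader, and tokenizer_cls any class constructed from a tokenizer path. A hedged sketch of a user-supplied builder that reuses ParallelAwareDataloader; the synthetic dataset is invented for illustration, and only the builder signature mirrors build_hf_dataloader:

from typing import Optional

import torch
from torch.utils.data import IterableDataset

from torchtitan.dataloader import ParallelAwareDataloader
from torchtitan.datasets.tokenizer import Tokenizer


class RandomTokenDataset(IterableDataset):
    """Toy dataset yielding random (input, label) token pairs."""

    def __init__(self, vocab_size: int, seq_len: int):
        self.vocab_size = vocab_size
        self.seq_len = seq_len

    def __iter__(self):
        while True:
            tokens = torch.randint(0, self.vocab_size, (self.seq_len + 1,))
            yield tokens[:-1], tokens[1:]


def build_random_dataloader(
    dataset_name: str,
    dataset_path: Optional[str],
    tokenizer: Tokenizer,
    batch_size: int,
    seq_len: int,
    dp_rank: int,
    dp_world_size: int,
    infinite: bool = True,
) -> ParallelAwareDataloader:
    # Same signature as build_hf_dataloader, so it can be passed to a TrainSpec
    # as build_dataloader_fn (assumption: the tokenizer exposes n_words).
    dataset = RandomTokenDataset(vocab_size=tokenizer.n_words, seq_len=seq_len)
    return ParallelAwareDataloader(
        dataset=dataset,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )
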
torchtitan/dataloader.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
+
+import pickle
+from abc import ABC, abstractmethod
+from typing import Any, Callable, TypeAlias
+
+from torch.distributed.checkpoint.stateful import Stateful
+from torch.utils.data import IterableDataset
+from torchdata.stateful_dataloader import StatefulDataLoader
+
+
+class BaseDataLoader(Stateful, ABC):
+    """Base class for all dataloaders.
+
+    This is used to enforce that all dataloaders have the methods defined in ``Stateful``,
+    ``state_dict()`` and ``load_state_dict()``.
+    """
+
+    @abstractmethod
+    def __iter__(self):
+        ...
+
+
+class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader):
+    """Dataloader that is aware of distributed data parallelism.
+
+    This dataloader is used to load data in a distributed data parallel fashion. It also
+    utilizes ``torchdata.stateful_dataloader.StatefulDataLoader`` to implement the necessary
+    methods such as ``__iter__``.
+
+    Args:
+        dataset (IterableDataset): The dataset to iterate over.
+        dp_rank: Data parallelism rank for this dataloader.
+        dp_world_size: The world size of the data parallelism.
+        batch_size: The batch size to use for each iteration.
+    """
+
+    dp_rank: int
+    dp_world_size: int
+    batch_size: int
+
+    def __init__(
+        self,
+        dataset: IterableDataset,
+        dp_rank: int,
+        dp_world_size: int,
+        batch_size: int,
+    ):
+        self.dp_world_size = dp_world_size
+        self.dp_rank = dp_rank
+        self.batch_size = batch_size
+        super().__init__(dataset, batch_size)
+        self._rank_id = f"dp_rank_{dp_rank}"
+
+    def state_dict(self) -> dict[str, Any]:
+        # Store state only for dp rank to avoid replicating the same state across other dimensions.
+        return {
+            # We don't have to use pickle as DCP will serialize the state_dict. However,
+            # we have to keep this for backward compatibility.
+            self._rank_id: pickle.dumps(super().state_dict()),
+            "world_size": self.dp_world_size,
+        }
+
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        # State being empty is valid.
+        if not state_dict:
+            return
+
+        if self._rank_id not in state_dict:
+            logger.warning(
+                f"DataLoader state is empty for dp rank {self.dp_rank}, "
+                "expected key {self._rank_id}"
+            )
+            return
+
+        assert self.dp_world_size == state_dict["world_size"], (
+            "dp_degree is inconsistent before and after checkpoint, "
+            "dataloader resharding is not supported yet."
+        )
+        # We don't have to use pickle as DCP will serialize the state_dict. However, we have to
+        # keep this for backward compatibility.
+        super().load_state_dict(pickle.loads(state_dict[self._rank_id]))
+
+
+DataLoaderBuilder: TypeAlias = Callable[[...], BaseDataLoader]

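BaseDataLoader is the minimal contract a custom loader has to satisfy: it must be iterable and checkpointable via state_dict()/load_state_dict(). A sketch of a loader implementing the contract directly, without StatefulDataLoader (everything here is illustrative, not part of the commit):

from typing import Any

import torch

from torchtitan.dataloader import BaseDataLoader


class CountingLoader(BaseDataLoader):
    """Yields consecutive integer token batches and checkpoints its position."""

    def __init__(self, batch_size: int, seq_len: int):
        self.batch_size = batch_size
        self.seq_len = seq_len
        self._step = 0

    def __iter__(self):
        while True:
            start = self._step * self.batch_size * self.seq_len
            batch = torch.arange(start, start + self.batch_size * self.seq_len)
            self._step += 1
            yield batch.view(self.batch_size, self.seq_len)

    def state_dict(self) -> dict[str, Any]:
        return {"step": self._step}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self._step = state_dict.get("step", 0)
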
torchtitan/datasets/__init__.py

Lines changed: 2 additions & 6 deletions
@@ -4,10 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from torchtitan.datasets.hf_datasets import build_hf_data_loader
-from torchtitan.datasets.tokenizer import build_tokenizer
+from torchtitan.datasets.hf_datasets import build_hf_dataloader

-__all__ = [
-    "build_hf_data_loader",
-    "build_tokenizer",
-]
+__all__ = ["build_hf_dataloader"]

torchtitan/datasets/hf_datasets.py

Lines changed: 26 additions & 52 deletions
@@ -4,18 +4,17 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-import pickle
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Optional

 import torch

 from datasets import Dataset, load_dataset
 from datasets.distributed import split_dataset_by_node
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.utils.data import IterableDataset
-from torchdata.stateful_dataloader import StatefulDataLoader

+from torchtitan.dataloader import ParallelAwareDataloader
 from torchtitan.datasets.tokenizer import Tokenizer
 from torchtitan.logging import logger

@@ -25,7 +24,7 @@ def _load_c4_dataset(dataset_path: str):
     return load_dataset(dataset_path, name="en", split="train", streaming=True)


-def _process_c4_text(sample: Dict[str, Any]) -> str:
+def _process_c4_text(sample: dict[str, Any]) -> str:
     """Process C4 dataset sample text."""
     return sample["text"]

@@ -75,8 +74,8 @@ def __init__(
         dataset_path: Optional[str],
         tokenizer: Tokenizer,
         seq_len: int = 2048,
-        world_size: int = 1,
-        rank: int = 0,
+        dp_rank: int = 0,
+        dp_world_size: int = 1,
         infinite: bool = False,
     ) -> None:
         # Force lowercase for consistent comparison
@@ -88,15 +87,15 @@
         ds = dataset_loader(path)

         self.dataset_name = dataset_name
-        self._data = split_dataset_by_node(ds, rank, world_size)
+        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
         self._tokenizer = tokenizer
         self.seq_len = seq_len
         self.infinite = infinite
         self._text_processor = text_processor

         # Variables for checkpointing
         self._sample_idx = 0
-        self._all_tokens: List[int] = []
+        self._all_tokens: list[int] = []

     def _get_data_iter(self):
         if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
@@ -142,56 +141,31 @@ def state_dict(self):
         return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}


-class DPAwareDataLoader(StatefulDataLoader, Stateful):
-    """
-    A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank.
-    """
-
-    def __init__(
-        self, dp_rank: int, hf_ds: IterableDataset, batch_size: int, world_size: int
-    ):
-        super().__init__(hf_ds, batch_size)
-        self._dp_rank = dp_rank
-        self._rank_id = f"dp_rank_{dp_rank}"
-        # Data loader resharding is not yet supported, so we need to store the world size to compare during loading
-        # raise error if dp_word_size does not match.
-        self._world_size = world_size
-
-    def state_dict(self) -> Dict[str, Any]:
-        # Store state only for dp rank to avoid replicating the same state across other dimensions
-        return {
-            self._rank_id: pickle.dumps(super().state_dict()),
-            "world_size": self._world_size,
-        }
-
-    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
-        # State being empty is valid
-        if not state_dict:
-            return
-
-        if self._rank_id not in state_dict:
-            logger.warning(
-                f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}"
-            )
-            return
-        assert (
-            self._world_size == state_dict["world_size"]
-        ), "dp_degree is inconsistent before and after checkpoint, dataloader resharding is not supported yet."
-        super().load_state_dict(pickle.loads(state_dict[self._rank_id]))
-
-
-def build_hf_data_loader(
+def build_hf_dataloader(
     dataset_name: str,
     dataset_path: Optional[str],
     tokenizer: Tokenizer,
     batch_size: int,
     seq_len: int,
-    world_size: int,
-    rank: int,
+    dp_rank: int,
+    dp_world_size: int,
     infinite: bool = True,
-):
+) -> ParallelAwareDataloader:
     """Build a data loader for HuggingFace datasets."""
+
     hf_ds = HuggingFaceDataset(
-        dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite
+        dataset_name=dataset_name,
+        dataset_path=dataset_path,
+        tokenizer=tokenizer,
+        seq_len=seq_len,
+        dp_rank=dp_rank,
+        dp_world_size=dp_world_size,
+        infinite=infinite,
+    )
+
+    return ParallelAwareDataloader(
+        dataset=hf_ds,
+        dp_rank=dp_rank,
+        dp_world_size=dp_world_size,
+        batch_size=batch_size,
     )
-    return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size, world_size=world_size)

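The dp_rank/dp_world_size arguments flow straight into datasets.distributed.split_dataset_by_node, which gives each data-parallel rank a disjoint shard of the (possibly streaming) dataset. A small standalone illustration of that sharding behaviour (toy data, not torchtitan code):

from datasets import Dataset
from datasets.distributed import split_dataset_by_node

ds = Dataset.from_dict({"text": [f"sample {i}" for i in range(8)]})

# With dp_world_size=4 each rank sees 2 of the 8 samples; which two depends on
# the sharding strategy chosen by the datasets library.
shard = split_dataset_by_node(ds, rank=1, world_size=4)
print(len(shard))  # 2
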
torchtitan/datasets/tokenizer/__init__.py

Lines changed: 1 addition & 8 deletions
@@ -7,12 +7,5 @@
 from torchtitan.datasets.tokenizer.tiktoken import TikTokenizer
 from torchtitan.datasets.tokenizer.tokenizer import Tokenizer

-from torchtitan.logging import logger

-
-def build_tokenizer(tokenizer_type: str, tokenizer_path: str) -> Tokenizer:
-    logger.info(f"Building {tokenizer_type} tokenizer locally from {tokenizer_path}")
-    if tokenizer_type == "tiktoken":
-        return TikTokenizer(tokenizer_path)
-    else:
-        raise ValueError(f"Unknown tokenizer type: {tokenizer_type}")
+__all__ = ["Tokenizer", "TikTokenizer"]

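With build_tokenizer gone, a tokenizer is either constructed directly (TikTokenizer(path)) or supplied through TrainSpec.tokenizer_cls. Any class implementing the Tokenizer interface should work as tokenizer_cls; a toy sketch, assuming the base class takes a tokenizer_path and exposes encode/decode plus an n_words count:

from torchtitan.datasets.tokenizer import Tokenizer


class WhitespaceTokenizer(Tokenizer):
    """Toy tokenizer for illustration only; splits on whitespace."""

    def __init__(self, tokenizer_path: str):
        super().__init__(tokenizer_path)
        self._vocab: dict[str, int] = {}
        self._n_words = 32_000  # assumed attribute backing the n_words property

    def encode(self, s: str, *, bos: bool = False, eos: bool = False) -> list[int]:
        return [self._vocab.setdefault(tok, len(self._vocab)) for tok in s.split()]

    def decode(self, ids: list[int]) -> str:
        rev = {v: k for k, v in self._vocab.items()}
        return " ".join(rev[i] for i in ids)
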
torchtitan/models/__init__.py

Lines changed: 0 additions & 3 deletions
@@ -8,6 +8,3 @@
 # Import the built-in models here so that the corresponding register_model_spec()
 # will be called.
 import torchtitan.models.llama  # noqa: F401
-
-
-model_name_to_tokenizer = {"llama3": "tiktoken"}

torchtitan/models/llama/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -6,6 +6,8 @@
 #
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.

+from torchtitan.datasets import build_hf_dataloader
+from torchtitan.datasets.tokenizer import TikTokenizer
 from torchtitan.models.llama.model import Transformer, TransformerModelArgs
 from torchtitan.optimizer import build_lr_schedulers, build_optimizers
 from torchtitan.train_spec import register_train_spec, TrainSpec
@@ -65,5 +67,7 @@
         pipelining_fn=pipeline_llama,
         build_optimizers_fn=build_optimizers,
         build_lr_schedulers_fn=build_lr_schedulers,
+        build_dataloader_fn=build_hf_dataloader,
+        tokenizer_cls=TikTokenizer,
     )
 )

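Taken together, a training entry point can now obtain both the tokenizer and the dataloader from the registered spec. A hedged sketch of that wiring (the trainer-side changes are not shown in the hunks above; names and paths here are illustrative):

from torchtitan.train_spec import get_train_spec

train_spec = get_train_spec("llama3")  # assumes the spec registered above uses this name

tokenizer = train_spec.tokenizer_cls("./tests/assets/test_tiktoken.model")
dataloader = train_spec.build_dataloader_fn(
    dataset_name="c4_test",
    dataset_path="./tests/assets/c4_test",
    tokenizer=tokenizer,
    batch_size=8,
    seq_len=2048,
    dp_rank=0,
    dp_world_size=1,
)

for batch in dataloader:
    break  # a ParallelAwareDataloader iterates like any other dataloader
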