Allow users to customize dataloader #836

Merged 17 commits on Feb 14, 2025
5 changes: 1 addition & 4 deletions scripts/estimate/estimation.py
@@ -16,10 +16,8 @@

from torchtitan import utils
from torchtitan.config_manager import JobConfig
from torchtitan.datasets import build_tokenizer
from torchtitan.float8 import Float8Handler
from torchtitan.logging import init_logger, logger
from torchtitan.models import model_name_to_tokenizer
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.parallelisms import ParallelDims
from torchtitan.train_spec import get_train_spec
@@ -83,8 +81,7 @@ def estimate_memory(job_config: JobConfig):
model_name = job_config.model.name

# build tokenizer
tokenizer_type = model_name_to_tokenizer[model_name]
tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
tokenizer = train_spec.tokenizer_cls(job_config.model.tokenizer_path)

train_context = utils.get_train_context(
parallel_dims.loss_parallel_enabled,
7 changes: 1 addition & 6 deletions scripts/generate/test_generate.py
@@ -27,10 +27,8 @@

from torchtitan import utils
from torchtitan.config_manager import JobConfig
from torchtitan.datasets import build_tokenizer
from torchtitan.logging import init_logger, logger
from torchtitan.metrics import build_device_memory_monitor
from torchtitan.models import model_name_to_tokenizer
from torchtitan.parallelisms import ParallelDims

from torchtitan.train_spec import get_train_spec
@@ -108,10 +106,7 @@ def test_generate(
logger.info(f"World Size: {world_size}, Local Rank: {local_rank} on {device}")

# Tokenizer setup
tokenizer = build_tokenizer(
model_name_to_tokenizer[train_spec.name], config.model.tokenizer_path
)

tokenizer = train_spec.tokenizer_cls(config.model.tokenizer_path)
model_config = train_spec.config[config.model.flavor]
model_config.norm_type = config.model.norm_type
model_config.max_seq_len = config.training.seq_len
12 changes: 6 additions & 6 deletions tests/unit_tests/test_dataset_checkpointing.py
@@ -5,8 +5,8 @@
# LICENSE file in the root directory of this source tree.

import torch
from torchtitan.datasets.hf_datasets import build_hf_data_loader
from torchtitan.datasets.tokenizer import build_tokenizer
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.datasets.tokenizer import TikTokenizer


class TestDatasetCheckpointing:
@@ -41,13 +41,13 @@ def test_c4_resumption(self):
def _build_dataloader(
self, dataset_name, dataset_path, batch_size, seq_len, world_size, rank
):
tokenizer = build_tokenizer("tiktoken", "./tests/assets/test_tiktoken.model")
return build_hf_data_loader(
tokenizer = TikTokenizer("./tests/assets/test_tiktoken.model")
return build_hf_dataloader(
dataset_name=dataset_name,
dataset_path=dataset_path,
tokenizer=tokenizer,
batch_size=1,
seq_len=1024,
world_size=4,
rank=0,
dp_world_size=4,
dp_rank=0,
)
6 changes: 6 additions & 0 deletions tests/unit_tests/test_train_spec.py
@@ -10,6 +10,8 @@
import torch
import torch.nn as nn
from torchtitan.config_manager import JobConfig
from torchtitan.datasets import build_hf_dataloader
from torchtitan.datasets.tokenizer import TikTokenizer
from torchtitan.models.llama import parallelize_llama, pipeline_llama
from torchtitan.optimizer import (
build_lr_schedulers,
@@ -60,6 +62,8 @@ def test_register_train_spec(self):
pipelining_fn=pipeline_llama,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
tokenizer_cls=TikTokenizer,
)
register_train_spec(spec)
new_spec = get_train_spec("fake")
@@ -78,6 +82,8 @@ def test_optim_hook(self):
pipelining_fn=pipeline_llama,
build_optimizers_fn=fake_build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
tokenizer_cls=TikTokenizer,
)
register_train_spec(spec)
new_spec = get_train_spec("fake2")
91 changes: 91 additions & 0 deletions torchtitan/dataloader.py
@@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

import pickle
from abc import ABC, abstractmethod
from typing import Any, Callable, TypeAlias

from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader

from torchtitan.logging import logger


class BaseDataLoader(Stateful, ABC):
"""Base class for all dataloaders.

This is used to enforce that all dataloaders have the methods defined in ``Stateful``,
``state_dict()`` and ``load_state_dict()``.
"""

@abstractmethod
def __iter__(self):
...


class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader):
"""Dataloader that is aware of distributed data parallelism.

This dataloader is used to load data in a distributed data parallel fashion. It also
utilizes ``torchdata.stateful_dataloader.StatefulDataLoader`` to implement the necessary
methods such as ``__iter__``.

Args:
dataset (IterableDataset): The dataset to iterate over.
dp_rank: Data parallelism rank for this dataloader.
dp_world_size: The world size of the data parallelism.
batch_size: The batch size to use for each iteration.
"""

dp_rank: int
dp_world_size: int
batch_size: int

def __init__(
self,
dataset: IterableDataset,
dp_rank: int,
dp_world_size: int,
batch_size: int,
):
self.dp_world_size = dp_world_size
self.dp_rank = dp_rank
self.batch_size = batch_size
super().__init__(dataset, batch_size)
self._rank_id = f"dp_rank_{dp_rank}"

def state_dict(self) -> dict[str, Any]:
# Store state only for dp rank to avoid replicating the same state across other dimensions.
return {
# We don't have to use pickle as DCP will serialize the state_dict. However,
# we have to keep this for backward compatibility.
self._rank_id: pickle.dumps(super().state_dict()),
"world_size": self.dp_world_size,
}

def load_state_dict(self, state_dict: dict[str, Any]) -> None:
# State being empty is valid.
if not state_dict:
return

if self._rank_id not in state_dict:
logger.warning(
f"DataLoader state is empty for dp rank {self.dp_rank}, "
"expected key {self._rank_id}"
)
return

assert self.dp_world_size == state_dict["world_size"], (
"dp_degree is inconsistent before and after checkpoint, "
"dataloader resharding is not supported yet."
)
# We don't have to use pickle as DCP will serialize the state_dict. However, we have to
# keep this for backward compatibility.
super().load_state_dict(pickle.loads(state_dict[self._rank_id]))


DataLoaderBuilder: TypeAlias = Callable[..., BaseDataLoader]
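
The two classes above are the extension point this PR introduces: BaseDataLoader pins down the Stateful interface, and ParallelAwareDataloader layers dp-rank-aware checkpointing on top of StatefulDataLoader. The sketch below shows one way a user-defined builder could slot in; MyJsonlDataset, build_custom_dataloader, and the JSONL path are hypothetical names used only for illustration, not part of this PR.

# Hypothetical sketch of a custom dataloader builder compatible with DataLoaderBuilder.
# MyJsonlDataset and build_custom_dataloader are illustrative names, not torchtitan APIs.
import json

from torch.utils.data import IterableDataset

from torchtitan.dataloader import ParallelAwareDataloader


class MyJsonlDataset(IterableDataset):
    """Streams text samples from a local JSONL file, sharded across dp ranks."""

    def __init__(self, path: str, dp_rank: int, dp_world_size: int):
        self.path = path
        self.dp_rank = dp_rank
        self.dp_world_size = dp_world_size

    def __iter__(self):
        with open(self.path) as f:
            for i, line in enumerate(f):
                # Simple round-robin sharding across data-parallel ranks.
                if i % self.dp_world_size == self.dp_rank:
                    yield json.loads(line)["text"]


def build_custom_dataloader(
    dataset_path: str,
    batch_size: int,
    dp_rank: int,
    dp_world_size: int,
    **kwargs,
) -> ParallelAwareDataloader:
    dataset = MyJsonlDataset(dataset_path, dp_rank, dp_world_size)
    # Returning a ParallelAwareDataloader keeps checkpoint save/resume working
    # without any extra state_dict plumbing in user code.
    return ParallelAwareDataloader(
        dataset=dataset,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )

Because the builder returns a BaseDataLoader subclass, the checkpointer can treat it exactly like the built-in HuggingFace dataloader.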
8 changes: 2 additions & 6 deletions torchtitan/datasets/__init__.py
@@ -4,10 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.datasets.hf_datasets import build_hf_data_loader
from torchtitan.datasets.tokenizer import build_tokenizer
from torchtitan.datasets.hf_datasets import build_hf_dataloader

__all__ = [
"build_hf_data_loader",
"build_tokenizer",
]
__all__ = ["build_hf_dataloader"]
78 changes: 26 additions & 52 deletions torchtitan/datasets/hf_datasets.py
@@ -4,18 +4,17 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pickle
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional

import torch

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader

from torchtitan.dataloader import ParallelAwareDataloader
from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.logging import logger

@@ -25,7 +24,7 @@ def _load_c4_dataset(dataset_path: str):
return load_dataset(dataset_path, name="en", split="train", streaming=True)


def _process_c4_text(sample: Dict[str, Any]) -> str:
def _process_c4_text(sample: dict[str, Any]) -> str:
"""Process C4 dataset sample text."""
return sample["text"]

@@ -75,8 +74,8 @@ def __init__(
dataset_path: Optional[str],
tokenizer: Tokenizer,
seq_len: int = 2048,
world_size: int = 1,
rank: int = 0,
dp_rank: int = 0,
dp_world_size: int = 1,
infinite: bool = False,
) -> None:
# Force lowercase for consistent comparison
@@ -88,15 +87,15 @@ def __init__(
ds = dataset_loader(path)

self.dataset_name = dataset_name
self._data = split_dataset_by_node(ds, rank, world_size)
self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
self._tokenizer = tokenizer
self.seq_len = seq_len
self.infinite = infinite
self._text_processor = text_processor

# Variables for checkpointing
self._sample_idx = 0
self._all_tokens: List[int] = []
self._all_tokens: list[int] = []

def _get_data_iter(self):
if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
@@ -142,56 +141,31 @@ def state_dict(self):
return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}


class DPAwareDataLoader(StatefulDataLoader, Stateful):
"""
A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank.
"""

def __init__(
self, dp_rank: int, hf_ds: IterableDataset, batch_size: int, world_size: int
):
super().__init__(hf_ds, batch_size)
self._dp_rank = dp_rank
self._rank_id = f"dp_rank_{dp_rank}"
# Data loader resharding is not yet supported, so we need to store the world size to compare during loading
# raise error if dp_word_size does not match.
self._world_size = world_size

def state_dict(self) -> Dict[str, Any]:
# Store state only for dp rank to avoid replicating the same state across other dimensions
return {
self._rank_id: pickle.dumps(super().state_dict()),
"world_size": self._world_size,
}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
# State being empty is valid
if not state_dict:
return

if self._rank_id not in state_dict:
logger.warning(
f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}"
)
return
assert (
self._world_size == state_dict["world_size"]
), "dp_degree is inconsistent before and after checkpoint, dataloader resharding is not supported yet."
super().load_state_dict(pickle.loads(state_dict[self._rank_id]))


def build_hf_data_loader(
def build_hf_dataloader(
dataset_name: str,
dataset_path: Optional[str],
tokenizer: Tokenizer,
batch_size: int,
seq_len: int,
world_size: int,
rank: int,
dp_rank: int,
dp_world_size: int,
infinite: bool = True,
):
) -> ParallelAwareDataloader:
"""Build a data loader for HuggingFace datasets."""

hf_ds = HuggingFaceDataset(
dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite
dataset_name=dataset_name,
dataset_path=dataset_path,
tokenizer=tokenizer,
seq_len=seq_len,
dp_rank=dp_rank,
dp_world_size=dp_world_size,
infinite=infinite,
)

return ParallelAwareDataloader(
dataset=hf_ds,
dp_rank=dp_rank,
dp_world_size=dp_world_size,
batch_size=batch_size,
)
return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size, world_size=world_size)
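
For reference, a minimal usage sketch of the renamed builder with the new dp_rank/dp_world_size keywords and a checkpoint round trip. The dataset name and asset paths are assumptions mirroring the unit test above, not values taken from this diff.

# Assumed dataset name and asset paths; shown only to illustrate the new keywords
# and the per-rank state_dict produced by ParallelAwareDataloader.
from torchtitan.datasets import build_hf_dataloader
from torchtitan.datasets.tokenizer import TikTokenizer

tokenizer = TikTokenizer("./tests/assets/test_tiktoken.model")

def make_loader():
    return build_hf_dataloader(
        dataset_name="c4_test",
        dataset_path="./tests/assets/c4_test",
        tokenizer=tokenizer,
        batch_size=1,
        seq_len=1024,
        dp_rank=0,
        dp_world_size=4,
    )

dataloader = make_loader()
next(iter(dataloader))           # consume one batch
state = dataloader.state_dict()  # keys: "dp_rank_0" and "world_size"

resumed = make_loader()          # must be built with the same dp_world_size
resumed.load_state_dict(state)   # continues from the saved position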
9 changes: 1 addition & 8 deletions torchtitan/datasets/tokenizer/__init__.py
@@ -7,12 +7,5 @@
from torchtitan.datasets.tokenizer.tiktoken import TikTokenizer
from torchtitan.datasets.tokenizer.tokenizer import Tokenizer

from torchtitan.logging import logger


def build_tokenizer(tokenizer_type: str, tokenizer_path: str) -> Tokenizer:
logger.info(f"Building {tokenizer_type} tokenizer locally from {tokenizer_path}")
if tokenizer_type == "tiktoken":
return TikTokenizer(tokenizer_path)
else:
raise ValueError(f"Unknown tokenizer type: {tokenizer_type}")
__all__ = ["Tokenizer", "TikTokenizer"]
3 changes: 0 additions & 3 deletions torchtitan/models/__init__.py
@@ -8,6 +8,3 @@
# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.models.llama # noqa: F401


model_name_to_tokenizer = {"llama3": "tiktoken"}
4 changes: 4 additions & 0 deletions torchtitan/models/llama/__init__.py
@@ -6,6 +6,8 @@
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.datasets import build_hf_dataloader
from torchtitan.datasets.tokenizer import TikTokenizer
from torchtitan.models.llama.model import Transformer, TransformerModelArgs
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.train_spec import register_train_spec, TrainSpec
@@ -65,5 +67,7 @@
pipelining_fn=pipeline_llama,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
tokenizer_cls=TikTokenizer,
)
)
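
This registration is the hook that makes the dataloader customizable: a downstream project can register its own TrainSpec that reuses the llama3 components but swaps build_dataloader_fn. The sketch below is illustrative only; the my_project.data import is hypothetical, and the spec fields not visible in this diff (name, cls, config, parallelize_fn) plus the llama3_configs import are assumed.

# Illustrative: register a variant train spec that only swaps the dataloader builder.
from my_project.data import build_custom_dataloader  # hypothetical module

from torchtitan.datasets.tokenizer import TikTokenizer
from torchtitan.models.llama import (
    llama3_configs,
    parallelize_llama,
    pipeline_llama,
    Transformer,
)
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.train_spec import register_train_spec, TrainSpec

register_train_spec(
    TrainSpec(
        name="llama3_custom_data",
        cls=Transformer,
        config=llama3_configs,
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llama,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_custom_dataloader,  # the customization point
        tokenizer_cls=TikTokenizer,
    )
)

Selecting the variant is then just a matter of pointing model.name at "llama3_custom_data" in the job config, the same lookup estimation.py performs via job_config.model.name.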