Allow users to customize dataloader #836

Open
wants to merge 6 commits into gh/fegin/11/base
12 changes: 5 additions & 7 deletions tests/unit_tests/test_dataset_checkpointing.py
@@ -5,8 +5,7 @@
# LICENSE file in the root directory of this source tree.

import torch
from torchtitan.datasets.hf_datasets import build_hf_data_loader
from torchtitan.datasets.tokenizer import build_tokenizer
from torchtitan.datasets.hf_datasets import build_hf_dataloader


class TestDatasetCheckpointing:
@@ -41,13 +40,12 @@ def test_c4_resumption(self):
def _build_dataloader(
self, dataset_name, dataset_path, batch_size, seq_len, world_size, rank
):
tokenizer = build_tokenizer("tiktoken", "./tests/assets/test_tiktoken.model")
return build_hf_data_loader(
return build_hf_dataloader(
dataset_name=dataset_name,
dataset_path=dataset_path,
tokenizer=tokenizer,
tokenizer_path="./tests/assets/test_tiktoken.model",
batch_size=1,
seq_len=1024,
world_size=4,
rank=0,
dp_world_size=4,
dp_rank=0,
)
3 changes: 3 additions & 0 deletions tests/unit_tests/test_train_spec.py
@@ -10,6 +10,7 @@
import torch
import torch.nn as nn
from torchtitan.config_manager import JobConfig
from torchtitan.datasets import build_hf_dataloader
from torchtitan.models.llama import parallelize_llama, pipeline_llama
from torchtitan.optimizer import (
build_lr_schedulers,
@@ -60,6 +61,7 @@ def test_register_train_spec(self):
pipelining_fn=pipeline_llama,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
)
register_train_spec(spec)
new_spec = get_train_spec("fake")
@@ -78,6 +80,7 @@ def test_optim_hook(self):
pipelining_fn=pipeline_llama,
build_optimizers_fn=fake_build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
)
register_train_spec(spec)
new_spec = get_train_spec("fake2")
135 changes: 135 additions & 0 deletions torchtitan/dataloader.py
@@ -0,0 +1,135 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

import pickle
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Optional, Protocol

from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader

from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.logging import logger


@dataclass
class BaseDataLoader(Stateful, ABC):
"""Base class for all dataloaders.

This is used to enforce that all dataloaders have the methods defined in ``Stateful``,
``state_dict()`` and ``load_state_dict()``.
"""

tokenizer: Tokenizer
dp_rank: int
dp_world_size: int
batch_size: int
Comment on lines +21 to +32
Contributor:

I'm not completely convinced that this is the right basic protocol:

  1. What if people don't care whether it supports checkpointing or not?
  2. What if people don't need it to be aware of DP ranks?
  3. Do all dataloaders need to perform tokenization?

I feel that as long as it's an iterator, it's good enough. The things returned from the iterator don't need to be input_ids and labels. For sequence masking / multimodal, a dataloader needs to return more, e.g. masks, images, etc.

Contributor Author (@fegin), Feb 12, 2025:

I think items 2 and 3 are reasonable, and I actually think we can remove them. But what's the point of not supporting checkpointing? For torchdata, which is a generic dataset library, it makes sense to have a very basic dataloader class. TorchTitan is a distributed training library; I don't see why supporting checkpointing during training shouldn't be a must.

Our checkpoint manager also assumes the dataloader to be Stateful. Removing Stateful is too relaxed, imo.

Contributor:

I'm thinking more from an experimentation-platform perspective. E.g., I'm a researcher working on a new type of attention; I have some data I'd like to load and I want to look at the throughput gain, but I don't care about fault tolerance.


@abstractmethod
def __iter__(self):
...
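
To make the protocol discussion above concrete, here is a minimal sketch of the "I just want an iterator, no fault tolerance" case: a ``BaseDataLoader`` subclass whose ``Stateful`` methods are no-ops. ``RandomTokenLoader``, its ``seq_len`` field, and the vocabulary size are illustrative assumptions, not part of this PR.

```python
from dataclasses import dataclass
from typing import Any

import torch

from torchtitan.dataloader import BaseDataLoader


@dataclass
class RandomTokenLoader(BaseDataLoader):
    """Hypothetical loader for throughput experiments: synthetic data, no checkpoint state."""

    seq_len: int = 2048

    def __iter__(self):
        # Yield (input, label) pairs of random token ids forever.
        while True:
            tokens = torch.randint(0, 32_000, (self.batch_size, self.seq_len + 1))
            yield tokens[:, :-1], tokens[:, 1:]

    # Checkpointing is deliberately a no-op for this loader.
    def state_dict(self) -> dict[str, Any]:
        return {}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        pass
```

Even so, the base class still requires the loader to carry ``tokenizer``, ``dp_rank``, ``dp_world_size``, and ``batch_size`` fields, which is exactly the coupling the review thread above questions.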


class DPDataLoader(StatefulDataLoader, BaseDataLoader):
Contributor:

I don't know, but I'm not really a big fan of this name. Maybe ParallelAwareDataloader?

"""Dataloader that is aware of distributed data parallelism.

This dataloader is used to load data in a distributed data parallel fashion. It also
utilizes ``torchdata.stateful_dataloader.StatefulDataLoader`` to implement the necessary
methods such as ``__iter__``.

Args:
dataset (IterableDataset): The dataset to iterate over.
tokenizer (Tokenizer): The tokenizer to use to tokenize the dataset.
dp_rank: Data parallelism rank for this dataloader.
dp_world_size: The world size of the data parallelism.
batch_size: The batch size to use for each iteration.
"""

def __init__(
self,
dataset: IterableDataset,
tokenizer: Tokenizer,
dp_rank: int,
dp_world_size: int,
batch_size: int,
):
BaseDataLoader.__init__(
self,
tokenizer=tokenizer,
dp_rank=dp_rank,
dp_world_size=dp_world_size,
batch_size=batch_size,
)
StatefulDataLoader.__init__(self, dataset, batch_size)
self._rank_id = f"dp_rank_{dp_rank}"

def state_dict(self) -> dict[str, Any]:
# Store state only for dp rank to avoid replicating the same state across other dimensions.
return {
# We don't have to use pickle as DCP will serialize the state_dict. However,
# we have to keep this for backward compatibility.
self._rank_id: pickle.dumps(StatefulDataLoader.state_dict(self)),
"world_size": self.dp_world_size,
}

def load_state_dict(self, state_dict: dict[str, Any]) -> None:
# State being empty is valid.
if not state_dict:
return

if self._rank_id not in state_dict:
logger.warning(
f"DataLoader state is empty for dp rank {self.dp_rank}, "
f"expected key {self._rank_id}."
)
return

assert self.dp_world_size == state_dict["world_size"], (
"dp_degree is inconsistent before and after checkpoint, "
"dataloader resharding is not supported yet."
)
# We don't have to use pickle as DCP will serialize the state_dict. However, we have to
# keep this for backward compatibility.
StatefulDataLoader.load_state_dict(
self, pickle.loads(state_dict[self._rank_id])
)


class DataLoaderBuilder(Protocol):
Contributor:

Ditto: is the signature of this protocol too strict?
Or is it too relaxed if we substitute it with something like a Callable from anything to Iterator?

Contributor Author:

For this one, an alternative in my mind was Callable[..., BaseDataLoader]. This just relaxes the input so users can define whatever they want.

Contributor:

sounds good to me

"""This is a protocol to annoate ``build_dataloader_fn``.

While mypy.extensions provides Arg to annotate the name, it requires another dependency on
mypy-extensions. Mypy also supports this annonation and it is easier to read.
"""

def __call__(
self,
dataset_name: str,
dataset_path: Optional[str],
tokenizer_path: str,
batch_size: int,
seq_len: int,
dp_rank: int,
dp_world_size: int,
) -> BaseDataLoader:
"""Function call

Args:
dataset_name (str): Name of the dataset to iterate over.
dataset_path (Optional[str]): Path to the dataset to load.
tokenizer_path (str): Path to the tokenizer to use.
batch_size (int): The batch size to use for each iteration.
seq_len (int): Sequence length for each batch.
dp_rank (int): Data parallelism rank for this dataloader.
dp_world_size (int): The world size of the data parallelism.

Returns:
BaseDataLoader: The dataloader.
"""
...
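
As a sketch of the intended extension point, a user-defined builder conforming to this protocol might look like the following. ``MySyntheticDataset`` is a hypothetical stand-in, while ``DPDataLoader``, ``BaseDataLoader``, and ``build_tokenizer`` come from this PR and the existing tokenizer module. The resulting function would then be plugged into a ``TrainSpec`` via ``build_dataloader_fn``, as the llama and unit-test diffs below do with ``build_hf_dataloader``.

```python
from typing import Optional

import torch
from torch.utils.data import IterableDataset

from torchtitan.dataloader import BaseDataLoader, DPDataLoader
from torchtitan.datasets.tokenizer import build_tokenizer


class MySyntheticDataset(IterableDataset):
    """Hypothetical dataset: yields fixed-shape random (input, label) pairs."""

    def __init__(self, seq_len: int, num_samples: int = 64):
        self.seq_len = seq_len
        self.num_samples = num_samples

    def __iter__(self):
        for _ in range(self.num_samples):
            tokens = torch.randint(0, 32_000, (self.seq_len + 1,))
            yield tokens[:-1], tokens[1:]


def build_my_dataloader(
    dataset_name: str,
    dataset_path: Optional[str],
    tokenizer_path: str,
    batch_size: int,
    seq_len: int,
    dp_rank: int,
    dp_world_size: int,
) -> BaseDataLoader:
    # dataset_name / dataset_path are unused by this synthetic example.
    tokenizer = build_tokenizer("tiktoken", tokenizer_path)
    dataset = MySyntheticDataset(seq_len=seq_len)
    return DPDataLoader(
        dataset=dataset,
        tokenizer=tokenizer,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )
```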
4 changes: 2 additions & 2 deletions torchtitan/datasets/__init__.py
@@ -4,10 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.datasets.hf_datasets import build_hf_data_loader
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.datasets.tokenizer import build_tokenizer

__all__ = [
"build_hf_data_loader",
"build_hf_dataloader",
"build_tokenizer",
]
91 changes: 34 additions & 57 deletions torchtitan/datasets/hf_datasets.py
@@ -4,28 +4,28 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pickle
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional

import torch

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader

from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.logging import logger
from torchtitan.dataloader import DPDataLoader

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torchtitan.datasets.tokenizer import build_tokenizer, Tokenizer
from torchtitan.logging import logger


def _load_c4_dataset(dataset_path: str):
"""Load C4 dataset with default configuration."""
return load_dataset(dataset_path, name="en", split="train", streaming=True)


def _process_c4_text(sample: Dict[str, Any]) -> str:
def _process_c4_text(sample: dict[str, Any]) -> str:
"""Process C4 dataset sample text."""
return sample["text"]

@@ -75,8 +75,8 @@ def __init__(
dataset_path: Optional[str],
tokenizer: Tokenizer,
seq_len: int = 2048,
world_size: int = 1,
rank: int = 0,
dp_rank: int = 0,
dp_world_size: int = 1,
infinite: bool = False,
) -> None:
# Force lowercase for consistent comparison
@@ -88,15 +88,15 @@ def __init__(
ds = dataset_loader(path)

self.dataset_name = dataset_name
self._data = split_dataset_by_node(ds, rank, world_size)
self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
Contributor:

Nice, this is indeed needed. When tracing the code, I needed to double-check its definition in the caller, although the comment specified that.

self._tokenizer = tokenizer
self.seq_len = seq_len
self.infinite = infinite
self._text_processor = text_processor

# Variables for checkpointing
self._sample_idx = 0
self._all_tokens: List[int] = []
self._all_tokens: list[int] = []

def _get_data_iter(self):
if self._sample_idx == 0:
@@ -142,56 +142,33 @@ def state_dict(self):
return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}


class DPAwareDataLoader(StatefulDataLoader, Stateful):
"""
A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank.
"""

def __init__(
self, dp_rank: int, hf_ds: IterableDataset, batch_size: int, world_size: int
):
super().__init__(hf_ds, batch_size)
self._dp_rank = dp_rank
self._rank_id = f"dp_rank_{dp_rank}"
# Data loader resharding is not yet supported, so we need to store the world size to compare during loading
# raise error if dp_word_size does not match.
self._world_size = world_size

def state_dict(self) -> Dict[str, Any]:
# Store state only for dp rank to avoid replicating the same state across other dimensions
return {
self._rank_id: pickle.dumps(super().state_dict()),
"world_size": self._world_size,
}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
# State being empty is valid
if not state_dict:
return

if self._rank_id not in state_dict:
logger.warning(
f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}"
)
return
assert (
self._world_size == state_dict["world_size"]
), "dp_degree is inconsistent before and after checkpoint, dataloader resharding is not supported yet."
super().load_state_dict(pickle.loads(state_dict[self._rank_id]))


def build_hf_data_loader(
def build_hf_dataloader(
dataset_name: str,
dataset_path: Optional[str],
tokenizer: Tokenizer,
tokenizer_path: str,
batch_size: int,
seq_len: int,
world_size: int,
rank: int,
dp_rank: int,
dp_world_size: int,
infinite: bool = True,
):
) -> DPDataLoader:
"""Build a data loader for HuggingFace datasets."""
tokenizer = build_tokenizer("tiktoken", tokenizer_path)

hf_ds = HuggingFaceDataset(
dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite
dataset_name=dataset_name,
dataset_path=dataset_path,
tokenizer=tokenizer,
seq_len=seq_len,
dp_rank=dp_rank,
dp_world_size=dp_world_size,
infinite=infinite,
)

return DPDataLoader(
dataset=hf_ds,
tokenizer=tokenizer,
dp_rank=dp_rank,
dp_world_size=dp_world_size,
batch_size=batch_size,
)
return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size, world_size=world_size)
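
A brief usage sketch of the renamed entry point and the ``Stateful`` round trip it gets from ``DPDataLoader``. The tokenizer path matches the unit test in this PR; the dataset name/path and batch parameters are illustrative assumptions.

```python
from torchtitan.datasets import build_hf_dataloader

dataloader = build_hf_dataloader(
    dataset_name="c4_test",                 # assumption: a small local test dataset
    dataset_path="./tests/assets/c4_test",  # assumption: its on-disk path
    tokenizer_path="./tests/assets/test_tiktoken.model",
    batch_size=1,
    seq_len=1024,
    dp_rank=0,
    dp_world_size=4,
)

# Consume a few batches, then capture and restore position the same way the
# checkpoint manager would through the Stateful interface.
it = iter(dataloader)
for _ in range(3):
    next(it)

state = dataloader.state_dict()    # {"dp_rank_0": <pickled loader state>, "world_size": 4}
dataloader.load_state_dict(state)  # resumes from the captured position on the next iter()
```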
3 changes: 0 additions & 3 deletions torchtitan/models/__init__.py
@@ -8,6 +8,3 @@
# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.models.llama # noqa: F401


model_name_to_tokenizer = {"llama3": "tiktoken"}
2 changes: 2 additions & 0 deletions torchtitan/models/llama/__init__.py
@@ -6,6 +6,7 @@
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.datasets import build_hf_dataloader
from torchtitan.models.llama.model import Transformer, TransformerModelArgs
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.train_spec import register_train_spec, TrainSpec
@@ -65,5 +66,6 @@
pipelining_fn=pipeline_llama,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
)
)