Allow users to customize dataloader #836
base: gh/fegin/11/base
@@ -0,0 +1,135 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

import pickle
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Optional, Protocol

from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader

from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.logging import logger  # used by DPDataLoader.load_state_dict below

@dataclass
class BaseDataLoader(Stateful, ABC):
    """Base class for all dataloaders.

    This is used to enforce that all dataloaders have the methods defined in ``Stateful``,
    ``state_dict()`` and ``load_state_dict()``.
    """

    tokenizer: Tokenizer
    dp_rank: int
    dp_world_size: int
    batch_size: int

    @abstractmethod
    def __iter__(self):
        ...


class DPDataLoader(StatefulDataLoader, BaseDataLoader):
Review comment: I don't know, but I'm not really a big fan of this name. Maybe …
"""Dataloader that is aware of distributed data parallelism. | ||
|
||
This dataloader is used to load data in a distributed data parallel fashion. It also | ||
utilizes ``torchdata.stateful_dataloader.StatefulDataLoader`` to implement the necessary | ||
methods such as ``__iter__``. | ||
|
||
Args: | ||
dataset (IterableDataset): The dataset to iterate over. | ||
tokenizer (Tokenizer): The tokenizer to use to tokenize the dataset. | ||
dp_rank: Data parallelism rank for this dataloader. | ||
dp_world_size: The world size of the data parallelism. | ||
batch_size: The batch size to use for each iteration. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
dataset: IterableDataset, | ||
tokenizer: Tokenizer, | ||
dp_rank: int, | ||
dp_world_size: int, | ||
batch_size: int, | ||
): | ||
BaseDataLoader.__init__( | ||
self, | ||
tokenizer=tokenizer, | ||
dp_rank=dp_rank, | ||
dp_world_size=dp_world_size, | ||
batch_size=batch_size, | ||
) | ||
StatefulDataLoader.__init__(self, dataset, batch_size) | ||
self._rank_id = f"dp_rank_{dp_rank}" | ||
|
||
def state_dict(self) -> dict[str, Any]: | ||
# Store state only for dp rank to avoid replicating the same state across other dimensions. | ||
return { | ||
# We don't have to use pickle as DCP will serialize the state_dict. However, | ||
# we have to keep this for backward compatibility. | ||
self._rank_id: pickle.dumps(StatefulDataLoader.state_dict(self)), | ||
"world_size": self.dp_world_size, | ||
} | ||
|
||
def load_state_dict(self, state_dict: dict[str, Any]) -> None: | ||
# State being empty is valid. | ||
if not state_dict: | ||
return | ||
|
||
if self._rank_id not in state_dict: | ||
logger.warning( | ||
f"DataLoader state is empty for dp rank {self.dp_rank}, " | ||
"expected key {self._rank_id}" | ||
) | ||
return | ||
|
||
assert self.dp_world_size == state_dict["world_size"], ( | ||
"dp_degree is inconsistent before and after checkpoint, " | ||
"dataloader resharding is not supported yet." | ||
) | ||
# We don't have to use pickle as DCP will serialize the state_dict. However, we have to | ||
# keep this for backward compatibility. | ||
StatefulDataLoader.load_state_dict( | ||
self, pickle.loads(state_dict[self._rank_id]) | ||
) | ||
|
||
|
||
class DataLoaderBuilder(Protocol):
Review comment: ditto: is the signature of this protocol too strict?
Reply: For this one, an alternative in my mind was …
Reply: sounds good to me.
"""This is a protocol to annoate ``build_dataloader_fn``. | ||
|
||
While mypy.extensions provides Arg to annotate the name, it requires another dependency on | ||
mypy-extensions. Mypy also supports this annonation and it is easier to read. | ||
""" | ||
|
||
def __call__( | ||
self, | ||
dataset_name: str, | ||
dataset_path: Optional[str], | ||
tokenizer_path: str, | ||
batch_size: int, | ||
seq_len: int, | ||
dp_rank: int, | ||
dp_world_size: int, | ||
) -> BaseDataLoader: | ||
"""Function call | ||
|
||
Args: | ||
dataset_name (str): Name of the dataset to iterate over. | ||
dataset_path (Optional[str]): Path to the dataset to load. | ||
tokenizer_path (str): Path to the tokenizer to use. | ||
batch_size (int): The batch size to use for each iteration. | ||
seq_len (int): Sequence length for each batch. | ||
dp_rank (int): Data parallelism rank for this dataloader. | ||
dp_world_size (int): The world size of the data parallelism. | ||
|
||
Returns: | ||
BaseDataLoader: The dataloader. | ||
""" | ||
... |
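
Since the point of this PR is to let users plug in their own dataloader, a rough sketch of a custom builder that satisfies ``DataLoaderBuilder`` could look like the following. ``MyJsonlDataset`` and ``build_my_dataloader`` are hypothetical names invented for illustration; only ``DPDataLoader``, ``build_tokenizer``, and the protocol signature come from the diff above.

# Hypothetical example, not part of this PR: a user-defined builder that matches the
# DataLoaderBuilder protocol and feeds a custom IterableDataset into DPDataLoader.
from typing import Optional

from torch.utils.data import IterableDataset

from torchtitan.dataloader import DPDataLoader
from torchtitan.datasets.tokenizer import build_tokenizer


class MyJsonlDataset(IterableDataset):
    """Toy stand-in for a real data source; each DP rank would read its own shard."""

    def __init__(self, path: Optional[str], seq_len: int, dp_rank: int, dp_world_size: int):
        self.path = path
        self.seq_len = seq_len
        self.dp_rank = dp_rank
        self.dp_world_size = dp_world_size

    def __iter__(self):
        # Yield dummy (input, label) token lists so the sketch stays self-contained.
        while True:
            yield [0] * self.seq_len, [0] * self.seq_len


def build_my_dataloader(
    dataset_name: str,
    dataset_path: Optional[str],
    tokenizer_path: str,
    batch_size: int,
    seq_len: int,
    dp_rank: int,
    dp_world_size: int,
) -> DPDataLoader:
    # Same parameter names and order as DataLoaderBuilder.__call__, so this function
    # type-checks wherever a DataLoaderBuilder is expected.
    tokenizer = build_tokenizer("tiktoken", tokenizer_path)
    dataset = MyJsonlDataset(dataset_path, seq_len, dp_rank, dp_world_size)
    return DPDataLoader(
        dataset=dataset,
        tokenizer=tokenizer,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )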
@@ -4,28 +4,28 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pickle
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional

import torch

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader

from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.logging import logger
from torchtitan.dataloader import DPDataLoader

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torchtitan.datasets.tokenizer import build_tokenizer, Tokenizer
from torchtitan.logging import logger


def _load_c4_dataset(dataset_path: str):
    """Load C4 dataset with default configuration."""
    return load_dataset(dataset_path, name="en", split="train", streaming=True)


def _process_c4_text(sample: Dict[str, Any]) -> str:
def _process_c4_text(sample: dict[str, Any]) -> str:
    """Process C4 dataset sample text."""
    return sample["text"]

@@ -75,8 +75,8 @@ def __init__(
        dataset_path: Optional[str],
        tokenizer: Tokenizer,
        seq_len: int = 2048,
        world_size: int = 1,
        rank: int = 0,
        dp_rank: int = 0,
        dp_world_size: int = 1,
        infinite: bool = False,
    ) -> None:
        # Force lowercase for consistent comparison
@@ -88,15 +88,15 @@ def __init__(
        ds = dataset_loader(path)

        self.dataset_name = dataset_name
        self._data = split_dataset_by_node(ds, rank, world_size)
        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
Review comment: nice, this is indeed needed. When tracing the code, I needed to double-check its definition in the caller, although the comment specified that.
        self._tokenizer = tokenizer
        self.seq_len = seq_len
        self.infinite = infinite
        self._text_processor = text_processor

        # Variables for checkpointing
        self._sample_idx = 0
        self._all_tokens: List[int] = []
        self._all_tokens: list[int] = []

    def _get_data_iter(self):
        if self._sample_idx == 0:
@@ -142,56 +142,33 @@ def state_dict(self):
        return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}


class DPAwareDataLoader(StatefulDataLoader, Stateful):
    """
    A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank.
    """

    def __init__(
        self, dp_rank: int, hf_ds: IterableDataset, batch_size: int, world_size: int
    ):
        super().__init__(hf_ds, batch_size)
        self._dp_rank = dp_rank
        self._rank_id = f"dp_rank_{dp_rank}"
        # Data loader resharding is not yet supported, so we need to store the world size
        # to compare during loading and raise an error if dp_world_size does not match.
        self._world_size = world_size

    def state_dict(self) -> Dict[str, Any]:
        # Store state only for dp rank to avoid replicating the same state across other dimensions
        return {
            self._rank_id: pickle.dumps(super().state_dict()),
            "world_size": self._world_size,
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # State being empty is valid
        if not state_dict:
            return

        if self._rank_id not in state_dict:
            logger.warning(
                f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}"
            )
            return
        assert (
            self._world_size == state_dict["world_size"]
        ), "dp_degree is inconsistent before and after checkpoint, dataloader resharding is not supported yet."
        super().load_state_dict(pickle.loads(state_dict[self._rank_id]))


def build_hf_data_loader(
def build_hf_dataloader(
    dataset_name: str,
    dataset_path: Optional[str],
    tokenizer: Tokenizer,
    tokenizer_path: str,
    batch_size: int,
    seq_len: int,
    world_size: int,
    rank: int,
    dp_rank: int,
    dp_world_size: int,
    infinite: bool = True,
):
) -> DPDataLoader:
    """Build a data loader for HuggingFace datasets."""
    tokenizer = build_tokenizer("tiktoken", tokenizer_path)

    hf_ds = HuggingFaceDataset(
        dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        tokenizer=tokenizer,
        seq_len=seq_len,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=infinite,
    )

    return DPDataLoader(
        dataset=hf_ds,
        tokenizer=tokenizer,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )
    return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size, world_size=world_size)
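
For comparison, a call to the renamed builder under its new signature would look roughly like the sketch below; the argument values are placeholders, not defaults taken from this PR.

# Illustrative call of the new build_hf_dataloader signature (placeholder values).
data_loader = build_hf_dataloader(
    dataset_name="c4",
    dataset_path=None,
    tokenizer_path="path/to/tokenizer.model",
    batch_size=8,
    seq_len=2048,
    dp_rank=0,
    dp_world_size=1,
)
for input_ids, labels in data_loader:
    ...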
Review comment: Not completely convincing to me if this is the right basic protocol: I feel that as long as it's an iterator, it's good enough. The things returned from the iterator don't need to be input_ids, labels. For sequence masking / multimodal, a dataloader needs to return more, e.g. mask, images, etc.

Reply: I think items 2 and 3 are reasonable, and I actually think we can remove them. But what's the point of not supporting checkpointing? For torchdata, which is a generic dataset library, it makes sense to have a very basic dataloader class. TorchTitan is a distributed training library; I don't see a reason why supporting checkpointing during training is not a must. Our checkpoint manager also assumes the dataloader to be Stateful. Removing Stateful is too relaxed, imo.

Reply: I think I'm more thinking from an experimentation-platform perspective. E.g., I'm a researcher working on a new type of attention; I have some data I'd like to load; I want to look at throughput gain; but I don't care about fault tolerance.
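
To make that experimentation-platform point concrete, a hypothetical "throughput-only" loader could keep the current contract while treating checkpoint state as disposable. This is only a sketch, not part of the PR: the class name is invented, and it assumes BaseDataLoader lives next to DPDataLoader in torchtitan.dataloader.

# Hypothetical sketch: satisfies BaseDataLoader (and thus Stateful), but its
# checkpoint state is deliberately trivial for quick throughput experiments.
from typing import Any

from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader

from torchtitan.dataloader import BaseDataLoader
from torchtitan.datasets.tokenizer import Tokenizer


class ThroughputOnlyDataLoader(StatefulDataLoader, BaseDataLoader):
    """Iterates like DPDataLoader, but saving/restoring its position is a no-op."""

    def __init__(
        self,
        dataset: IterableDataset,
        tokenizer: Tokenizer,
        dp_rank: int,
        dp_world_size: int,
        batch_size: int,
    ):
        BaseDataLoader.__init__(
            self,
            tokenizer=tokenizer,
            dp_rank=dp_rank,
            dp_world_size=dp_world_size,
            batch_size=batch_size,
        )
        StatefulDataLoader.__init__(self, dataset, batch_size)

    def state_dict(self) -> dict[str, Any]:
        # Nothing worth persisting for a throughput-only run.
        return {}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # Always restart from the beginning of the stream.
        pass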