Allow users to customize dataloader #836

Open
wants to merge 6 commits into
base: gh/fegin/11/base


@fegin fegin commented Feb 12, 2025

[ghstack-poisoned]
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Feb 12, 2025
@fegin fegin requested review from fduwjj and tianyu-l February 12, 2025 08:25
# build dataloader
data_loader = build_hf_data_loader(
Contributor

Nice!!

@@ -88,15 +88,15 @@ def __init__(
         ds = dataset_loader(path)

         self.dataset_name = dataset_name
-        self._data = split_dataset_by_node(ds, rank, world_size)
+        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
Contributor

Nice, this is indeed needed. When tracing the code, I had to double-check its definition in the caller, even though the comment specified it.
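
To illustrate why the shard index should be the data-parallel rank rather than the global rank, here is a minimal stdlib-only sketch. The mesh layout (DP outer, TP inner), the helper names `dp_coords` and `shard`, and the round-robin split are all assumptions made for illustration; this is not torchtitan's code and not how `split_dataset_by_node` is implemented.

```python
from __future__ import annotations


def dp_coords(global_rank: int, world_size: int, tp_size: int) -> tuple[int, int]:
    """Return (dp_rank, dp_world_size), assuming a hypothetical row-major (DP, TP) mesh."""
    if world_size % tp_size != 0:
        raise ValueError("world_size must be divisible by tp_size")
    return global_rank // tp_size, world_size // tp_size


def shard(samples: list[int], rank: int, num_shards: int) -> list[int]:
    """Round-robin shard: a stand-in for a real dataset split, not the library's algorithm."""
    return samples[rank::num_shards]


samples = list(range(8))
world_size, tp_size = 4, 2

# Ranks 0 and 1 share a tensor-parallel group, so they must read identical data.
by_dp_rank = [
    shard(samples, *dp_coords(r, world_size, tp_size)) for r in range(world_size)
]

# Sharding by global rank would instead hand every rank a different slice.
by_global_rank = [shard(samples, r, world_size) for r in range(world_size)]
```

Under these assumptions, ranks in the same tensor-parallel group end up with the same shard when `dp_rank` is used, while sharding by global rank splits them apart.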

...


class DPDataLoader(StatefulDataLoader, BaseDataLoader):
Contributor

I'm not really a big fan of this name. Maybe ParallelAwareDataloader?

Comment on lines +21 to +32
@dataclass
class BaseDataLoader(Stateful, ABC):
    """Base class for all dataloaders.

    This is used to enforce that all dataloaders have the methods defined in ``Stateful``,
    ``state_dict()`` and ``load_state_dict()``.
    """

    tokenizer: Tokenizer
    dp_rank: int
    dp_world_size: int
    batch_size: int
Contributor

I'm not fully convinced this is the right base protocol:

  1. What if people don't care whether it supports checkpointing?
  2. What if people don't need it to be aware of DP ranks?
  3. Do all data loaders need to perform tokenization?

I feel that as long as it's an iterator, it's good enough. The things returned from the iterator don't need to be input_ids and labels. For sequence masking or multimodal training, a dataloader needs to return more, e.g. masks, images, etc.
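
A rough sketch of the "just an iterator" idea, for illustration only: a loader that yields dict batches with an extra key, and a consumer that relies on nothing but iteration. The names `masked_batches` and `consume`, the dict keys, and the padding scheme are hypothetical and not part of this PR.

```python
from __future__ import annotations

from typing import Any, Iterable, Iterator


def masked_batches(
    token_rows: list[list[int]], batch_size: int, pad_id: int = 0
) -> Iterator[dict[str, Any]]:
    """Yield dict batches carrying input_ids, labels, and a padding mask."""
    for start in range(0, len(token_rows), batch_size):
        rows = token_rows[start:start + batch_size]
        width = max(len(r) for r in rows)
        # Right-pad every row to the longest row in the batch.
        padded = [r + [pad_id] * (width - len(r)) for r in rows]
        mask = [[1] * len(r) + [0] * (width - len(r)) for r in rows]
        yield {
            "input_ids": [p[:-1] for p in padded],
            "labels": [p[1:] for p in padded],  # next-token targets
            "mask": [m[1:] for m in mask],      # 0 where the label is padding
        }


def consume(loader: Iterable[dict[str, Any]]) -> int:
    """Trainer stub: counts samples while depending only on iteration."""
    return sum(len(batch["input_ids"]) for batch in loader)
```

The consumer never inspects the loader's type, so a batch could carry images or other fields without changing the loop. The trade-off raised in this thread still applies: such a loader offers no checkpointing hooks.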

Contributor Author

@fegin fegin Feb 12, 2025

I think items 2 and 3 are reasonable, and I actually think we can remove them. But what's the point of not supporting checkpointing? For torchdata, which is a generic dataset library, it makes sense to have a very basic dataloader class. TorchTitan is a distributed training library, so I don't see why supporting checkpointing during training would not be a must.

Our checkpoint manager also assumes the dataloader is Stateful. Removing Stateful is too relaxed, imo.
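
For readers less familiar with the contract being discussed, here is a minimal sketch of what a checkpointable loader buys you: `state_dict()` captures the read position and `load_state_dict()` restores it in a fresh instance. `StatefulLoader` is a local stand-in written for this example, not the `Stateful` class used in the PR, and `ListLoader` is a hypothetical toy.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Iterator


class StatefulLoader(ABC):
    """Local stand-in for a stateful dataloader base (assumption, not the PR's class)."""

    @abstractmethod
    def state_dict(self) -> dict[str, Any]: ...

    @abstractmethod
    def load_state_dict(self, state_dict: dict[str, Any]) -> None: ...

    @abstractmethod
    def __iter__(self) -> Iterator[Any]: ...


class ListLoader(StatefulLoader):
    """Toy loader that remembers how many samples it has already yielded."""

    def __init__(self, samples: list[int]) -> None:
        self._samples = samples
        self._pos = 0

    def __iter__(self) -> Iterator[int]:
        while self._pos < len(self._samples):
            item = self._samples[self._pos]
            self._pos += 1  # advance first so a saved state never replays this item
            yield item

    def state_dict(self) -> dict[str, Any]:
        return {"pos": self._pos}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self._pos = state_dict["pos"]


loader = ListLoader([10, 11, 12, 13])
it = iter(loader)
seen = [next(it), next(it)]
checkpoint = loader.state_dict()

# A fresh loader resumes where the first one stopped.
resumed = ListLoader([10, 11, 12, 13])
resumed.load_state_dict(checkpoint)
rest = list(resumed)
```

A checkpoint manager that calls these two methods on every component can treat the loader like any other piece of training state, which is the assumption described above.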

Contributor

I'm thinking more from an experimentation-platform perspective.
E.g. I'm a researcher working on a new type of attention. I have some data I'd like to load and I want to look at the throughput gain, but I don't care about fault tolerance.

)


class DataLoaderBuilder(Protocol):
Contributor

Ditto: is the signature of this protocol too strict?
Would it be too relaxed to replace it with something like a Callable from anything to an Iterator?

Contributor Author

For this one, an alternative I had in mind was Callable[..., BaseDataLoader]. This just relaxes the input signature, so users can define whatever arguments they want.

Contributor

sounds good to me
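
A minimal sketch of the relaxed builder type agreed on here: a callable whose arguments are unconstrained but whose return value must still be a base dataloader. In Python typing, open arguments are spelled `Callable[..., X]`. The `BaseDataLoader` below is a stripped-down stand-in, and `RangeLoader`, `build_range_loader`, and `make_loader` are hypothetical names for illustration, not code from this PR.

```python
from __future__ import annotations

from typing import Any, Callable, Iterator


class BaseDataLoader:
    """Stripped-down stand-in for the PR's base class (assumption)."""

    def __iter__(self) -> Iterator[Any]:
        raise NotImplementedError

    def state_dict(self) -> dict[str, Any]:
        return {}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        pass


# Arguments are left open; only the return type is constrained.
DataLoaderBuilder = Callable[..., BaseDataLoader]


class RangeLoader(BaseDataLoader):
    """Hypothetical loader yielding batches of consecutive integers."""

    def __init__(self, n: int, batch_size: int) -> None:
        self.n = n
        self.batch_size = batch_size

    def __iter__(self) -> Iterator[list[int]]:
        for start in range(0, self.n, self.batch_size):
            yield list(range(start, min(start + self.batch_size, self.n)))


def build_range_loader(*, n: int, batch_size: int) -> BaseDataLoader:
    return RangeLoader(n, batch_size)


def make_loader(builder: DataLoaderBuilder, **kwargs: Any) -> BaseDataLoader:
    """Trainer-side hook: forwards whatever arguments the user's builder expects."""
    return builder(**kwargs)
```

With this shape, a custom builder can take any keyword arguments it likes, while the trainer can still count on the returned object exposing `state_dict()` and `load_state_dict()`.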

4 participants