-
Notifications
You must be signed in to change notification settings - Fork 270
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow users to customize dataloader #836
base: gh/fegin/11/base
Are you sure you want to change the base?
Conversation
# build dataloader | ||
data_loader = build_hf_data_loader( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!!
@@ -88,15 +88,15 @@ def __init__( | |||
ds = dataset_loader(path) | |||
|
|||
self.dataset_name = dataset_name | |||
self._data = split_dataset_by_node(ds, rank, world_size) | |||
self._data = split_dataset_by_node(ds, dp_rank, dp_world_size) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice, this is indeed needed. When tracing the code.. I need to double check its definition in the caller, although the comment specified that.
... | ||
|
||
|
||
class DPDataLoader(StatefulDataLoader, BaseDataLoader): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know but not really a big fan of this name. Maybe ParallelAwareDataloader
?
@dataclass | ||
class BaseDataLoader(Stateful, ABC): | ||
"""Base class for all dataloaders. | ||
|
||
This is used to enforce that all dataloaders have the methods defined in ``Stateful``, | ||
``state_dict()`` and ``load_state_dict()``. | ||
""" | ||
|
||
tokenizer: Tokenizer | ||
dp_rank: int | ||
dp_world_size: int | ||
batch_size: int |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not completely convincing to me if this is the right basic protocol:
- What if people don't care if it supports checkpointing or not
- what if people don't need it to be aware of DP ranks
- Do all data loaders need to perform tokenization?
I feel that as along as it's an iterator, it's good enough. The things returned from the iterator don't need to be input_ids, labels
. For sequence masking / multimodal, a dataloader needs to return more, e.g. mask, images, etc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think item 2 and 3 are reasonable. And I actually think we can remove them. But what's the point of not supporting checkpointing? For torchdata, which is a generic dataset library, it makes sense to have a very basic dataloader class. TorchTitan is a distributed training library, I don't see a reason why supporting checkpointing during training is not a must.
Our checkpoint manager also assume dataloader to be Stateful. Removing Stateful is too relaxed, imo.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I'm more thinking from an experimentation platform perspective.
E.g. I'm a researcher working on a new type of attention. I have some data I'd like to load; I want to look at throughput gain; but I don't care about fault tolerance.
) | ||
|
||
|
||
class DataLoaderBuilder(Protocol): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto: is the signature of this protocol too strict?
Is it too relaxed if we substitute it to something like a Callable
from anything to Iterator?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For this one, an alternative in my mind was Callable[[...], BaseDataloader]
. This just relaxs the input and users can define whatever they want.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sounds good to me
Stack from ghstack (oldest at bottom):