 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-import pickle
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Optional

 import torch

 from datasets import Dataset, load_dataset
 from datasets.distributed import split_dataset_by_node
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.utils.data import IterableDataset
-from torchdata.stateful_dataloader import StatefulDataLoader

+from torchtitan.dataloader import ParallelAwareDataloader
 from torchtitan.datasets.tokenizer import Tokenizer
 from torchtitan.logging import logger

@@ -25,7 +24,7 @@ def _load_c4_dataset(dataset_path: str):
     return load_dataset(dataset_path, name="en", split="train", streaming=True)


-def _process_c4_text(sample: Dict[str, Any]) -> str:
+def _process_c4_text(sample: dict[str, Any]) -> str:
     """Process C4 dataset sample text."""
     return sample["text"]

@@ -75,8 +74,8 @@ def __init__(
         dataset_path: Optional[str],
         tokenizer: Tokenizer,
         seq_len: int = 2048,
-        world_size: int = 1,
-        rank: int = 0,
+        dp_rank: int = 0,
+        dp_world_size: int = 1,
         infinite: bool = False,
     ) -> None:
         # Force lowercase for consistent comparison
@@ -88,15 +87,15 @@ def __init__(
         ds = dataset_loader(path)

         self.dataset_name = dataset_name
-        self._data = split_dataset_by_node(ds, rank, world_size)
+        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
         self._tokenizer = tokenizer
         self.seq_len = seq_len
         self.infinite = infinite
         self._text_processor = text_processor

         # Variables for checkpointing
         self._sample_idx = 0
-        self._all_tokens: List[int] = []
+        self._all_tokens: list[int] = []

     def _get_data_iter(self):
         if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
@@ -142,56 +141,31 @@ def state_dict(self):
         return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}


-class DPAwareDataLoader(StatefulDataLoader, Stateful):
-    """
-    A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank.
-    """
-
-    def __init__(
-        self, dp_rank: int, hf_ds: IterableDataset, batch_size: int, world_size: int
-    ):
-        super().__init__(hf_ds, batch_size)
-        self._dp_rank = dp_rank
-        self._rank_id = f"dp_rank_{dp_rank}"
-        # Data loader resharding is not yet supported, so we need to store the world size to compare during loading
-        # raise error if dp_word_size does not match.
-        self._world_size = world_size
-
-    def state_dict(self) -> Dict[str, Any]:
-        # Store state only for dp rank to avoid replicating the same state across other dimensions
-        return {
-            self._rank_id: pickle.dumps(super().state_dict()),
-            "world_size": self._world_size,
-        }
-
-    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
-        # State being empty is valid
-        if not state_dict:
-            return
-
-        if self._rank_id not in state_dict:
-            logger.warning(
-                f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}"
-            )
-            return
-        assert (
-            self._world_size == state_dict["world_size"]
-        ), "dp_degree is inconsistent before and after checkpoint, dataloader resharding is not supported yet."
-        super().load_state_dict(pickle.loads(state_dict[self._rank_id]))
-
-
-def build_hf_data_loader(
+def build_hf_dataloader(
     dataset_name: str,
     dataset_path: Optional[str],
     tokenizer: Tokenizer,
     batch_size: int,
     seq_len: int,
-    world_size: int,
-    rank: int,
+    dp_rank: int,
+    dp_world_size: int,
     infinite: bool = True,
-):
+) -> ParallelAwareDataloader:
     """Build a data loader for HuggingFace datasets."""
+
     hf_ds = HuggingFaceDataset(
-        dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite
+        dataset_name=dataset_name,
+        dataset_path=dataset_path,
+        tokenizer=tokenizer,
+        seq_len=seq_len,
+        dp_rank=dp_rank,
+        dp_world_size=dp_world_size,
+        infinite=infinite,
+    )
+
+    return ParallelAwareDataloader(
+        dataset=hf_ds,
+        dp_rank=dp_rank,
+        dp_world_size=dp_world_size,
+        batch_size=batch_size,
     )
-    return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size, world_size=world_size)
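
Note: the sketch below is a hypothetical usage example of the renamed build_hf_dataloader entry point, not part of this diff. The build_tokenizer helper, the tokenizer path, and the batch/sequence/data-parallel values are assumptions chosen for illustration.

    # Hypothetical usage sketch (not from this change): build a dataloader for
    # one data-parallel rank using the new dp_rank/dp_world_size arguments.
    from torchtitan.datasets.tokenizer import build_tokenizer  # assumed helper; adjust to your checkout

    tokenizer = build_tokenizer("tiktoken", "./assets/tokenizer/tokenizer.model")  # assumed path
    dataloader = build_hf_dataloader(
        dataset_name="c4",
        dataset_path=None,        # fall back to the dataset's default path (per existing behavior)
        tokenizer=tokenizer,
        batch_size=8,
        seq_len=2048,
        dp_rank=0,                # this rank's index in the data-parallel mesh
        dp_world_size=4,          # total number of data-parallel ranks
        infinite=True,
    )

The only change visible to callers is the construction site: ParallelAwareDataloader, imported from torchtitan.dataloader, takes over the role of the removed DPAwareDataLoader as the checkpointable, data-parallel-aware loader, and the old positional world_size/rank pair is replaced by explicit dp_rank/dp_world_size keywords.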