 from jax.experimental import checkify

 from MaxText import exceptions
-from MaxText import sharding
 from MaxText.utils.goodput_utils import (
     GoodputEvent,
     maybe_record_goodput,
@@ -37,7 +36,6 @@ def __init__(self, config, mesh, data_iterator, goodput_recorder):
     self.goodput_recorder = goodput_recorder
     self.data_iterator = data_iterator
     self.last_batch = None
-    self.input_data_shardings = sharding.get_input_data_sharding(config, mesh)

   def load_next_batch(self):
     """Loads the next batch. Can keep reusing the same batch for performance reasons."""
@@ -47,12 +45,7 @@ def load_next_batch(self):
           example_batch = self.last_batch
         else:
           example_batch = next(self.data_iterator)
-        # Reshard data from loaded sharding to performant activation sharding
-        self.last_batch = sharding.maybe_shard_with_name(
-            example_batch,
-            self.input_data_shardings,
-            self.config.shard_mode,
-        )
+        self.last_batch = example_batch
         self.check_example_batch()
       except Exception as e:  # pylint: disable=broad-except
         if isinstance(e, StopIteration):
@@ -68,3 +61,115 @@ def check_example_batch(self):
     # pylint: disable=not-callable
     err, _ = jax.jit(jittable_f)(self.last_batch["inputs"][: self.config.global_batch_size_to_train_on, :])
     err.throw()
+
+
+class RampUpDataLoader(DataLoader):
67+ """
68+ A DataLoader that implements batch size ramp-up.
69+
70+ It dynamically increases the 'global_batch_size_current' in the config
71+ object based on the training step. The rest of the training pipeline
72+ (including the parent's `check_example_batch` and the training step itself)
73+ is assumed to read this config value to determine the logical batch size.
74+ """
+
+  def __init__(self, config, mesh, data_iterator, goodput_recorder):
+    # Call parent constructor
+    super().__init__(config, mesh, data_iterator, goodput_recorder)
+
+    # Read ramp-up parameters from config
+    self.global_batch_size_end = config.global_batch_size_to_load
+    self.global_batch_size_start = config.global_batch_size_to_load_start
+    self.increment = config.global_batch_size_to_load_increment
+    self.samples_per_increment = config.rampup_samples_per_increment_to_load

+    # Ramp-up is active only while the starting size is below the final size
+    self.rampup_active = self.global_batch_size_start < self.global_batch_size_end
+
+    # State for tracking ramp-up progress
+    self.accum_samples = 0
+    self.global_batch_size_current = self.global_batch_size_start
+    self.batch_buffer = None
+    self.buffer_start = 0
+
+  def load_next_batch(self):
+    """
+    Advances the ramp-up schedule, then returns the next slice of
+    `global_batch_size_current` samples from the buffered batches.
+    """
+    # If ramp-up is not active, just behave like the parent
+    if not self.rampup_active:
+      return super().load_next_batch()
+
+    # During ramp-up, full-size batches are staged in a buffer and consumed
+    # in slices. First, check whether it is time to grow the batch size.
+    is_time_to_increment = self.accum_samples >= self.samples_per_increment
+
+    if is_time_to_increment:
+      # Grow the batch size (clamped to the final size) and reset the counter
+      self.global_batch_size_current = min(
+          self.global_batch_size_current + self.increment, self.global_batch_size_end
+      )
+      self.accum_samples = 0
+      self.rampup_active = self.global_batch_size_current < self.global_batch_size_end
+
+    self.accum_samples += self.global_batch_size_current
+    slice_start, slice_end = self.buffer_start, self.buffer_start + self.global_batch_size_current
+
+    # Load the first batch if the buffer is empty
+    if self.batch_buffer is None:
+      self.batch_buffer = super().load_next_batch()
+      slice_start, slice_end = 0, self.global_batch_size_current
+
+    # The buffer holds exactly global_batch_size_end samples; when the slice
+    # runs past its end, load a fresh batch and wrap around into it
+    if slice_end > self.global_batch_size_end:
+      old_buffer, self.batch_buffer = self.batch_buffer, super().load_next_batch()
+
+      def _slice_and_concat(old_data, new_data):
+        sliced_old_data = jax.lax.dynamic_slice_in_dim(
+            old_data,
+            slice_start,
+            self.global_batch_size_end - slice_start,
+            axis=0,
+        )
+        sliced_new_data = jax.lax.dynamic_slice_in_dim(
+            new_data,
+            0,
+            slice_end - self.global_batch_size_end,
+            axis=0,
+        )
+        return jax.lax.concatenate((sliced_old_data, sliced_new_data), dimension=0)
+
+      self.buffer_start = slice_end - self.global_batch_size_end
+      return jax.tree.map(_slice_and_concat, old_buffer, self.batch_buffer)
+    else:
+
+      def _slice(data):
+        return jax.lax.dynamic_slice_in_dim(
+            data,
+            slice_start,
+            self.global_batch_size_current,
+            axis=0,
+        )
+
+      self.buffer_start = slice_end
+      return jax.tree.map(_slice, self.batch_buffer)
+
+
+def create_dataloader(config, mesh, data_iterator, goodput_recorder):
+  """
+  Creates a RampUpDataLoader when batch-size ramp-up is enabled,
+  otherwise a plain DataLoader.
+  """
+  if config.enable_rampup_batch_size:
+    return RampUpDataLoader(config, mesh, data_iterator, goodput_recorder)
+  else:
+    return DataLoader(config, mesh, data_iterator, goodput_recorder)
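
Usage sketch (not part of the commit): a minimal way to exercise the new factory,
assuming `maybe_record_goodput` no-ops when the recorder is None and that `config`
exposes the flags read by the loaders above. The SimpleNamespace values, the
`max_checkify` flag (read by the parent loader outside this diff), and the dummy
iterator are illustrative only.

  from types import SimpleNamespace
  import numpy as np

  # Hypothetical config; the key names match what the loaders read above,
  # the values are made up for illustration.
  config = SimpleNamespace(
      enable_rampup_batch_size=True,
      reuse_example_batch=False,
      max_checkify=False,                         # skip check_example_batch's checkify pass
      global_batch_size_to_train_on=128,
      global_batch_size_to_load=128,              # final (end-of-ramp) batch size
      global_batch_size_to_load_start=32,         # initial batch size
      global_batch_size_to_load_increment=32,     # growth per ramp-up stage
      rampup_samples_per_increment_to_load=1024,
  )

  def dummy_batches():
    # Endless stream of full-size batches, standing in for the real iterator
    while True:
      yield {"inputs": np.zeros((config.global_batch_size_to_load, 2048), np.int32)}

  loader = create_dataloader(config, mesh=None, data_iterator=dummy_batches(), goodput_recorder=None)
  batch = loader.load_next_batch()  # leading dim is 32 during the first stage
  # Subsequent calls return 64-, 96-, then 128-sample slices as ramp-up proceeds.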