 from jax.experimental import checkify
 
 from MaxText import exceptions
-from MaxText import sharding
+from MaxText import max_logging
 from MaxText.utils.goodput_utils import (
     GoodputEvent,
     maybe_record_goodput,
@@ -42,7 +42,6 @@ def __init__(self, config, mesh, data_iterator, goodput_recorder):
     else:
       self.data_iterator = data_iterator
     self.last_batch = None
-    self.input_data_shardings = sharding.get_input_data_sharding(config, mesh)
 
   def update_data_iterator(self):
     """Update to the next data iterator in the list, if applicable."""
@@ -59,12 +58,7 @@ def load_next_batch(self):
         else:
           example_batch = next(self.data_iterator)
           self.update_data_iterator()
-        # Reshard data from loaded sharding to performant activation sharding
-        self.last_batch = sharding.maybe_shard_with_name(
-            example_batch,
-            self.input_data_shardings,
-            self.config.shard_mode,
-        )
+        self.last_batch = example_batch
         self.check_example_batch()
       except Exception as e:  # pylint: disable=broad-except
         if isinstance(e, StopIteration):
@@ -80,3 +74,108 @@ def check_example_batch(self):
       # pylint: disable=not-callable
       err, _ = jax.jit(jittable_f)(self.last_batch["inputs"][: self.config.global_batch_size_to_train_on, :])
       err.throw()
+
+
+class RampUpDataLoader(DataLoader):
+  """
+  A DataLoader that implements batch-size ramp-up.
+
+  It dynamically increases `global_batch_size_current` based on the number of
+  samples consumed so far. The rest of the training pipeline (including the
+  parent's `check_example_batch` and the training step itself) is assumed to
+  read this value to determine the logical batch size.
+  """
+
+  def __init__(self, config, mesh, data_iterator, goodput_recorder):
+    # Call the parent constructor.
+    super().__init__(config, mesh, data_iterator, goodput_recorder)
+
+    # Read the ramp-up parameters from the config.
+    self.global_batch_size_end = config.global_batch_size_to_load
+    self.global_batch_size_start = config.global_batch_size_to_load_start
+    self.increment = config.global_batch_size_to_load_increment
+    self.samples_per_increment = config.rampup_samples_per_increment_to_load
+
+    # Ramp-up is active only if the starting batch size is below the target.
+    self.rampup_active = self.global_batch_size_start < self.global_batch_size_end
+
+    # State for tracking ramp-up progress.
+    self.accum_samples = 0
+    self.global_batch_size_current = self.global_batch_size_start
+    self.batch_buffer = None
+    self.buffer_start = 0
+
+  def load_next_batch(self):
+    """
+    Updates the batch size according to the ramp-up schedule and then loads
+    the next logical batch, slicing it out of batches loaded by the parent.
+    """
+    # If ramp-up is not active, behave exactly like the parent.
+    if not self.rampup_active:
+      return super().load_next_batch()
+
+    # During the ramp-up phase, loaded batches are kept in a buffer and logical
+    # batches are sliced out of it. First, check if it is time to increment the batch size.
+    is_time_to_increment = self.accum_samples >= self.samples_per_increment
+
+    if is_time_to_increment:
+      # Increase the current batch size and reset the accumulated-sample counter.
+      max_logging.log(
+          f"Global batch size increases from {self.global_batch_size_current}"
+          f" to {self.global_batch_size_current + self.increment}"
+      )
+      self.global_batch_size_current += self.increment
+      self.accum_samples = 0
+      self.rampup_active = self.global_batch_size_current < self.global_batch_size_end
+
+    self.accum_samples += self.global_batch_size_current
+    slice_start, slice_end = self.buffer_start, self.buffer_start + self.global_batch_size_current
+
+    # Load a new batch if the buffer is empty or the slice runs past the buffer end.
+    if self.batch_buffer is None:
+      self.batch_buffer = super().load_next_batch()
+      slice_start, slice_end = 0, self.global_batch_size_current
+
+    if slice_end > self.global_batch_size_end:
+      old_buffer, self.batch_buffer = self.batch_buffer, super().load_next_batch()
+
+      # self.global_batch_size_end equals the buffer size along the batch axis:
+      # take the tail of the old buffer and the head of the new one.
+      def _slice_and_concat(old_data, new_data):
+        sliced_old_data = jax.lax.dynamic_slice_in_dim(
+            old_data,
+            slice_start,
+            self.global_batch_size_end - slice_start,
+            axis=0,
+        )
+        sliced_new_data = jax.lax.dynamic_slice_in_dim(
+            new_data,
+            0,
+            slice_end - self.global_batch_size_end,
+            axis=0,
+        )
+        return jax.lax.concatenate((sliced_old_data, sliced_new_data), dimension=0)
+
+      self.buffer_start = slice_end - self.global_batch_size_end
+      return jax.tree.map(_slice_and_concat, old_buffer, self.batch_buffer)
+    else:
+
+      def _slice(data):
+        return jax.lax.dynamic_slice_in_dim(
+            data,
+            slice_start,
+            self.global_batch_size_current,
+            axis=0,
+        )
+
+      self.buffer_start = slice_end
+      return jax.tree.map(_slice, self.batch_buffer)
+
+
+def create_dataloader(config, mesh, data_iterator, goodput_recorder):
+  """
+  Create the appropriate DataLoader based on the config.
+  """
+  if config.enable_rampup_batch_size:
+    return RampUpDataLoader(config, mesh, data_iterator, goodput_recorder)
+  else:
+    return DataLoader(config, mesh, data_iterator, goodput_recorder)
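
For context on how the ramp-up schedule behaves, below is a rough sketch that replays the batch-size accumulation logic of RampUpDataLoader.load_next_batch without loading any data. The standalone rampup_schedule helper and the example numbers are illustrative assumptions for this note, not part of the change; only the meaning of the parameters (start size, target size, increment, samples per increment) comes from the diff above.

# Illustrative sketch only (not part of this PR): replays the batch-size
# schedule used by RampUpDataLoader, ignoring the data buffering.
def rampup_schedule(start, end, increment, samples_per_increment, num_steps):
  """Return the logical batch size used at each of `num_steps` steps."""
  current, accum, active = start, 0, start < end
  sizes = []
  for _ in range(num_steps):
    if active and accum >= samples_per_increment:
      current += increment
      accum = 0
      active = current < end
    accum += current
    sizes.append(current)
  return sizes

# With start=8, end=32, increment=8 and 64 samples per increment, the first
# 8 steps use batch size 8, the next 4 use 16, the next 3 use 24, and all
# later steps use the target size 32.
print(rampup_schedule(8, 32, 8, 64, 18))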