diff --git a/src/MaxText/data_loader.py b/src/MaxText/data_loader.py
index 352b41841..9eadfb5fa 100644
--- a/src/MaxText/data_loader.py
+++ b/src/MaxText/data_loader.py
@@ -27,12 +27,77 @@
 )
 
 
+class RampupBatchCalculator:
+  """
+  Tracks the current global batch size for a given train step during batch-size ramp-up.
+  """
+
+  def __init__(self, config, step_num):
+    self._verify_inputs(config)
+    self._init_values(config)
+    self.num_accum_samples = 0
+
+    # Compute the number of samples already consumed given the recovered step number
+    self._recover_states(step_num)
+
+  def _verify_inputs(self, config):
+    """Verify the rampup-batch-related inputs."""
+    diff_batch_size = config.per_device_batch_size - config.per_device_batch_size_start
+    if diff_batch_size <= 0:
+      raise ValueError(
+          "per_device_batch_size must be greater than per_device_batch_size_start. "
+          f"Got per_device_batch_size={config.per_device_batch_size} and "
+          f"per_device_batch_size_start={config.per_device_batch_size_start}."
+      )
+    if diff_batch_size % config.per_device_batch_size_increment:
+      raise ValueError(
+          "Expected the rampup batch size change to be divisible by per_device_batch_size_increment. "
+          f"Got per_device_batch_size={config.per_device_batch_size} and "
+          f"per_device_batch_size_start={config.per_device_batch_size_start}."
+      )
+
+  def _init_values(self, config):
+    """Initialize rampup-batch-related parameters."""
+    diff_batch_size = config.per_device_batch_size - config.per_device_batch_size_start
+    num_increments = diff_batch_size // config.per_device_batch_size_increment
+    self.samples_per_increment = config.global_rampup_samples / num_increments
+    num_devices = config.num_target_devices
+    self.global_batch_size_end = num_devices * config.per_device_batch_size
+    self.global_batch_size_start = num_devices * config.per_device_batch_size_start
+    self.increment = num_devices * config.per_device_batch_size_increment
+    self.global_rampup_samples = config.global_rampup_samples
+    self.global_batch_size_current = self.global_batch_size_start
+
+  def _recover_states(self, step_num):
+    """Recover the number of samples already consumed by the recovered steps."""
+    if step_num < 0:
+      return
+    for _ in range(step_num + 1):
+      if not self.update():
+        # Ramp-up already finished before the recovered step; stop replaying.
+        break
+
+  def update(self):
+    """Account for one train step and return whether the rampup phase is still active."""
+    self.num_accum_samples += self.global_batch_size_current
+    # Check if it's time to increment the batch size
+    is_time_to_increment = self.num_accum_samples >= self.samples_per_increment
+    if is_time_to_increment:
+      max_logging.log(
+          f"Global batch size increments from {self.global_batch_size_current}"
+          f" to {self.global_batch_size_current + self.increment}. "
+          f"{self.num_accum_samples} data samples already used."
+      )
+      self.global_batch_size_current += self.increment
+      self.num_accum_samples = 0
+      # Count the current step's samples at the new batch size toward the next increment
+      self.num_accum_samples += self.global_batch_size_current
+    # Return whether the rampup phase is still active
+    return self.global_batch_size_current < self.global_batch_size_end
+
+
 class DataLoader:
   """
   Loads preprocessed data for training.
   """
 
-  def __init__(self, config, mesh, data_iterator, goodput_recorder):
+  def __init__(self, config, data_iterator, goodput_recorder):
     self.config = config
     self.goodput_recorder = goodput_recorder
     if isinstance(data_iterator, list):
@@ -86,22 +151,13 @@ class RampUpDataLoader(DataLoader):
   is assumed to read this config value to determine the logical batch size.
""" - def __init__(self, config, mesh, data_iterator, goodput_recorder): + def __init__(self, config, data_iterator, goodput_recorder, latest_step): # Call parent constructor - super().__init__(config, mesh, data_iterator, goodput_recorder) - - # Get ramp-up parameters from config, with safe defaults - self.global_batch_size_end = config.global_batch_size_to_load - self.global_batch_size_start = config.global_batch_size_to_load_start - self.increment = config.global_batch_size_to_load_increment - self.samples_per_increment = config.rampup_samples_per_increment_to_load - - # Check if ramp-up is active - self.rampup_active = self.global_batch_size_start < self.global_batch_size_end - - # State for tracking ramp-up - self.accum_samples = 0 - self.global_batch_size_current = self.global_batch_size_start + super().__init__(config, data_iterator, goodput_recorder) + + # Initialize batch size calculator + self.calculator = RampupBatchCalculator(config, latest_step) + self.rampup_active = self.calculator.num_accum_samples < config.global_rampup_samples self.batch_buffer = None self.buffer_start = 0 @@ -114,29 +170,17 @@ def load_next_batch(self): if not self.rampup_active: return super().load_next_batch() - # If in rampup phase, we use batch buffer to save data - # Check if it's time to increment the batch size - is_time_to_increment = self.accum_samples >= self.samples_per_increment - - if is_time_to_increment: - # Update current batch size and refresh accumulate samples - max_logging.log( - f"Global batch size increments from {self.global_batch_size_current}" - f" to {self.global_batch_size_current + self.increment}" - ) - self.global_batch_size_current += self.increment - self.accum_samples = 0 - self.rampup_active = self.global_batch_size_current < self.global_batch_size_end + self.rampup_active = self.calculator.update() - self.accum_samples += self.global_batch_size_current - slice_start, slice_end = self.buffer_start, self.buffer_start + self.global_batch_size_current + slice_start, slice_end = self.buffer_start, self.buffer_start + self.calculator.global_batch_size_current - # Load new batch if batch_buffer is None or slice overpast the buffer end + # Load new batch if batch_buffer is None if self.batch_buffer is None: self.batch_buffer = super().load_next_batch() - slice_start, slice_end = 0, self.global_batch_size_current + slice_start, slice_end = 0, self.calculator.global_batch_size_current - if slice_end > self.global_batch_size_end: + # If the slice end overpast batch end we collect new batch data + if slice_end > self.calculator.global_batch_size_end: old_buffer, self.batch_buffer = self.batch_buffer, super().load_next_batch() # self.global_batch_size_end is batch_buffer size @@ -144,26 +188,25 @@ def _slice_and_concat(old_data, new_data): sliced_old_data = jax.lax.dynamic_slice_in_dim( old_data, slice_start, - self.global_batch_size_end - slice_start, + self.calculator.global_batch_size_end - slice_start, axis=0, ) sliced_new_data = jax.lax.dynamic_slice_in_dim( new_data, 0, - slice_end - self.global_batch_size_end, + slice_end - self.calculator.global_batch_size_end, axis=0, ) return jax.lax.concatenate((sliced_old_data, sliced_new_data), dimension=0) - self.buffer_start = slice_end - self.global_batch_size_end + self.buffer_start = slice_end - self.calculator.global_batch_size_end return jax.tree.map(_slice_and_concat, old_buffer, self.batch_buffer) else: - def _slice(data): return jax.lax.dynamic_slice_in_dim( data, slice_start, - self.global_batch_size_current, + 
+            self.calculator.global_batch_size_current,
            axis=0,
        )
 
@@ -171,11 +214,11 @@ def _slice(data):
       return jax.tree.map(_slice, self.batch_buffer)
 
 
-def create_dataloader(config, mesh, data_iterator, goodput_recorder):
+def create_dataloader(config, data_iterator, goodput_recorder, latest_step=-1):
   """
   Create the dataloader
   """
   if config.enable_rampup_batch_size:
-    return RampUpDataLoader(config, mesh, data_iterator, goodput_recorder)
+    return RampUpDataLoader(config, data_iterator, goodput_recorder, latest_step)
   else:
-    return DataLoader(config, mesh, data_iterator, goodput_recorder)
+    return DataLoader(config, data_iterator, goodput_recorder)
diff --git a/src/MaxText/pyconfig.py b/src/MaxText/pyconfig.py
index 62d896d3d..80c0902c9 100644
--- a/src/MaxText/pyconfig.py
+++ b/src/MaxText/pyconfig.py
@@ -758,6 +758,7 @@ def user_init(raw_keys):
 
   # This is the first command that initializes the backend - it calls
   # jax.devices()
+  raw_keys["num_target_devices"] = get_num_target_devices(raw_keys)
   (
       raw_keys["global_batch_size_to_load"],
       raw_keys["global_batch_size_to_train_on"],
@@ -765,7 +766,7 @@
   ) = calculate_global_batch_sizes(
       raw_keys["per_device_batch_size"],
       raw_keys["expansion_factor_real_data"],
-      get_num_target_devices(raw_keys),
+      raw_keys["num_target_devices"],
       raw_keys["gradient_accumulation_steps"],
   )
 
@@ -779,7 +780,7 @@
   ) = calculate_global_batch_sizes(
       raw_keys["per_device_batch_size_start"],
       raw_keys["expansion_factor_real_data"],
-      get_num_target_devices(raw_keys),
+      raw_keys["num_target_devices"],
       raw_keys["gradient_accumulation_steps"],
   )
 
@@ -790,7 +791,7 @@
   ) = calculate_global_batch_sizes(
       raw_keys["per_device_batch_size_increment"],
       raw_keys["expansion_factor_real_data"],
-      get_num_target_devices(raw_keys),
+      raw_keys["num_target_devices"],
       raw_keys["gradient_accumulation_steps"],
   )
diff --git a/src/MaxText/train.py b/src/MaxText/train.py
index bb48a2c72..0431175df 100644
--- a/src/MaxText/train.py
+++ b/src/MaxText/train.py
@@ -392,7 +392,7 @@ def train_loop(config, recorder, state=None):
 
   start_step = get_first_step(state)  # this is the start_step for training
   prof = profiler.Profiler(config, offset_step=start_step)
-  data_loader = create_dataloader(config, mesh, data_iterator, recorder)
+  data_loader = create_dataloader(config, data_iterator, recorder, checkpoint_manager.latest_step() if checkpoint_manager else -1)
   metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)
 
   # Write train config params, num model params, and XLA flags to tensorboard
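For reference, the schedule encoded by the new RampupBatchCalculator can be sketched standalone. The snippet below is illustrative only and not part of the patch: FakeConfig and rampup_schedule are hypothetical names, and the config values (8 target devices, per-device batch size ramping from 1 to 4 in increments of 1 over 3000 global ramp-up samples) are made up. The real class additionally replays this state from the latest checkpoint step and logs each increment via max_logging.

# Illustrative sketch only -- mirrors the ramp-up schedule of RampupBatchCalculator
# with hypothetical config values; not part of the patch above.
from collections import Counter
from dataclasses import dataclass


@dataclass
class FakeConfig:
  """Hypothetical stand-in for the MaxText config fields used by the calculator."""
  per_device_batch_size_start: int = 1
  per_device_batch_size: int = 4
  per_device_batch_size_increment: int = 1
  global_rampup_samples: int = 3000
  num_target_devices: int = 8


def rampup_schedule(config):
  """Yield the global batch size used at each step until ramp-up completes."""
  num_increments = (
      config.per_device_batch_size - config.per_device_batch_size_start
  ) // config.per_device_batch_size_increment
  samples_per_increment = config.global_rampup_samples / num_increments  # 1000 samples per bump here
  batch_size = config.num_target_devices * config.per_device_batch_size_start  # starts at 8
  batch_size_end = config.num_target_devices * config.per_device_batch_size  # ramps up to 32
  increment = config.num_target_devices * config.per_device_batch_size_increment
  accum_samples = 0
  while batch_size < batch_size_end:
    yield batch_size
    accum_samples += batch_size
    if accum_samples >= samples_per_increment:  # enough samples consumed: bump the batch size
      batch_size += increment
      accum_samples = 0


counts = Counter(rampup_schedule(FakeConfig()))
print(dict(counts))  # prints {8: 125, 16: 63, 24: 42}: 230 ramp-up steps before reaching 32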