-
Couldn't load subscription status.
- Fork 418
Add rampup batch size support in MaxText #2535
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,7 +20,6 @@ | |
| from jax.experimental import checkify | ||
|
|
||
| from MaxText import exceptions | ||
| from MaxText import maxtext_utils | ||
| from MaxText.utils.goodput_utils import ( | ||
| GoodputEvent, | ||
| maybe_record_goodput, | ||
|
|
@@ -37,7 +36,6 @@ def __init__(self, config, mesh, data_iterator, goodput_recorder): | |
| self.goodput_recorder = goodput_recorder | ||
| self.data_iterator = data_iterator | ||
| self.last_batch = None | ||
| self.input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh) | ||
|
|
||
| def load_next_batch(self): | ||
| """Loads the next batch. Can keep reusing the same batch for performance reasons.""" | ||
|
|
@@ -47,12 +45,7 @@ def load_next_batch(self): | |
| example_batch = self.last_batch | ||
| else: | ||
| example_batch = next(self.data_iterator) | ||
| # Reshard data from loaded sharding to performant activation sharding | ||
| self.last_batch = maxtext_utils.maybe_shard_with_name( | ||
| example_batch, | ||
| self.input_data_shardings, | ||
| self.config.shard_mode, | ||
| ) | ||
| self.last_batch = example_batch | ||
NuojCheng marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| self.check_example_batch() | ||
| except Exception as e: # pylint: disable=broad-except | ||
| if isinstance(e, StopIteration): | ||
|
|
@@ -68,3 +61,67 @@ def check_example_batch(self): | |
| # pylint: disable=not-callable | ||
| err, _ = jax.jit(jittable_f)(self.last_batch["inputs"][: self.config.global_batch_size_to_train_on, :]) | ||
| err.throw() | ||
|
|
||
|
|
||
class RampUpDataLoader(DataLoader):
  """A DataLoader that implements global batch size ramp-up.

  Starting from ``global_batch_size_to_load_start``, the effective batch size
  grows by ``global_batch_size_to_load_increment`` each time
  ``rampup_samples_per_increment_to_load`` samples have been consumed, until it
  reaches ``global_batch_size_to_load``. While ramp-up is active, each full
  batch produced by the parent loader is sliced down to the current batch size
  along axis 0 and the remaining examples are discarded.

  The rest of the training pipeline (including the parent's
  ``check_example_batch`` and the training step itself) is assumed to read the
  current batch size from config to determine the logical batch size.
  """

  # NOTE(review): the tail of every over-sized batch is thrown away, which
  # skips training data — confirm this is acceptable for the workload, or
  # cache the unused examples instead of dropping them.

  def __init__(self, config, mesh, data_iterator, goodput_recorder):
    """Initializes ramp-up state from ``config``.

    Args:
      config: Run configuration; must expose the ramp-up fields read below.
      mesh: Device mesh, forwarded to the parent ``DataLoader``.
      data_iterator: Source iterator, forwarded to the parent ``DataLoader``.
      goodput_recorder: Goodput recorder, forwarded to the parent ``DataLoader``.
    """
    super().__init__(config, mesh, data_iterator, goodput_recorder)

    # Ramp-up schedule parameters read from the config.
    self.global_batch_size_end = config.global_batch_size_to_load
    self.global_batch_size_start = config.global_batch_size_to_load_start
    self.increment = config.global_batch_size_to_load_increment
    # Presumably counted in global samples (across all hosts) — TODO confirm.
    self.samples_per_increment = config.rampup_samples_per_increment_to_load

    # Ramp-up is only active when there is room to grow.
    self.rampup_active = self.global_batch_size_start < self.global_batch_size_end

    # Samples consumed since the last batch-size increment.
    self.accum_samples = 0
    self.global_batch_size_current = self.global_batch_size_start

  def load_next_batch(self):
    """Updates the batch size per the schedule, then loads the next batch.

    Returns:
      The next example batch. While ramp-up is active, every leaf array in the
      batch is truncated to ``global_batch_size_current`` along axis 0.
    """
    # Once ramp-up has completed, behave exactly like the parent loader.
    if not self.rampup_active:
      return super().load_next_batch()

    # Grow the batch size once enough samples have been consumed.
    if self.accum_samples >= self.samples_per_increment:
      # Clamp to the final size: if (end - start) is not an exact multiple of
      # the increment, an unclamped "+=" would overshoot the loaded batch and
      # make the slice below larger than the data's leading dimension.
      self.global_batch_size_current = min(
          self.global_batch_size_current + self.increment, self.global_batch_size_end
      )
      self.accum_samples = 0
      self.rampup_active = self.global_batch_size_current < self.global_batch_size_end

    def _slice(data):
      # During ramp-up, keep the leading `global_batch_size_current` examples
      # of the fully-loaded batch and discard the rest.
      return jax.lax.dynamic_slice_in_dim(data, 0, self.global_batch_size_current, axis=0)

    self.accum_samples += self.global_batch_size_current
    return jax.tree.map(_slice, super().load_next_batch())
|
|
||
|
|
||
def create_dataloader(config, mesh, data_iterator, goodput_recorder):
  """Builds the data loader appropriate for this run.

  Selects ``RampUpDataLoader`` when batch-size ramp-up is enabled in the
  config, and the plain ``DataLoader`` otherwise.
  """
  loader_cls = RampUpDataLoader if config.enable_rampup_batch_size else DataLoader
  return loader_cls(config, mesh, data_iterator, goodput_recorder)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -105,13 +105,21 @@ def log_metrics(self, metrics, step, is_training): | |
| """Logs metrics via max_logging.""" | ||
| if is_training: | ||
| loss = metrics["scalar"]["learning/loss"] | ||
| log_message = ( | ||
| f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, " | ||
| f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, " | ||
| f"Tokens/s/device: {metrics['scalar']['perf/per_device_tokens_per_sec']:.3f}, " | ||
| f"total_weights: {metrics['scalar']['learning/total_weights']}, " | ||
| f"loss: {loss:.3f}" | ||
| ) | ||
| # Do not show flops and tokens during batch size rampup | ||
NuojCheng marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if step >= self.config.rampup_end_step: | ||
| log_message = ( | ||
| f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, " | ||
| f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, " | ||
| f"Tokens/s/device: {metrics['scalar']['perf/per_device_tokens_per_sec']:.3f}, " | ||
| f"total_weights: {metrics['scalar']['learning/total_weights']}, " | ||
| f"loss: {loss:.3f}" | ||
| ) | ||
| else: | ||
| log_message = ( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. How about adding a message to indicate that training is in the ramp-up phase? |
||
| f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, " | ||
| f"total_weights: {metrics['scalar']['learning/total_weights']}, " | ||
| f"loss: {loss:.3f}" | ||
| ) | ||
|
|
||
| if self.config.mtp_num_layers > 0: | ||
| mtp_loss = metrics["scalar"].get("learning/mtp_loss", 0.0) | ||
|
|
@@ -213,15 +221,16 @@ def buffer_and_write_train_metrics(self, metrics, step, step_time_delta): | |
| def record_train_metrics(self, metrics, step, step_time): | ||
| """Records training metrics for the current step.""" | ||
| metrics["scalar"].update({"perf/step_time_seconds": step_time}) | ||
| metrics["scalar"].update({"perf/per_device_tflops": self.metadata[MetadataKey.PER_DEVICE_TFLOPS]}) | ||
| metrics["scalar"].update( | ||
| {"perf/per_device_tflops_per_sec": (self.metadata[MetadataKey.PER_DEVICE_TFLOPS] / step_time)} | ||
| ) | ||
| metrics["scalar"].update({"perf/per_device_tokens": self.metadata[MetadataKey.PER_DEVICE_TOKENS]}) | ||
| metrics["scalar"].update( | ||
| {"perf/per_device_tokens_per_sec": (self.metadata[MetadataKey.PER_DEVICE_TOKENS] / step_time)} | ||
| ) | ||
| metrics["scalar"].update({"learning/current_learning_rate": self.learning_rate_schedule(step)}) | ||
| if step >= self.config.rampup_end_step: | ||
| metrics["scalar"].update({"perf/per_device_tflops": self.metadata[MetadataKey.PER_DEVICE_TFLOPS]}) | ||
| metrics["scalar"].update( | ||
| {"perf/per_device_tflops_per_sec": (self.metadata[MetadataKey.PER_DEVICE_TFLOPS] / step_time)} | ||
| ) | ||
| metrics["scalar"].update({"perf/per_device_tokens": self.metadata[MetadataKey.PER_DEVICE_TOKENS]}) | ||
| metrics["scalar"].update( | ||
| {"perf/per_device_tokens_per_sec": (self.metadata[MetadataKey.PER_DEVICE_TOKENS] / step_time)} | ||
| ) | ||
| if self.performance_metric_queue: | ||
| self.performance_metric_queue.put(step_time) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems like rampup_samples refers to global samples — we could add a comment to clarify that.