Add rampup batch size support in MaxText #2535
base: main
Conversation
🤖 Hi @NuojCheng, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
It seems we are out of quota for the free tier. We are going to upgrade to Tier 1; it should be better soon.
🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
Description
This PR adds support for ramp-up batch size, a feature originally proposed in the GPT-3 paper and implemented in Megatron.
When enabled, the per-device batch size starts at a smaller value (`per_device_batch_size_start`) and gradually increases by `per_device_batch_size_increment` until it reaches the target `per_device_batch_size` over a specified number of `rampup_samples`. This can help improve training stability, especially during the initial training phases.

This feature introduces four new configuration parameters, which align with the Megatron implementation (a worked example follows the list):
- `enable_rampup_batch_size`: (default: `False`) Set to `True` to enable the ramp-up feature.
- `per_device_batch_size_start`: The per-device batch size to use at the beginning of training.
- `per_device_batch_size_increment`: The amount by which to increase the per-device batch size at each ramp-up step.
- `rampup_samples`: The total number of samples to process before reaching the full target batch size.
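To make the parameters concrete, here is a minimal sketch of how the current per-device batch size could be derived from them. The helper name and the even-split linear schedule are assumptions based on the description above and Megatron's behavior, not code from this PR:

```python
def rampup_batch_size(samples_consumed: int,
                      per_device_batch_size_start: float,
                      per_device_batch_size_increment: float,
                      per_device_batch_size: float,
                      rampup_samples: int) -> float:
  """Hypothetical helper: per-device batch size after `samples_consumed` samples.

  The batch size grows from the start value toward the target in steps of
  `per_device_batch_size_increment`, spread evenly over `rampup_samples`.
  """
  if samples_consumed >= rampup_samples:
    return per_device_batch_size
  # How many increments separate the start size from the target size.
  num_increments = (per_device_batch_size - per_device_batch_size_start) / per_device_batch_size_increment
  # Samples processed before each increment, assuming an even split of the ramp-up window.
  samples_per_increment = rampup_samples / num_increments
  completed = int(samples_consumed // samples_per_increment)
  return min(per_device_batch_size,
             per_device_batch_size_start + completed * per_device_batch_size_increment)
```

For example, with `per_device_batch_size_start=1`, `per_device_batch_size_increment=1`, `per_device_batch_size=4`, and `rampup_samples=3000`, the per-device batch size would be 1, 2, 3, and finally 4 as training crosses each 1000-sample boundary (under the even-split assumption above).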
The PR includes the following changes:

- `RampupDataLoader`: Adds a new `RampupDataLoader` class that inherits from the base `DataLoader`. Its primary responsibility is to truncate the input data to match the correct ramp-up shape for the current training step (a minimal sketch of the idea appears after this list).
- Updates `pyconfig.py` to register and validate the new ramp-up configuration parameters.
- Updates `data_loader_tests.py` to verify the `RampupDataLoader`'s slicing and increment logic.

FIXES: b/452468482
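The snippet below sketches the truncation idea described above; it is not the PR's actual implementation. The base-class import path, the `load_next_batch` method name, and config fields beyond the four new parameters (e.g. `num_devices`) are assumptions:

```python
from MaxText.data_loader import DataLoader  # assumed location of the existing base class


class RampupDataLoader(DataLoader):
  """Sketch: yields batches truncated to the current ramp-up batch size."""

  def __init__(self, config, *args, **kwargs):
    super().__init__(config, *args, **kwargs)
    self.config = config
    self._samples_consumed = 0

  def load_next_batch(self):
    # The underlying loader still produces full-size batches; slice them down
    # to the global size dictated by the ramp-up schedule for the current step.
    batch = super().load_next_batch()
    per_device = rampup_batch_size(  # helper from the sketch above
        self._samples_consumed,
        self.config.per_device_batch_size_start,
        self.config.per_device_batch_size_increment,
        self.config.per_device_batch_size,
        self.config.rampup_samples)
    global_size = int(per_device * self.config.num_devices)  # assumed field name
    # Note: in a real sharded setting the slice would have to respect the mesh
    # and sharding layout; that detail is omitted in this sketch.
    batch = {k: v[:global_size] for k, v in batch.items()}
    self._samples_consumed += global_size
    return batch
```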
Tests
New data_loader_test.
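A hypothetical check along these lines (using the schedule helper sketched earlier and assumed example values, not the PR's actual test) illustrates the kind of increment logic the new test could verify:

```python
def test_rampup_batch_size_schedule():
  # Example values for the sketch above; not taken from the PR.
  start, increment, target, rampup_samples = 1, 1, 4, 3000
  sizes = [rampup_batch_size(n, start, increment, target, rampup_samples)
           for n in range(0, 4001, 500)]
  assert sizes == [1, 1, 2, 2, 3, 3, 4, 4, 4]
```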
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.