
Commit 9143004

add rampup batch size

1 parent 86937a7 commit 9143004

6 files changed: +281 -36 lines changed

src/MaxText/configs/base.yml

Lines changed: 12 additions & 0 deletions
@@ -512,6 +512,18 @@ packing: True
 num_epoch: 1 # only grain and tfds pipeline supports num_epoch > 1
 generate_padding_batch_train: False
 generate_padding_batch_eval: False
+# Ramp up the batch size, similar to Megatron-LM; see
+# https://github.com/NVIDIA/Megatron-LM/blob/2a01637aa54ccdaf7ea9afc1f1b80f58c53d7f3c/megatron/core/num_microbatches_calculator.py#L233-L237
+# The ramp-up proceeds in stages from `per_device_batch_size_start` up to
+# the final `per_device_batch_size`. For a clean ramp-up, the total range
+# (`per_device_batch_size` - `per_device_batch_size_start`) should be evenly
+# divisible by `per_device_batch_size_increment`.
+enable_rampup_batch_size: False
+per_device_batch_size_start: 4.0
+per_device_batch_size_increment: 2.0
+# The target number of training samples to process during the ramp-up phase.
+# There is no strict rule for this value; it only needs to be positive.
+global_rampup_samples: 500

 # direct preference optimization (DPO)
 use_dpo: False
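
For concreteness, a hypothetical override built on the defaults above (the final `per_device_batch_size` value is illustrative, not part of this commit):

# Illustrative ramp-up configuration; per_device_batch_size is hypothetical.
enable_rampup_batch_size: True
per_device_batch_size: 8.0            # final batch size after ramp-up
per_device_batch_size_start: 4.0      # batch size at step 0
per_device_batch_size_increment: 2.0  # (8.0 - 4.0) is evenly divisible by 2.0
global_rampup_samples: 500            # 2 increments -> 250 samples per stage

With these values the per-device batch size steps through 4.0 -> 6.0 -> 8.0, consuming roughly 250 samples at 4.0 and another 250 at 6.0 before settling at 8.0.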

src/MaxText/data_loader.py

Lines changed: 107 additions & 8 deletions
@@ -20,7 +20,7 @@
 from jax.experimental import checkify

 from MaxText import exceptions
-from MaxText import sharding
+from MaxText import max_logging
 from MaxText.utils.goodput_utils import (
     GoodputEvent,
     maybe_record_goodput,
@@ -42,7 +42,6 @@ def __init__(self, config, mesh, data_iterator, goodput_recorder):
     else:
       self.data_iterator = data_iterator
     self.last_batch = None
-    self.input_data_shardings = sharding.get_input_data_sharding(config, mesh)

   def update_data_iterator(self):
     """Update to the next data iterator in the list, if applicable."""
@@ -59,12 +58,7 @@ def load_next_batch(self):
         else:
           example_batch = next(self.data_iterator)
         self.update_data_iterator()
-        # Reshard data from loaded sharding to performant activation sharding
-        self.last_batch = sharding.maybe_shard_with_name(
-            example_batch,
-            self.input_data_shardings,
-            self.config.shard_mode,
-        )
+        self.last_batch = example_batch
         self.check_example_batch()
       except Exception as e:  # pylint: disable=broad-except
         if isinstance(e, StopIteration):
@@ -80,3 +74,108 @@ def check_example_batch(self):
     # pylint: disable=not-callable
     err, _ = jax.jit(jittable_f)(self.last_batch["inputs"][: self.config.global_batch_size_to_train_on, :])
     err.throw()
+
+
+class RampUpDataLoader(DataLoader):
+  """
+  A DataLoader that implements batch-size ramp-up.
+
+  It dynamically increases the current global batch size based on the number
+  of samples consumed so far. The rest of the training pipeline (including
+  the parent's `check_example_batch` and the training step itself) is
+  assumed to read this value to determine the logical batch size.
+  """
+
+  def __init__(self, config, mesh, data_iterator, goodput_recorder):
+    # Call parent constructor
+    super().__init__(config, mesh, data_iterator, goodput_recorder)
+
+    # Get ramp-up parameters from config
+    self.global_batch_size_end = config.global_batch_size_to_load
+    self.global_batch_size_start = config.global_batch_size_to_load_start
+    self.increment = config.global_batch_size_to_load_increment
+    self.samples_per_increment = config.rampup_samples_per_increment_to_load
+
+    # Check if ramp-up is active
+    self.rampup_active = self.global_batch_size_start < self.global_batch_size_end
+
+    # State for tracking ramp-up
+    self.accum_samples = 0
+    self.global_batch_size_current = self.global_batch_size_start
+    self.batch_buffer = None
+    self.buffer_start = 0
+
+  def load_next_batch(self):
+    """
+    Updates the batch size based on the schedule, then loads the next
+    batch using the parent method.
+    """
+    # If ramp-up is not active, just behave like the parent
+    if not self.rampup_active:
+      return super().load_next_batch()
+
+    # During the ramp-up phase, stage loaded data in a batch buffer.
+    # Check if it's time to increment the batch size
+    is_time_to_increment = self.accum_samples >= self.samples_per_increment
+
+    if is_time_to_increment:
+      # Update the current batch size and reset the accumulated sample count
+      max_logging.log(
+          f"Global batch size increments from {self.global_batch_size_current}"
+          f" to {self.global_batch_size_current + self.increment}"
+      )
+      self.global_batch_size_current += self.increment
+      self.accum_samples = 0
+      self.rampup_active = self.global_batch_size_current < self.global_batch_size_end
+
+    self.accum_samples += self.global_batch_size_current
+    slice_start, slice_end = self.buffer_start, self.buffer_start + self.global_batch_size_current
+
+    # Load a new batch if batch_buffer is None or the slice runs past the buffer end
+    if self.batch_buffer is None:
+      self.batch_buffer = super().load_next_batch()
+      slice_start, slice_end = 0, self.global_batch_size_current
+
+    if slice_end > self.global_batch_size_end:
+      old_buffer, self.batch_buffer = self.batch_buffer, super().load_next_batch()

+      # self.global_batch_size_end is the batch_buffer size along axis 0
+      def _slice_and_concat(old_data, new_data):
+        sliced_old_data = jax.lax.dynamic_slice_in_dim(
+            old_data,
+            slice_start,
+            self.global_batch_size_end - slice_start,
+            axis=0,
+        )
+        sliced_new_data = jax.lax.dynamic_slice_in_dim(
+            new_data,
+            0,
+            slice_end - self.global_batch_size_end,
+            axis=0,
+        )
+        return jax.lax.concatenate((sliced_old_data, sliced_new_data), dimension=0)
+
+      self.buffer_start = slice_end - self.global_batch_size_end
+      return jax.tree.map(_slice_and_concat, old_buffer, self.batch_buffer)
+    else:
+
+      def _slice(data):
+        return jax.lax.dynamic_slice_in_dim(
+            data,
+            slice_start,
+            self.global_batch_size_current,
+            axis=0,
+        )
+
+      self.buffer_start = slice_end
+      return jax.tree.map(_slice, self.batch_buffer)
+
+
+def create_dataloader(config, mesh, data_iterator, goodput_recorder):
+  """
+  Create the data loader: the ramp-up variant when enabled, else the default.
+  """
+  if config.enable_rampup_batch_size:
+    return RampUpDataLoader(config, mesh, data_iterator, goodput_recorder)
+  else:
+    return DataLoader(config, mesh, data_iterator, goodput_recorder)
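
To make the buffer arithmetic above concrete, here is a minimal, self-contained simulation of the slicing schedule. The sizes are hypothetical, a numpy array of sample ids stands in for the loaded pytree, and `current < end` plays the role of `rampup_active`; this mirrors the logic but is not MaxText code:

import itertools
import numpy as np

start, end, inc, samples_per_inc = 4, 8, 2, 12  # hypothetical global sizes
ids = itertools.count()

def load():
  # Every load returns a full-size batch, like DataLoader.load_next_batch().
  return np.array([next(ids) for _ in range(end)])

current, accum, buf, buf_start = start, 0, None, 0
for step in range(6):
  if accum >= samples_per_inc and current < end:
    current, accum = current + inc, 0  # bump the logical batch size
  accum += current
  lo, hi = buf_start, buf_start + current
  if buf is None:  # first call: fill the buffer
    buf, (lo, hi) = load(), (0, current)
  if hi > end:  # slice runs past the buffer: tail of old + head of a fresh load
    old, buf = buf, load()
    batch, buf_start = np.concatenate([old[lo:end], buf[: hi - end]]), hi - end
  else:
    batch, buf_start = buf[lo:hi], hi
  print(step, current, batch)
# Logical batch sizes per step: 4, 4, 4, 6, 6, 8 -- no loaded sample is dropped.

The same invariant holds in RampUpDataLoader: each logical batch is either a contiguous slice of the buffer or the old buffer's tail concatenated with the head of a fresh load.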

src/MaxText/metric_logger.py

Lines changed: 25 additions & 15 deletions
@@ -105,13 +105,22 @@ def log_metrics(self, metrics, step, is_training):
     """Logs metrics via max_logging."""
     if is_training:
       loss = metrics["scalar"]["learning/loss"]
-      log_message = (
-          f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
-          f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, "
-          f"Tokens/s/device: {metrics['scalar']['perf/per_device_tokens_per_sec']:.3f}, "
-          f"total_weights: {metrics['scalar']['learning/total_weights']}, "
-          f"loss: {loss:.3f}"
-      )
+      # Do not report TFLOP/s and tokens/s during batch-size ramp-up
+      if step >= self.config.rampup_end_step:
+        log_message = (
+            f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
+            f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, "
+            f"Tokens/s/device: {metrics['scalar']['perf/per_device_tokens_per_sec']:.3f}, "
+            f"total_weights: {metrics['scalar']['learning/total_weights']}, "
+            f"loss: {loss:.3f}"
+        )
+      else:
+        log_message = (
+            "[Rampup Batch Size Phase]: "
+            f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
+            f"total_weights: {metrics['scalar']['learning/total_weights']}, "
+            f"loss: {loss:.3f}"
+        )

     if self.config.mtp_num_layers > 0:
       mtp_loss = metrics["scalar"].get("learning/mtp_loss", 0.0)
@@ -213,15 +222,16 @@ def buffer_and_write_train_metrics(self, metrics, step, step_time_delta):
   def record_train_metrics(self, metrics, step, step_time):
     """Records training metrics for the current step."""
     metrics["scalar"].update({"perf/step_time_seconds": step_time})
-    metrics["scalar"].update({"perf/per_device_tflops": self.metadata[MetadataKey.PER_DEVICE_TFLOPS]})
-    metrics["scalar"].update(
-        {"perf/per_device_tflops_per_sec": (self.metadata[MetadataKey.PER_DEVICE_TFLOPS] / step_time)}
-    )
-    metrics["scalar"].update({"perf/per_device_tokens": self.metadata[MetadataKey.PER_DEVICE_TOKENS]})
-    metrics["scalar"].update(
-        {"perf/per_device_tokens_per_sec": (self.metadata[MetadataKey.PER_DEVICE_TOKENS] / step_time)}
-    )
     metrics["scalar"].update({"learning/current_learning_rate": self.learning_rate_schedule(step)})
+    if step >= self.config.rampup_end_step:
+      metrics["scalar"].update({"perf/per_device_tflops": self.metadata[MetadataKey.PER_DEVICE_TFLOPS]})
+      metrics["scalar"].update(
+          {"perf/per_device_tflops_per_sec": (self.metadata[MetadataKey.PER_DEVICE_TFLOPS] / step_time)}
+      )
+      metrics["scalar"].update({"perf/per_device_tokens": self.metadata[MetadataKey.PER_DEVICE_TOKENS]})
+      metrics["scalar"].update(
+          {"perf/per_device_tokens_per_sec": (self.metadata[MetadataKey.PER_DEVICE_TOKENS] / step_time)}
+      )
     if self.performance_metric_queue:
       self.performance_metric_queue.put(step_time)
227237

src/MaxText/pyconfig.py

Lines changed: 80 additions & 0 deletions
@@ -193,6 +193,21 @@ def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max
     raise ValueError("We currently don't support vocab tiling on NNX module.")


+def validate_rampup_batch_size(batch_size_start, batch_size_end, batch_size_increment, global_rampup_samples):
+  assert batch_size_start > 0, f"per_device_batch_size_start should be positive, got {batch_size_start}."
+  assert batch_size_increment > 0, f"per_device_batch_size_increment should be positive, got {batch_size_increment}."
+  assert global_rampup_samples > 0, f"global_rampup_samples should be positive, got {global_rampup_samples}."
+  diff_batch_size = batch_size_end - batch_size_start
+  assert diff_batch_size > 0, (
+      "per_device_batch_size must be greater than per_device_batch_size_start. "
+      f"Got per_device_batch_size={batch_size_end} and per_device_batch_size_start={batch_size_start}."
+  )
+  assert diff_batch_size % batch_size_increment == 0, (
+      "Expected the ramp-up range to be evenly divisible by per_device_batch_size_increment. "
+      f"Got per_device_batch_size={batch_size_end} and per_device_batch_size_start={batch_size_start}."
+  )
+
+
 def validate_keys(keys):
   validate_attention_kernel(keys["attention"])
   validate_attention_type(keys["attention_type"])
@@ -211,6 +226,13 @@ def validate_keys(keys):
   validate_vocab_tiling(
       keys["num_vocab_tiling"], keys["per_device_batch_size"], keys["max_target_length"], keys["enable_nnx"]
   )
+  if keys["enable_rampup_batch_size"]:
+    validate_rampup_batch_size(
+        keys["per_device_batch_size_start"],
+        keys["per_device_batch_size"],
+        keys["per_device_batch_size_increment"],
+        keys["global_rampup_samples"],
+    )

   # TODO remove after b/435512699 resolved
   if keys["context_parallel_size"] > 1 and keys["context_parallel_load_balance"] and keys["attention_type"] == "chunk":
@@ -705,6 +727,43 @@ def user_init(raw_keys):
       raw_keys["gradient_accumulation_steps"],
   )

+  # Initialize the starting global batch size and the global increments if
+  # ramp-up batch size is enabled
+  if raw_keys["enable_rampup_batch_size"]:
+    (
+        raw_keys["global_batch_size_to_load_start"],
+        raw_keys["global_batch_size_to_train_on_start"],
+        raw_keys["micro_batch_size_to_train_on_start"],
+    ) = calculate_global_batch_sizes(
+        raw_keys["per_device_batch_size_start"],
+        raw_keys["expansion_factor_real_data"],
+        get_num_target_devices(raw_keys),
+        raw_keys["gradient_accumulation_steps"],
+    )
+
+    (
+        raw_keys["global_batch_size_to_load_increment"],
+        raw_keys["global_batch_size_to_train_on_increment"],
+        raw_keys["micro_batch_size_to_train_on_increment"],
+    ) = calculate_global_batch_sizes(
+        raw_keys["per_device_batch_size_increment"],
+        raw_keys["expansion_factor_real_data"],
+        get_num_target_devices(raw_keys),
+        raw_keys["gradient_accumulation_steps"],
+    )
+
+    (
+        raw_keys["rampup_samples_per_increment_to_load"],
+        raw_keys["rampup_end_step"],
+    ) = calculate_rampup_samples_and_steps(
+        raw_keys["global_batch_size_to_load_start"],
+        raw_keys["global_batch_size_to_load"],
+        raw_keys["global_batch_size_to_load_increment"],
+        raw_keys["global_rampup_samples"],
+    )
+  else:
+    raw_keys["rampup_end_step"] = 0
+
   if raw_keys["eval_per_device_batch_size"] <= 0:
     raw_keys["eval_per_device_batch_size"] = raw_keys["per_device_batch_size"]

@@ -1252,6 +1311,27 @@ def calculate_global_batch_sizes(
   return global_batch_size_to_load, global_batch_size_to_train_on, micro_batch_size_to_train_on


+def calculate_rampup_samples_and_steps(
+    batch_size_start,
+    batch_size_end,
+    batch_size_increment,
+    global_rampup_samples,
+):
+  """Calculates the number of samples per increment and the total number of ramp-up steps."""
+  diff_batch_size = batch_size_end - batch_size_start
+  num_increments = diff_batch_size // batch_size_increment
+  rampup_samples_per_increment = global_rampup_samples / num_increments
+  total_rampup_steps = 0
+  current_batch_size = batch_size_start
+
+  for _ in range(num_increments):
+    steps_for_this_stage = math.ceil(rampup_samples_per_increment / current_batch_size)
+    total_rampup_steps += steps_for_this_stage
+    current_batch_size += batch_size_increment
+
+  return rampup_samples_per_increment, total_rampup_steps
+
+
 def get_num_target_devices(raw_keys):
   # In AOT case compile_topology is set (e.g. is not the empty string), and we determine the
   # number of devices from the compile_topology. In non-AOT settings we simply can use jax.devices().
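
As a worked example of `calculate_rampup_samples_and_steps`, with hypothetical global batch sizes (e.g. the per-device defaults above on 8 chips):

import math

# Hypothetical global sizes: start 32, final 64, increment 16 (e.g. the
# per-device values 4.0 / 8.0 / 2.0 on 8 chips), global_rampup_samples = 500.
start, end, inc, global_rampup_samples = 32, 64, 16, 500

num_increments = (end - start) // inc  # 2
samples_per_increment = global_rampup_samples / num_increments  # 250.0
total_steps, current = 0, start
for _ in range(num_increments):
  total_steps += math.ceil(samples_per_increment / current)  # 8 steps at 32, then 6 at 48
  current += inc
print(samples_per_increment, total_steps)  # 250.0 14, so rampup_end_step = 14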

src/MaxText/train.py

Lines changed: 8 additions & 2 deletions
@@ -53,7 +53,7 @@
 from MaxText import sharding
 from MaxText.layers.multi_token_prediction import calculate_mtp_acceptance_rate, calculate_mtp_loss
 from MaxText.common_types import ShardMode
-from MaxText.data_loader import DataLoader
+from MaxText.data_loader import create_dataloader
 from MaxText.globals import EPS
 from MaxText.metric_logger import MetricLogger
 from MaxText.utils import gcs_utils
@@ -391,7 +391,7 @@ def train_loop(config, recorder, state=None):

   start_step = get_first_step(state)  # this is the start_step for training
   prof = profiler.Profiler(config, offset_step=start_step)
-  data_loader = DataLoader(config, mesh, data_iterator, recorder)
+  data_loader = create_dataloader(config, mesh, data_iterator, recorder)
   metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)

   # Write train config params, num model params, and XLA flags to tensorboard
@@ -404,6 +404,12 @@ def train_loop(config, recorder, state=None):

     with jax.profiler.StepTraceAnnotation("train", step_num=step):
       example_batch = data_loader.load_next_batch()
+      # Reshard data from loaded sharding to performant activation sharding
+      example_batch = sharding.maybe_shard_with_name(
+          example_batch,
+          sharding.get_input_data_sharding(config, mesh),
+          shard_mode=config.shard_mode,
+      )
       # pylint: disable=not-callable
       nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
       with maybe_record_goodput(recorder, GoodputEvent.STEP, step):
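
Taken together, a condensed sketch of how the new pieces compose in `train_loop` (not verbatim; `config`, `mesh`, `data_iterator`, `recorder`, and `start_step` are assumed to be in scope as earlier in the function). The reshard moves out of `DataLoader` so that it applies to the batch the ramp-up loader actually emits, after any slicing or concatenation:

from MaxText import sharding
from MaxText.data_loader import create_dataloader

data_loader = create_dataloader(config, mesh, data_iterator, recorder)
for step in range(start_step, config.steps):
  example_batch = data_loader.load_next_batch()  # ramp-up slicing happens inside
  example_batch = sharding.maybe_shard_with_name(
      example_batch,
      sharding.get_input_data_sharding(config, mesh),
      shard_mode=config.shard_mode,
  )
  # ... run the train step on example_batch ...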
