
Commit 47a1028

add rampup batch size
1 parent 151fa9f commit 47a1028

6 files changed: +276, -36 lines changed

src/MaxText/configs/base.yml

Lines changed: 12 additions & 0 deletions
@@ -507,6 +507,18 @@ packing: True
 num_epoch: 1 # only grain and tfds pipeline supports num_epoch > 1
 generate_padding_batch_train: False
 generate_padding_batch_eval: False
+# Ramp-up batch size, similar to Megatron-LM, see
+# https://github.com/NVIDIA/Megatron-LM/blob/2a01637aa54ccdaf7ea9afc1f1b80f58c53d7f3c/megatron/core/num_microbatches_calculator.py#L233-L237
+# The ramp-up proceeds in stages from `per_device_batch_size_start` up to
+# the final `per_device_batch_size`. For a clean ramp-up, the total range
+# (`per_device_batch_size` - `per_device_batch_size_start`)
+# should be evenly divisible by `per_device_batch_size_increment`.
+enable_rampup_batch_size: False
+per_device_batch_size_start: 4.0
+per_device_batch_size_increment: 2.0
+# The target number of training samples to process during the ramp-up phase.
+# There is no strict rule for this value; it only needs to be positive.
+global_rampup_samples: 500

 # direct preference optimization (DPO)
 use_dpo: False
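As a worked example of the divisibility note above: with the defaults shown here and a hypothetical final `per_device_batch_size` of 8.0, the range 8.0 - 4.0 = 4.0 is evenly divisible by the 2.0 increment, so the ramp-up runs in two stages (per-device batch sizes 4.0 and then 6.0), and each stage targets roughly 250 of the 500 `global_rampup_samples` before training settles at the final batch size.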

src/MaxText/data_loader.py

Lines changed: 102 additions & 8 deletions
@@ -20,7 +20,6 @@
 from jax.experimental import checkify

 from MaxText import exceptions
-from MaxText import sharding
 from MaxText.utils.goodput_utils import (
     GoodputEvent,
     maybe_record_goodput,
@@ -37,7 +36,6 @@ def __init__(self, config, mesh, data_iterator, goodput_recorder):
     self.goodput_recorder = goodput_recorder
     self.data_iterator = data_iterator
     self.last_batch = None
-    self.input_data_shardings = sharding.get_input_data_sharding(config, mesh)

   def load_next_batch(self):
     """Loads the next batch. Can keep reusing the same batch for performance reasons."""
@@ -47,12 +45,7 @@ def load_next_batch(self):
           example_batch = self.last_batch
         else:
           example_batch = next(self.data_iterator)
-          # Reshard data from loaded sharding to performant activation sharding
-          self.last_batch = sharding.maybe_shard_with_name(
-              example_batch,
-              self.input_data_shardings,
-              self.config.shard_mode,
-          )
+          self.last_batch = example_batch
         self.check_example_batch()
       except Exception as e: # pylint: disable=broad-except
         if isinstance(e, StopIteration):
@@ -68,3 +61,104 @@ def check_example_batch(self):
     # pylint: disable=not-callable
     err, _ = jax.jit(jittable_f)(self.last_batch["inputs"][: self.config.global_batch_size_to_train_on, :])
     err.throw()
+
+
+class RampUpDataLoader(DataLoader):
+  """
+  A DataLoader that implements batch size ramp-up.
+
+  It dynamically increases its current global batch size based on the number
+  of samples seen so far, and slices the batches returned by the parent
+  DataLoader down to that size until the final batch size is reached.
+  """
+
+  def __init__(self, config, mesh, data_iterator, goodput_recorder):
+    # Call parent constructor
+    super().__init__(config, mesh, data_iterator, goodput_recorder)
+
+    # Get ramp-up parameters from config
+    self.global_batch_size_end = config.global_batch_size_to_load
+    self.global_batch_size_start = config.global_batch_size_to_load_start
+    self.increment = config.global_batch_size_to_load_increment
+    self.samples_per_increment = config.rampup_samples_per_increment_to_load
+
+    # Check if ramp-up is active
+    self.rampup_active = self.global_batch_size_start < self.global_batch_size_end
+
+    # State for tracking ramp-up
+    self.accum_samples = 0
+    self.global_batch_size_current = self.global_batch_size_start
+    self.batch_buffer = None
+    self.buffer_start = 0
+
+  def load_next_batch(self):
+    """
+    Updates the batch size based on the schedule and then loads the next
+    batch using the parent method.
+    """
+    # If ramp-up is not active, just behave like the parent
+    if not self.rampup_active:
+      return super().load_next_batch()
+
+    # During the ramp-up phase, a batch buffer holds the loaded data.
+    # Check whether it is time to increment the batch size.
+    is_time_to_increment = self.accum_samples >= self.samples_per_increment
+
+    if is_time_to_increment:
+      # Update the current batch size and reset the accumulated sample count
+      self.global_batch_size_current += self.increment
+      self.accum_samples = 0
+      self.rampup_active = self.global_batch_size_current < self.global_batch_size_end
+
+    self.accum_samples += self.global_batch_size_current
+    slice_start, slice_end = self.buffer_start, self.buffer_start + self.global_batch_size_current
+
+    # Load a new batch if batch_buffer is None or the slice runs past the buffer end
+    if self.batch_buffer is None:
+      self.batch_buffer = super().load_next_batch()
+      slice_start, slice_end = 0, self.global_batch_size_current
+
+    if slice_end > self.global_batch_size_end:
+      old_buffer, self.batch_buffer = self.batch_buffer, super().load_next_batch()

+      # self.global_batch_size_end is the batch_buffer size
+      def _slice_and_concat(old_data, new_data):
+        sliced_old_data = jax.lax.dynamic_slice_in_dim(
+            old_data,
+            slice_start,
+            self.global_batch_size_end - slice_start,
+            axis=0,
+        )
+        sliced_new_data = jax.lax.dynamic_slice_in_dim(
+            new_data,
+            0,
+            slice_end - self.global_batch_size_end,
+            axis=0,
+        )
+        return jax.lax.concatenate((sliced_old_data, sliced_new_data), dimension=0)
+
+      self.buffer_start = slice_end - self.global_batch_size_end
+      return jax.tree.map(_slice_and_concat, old_buffer, self.batch_buffer)
+    else:
+
+      def _slice(data):
+        return jax.lax.dynamic_slice_in_dim(
+            data,
+            slice_start,
+            self.global_batch_size_current,
+            axis=0,
+        )
+
+      self.buffer_start = slice_end
+      return jax.tree.map(_slice, self.batch_buffer)
+
+
+def create_dataloader(config, mesh, data_iterator, goodput_recorder):
+  """
+  Create the dataloader
+  """
+  if config.enable_rampup_batch_size:
+    return RampUpDataLoader(config, mesh, data_iterator, goodput_recorder)
+  else:
+    return DataLoader(config, mesh, data_iterator, goodput_recorder)
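To make the buffer slicing above concrete, here is a minimal standalone sketch (not part of the commit) of the per-leaf operation `RampUpDataLoader` relies on: `jax.tree.map` applies `jax.lax.dynamic_slice_in_dim` along the batch axis of every array in the loaded batch, carving the smaller current logical batch out of the full-size loaded buffer. The field names, shapes, and offsets below are invented for illustration.

import jax
import jax.numpy as jnp

# Hypothetical loaded buffer: the full global batch of 8 examples per load.
loaded_batch = {
    "inputs": jnp.arange(8 * 4).reshape(8, 4),
    "targets": jnp.arange(8 * 4).reshape(8, 4),
}

buffer_start = 2          # where the previous slice ended
current_batch_size = 4    # current (ramped-up) logical batch size

def _slice(data):
  # Take `current_batch_size` rows starting at `buffer_start` along axis 0.
  return jax.lax.dynamic_slice_in_dim(data, buffer_start, current_batch_size, axis=0)

small_batch = jax.tree.map(_slice, loaded_batch)
print(small_batch["inputs"].shape)  # (4, 4)

When the requested slice would run past the end of the loaded buffer, the loader instead concatenates the tail of the old buffer with the head of a freshly loaded batch (the `_slice_and_concat` branch above), so no loaded example is skipped during ramp-up.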

src/MaxText/metric_logger.py

Lines changed: 25 additions & 15 deletions
@@ -105,13 +105,22 @@ def log_metrics(self, metrics, step, is_training):
     """Logs metrics via max_logging."""
     if is_training:
       loss = metrics["scalar"]["learning/loss"]
-      log_message = (
-          f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
-          f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, "
-          f"Tokens/s/device: {metrics['scalar']['perf/per_device_tokens_per_sec']:.3f}, "
-          f"total_weights: {metrics['scalar']['learning/total_weights']}, "
-          f"loss: {loss:.3f}"
-      )
+      # Do not show flops and tokens during batch size rampup
+      if step >= self.config.rampup_end_step:
+        log_message = (
+            f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
+            f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, "
+            f"Tokens/s/device: {metrics['scalar']['perf/per_device_tokens_per_sec']:.3f}, "
+            f"total_weights: {metrics['scalar']['learning/total_weights']}, "
+            f"loss: {loss:.3f}"
+        )
+      else:
+        log_message = (
+            "[Rampup Batch Size Phase]: "
+            f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
+            f"total_weights: {metrics['scalar']['learning/total_weights']}, "
+            f"loss: {loss:.3f}"
+        )

       if self.config.mtp_num_layers > 0:
         mtp_loss = metrics["scalar"].get("learning/mtp_loss", 0.0)
@@ -213,15 +222,16 @@ def buffer_and_write_train_metrics(self, metrics, step, step_time_delta):
   def record_train_metrics(self, metrics, step, step_time):
     """Records training metrics for the current step."""
     metrics["scalar"].update({"perf/step_time_seconds": step_time})
-    metrics["scalar"].update({"perf/per_device_tflops": self.metadata[MetadataKey.PER_DEVICE_TFLOPS]})
-    metrics["scalar"].update(
-        {"perf/per_device_tflops_per_sec": (self.metadata[MetadataKey.PER_DEVICE_TFLOPS] / step_time)}
-    )
-    metrics["scalar"].update({"perf/per_device_tokens": self.metadata[MetadataKey.PER_DEVICE_TOKENS]})
-    metrics["scalar"].update(
-        {"perf/per_device_tokens_per_sec": (self.metadata[MetadataKey.PER_DEVICE_TOKENS] / step_time)}
-    )
     metrics["scalar"].update({"learning/current_learning_rate": self.learning_rate_schedule(step)})
+    if step >= self.config.rampup_end_step:
+      metrics["scalar"].update({"perf/per_device_tflops": self.metadata[MetadataKey.PER_DEVICE_TFLOPS]})
+      metrics["scalar"].update(
+          {"perf/per_device_tflops_per_sec": (self.metadata[MetadataKey.PER_DEVICE_TFLOPS] / step_time)}
+      )
+      metrics["scalar"].update({"perf/per_device_tokens": self.metadata[MetadataKey.PER_DEVICE_TOKENS]})
+      metrics["scalar"].update(
+          {"perf/per_device_tokens_per_sec": (self.metadata[MetadataKey.PER_DEVICE_TOKENS] / step_time)}
+      )
     if self.performance_metric_queue:
       self.performance_metric_queue.put(step_time)
src/MaxText/pyconfig.py

Lines changed: 80 additions & 0 deletions
@@ -194,6 +194,21 @@ def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max
     raise ValueError("We currently don't support vocab tiling on NNX module.")


+def validate_rampup_batch_size(batch_size_start, batch_size_end, batch_size_increment, global_rampup_samples):
+  assert batch_size_start > 0, f"per_device_batch_size_start should be positive, got {batch_size_start}."
+  assert batch_size_increment > 0, f"per_device_batch_size_increment should be positive, got {batch_size_increment}."
+  assert global_rampup_samples > 0, f"global_rampup_samples should be positive, got {global_rampup_samples}."
+  diff_batch_size = batch_size_end - batch_size_start
+  assert diff_batch_size > 0, (
+      "per_device_batch_size must be greater than per_device_batch_size_start. "
+      f"Got per_device_batch_size={batch_size_end} and per_device_batch_size_start={batch_size_start}."
+  )
+  assert diff_batch_size % batch_size_increment == 0, (
+      "Expected the ramp-up batch size range to be divisible by the batch size increment. "
+      f"Got per_device_batch_size={batch_size_end} and per_device_batch_size_start={batch_size_start}."
+  )
+
+
 def validate_keys(keys):
   validate_attention_kernel(keys["attention"])
   validate_attention_type(keys["attention_type"])
@@ -212,6 +227,13 @@ def validate_keys(keys):
   validate_vocab_tiling(
       keys["num_vocab_tiling"], keys["per_device_batch_size"], keys["max_target_length"], keys["enable_nnx"]
   )
+  if keys["enable_rampup_batch_size"]:
+    validate_rampup_batch_size(
+        keys["per_device_batch_size_start"],
+        keys["per_device_batch_size"],
+        keys["per_device_batch_size_increment"],
+        keys["global_rampup_samples"],
+    )

   # TODO remove after b/435512699 resolved
   if keys["context_parallel_size"] > 1 and keys["context_parallel_load_balance"] and keys["attention_type"] == "chunk":
@@ -706,6 +728,43 @@ def user_init(raw_keys):
       raw_keys["gradient_accumulation_steps"],
   )

+  # Initialize the starting global batch size and the global increments if
+  # rampup batch size is enabled
+  if raw_keys["enable_rampup_batch_size"]:
+    (
+        raw_keys["global_batch_size_to_load_start"],
+        raw_keys["global_batch_size_to_train_on_start"],
+        raw_keys["micro_batch_size_to_train_on_start"],
+    ) = calculate_global_batch_sizes(
+        raw_keys["per_device_batch_size_start"],
+        raw_keys["expansion_factor_real_data"],
+        get_num_target_devices(raw_keys),
+        raw_keys["gradient_accumulation_steps"],
+    )
+
+    (
+        raw_keys["global_batch_size_to_load_increment"],
+        raw_keys["global_batch_size_to_train_on_increment"],
+        raw_keys["micro_batch_size_to_train_on_increment"],
+    ) = calculate_global_batch_sizes(
+        raw_keys["per_device_batch_size_increment"],
+        raw_keys["expansion_factor_real_data"],
+        get_num_target_devices(raw_keys),
+        raw_keys["gradient_accumulation_steps"],
+    )
+
+    (
+        raw_keys["rampup_samples_per_increment_to_load"],
+        raw_keys["rampup_end_step"],
+    ) = calculate_rampup_samples_and_steps(
+        raw_keys["global_batch_size_to_load_start"],
+        raw_keys["global_batch_size_to_load"],
+        raw_keys["global_batch_size_to_load_increment"],
+        raw_keys["global_rampup_samples"],
+    )
+  else:
+    raw_keys["rampup_end_step"] = 0
+
   if raw_keys["eval_per_device_batch_size"] <= 0:
     raw_keys["eval_per_device_batch_size"] = raw_keys["per_device_batch_size"]
@@ -1253,6 +1312,27 @@ def calculate_global_batch_sizes(
   return global_batch_size_to_load, global_batch_size_to_train_on, micro_batch_size_to_train_on


+def calculate_rampup_samples_and_steps(
+    batch_size_start,
+    batch_size_end,
+    batch_size_increment,
+    global_rampup_samples,
+):
+  """Calculates the number of samples per increment and the number of steps for batch ramp-up."""
+  diff_batch_size = batch_size_end - batch_size_start
+  num_increments = diff_batch_size // batch_size_increment
+  rampup_samples_per_increment = global_rampup_samples / num_increments
+  total_rampup_steps = 0
+  current_batch_size = batch_size_start
+
+  for _ in range(num_increments):
+    steps_for_this_stage = math.ceil(rampup_samples_per_increment / current_batch_size)
+    total_rampup_steps += steps_for_this_stage
+    current_batch_size += batch_size_increment
+
+  return rampup_samples_per_increment, total_rampup_steps
+
+
 def get_num_target_devices(raw_keys):
   # In AOT case compile_topology is set (e.g. is not the empty string), and we determine the
   # number of devices from the compile_topology. In non-AOT settings we simply can use jax.devices().
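To illustrate `calculate_rampup_samples_and_steps`, here is a small standalone sketch (not part of the commit) that reproduces its arithmetic for hypothetical global batch sizes; the actual values depend on the per-device settings and the device count, so the numbers below are only an example.

import math

def rampup_samples_and_steps(start, end, increment, global_rampup_samples):
  # Mirrors the arithmetic of calculate_rampup_samples_and_steps above.
  num_increments = (end - start) // increment                      # e.g. (64 - 32) // 16 = 2 stages
  samples_per_increment = global_rampup_samples / num_increments   # e.g. 500 / 2 = 250.0
  total_steps, current = 0, start
  for _ in range(num_increments):
    total_steps += math.ceil(samples_per_increment / current)      # ceil(250/32)=8, then ceil(250/48)=6
    current += increment
  return samples_per_increment, total_steps

print(rampup_samples_and_steps(32, 64, 16, 500))  # (250.0, 14), i.e. rampup_end_step would be 14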

src/MaxText/train.py

Lines changed: 8 additions & 2 deletions
@@ -53,7 +53,7 @@
 from MaxText import sharding
 from MaxText.layers.multi_token_prediction import calculate_mtp_acceptance_rate, calculate_mtp_loss
 from MaxText.common_types import ShardMode
-from MaxText.data_loader import DataLoader
+from MaxText.data_loader import create_dataloader
 from MaxText.globals import EPS
 from MaxText.metric_logger import MetricLogger
 from MaxText.utils import gcs_utils
@@ -391,7 +391,7 @@ def train_loop(config, recorder, state=None):

   start_step = get_first_step(state) # this is the start_step for training
   prof = profiler.Profiler(config, offset_step=start_step)
-  data_loader = DataLoader(config, mesh, data_iterator, recorder)
+  data_loader = create_dataloader(config, mesh, data_iterator, recorder)
   metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)

   # Write train config params, num model params, and XLA flags to tensorboard
@@ -404,6 +404,12 @@

     with jax.profiler.StepTraceAnnotation("train", step_num=step):
       example_batch = data_loader.load_next_batch()
+      # Reshard data from loaded sharding to performant activation sharding
+      example_batch = sharding.maybe_shard_with_name(
+          example_batch,
+          sharding.get_input_data_sharding(config, mesh),
+          shard_mode=config.shard_mode,
+      )
       # pylint: disable=not-callable
       nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
       with maybe_record_goodput(recorder, GoodputEvent.STEP, step):
