12 changes: 12 additions & 0 deletions src/MaxText/configs/base.yml
@@ -507,6 +507,18 @@ packing: True
num_epoch: 1 # only grain and tfds pipeline supports num_epoch > 1
generate_padding_batch_train: False
generate_padding_batch_eval: False
# Rampup batch size, similar to Megatron-LM, see
# https://github.com/NVIDIA/Megatron-LM/blob/2a01637aa54ccdaf7ea9afc1f1b80f58c53d7f3c/megatron/core/num_microbatches_calculator.py#L233-L237
# The ramp-up proceeds in stages from `per_device_batch_size_start` up to
# the final `per_device_batch_size`. For a clean ramp-up, the total range
# (`per_device_batch_size` - `per_device_batch_size_start`)
# should be evenly divisible by `per_device_batch_size_increment`.
enable_rampup_batch_size: False
per_device_batch_size_start: 4.0
per_device_batch_size_increment: 2.0
# The target number of training samples to process during the ramp-up phase.
# There is no strict rule for this value; it only needs to be positive.
rampup_samples: 500
Collaborator: Seems like rampup_samples means global samples? We can add a comment to clarify.

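To make the staging concrete, here is a minimal sketch of the schedule arithmetic, assuming a final per_device_batch_size of 8.0 (an illustrative value, not a recommendation) together with the defaults above:

def rampup_stages(start, end, increment, rampup_samples):
  # (end - start) must divide evenly by increment for a clean ramp-up.
  assert (end - start) % increment == 0
  num_stages = int((end - start) // increment)
  samples_per_stage = rampup_samples / num_stages
  return [(start + i * increment, samples_per_stage) for i in range(num_stages)]

# start=4.0, increment=2.0, end=8.0, rampup_samples=500 gives two stages of
# 250 samples each, at per-device batch sizes 4.0 and 6.0.
print(rampup_stages(4.0, 8.0, 2.0, 500))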

# direct preference optimization (DPO)
use_dpo: False
73 changes: 65 additions & 8 deletions src/MaxText/data_loader.py
@@ -20,7 +20,6 @@
from jax.experimental import checkify

from MaxText import exceptions
from MaxText import maxtext_utils
from MaxText.utils.goodput_utils import (
GoodputEvent,
maybe_record_goodput,
@@ -37,7 +36,6 @@ def __init__(self, config, mesh, data_iterator, goodput_recorder):
self.goodput_recorder = goodput_recorder
self.data_iterator = data_iterator
self.last_batch = None
self.input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)

def load_next_batch(self):
"""Loads the next batch. Can keep reusing the same batch for performance reasons."""
@@ -47,12 +45,7 @@ def load_next_batch(self):
example_batch = self.last_batch
else:
example_batch = next(self.data_iterator)
# Reshard data from loaded sharding to performant activation sharding
self.last_batch = maxtext_utils.maybe_shard_with_name(
example_batch,
self.input_data_shardings,
self.config.shard_mode,
)
self.last_batch = example_batch
self.check_example_batch()
except Exception as e: # pylint: disable=broad-except
if isinstance(e, StopIteration):
@@ -68,3 +61,67 @@ def check_example_batch(self):
# pylint: disable=not-callable
err, _ = jax.jit(jittable_f)(self.last_batch["inputs"][: self.config.global_batch_size_to_train_on, :])
err.throw()


class RampUpDataLoader(DataLoader):
"""
A DataLoader that implements batch size ramp-up.

It dynamically increases the 'global_batch_size_current' in the config
object based on the training step. The rest of the training pipeline
(including the parent's `check_example_batch` and the training step itself)
is assumed to read this config value to determine the logical batch size.
"""

def __init__(self, config, mesh, data_iterator, goodput_recorder):
# Call parent constructor
super().__init__(config, mesh, data_iterator, goodput_recorder)

# Get ramp-up parameters from config, with safe defaults
self.global_batch_size_end = config.global_batch_size_to_load
self.global_batch_size_start = config.global_batch_size_to_load_start
self.increment = config.global_batch_size_to_load_increment
self.samples_per_increment = config.rampup_samples_per_increment_to_load

# Check if ramp-up is active
self.rampup_active = self.global_batch_size_start < self.global_batch_size_end

# State for tracking ramp-up
self.accum_samples = 0
self.global_batch_size_current = self.global_batch_size_start

def load_next_batch(self):
"""
Updates the batch size based on the schedule and then loads the next
batch using the parent method.
"""
# If ramp-up is not active, just behave like the parent
if not self.rampup_active:
return super().load_next_batch()

# Check if it's time to increment the batch size
is_time_to_increment = self.accum_samples >= self.samples_per_increment

if is_time_to_increment:
# Update the current batch size and reset the accumulated-sample counter
self.global_batch_size_current += self.increment
self.accum_samples = 0
self.rampup_active = self.global_batch_size_current < self.global_batch_size_end

def _slice(data):
# When ramp-up is active, take a leading slice of the batch and discard the rest
return jax.lax.dynamic_slice_in_dim(data, 0, self.global_batch_size_current, axis=0)

Collaborator: I got the impression that customer's workload may be sensitive to data. Could we confirm with them if skipping data is ok? Otherwise, we can save the unused data in a "cache" to avoid skipping.

self.accum_samples += self.global_batch_size_current

return jax.tree.map(_slice, super().load_next_batch())
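
On the caching suggestion above, here is a minimal sketch of that idea, assuming dict-of-arrays batches; the class and method names are illustrative and not part of this PR. Leftover rows from each loaded batch go into a buffer that later calls drain first, so no data is skipped:

import numpy as np

class CachingRampUpLoader:
  """Sketch: serve ramped-up batches without discarding sliced-off rows."""

  def __init__(self, data_iterator):
    self.data_iterator = data_iterator
    self._cache = None  # leftover rows from previously loaded batches

  def load(self, batch_size):
    parts = [] if self._cache is None else [self._cache]
    num_rows = sum(len(p["inputs"]) for p in parts)
    while num_rows < batch_size:
      batch = next(self.data_iterator)
      parts.append(batch)
      num_rows += len(batch["inputs"])
    merged = {k: np.concatenate([p[k] for p in parts]) for k in parts[0]}
    # Return exactly batch_size rows and keep the remainder for the next call.
    self._cache = {k: v[batch_size:] for k, v in merged.items()} if num_rows > batch_size else None
    return {k: v[:batch_size] for k, v in merged.items()}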


def create_dataloader(config, mesh, data_iterator, goodput_recorder):
"""
Create the dataloader
"""
if config.enable_rampup_batch_size:
return RampUpDataLoader(config, mesh, data_iterator, goodput_recorder)
else:
return DataLoader(config, mesh, data_iterator, goodput_recorder)
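
A minimal usage sketch of the factory (construction of config, mesh, and data_iterator is elided; passing None as the recorder is illustrative):

loader = create_dataloader(config, mesh, data_iterator, goodput_recorder=None)
batch = loader.load_next_batch()  # sliced to the current ramp-up size when enabled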
39 changes: 24 additions & 15 deletions src/MaxText/metric_logger.py
@@ -105,13 +105,21 @@ def log_metrics(self, metrics, step, is_training):
"""Logs metrics via max_logging."""
if is_training:
loss = metrics["scalar"]["learning/loss"]
log_message = (
f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, "
f"Tokens/s/device: {metrics['scalar']['perf/per_device_tokens_per_sec']:.3f}, "
f"total_weights: {metrics['scalar']['learning/total_weights']}, "
f"loss: {loss:.3f}"
)
# Do not report FLOPs and token throughput during batch-size ramp-up
if step >= self.config.rampup_end_step:
log_message = (
f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, "
f"Tokens/s/device: {metrics['scalar']['perf/per_device_tokens_per_sec']:.3f}, "
f"total_weights: {metrics['scalar']['learning/total_weights']}, "
f"loss: {loss:.3f}"
)
else:
log_message = (
f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
f"total_weights: {metrics['scalar']['learning/total_weights']}, "
f"loss: {loss:.3f}"
)

Collaborator: How about adding some message to indicate it's in the rampup phase?

if self.config.mtp_num_layers > 0:
mtp_loss = metrics["scalar"].get("learning/mtp_loss", 0.0)
@@ -213,15 +221,16 @@ def buffer_and_write_train_metrics(self, metrics, step, step_time_delta):
def record_train_metrics(self, metrics, step, step_time):
"""Records training metrics for the current step."""
metrics["scalar"].update({"perf/step_time_seconds": step_time})
metrics["scalar"].update({"perf/per_device_tflops": self.metadata[MetadataKey.PER_DEVICE_TFLOPS]})
metrics["scalar"].update(
{"perf/per_device_tflops_per_sec": (self.metadata[MetadataKey.PER_DEVICE_TFLOPS] / step_time)}
)
metrics["scalar"].update({"perf/per_device_tokens": self.metadata[MetadataKey.PER_DEVICE_TOKENS]})
metrics["scalar"].update(
{"perf/per_device_tokens_per_sec": (self.metadata[MetadataKey.PER_DEVICE_TOKENS] / step_time)}
)
metrics["scalar"].update({"learning/current_learning_rate": self.learning_rate_schedule(step)})
if step >= self.config.rampup_end_step:
metrics["scalar"].update({"perf/per_device_tflops": self.metadata[MetadataKey.PER_DEVICE_TFLOPS]})
metrics["scalar"].update(
{"perf/per_device_tflops_per_sec": (self.metadata[MetadataKey.PER_DEVICE_TFLOPS] / step_time)}
)
metrics["scalar"].update({"perf/per_device_tokens": self.metadata[MetadataKey.PER_DEVICE_TOKENS]})
metrics["scalar"].update(
{"perf/per_device_tokens_per_sec": (self.metadata[MetadataKey.PER_DEVICE_TOKENS] / step_time)}
)
if self.performance_metric_queue:
self.performance_metric_queue.put(step_time)

80 changes: 80 additions & 0 deletions src/MaxText/pyconfig.py
@@ -194,6 +194,21 @@ def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max
raise ValueError("We currently don't support vocab tiling on NNX module.")


def validate_rampup_batch_size(batch_size_start, batch_size_end, batch_size_increment, rampup_samples):
"""Validates the batch-size ramp-up settings."""
assert batch_size_start > 0, f"per_device_batch_size_start should be positive, got {batch_size_start}."
assert batch_size_increment > 0, f"per_device_batch_size_increment should be positive, got {batch_size_increment}."
assert rampup_samples > 0, f"rampup_samples should be positive, got {rampup_samples}."
diff_batch_size = batch_size_end - batch_size_start
assert diff_batch_size > 0, (
"Expected the final batch size to be larger than the starting batch size; "
f"got batch size {batch_size_end} and starting batch size {batch_size_start}."
)
assert diff_batch_size % batch_size_increment == 0, (
"Expected the batch size range to be evenly divisible by the batch size increment; "
f"got batch size diff {diff_batch_size} and batch size increment {batch_size_increment}."
)
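
For instance, with illustrative values (not the defaults), the divisibility rule behaves as follows:

validate_rampup_batch_size(4.0, 8.0, 2.0, 500)  # passes: (8.0 - 4.0) % 2.0 == 0
validate_rampup_batch_size(4.0, 9.0, 2.0, 500)  # raises: (9.0 - 4.0) % 2.0 == 1.0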


def validate_keys(keys):
validate_attention_kernel(keys["attention"])
validate_attention_type(keys["attention_type"])
@@ -212,6 +227,13 @@ def validate_keys(keys):
validate_vocab_tiling(
keys["num_vocab_tiling"], keys["per_device_batch_size"], keys["max_target_length"], keys["enable_nnx"]
)
if keys["enable_rampup_batch_size"]:
validate_rampup_batch_size(
keys["per_device_batch_size_start"],
keys["per_device_batch_size"],
keys["per_device_batch_size_increment"],
keys["rampup_samples"],
)

# TODO remove after b/435512699 resolved
if keys["context_parallel_size"] > 1 and keys["context_parallel_load_balance"] and keys["attention_type"] == "chunk":
@@ -706,6 +728,43 @@ def user_init(raw_keys):
raw_keys["gradient_accumulation_steps"],
)

# Initialize the starting global batch sizes and global increments when
# batch-size ramp-up is enabled
if raw_keys["enable_rampup_batch_size"]:
(
raw_keys["global_batch_size_to_load_start"],
raw_keys["global_batch_size_to_train_on_start"],
raw_keys["micro_batch_size_to_train_on_start"],
) = calculate_global_batch_sizes(
raw_keys["per_device_batch_size_start"],
raw_keys["expansion_factor_real_data"],
get_num_target_devices(raw_keys),
raw_keys["gradient_accumulation_steps"],
)

(
raw_keys["global_batch_size_to_load_increment"],
raw_keys["global_batch_size_to_train_on_increment"],
raw_keys["micro_batch_size_to_train_on_increment"],
) = calculate_global_batch_sizes(
raw_keys["per_device_batch_size_increment"],
raw_keys["expansion_factor_real_data"],
get_num_target_devices(raw_keys),
raw_keys["gradient_accumulation_steps"],
)

(
raw_keys["rampup_samples_per_increment_to_load"],
raw_keys["rampup_end_step"],
) = calculate_rampup_samples_and_steps(
raw_keys["global_batch_size_to_load_start"],
raw_keys["global_batch_size_to_load"],
raw_keys["global_batch_size_to_load_increment"],
raw_keys["rampup_samples"],
)
else:
raw_keys["rampup_end_step"] = 0

if raw_keys["eval_per_device_batch_size"] <= 0:
raw_keys["eval_per_device_batch_size"] = raw_keys["per_device_batch_size"]

@@ -1253,6 +1312,27 @@ def calculate_global_batch_sizes(
return global_batch_size_to_load, global_batch_size_to_train_on, micro_batch_size_to_train_on


def calculate_rampup_samples_and_steps(
batch_size_start,
batch_size_end,
batch_size_increment,
rampup_samples,
):
"""Calculate num of samples for each increment and num of steps for batch rampup"""
diff_batch_size = batch_size_end - batch_size_start
num_increments = diff_batch_size // batch_size_increment
rampup_samples_per_increment = rampup_samples / num_increments
total_rampup_steps = 0
current_batch_size = batch_size_start

for _ in range(num_increments):
steps_for_this_stage = math.ceil(rampup_samples_per_increment / current_batch_size)
total_rampup_steps += steps_for_this_stage
current_batch_size += batch_size_increment

return rampup_samples_per_increment, total_rampup_steps
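
A worked example with illustrative values (not the config defaults), showing how the per-stage step ceilings sum to the total:

# start=8, end=32, increment=8, rampup_samples=300 -> 3 increments of 100 samples each:
#   stage 1: batch size 8  -> ceil(100 / 8)  = 13 steps
#   stage 2: batch size 16 -> ceil(100 / 16) = 7 steps
#   stage 3: batch size 24 -> ceil(100 / 24) = 5 steps
assert calculate_rampup_samples_and_steps(8, 32, 8, 300) == (100.0, 25)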


def get_num_target_devices(raw_keys):
# In AOT case compile_topology is set (e.g. is not the empty string), and we determine the
# number of devices from the compile_topology. In non-AOT settings we simply can use jax.devices().
9 changes: 7 additions & 2 deletions src/MaxText/train.py
@@ -52,7 +52,7 @@
from MaxText import pyconfig
from MaxText.layers.multi_token_prediction import calculate_mtp_acceptance_rate, calculate_mtp_loss
from MaxText.common_types import ShardMode
from MaxText.data_loader import DataLoader
from MaxText.data_loader import create_dataloader
from MaxText.globals import EPS
from MaxText.metric_logger import MetricLogger
from MaxText.utils import gcs_utils
@@ -392,7 +392,7 @@ def train_loop(config, recorder, state=None):

start_step = get_first_step(state) # this is the start_step for training
prof = profiler.Profiler(config, offset_step=start_step)
data_loader = DataLoader(config, mesh, data_iterator, recorder)
data_loader = create_dataloader(config, mesh, data_iterator, recorder)
metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)

# Write train config params, num model params, and XLA flags to tensorboard
@@ -405,6 +405,11 @@

with jax.profiler.StepTraceAnnotation("train", step_num=step):
example_batch = data_loader.load_next_batch()
# Reshard data from loaded sharding to performant activation sharding
example_batch = jax.lax.with_sharding_constraint(
example_batch,
maxtext_utils.get_input_data_sharding(config, mesh),
)
# pylint: disable=not-callable
nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
with maybe_record_goodput(recorder, GoodputEvent.STEP, step):
58 changes: 47 additions & 11 deletions tests/data_loader_test.py
@@ -24,7 +24,7 @@
from jax.sharding import Mesh
from jax.experimental import mesh_utils

from MaxText.data_loader import DataLoader
from MaxText.data_loader import DataLoader, RampUpDataLoader
from MaxText import exceptions
from MaxText import pyconfig
from MaxText.globals import MAXTEXT_PKG_DIR
@@ -34,22 +34,34 @@ class DataLoaderTest(unittest.TestCase):

def setUp(self):
super().setUp()
self.config = self.get_test_config(reuse_example_batch=False)
self.config_reuse_example = self.get_test_config(reuse_example_batch=True)
self.config = self.get_test_config(reuse_example_batch=False, per_device_batch_size=1)
self.config_reuse_example = self.get_test_config(reuse_example_batch=True, per_device_batch_size=1)
self.config_rampup = self.get_test_config(
reuse_example_batch=False,
per_device_batch_size=4.0, # This is the 'end' batch size
enable_rampup_batch_size=True,
per_device_batch_size_start=1.0,
per_device_batch_size_increment=1.0,
rampup_samples=12,
)
self.mesh_shape_1d = (len(jax.devices()),)
self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes)
self.mock_data_iterator = MagicMock()

def get_test_config(self, reuse_example_batch):
def get_test_config(self, reuse_example_batch, **kwargs):
"""Generate config for tests"""
args = {
"run_name": "test",
"mesh_axes": ["data"],
"logical_axis_rules": [["batch", "data"]],
"data_sharding": ["data"],
"enable_checkpointing": False,
"reuse_example_batch": reuse_example_batch,
}
args.update(kwargs)
return pyconfig.initialize(
[None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
per_device_batch_size=1,
run_name="test",
mesh_axes=["data"],
logical_axis_rules=[["batch", "data"]],
data_sharding=["data"],
enable_checkpointing=False,
reuse_example_batch=reuse_example_batch,
**args,
)

def test_load_next_batch_success(self):
@@ -117,6 +129,30 @@ def test_load_next_batch_throws_exception(self):
_ = data_loader.load_next_batch()
self.assertTrue(str(e.exception).startswith("You may have run out of training data."))

def test_rampup_loader_initial_batch_is_sliced(self):
"""Tests that RampUpLoader correctly slices and increment."""
# Mock iterator returns a FULL batch (size 4)
full_batch_size = self.config_rampup.global_batch_size_to_load
full_shape = [full_batch_size, self.config_rampup.max_target_length]
full_batch = {"inputs": np.ones(full_shape, dtype=int)}
self.mock_data_iterator.__next__.return_value = full_batch

# Create the RampUpDataLoader
data_loader = RampUpDataLoader(self.config_rampup, self.mesh, self.mock_data_iterator, None)

# Check that the batch is sliced down to the starting batch size (1 per device)
curr_batch_size = self.config_rampup.global_batch_size_to_load_start
increment_batch_size = self.config_rampup.global_batch_size_to_load_increment

# We only test the first 4 batches; the first 3 are ramp-up batches
for _ in range(4):
batch = data_loader.load_next_batch()
expected_shape = (curr_batch_size, self.config_rampup.max_target_length)
self.assertEqual(batch["inputs"].shape, expected_shape)
self.assertTrue((batch["inputs"] == 1).all())
if data_loader.rampup_active:
curr_batch_size += increment_batch_size


if __name__ == "__main__":
unittest.main()