134 changes: 125 additions & 9 deletions src/MaxText/checkpointing.py
@@ -33,6 +33,13 @@
import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager
import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager
# pylint: disable=too-many-positional-arguments
import dataclasses
import json
from typing import Optional

from grain._src.core import sharding
from grain._src.python import checkpoint_handlers as grain_checkpoint_handlers
from grain._src.python.dataset import dataset

CheckpointManager = ocp.CheckpointManager
CheckpointManagerOptions = ocp.CheckpointManagerOptions
@@ -44,6 +51,80 @@
EmergencyReplicatorCheckpointManager = emergency_replicator_checkpoint_manager.ReplicatorCheckpointManager


class MaxtextGrainCheckpointHandler(grain_checkpoint_handlers.CheckpointHandler):
"""A CheckpointHandler that allows specifying process_index and process_count."""
def save(
self,
directory: epath.Path,
# `item` is for backwards compatibility with older Orbax API, see
# https://orbax.readthedocs.io/en/latest/api_refactor.html.
item: Optional[grain_checkpoint_handlers.IteratorType] = None,
args: Any = None,
):
"""Saves the given iterator to the checkpoint in `directory`."""
item = item or args.item # pytype:disable=attribute-error

def save_single_process(item, process_index, process_count):
filename = directory / f"process_{process_index}-of-{process_count}.json"
if isinstance(item, dataset.DatasetIterator):
state = json.dumps(item.get_state(), indent=4)
else:
state = item.get_state().decode()
filename.write_text(state)

if isinstance(item, list):
for local_iterator, process_index, process_count in item:
save_single_process(local_iterator, process_index, process_count)
else:
process_index, process_count = sharding.get_process_index_and_count()
save_single_process(item, process_index, process_count)

def restore(
self,
directory: epath.Path,
item: Optional[grain_checkpoint_handlers.IteratorType] = None,
args: Any = None,
) -> grain_checkpoint_handlers.IteratorType:
"""Restores the given iterator from the checkpoint in `directory`."""
item = item or args.item
process_index = getattr(args, "process_index", None)
process_count = getattr(args, "process_count", None)

def restore_single_process(item, process_index, process_count):
filename = directory / f"process_{process_index}-of-{process_count}.json"
if not filename.exists():
raise ValueError(f"File {filename} does not exist.")
state = filename.read_text()
if isinstance(item, dataset.DatasetIterator):
state = json.loads(state)
else:
state = state.encode()
item.set_state(state)
return item

if isinstance(item, list):
restored_items = []
for data_iter, process_idx in zip(item, process_index):
restored_items.append(restore_single_process(data_iter, process_idx, process_count))
return restored_items
else:
if process_index is None or process_count is None:
process_index, process_count = sharding.get_process_index_and_count()
return restore_single_process(item, process_index, process_count)

@ocp.args.register_with_handler(MaxtextGrainCheckpointHandler, for_save=True)
@dataclasses.dataclass
class MaxtextGrainCheckpointSave(ocp.args.CheckpointArgs):
item: Any

@ocp.args.register_with_handler(MaxtextGrainCheckpointHandler, for_restore=True)
@dataclasses.dataclass
class MaxtextGrainCheckpointRestore(ocp.args.CheckpointArgs):
item: Any
process_index: Optional[int | list[int]] = None
process_count: Optional[int] = None

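As an aside (not part of the diff): the handler above writes one small JSON file per data-loading process under the step's `iter/` directory, named by that process's index and the total process count. A minimal sketch of the resulting layout, assuming four data-loading processes:

```python
# Illustrative sketch only: file names MaxtextGrainCheckpointHandler writes for a
# checkpoint taken with 4 data-loading processes (process_count = 4).
process_count = 4
for process_index in range(process_count):
    print(f"process_{process_index}-of-{process_count}.json")
# Each file holds json.dumps(iterator.get_state()) for a grain DatasetIterator,
# or the decoded bytes of get_state() for other iterator types.
```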

def _load_full_state_from_path(
path,
abstract_unboxed_pre_state,
@@ -111,10 +192,14 @@ def create_orbax_checkpoint_manager(

max_logging.log(f"Creating checkpoint manager with ocdbt={use_ocdbt} and zarr3={use_zarr3}")

# Base configuration for all dataset types
item_names = ("items",)
# We need OCDBT and Zarr3 to control the maximum file size in the checkpoint
item_handlers = {"items": PyTreeCheckpointHandler(use_ocdbt=use_ocdbt, use_zarr3=use_zarr3)}

if dataset_type == "grain":
item_names = ("items", "iter")
else:
item_names = ("items",)
item_names += ("iter",)
item_handlers["iter"] = MaxtextGrainCheckpointHandler()

# local storage checkpoint needs parent directory created
p = epath.Path(checkpoint_dir)
@@ -366,8 +451,29 @@ def map_to_pspec(data):
case (checkpoint_manager, dataset_type, data_iterator) if dataset_type == "grain" and data_iterator and (
checkpoint_manager.directory / str(step) / "iter"
).exists():
grain_iter = grain.PyGrainCheckpointRestore(data_iterator.local_iterator)
return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args, iter=grain_iter)), None)
directory = checkpoint_manager.directory / str(step) / "iter"
process_count_jax = jax.process_count()
process_count_stored = len(list(directory.glob("process_*-of-*.json")))
if process_count_stored > process_count_jax:
assert isinstance(data_iterator, list), f"{process_count_stored} processes found in Grain checkpoint directory {directory}, but only {process_count_jax} JAX processes in this run; set grain_checkpoint_scaling_factor accordingly."
assert process_count_stored / process_count_jax == len(data_iterator), f"{process_count_stored} processes found in checkpoint directory {directory}, but this run would restore {len(data_iterator)=} * {process_count_jax=} iterators."
#restored_state = checkpoint_manager.restore(step, args=Composite(items=checkpoint_args))
local_iterator_list = [x.local_iterator for x in data_iterator]
process_index_list = [jax.process_index() + i * process_count_jax for i in range(len(data_iterator))]
grain_iter_list = MaxtextGrainCheckpointRestore(local_iterator_list, process_index=process_index_list, process_count=process_count_stored)
restored_state = checkpoint_manager.restore(step, args=Composite(items=checkpoint_args, iter=grain_iter_list))
# for i, data_iter in enumerate(data_iterator):
# process_index = jax.process_index() + i * process_count_jax
# grain_iter = MaxtextGrainCheckpointRestore(data_iter.local_iterator, process_index=process_index, process_count=process_count_stored)
# # Restore each iterator. The restore is done in-place on the iterator object.
# restored_state = checkpoint_manager.restore(step, args=Composite(iter=grain_iter))
return (restored_state, None)
elif process_count_stored == process_count_jax:
assert not isinstance(data_iterator, list), f"{process_count_stored} processes found in Grain checkpoint directory {directory}, equal to the number of JAX processes; do not set grain_checkpoint_scaling_factor."
grain_iter = MaxtextGrainCheckpointRestore(data_iterator.local_iterator)
return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args, iter=grain_iter)), None)
else:
raise ValueError(f"Error restoring Grain checkpoint: {process_count_stored} stored processes cannot be restored onto {process_count_jax} JAX processes.")
# Case 3: Default/Fallback case.
# This case acts as a wildcard ('_') and matches if none of the preceding cases were met.
case _:
@@ -518,15 +624,25 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator=
save_args=jax.tree.map(lambda _: ocp.SaveArgs(chunk_byte_size=chunk_byte_size), state),
ocdbt_target_data_file_size=chunk_byte_size,
)
save_args_composite = {"items": checkpoint_args}
if config and config.dataset_type == "grain":
if not isinstance(data_iterator, list):
data_iterator = [data_iterator]
grain_iters_to_save = []
process_count_total = jax.process_count() * len(data_iterator)
for i, data_iter in enumerate(data_iterator):
process_index = jax.process_index() + i * jax.process_count()
grain_iters_to_save.append(
(data_iter.local_iterator, process_index, process_count_total)
)
save_args_composite["iter"] = MaxtextGrainCheckpointSave(item=grain_iters_to_save)

match (checkpoint_manager, config):
case (checkpoint_manager, _) if isinstance(
checkpoint_manager, (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager)
):
replicator_error_handler(config)
return checkpoint_manager.save(step, args=Composite(state=checkpoint_args), force=force)
case (_, config) if config and config.dataset_type == "grain":
grain_iter = grain.PyGrainCheckpointSave(data_iterator.local_iterator)
return checkpoint_manager.save(step, args=Composite(items=checkpoint_args, iter=grain_iter), force=force)
case _:
return checkpoint_manager.save(step, args=Composite(items=checkpoint_args), force=force)
return checkpoint_manager.save(step, args=Composite(**save_args_composite), force=force)

2 changes: 2 additions & 0 deletions src/MaxText/configs/base.yml
@@ -541,6 +541,8 @@ grain_eval_files: ''
grain_file_type: 'arrayrecord' # arrayrecord or parquet
grain_worker_count: 1
grain_worker_count_eval: 1
# Experimental: number of Grain iterators hosted per JAX process, allowing a Grain checkpoint written with N * factor data-loading processes to be restored on N JAX processes.
grain_checkpoint_scaling_factor: 1
# for using pathways
colocated_python_data_input: False # experimental feature, under testing

14 changes: 13 additions & 1 deletion src/MaxText/data_loader.py
@@ -35,10 +35,21 @@ class DataLoader:
def __init__(self, config, mesh, data_iterator, goodput_recorder):
self.config = config
self.goodput_recorder = goodput_recorder
self.data_iterator = data_iterator
if isinstance(data_iterator, list):
self.data_iterator_list = data_iterator
self.data_iterator_index = 0
self.data_iterator = self.data_iterator_list[self.data_iterator_index]
else:
self.data_iterator = data_iterator
self.last_batch = None
self.input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)

def update_data_iterator(self):
"""Update to the next data iterator in the list, if applicable."""
if hasattr(self, "data_iterator_list"):
self.data_iterator_index = (self.data_iterator_index + 1) % len(self.data_iterator_list)
self.data_iterator = self.data_iterator_list[self.data_iterator_index]

def load_next_batch(self):
"""Loads the next batch. Can keep reusing the same batch for performance reasons."""
with maybe_record_goodput(self.goodput_recorder, GoodputEvent.DATA_LOADING):
@@ -47,6 +58,7 @@ def load_next_batch(self):
example_batch = self.last_batch
else:
example_batch = next(self.data_iterator)
self.update_data_iterator()
# Reshard data from loaded sharding to performant activation sharding
self.last_batch = jax.lax.with_sharding_constraint(example_batch, self.input_data_shardings)
self.check_example_batch()
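To make the round-robin behaviour above concrete (a sketch with plain Python iterators, not MaxText classes): when `DataLoader` receives a list of iterators, each call to `load_next_batch` pulls from the current iterator and then advances the index modulo the list length, interleaving the per-iterator streams.

```python
# Round-robin sketch mirroring DataLoader.update_data_iterator above.
iterators = [iter([0, 2, 4]), iter([1, 3, 5])]  # stand-ins for MultiHostDataLoadIterator
index = 0
batches = []
for _ in range(6):
    batches.append(next(iterators[index]))
    index = (index + 1) % len(iterators)  # same modular update as update_data_iterator
print(batches)  # [0, 1, 2, 3, 4, 5] -- the two streams are interleaved
```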
9 changes: 7 additions & 2 deletions src/MaxText/experimental/rl/grpo_trainer.py
@@ -877,8 +877,13 @@ def generation_worker_fn(
max_utils.print_mem_stats("After params initialized")

metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta)
state_to_save = _split_grpo_state(state)[0]
checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator)

if config.save_checkpoint_on_completion and (config.steps - 1) % config.checkpoint_period != 0:
state_to_save = _split_grpo_state(state)[0]
checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator)
elif checkpoint_manager is not None:
# In case the last periodic checkpoint save is still in progress
checkpoint_manager.wait_until_finished()
except exceptions.StopTraining as e:
max_logging.log(f"Training stopped: {str(e)}")
finally:
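The same final-save guard appears in sft_trainer.py and train.py below. A small worked example of the condition, with assumed values that are not from the PR:

```python
# Skip the extra final save when the last step already coincides with a
# periodic checkpoint; in that case only wait for the in-flight save.
steps, checkpoint_period = 101, 20
last_step = steps - 1                    # 100
if last_step % checkpoint_period != 0:   # 100 % 20 == 0, so this branch is skipped here
    print("save a final checkpoint")
else:
    print("wait_until_finished() on the periodic checkpoint")
```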
16 changes: 13 additions & 3 deletions src/MaxText/input_pipeline/_grain_data_processing.py
@@ -194,7 +194,7 @@ def make_grain_train_iterator(
assert (
config.global_batch_size_to_load % global_mesh.size == 0
), "Batch size should be divisible by number of global devices."
if not config.colocated_python_data_input:
if not config.colocated_python_data_input and config.grain_checkpoint_scaling_factor == 1:
train_ds = get_datasets(
config.grain_train_files,
config.grain_file_type,
@@ -250,8 +250,18 @@
tokenize=config.tokenize_train_data,
grain_worker_count=config.grain_worker_count,
)
global_shape = (config.global_batch_size_to_load, config.max_target_length)
return multihost_dataloading.RemoteIterator(get_ds_fn, preprocessing_fn, global_mesh, global_shape)
if config.colocated_python_data_input:
global_shape = (config.global_batch_size_to_load, config.max_target_length)
return multihost_dataloading.RemoteIterator(get_ds_fn, preprocessing_fn, global_mesh, global_shape)
elif config.grain_checkpoint_scaling_factor > 1:
train_dataloader_list = []
dataloading_host_count = len(process_indices) * config.grain_checkpoint_scaling_factor
for i in range(config.grain_checkpoint_scaling_factor):
dataloading_host_index = len(process_indices) * i + process_indices.index(jax.process_index())
train_ds = get_ds_fn(dataloading_host_index=dataloading_host_index, dataloading_host_count=dataloading_host_count)
train_dataloader = preprocessing_fn(train_ds)
train_dataloader_list.append(train_dataloader)
return [multihost_dataloading.MultiHostDataLoadIterator(x, global_mesh, config.generate_padding_batch_train) for x in train_dataloader_list]

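A worked example of the shard bookkeeping above, with assumed values: with two JAX processes in `process_indices` and `grain_checkpoint_scaling_factor` set to 2, every process builds two iterators over four dataset shards, and together the processes cover each shard exactly once.

```python
# Sketch of the dataloading_host_index arithmetic in make_grain_train_iterator.
process_indices = [0, 1]          # JAX processes participating in data loading (assumed)
scaling_factor = 2                # grain_checkpoint_scaling_factor
dataloading_host_count = len(process_indices) * scaling_factor  # 4 shards in total
for jax_process in process_indices:
    shards = [len(process_indices) * i + process_indices.index(jax_process) for i in range(scaling_factor)]
    print(jax_process, shards)
# 0 [0, 2]
# 1 [1, 3]
# checkpointing.py applies the equivalent mapping (jax.process_index() + i * jax.process_count())
# when writing and restoring the per-shard Grain iterator states.
```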

def make_grain_eval_iterator(
9 changes: 7 additions & 2 deletions src/MaxText/maxtext_utils.py
@@ -1124,8 +1124,13 @@ def setup_initial_state(
):
state = restored
else:
if "iter" in restored and restored["iter"] is not None:
data_iterator.local_iterator = restored["iter"]
# if "iter" in restored and restored["iter"] is not None:
# if isinstance(restored["iter"], list):
# for i, restored_iter in enumerate(restored["iter"]) :
# data_iterator[i].local_iterator = restored_iter
# else:
# data_iterator.local_iterator = restored["iter"]
# The data_iterator state is updated in place, so no explicit assignment is needed here
state = restored["items"]
else:
init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training)
6 changes: 5 additions & 1 deletion src/MaxText/sft_trainer.py
@@ -139,7 +139,11 @@ def train_loop(config, recorder, state=None):

metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta)

checkpointing.maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator)
if config.save_checkpoint_on_completion and (config.steps - 1) % config.checkpoint_period != 0:
checkpointing.maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator)
elif checkpoint_manager is not None:
# In case the last periodic checkpoint save is still in progress
checkpoint_manager.wait_until_finished()
except exceptions.StopTraining as e:
max_logging.log(f"Training stopped: {str(e)}")
finally:
5 changes: 4 additions & 1 deletion src/MaxText/train.py
@@ -484,9 +484,12 @@ def train_loop(config, recorder, state=None):

metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta)

if config.save_checkpoint_on_completion:
if config.save_checkpoint_on_completion and (config.steps - 1) % config.checkpoint_period != 0:
state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0]
checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator)
elif checkpoint_manager is not None:
# In case the last periodic checkpoint save is still in progress
checkpoint_manager.wait_until_finished()
except exceptions.StopTraining as e:
max_logging.log(f"Training stopped: {str(e)}")
finally: