From 3996b634a58c7b7de3f7f7b815058245d8c69e01 Mon Sep 17 00:00:00 2001
From: Howard Huang
Date: Mon, 10 Feb 2025 15:57:48 -0500
Subject: [PATCH] Support ZBVZeroBubbleSchedule (#817)

This depends on the changes in this PyTorch stack:
https://github.com/pytorch/pytorch/pull/146217

Add support for running `ZBVZeroBubbleSchedule` and v-shaped CSV schedules
in torchtitan.

Fixes https://github.com/pytorch/torchtitan/issues/774

---------

Co-authored-by: tianyu-l <150487191+tianyu-l@users.noreply.github.com>
---
 tests/integration_tests.py                | 12 ++++++++++++
 torchtitan/parallelisms/pipeline_llama.py | 24 ++++++++++++++++++++---
 train.py                                  | 21 ++++++++++----------
 3 files changed, 44 insertions(+), 13 deletions(-)

diff --git a/tests/integration_tests.py b/tests/integration_tests.py
index 1bdd5df91..c067ac37e 100755
--- a/tests/integration_tests.py
+++ b/tests/integration_tests.py
@@ -139,6 +139,18 @@ def build_test_list():
             "pp_looped_zero_bubble",
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--experimental.pipeline_parallel_degree 2",
+                    "--experimental.pipeline_parallel_schedule ZBVZeroBubble",
+                    "--experimental.pipeline_parallel_microbatches 8",
+                ],
+            ],
+            "PP zero bubble test (v shaped)",
+            "pp_zbv",
+            ngpu=2,
+        ),
         OverrideDefinitions(
             [
                 [
diff --git a/torchtitan/parallelisms/pipeline_llama.py b/torchtitan/parallelisms/pipeline_llama.py
index 6605a57d6..8fe892ab3 100644
--- a/torchtitan/parallelisms/pipeline_llama.py
+++ b/torchtitan/parallelisms/pipeline_llama.py
@@ -13,7 +13,10 @@
 import torch.nn as nn
 from torch.distributed import DeviceMesh
 from torch.distributed.pipelining import PipelineStage
-
+from torch.distributed.pipelining.schedules import (
+    get_schedule_class,
+    ScheduleZBVZeroBubble,
+)
 from torchtitan.config_manager import JobConfig
 from torchtitan.logging import logger
 from torchtitan.models.llama.model import ModelArgs
@@ -43,7 +46,16 @@ def pipeline_llama(
 
     pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
 
-    return pp_schedule, models
+    # This is used in the train loop to determine whether to pass in the input_ids and labels
+    has_first_stage = False
+    has_last_stage = False
+    for stage in stages:
+        if stage.is_first:
+            has_first_stage = True
+        if stage.is_last:
+            has_last_stage = True
+
+    return pp_schedule, models, has_first_stage, has_last_stage
 
 
 def pipeline_llama_manual_split(
@@ -103,7 +115,13 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal
 
     stages = []
     models = []
-    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style="loop"):
+
+    schedule_class = get_schedule_class(
+        job_config.experimental.pipeline_parallel_schedule
+    )
+    style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"
+
+    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
         start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
         stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
         stage, model_chunk = _build_stage(
diff --git a/train.py b/train.py
index 761393f7f..2e2a7ec4b 100644
--- a/train.py
+++ b/train.py
@@ -151,7 +151,12 @@ def loss_fn(pred, labels):
     # apply parallelisms and initialization
     if parallel_dims.pp_enabled:
         # apply PT-D Pipeline Parallel
-        pp_schedule, model_parts = models_pipelining_fns[model_name](
+        (
+            pp_schedule,
+            model_parts,
+            has_first_stage,
+            has_last_stage,
+        ) = models_pipelining_fns[model_name](
             model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
         )
         # when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead
@@ -285,22 +290,18 @@ def loss_fn(pred, labels):
 
         if parallel_dims.pp_enabled:
             # Pipeline Parallel forward / backward inside step() call
-            is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
-
             with train_context(optional_context_parallel_ctx):
-                if pp_mesh.get_local_rank() == 0:
-                    pp_schedule.step(input_ids)
-                elif is_last_stage:
-                    losses = []
-                    pp_schedule.step(target=labels, losses=losses)
+                targets, losses = (labels, []) if has_last_stage else (None, None)
+                if has_first_stage:
+                    pp_schedule.step(input_ids, target=targets, losses=losses)
                 else:
-                    pp_schedule.step()
+                    pp_schedule.step(target=targets, losses=losses)
 
             # accumulate losses across pipeline microbatches
             # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
             loss = (
                 torch.mean(torch.stack(losses)).to(device)
-                if is_last_stage
+                if has_last_stage
                 else torch.tensor([-1.0], device=device)
             )
         else:
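
Why the train loop switches from rank checks to `has_first_stage` /
`has_last_stage`: under the "v" placement the first and last model chunks
live on the same rank (rank 0), so the old
`pp_mesh.get_local_rank() == pp_mesh.size() - 1` test no longer identifies
the rank that computes the loss, and a single rank may need to both feed
`input_ids` and collect `losses`. The sketch below mimics what
`stage_ids_this_rank` returns for the two styles; it is an illustration
(the helper name and the two-stages-per-rank assumption are ours), not the
torchtitan implementation.

```python
# Hypothetical stand-in for torchtitan's stage_ids_this_rank, assuming
# V-shaped schedules place exactly two stages on every rank.
def stage_ids_this_rank_sketch(
    pp_rank: int, pp_size: int, num_stages: int, style: str = "loop"
) -> tuple[int, ...]:
    if style == "loop":
        # Interleaved: rank r holds stages r, r + pp_size, r + 2 * pp_size, ...
        stages_per_rank = num_stages // pp_size
        return tuple(pp_rank + s * pp_size for s in range(stages_per_rank))
    if style == "v":
        # V-shaped: rank r holds stage r on the way "down" and its mirror
        # stage on the way back "up", so rank 0 gets both the first and the
        # last stage of the pipeline.
        assert num_stages == 2 * pp_size, "v style assumes 2 stages per rank"
        return (pp_rank, num_stages - 1 - pp_rank)
    raise ValueError(f"unsupported style: {style!r}")

# pp_size=2, num_stages=4 (the shape exercised by the pp_zbv test):
#   loop -> rank 0: (0, 2), rank 1: (1, 3)
#   v    -> rank 0: (0, 3), rank 1: (1, 2)
for style in ("loop", "v"):
    for rank in range(2):
        print(style, rank, stage_ids_this_rank_sketch(rank, 2, 4, style))
```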
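
For completeness, a minimal self-contained sketch of the schedule-to-style
mapping added in `pipeline_llama_manual_split`, assuming a PyTorch build
that already includes the stack linked above (`get_schedule_class` and
`ScheduleZBVZeroBubble` come from `torch.distributed.pipelining.schedules`);
the hard-coded `schedule_name` stands in for
`job_config.experimental.pipeline_parallel_schedule`:

```python
# Map the config string to a schedule class, then pick the stage placement
# style: "v" only for the V-shaped ZBV schedule, "loop" otherwise.
from torch.distributed.pipelining.schedules import (
    get_schedule_class,
    ScheduleZBVZeroBubble,
)

schedule_name = "ZBVZeroBubble"  # e.g. --experimental.pipeline_parallel_schedule
schedule_class = get_schedule_class(schedule_name)
style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"
print(schedule_class.__name__, style)  # -> ScheduleZBVZeroBubble v
```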