From 205983a09f66eb6841cf74b074c9e4dd3c109c03 Mon Sep 17 00:00:00 2001
From: Howard Huang
Date: Thu, 10 Oct 2024 10:06:13 -0700
Subject: [PATCH] Add zero bubble to test runner

ghstack-source-id: 2cf5b500cb55c78046ad77acb8ecfe6497650961
Pull Request resolved: https://github.com/pytorch/torchtitan/pull/605
---
 test_runner.py                              | 13 +++++++++++++
 torchtitan/config_manager.py                |  5 ++---
 torchtitan/parallelisms/pipelining_utils.py | 15 ++++++---------
 3 files changed, 21 insertions(+), 12 deletions(-)

diff --git a/test_runner.py b/test_runner.py
index a937e517..7221e587 100755
--- a/test_runner.py
+++ b/test_runner.py
@@ -149,6 +149,19 @@ def build_test_list():
             requires_seed_checkpoint=True,
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    "--experimental.pipeline_parallel_degree 4",
+                    "--experimental.pipeline_parallel_schedule InterleavedZeroBubble",
+                ],
+            ],
+            "PP looped zero bubble test",
+            "pp_looped_zero_bubble",
+            requires_seed_checkpoint=True,
+            ngpu=4,
+        ),
         OverrideDefinitions(
             [
                 [
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index ca1b9625..88e51f02 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -299,11 +299,10 @@ def __init__(self):
         self.parser.add_argument(
             "--experimental.pipeline_parallel_schedule",
             type=str,
-            choices=["1F1B", "GPipe", "Interleaved1F1B", "FlexibleInterleaved1F1B"],
             default="1F1B",
             help="""
-                Specify the Pipeline Parallel schedule to use.
-
+                Specify the Pipeline Parallel schedule to use. The supported schedules are:
+                https://github.com/pytorch/pytorch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/torch/distributed/pipelining/schedules.py#L2161.
                 The schedule must be compatible with the split points and stages_per_rank.

                 Looped schedules (e.g. Interleaved1F1B) require specifying pipeline_parallel_degree = number of ranks,
diff --git a/torchtitan/parallelisms/pipelining_utils.py b/torchtitan/parallelisms/pipelining_utils.py
index 4bf657e0..c154934f 100644
--- a/torchtitan/parallelisms/pipelining_utils.py
+++ b/torchtitan/parallelisms/pipelining_utils.py
@@ -5,11 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 from typing import Tuple

-from torch.distributed.pipelining import (
-    ScheduleFlexibleInterleaved1F1B,
-    ScheduleInterleaved1F1B,
-)
-
 from torch.distributed.pipelining.schedules import (
     get_schedule_class,
     PipelineScheduleMulti,
@@ -61,13 +56,15 @@ def generate_split_points(job_config, pp_dim, model_config):


 def build_pipeline_schedule(job_config, stages, loss_fn):
-    looped_schedule = False
-
     schedule_class = get_schedule_class(
         job_config.experimental.pipeline_parallel_schedule
     )
-    if schedule_class in [ScheduleInterleaved1F1B, ScheduleFlexibleInterleaved1F1B]:
-        looped_schedule = True
+    if schedule_class in [PipelineScheduleSingle, PipelineScheduleMulti]:
+        raise ValueError(
+            f"{schedule_class} is not supported as we do not support custom CSV schedules."
+        )
+
+    looped_schedule = issubclass(schedule_class, PipelineScheduleMulti)
     logger.info(
         f"Using pipeline schedule {job_config.experimental.pipeline_parallel_schedule}"
     )
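
Note: a minimal standalone sketch of the schedule-resolution logic this patch introduces in `build_pipeline_schedule`, for anyone who wants to poke at it outside torchtitan. It assumes a PyTorch build recent enough that `get_schedule_class` recognizes `InterleavedZeroBubble` (nightly as of this PR); the schedule names are illustrative, and the base-class guard mirrors the custom-CSV check added in `pipelining_utils.py`:

```python
# Sketch only: assumes a recent PyTorch where get_schedule_class knows
# "InterleavedZeroBubble". Names below are illustrative, not exhaustive.
from torch.distributed.pipelining.schedules import (
    get_schedule_class,
    PipelineScheduleMulti,
    PipelineScheduleSingle,
)

for name in ["1F1B", "GPipe", "Interleaved1F1B", "InterleavedZeroBubble"]:
    schedule_class = get_schedule_class(name)
    # The patch rejects the abstract base classes, since get_schedule_class
    # returns them only for custom CSV schedules, which torchtitan does not support.
    if schedule_class in [PipelineScheduleSingle, PipelineScheduleMulti]:
        raise ValueError(f"{schedule_class} is not supported")
    # Looped (multi-stage-per-rank) schedules are detected by subclassing
    # PipelineScheduleMulti rather than a hard-coded class list, so new
    # schedules such as zero bubble are picked up without further changes.
    looped = issubclass(schedule_class, PipelineScheduleMulti)
    print(f"{name}: {schedule_class.__name__}, looped={looped}")
```

This is why the patch can drop both the `choices=[...]` list in `config_manager.py` and the explicit `ScheduleInterleaved1F1B` / `ScheduleFlexibleInterleaved1F1B` imports: any schedule name `get_schedule_class` accepts now works, including the new `InterleavedZeroBubble` exercised by the test.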