Add zero bubble to test runner
ghstack-source-id: 2cf5b500cb55c78046ad77acb8ecfe6497650961
Pull Request resolved: pytorch#605
H-Huang committed Oct 10, 2024
1 parent cec7fb6 commit 205983a
Showing 3 changed files with 21 additions and 12 deletions.
13 changes: 13 additions & 0 deletions test_runner.py
@@ -149,6 +149,19 @@ def build_test_list():
             requires_seed_checkpoint=True,
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    "--experimental.pipeline_parallel_degree 4",
+                    "--experimental.pipeline_parallel_schedule InterleavedZeroBubble",
+                ],
+            ],
+            "PP looped zero bubble test",
+            "pp_looped_zero_bubble",
+            requires_seed_checkpoint=True,
+            ngpu=4,
+        ),
         OverrideDefinitions(
             [
                 [
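
The new entry runs the InterleavedZeroBubble schedule with pipeline_parallel_degree 4 on 4 GPUs, starting from a seed checkpoint. As a rough sketch of what the test runner does with such an entry (the wrapper script name and NGPU handling are assumptions, not shown in this diff), the override list flattens into a single launch command:

    # Hypothetical reconstruction of the command this test entry maps to;
    # test_runner.py assembles the real one.
    overrides = [
        "--checkpoint.enable_checkpoint",
        "--experimental.pipeline_parallel_degree 4",
        "--experimental.pipeline_parallel_schedule InterleavedZeroBubble",
    ]
    ngpu = 4
    cmd = f"NGPU={ngpu} ./run_llama_train.sh " + " ".join(overrides)
    print(cmd)
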
5 changes: 2 additions & 3 deletions torchtitan/config_manager.py
@@ -299,11 +299,10 @@ def __init__(self):
         self.parser.add_argument(
             "--experimental.pipeline_parallel_schedule",
             type=str,
-            choices=["1F1B", "GPipe", "Interleaved1F1B", "FlexibleInterleaved1F1B"],
             default="1F1B",
             help="""
-                Specify the Pipeline Parallel schedule to use.
+                Specify the Pipeline Parallel schedule to use. The supported schedules are:
+                https://github.com/pytorch/pytorch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/torch/distributed/pipelining/schedules.py#L2161.
                 The schedule must be compatible with the split points and stages_per_rank.
                 Looped schedules (e.g. Interleaved1F1B) require specifying pipeline_parallel_degree = number of ranks,
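
With the choices list removed, the flag accepts any schedule name as a plain string (including the new InterleavedZeroBubble), and validation moves to get_schedule_class when the schedule is built. A minimal standalone sketch of the relaxed argument, not the actual JobConfig parser:

    import argparse

    parser = argparse.ArgumentParser()
    # No `choices` restriction: any schedule name parses here and is
    # validated later by get_schedule_class.
    parser.add_argument(
        "--experimental.pipeline_parallel_schedule",
        type=str,
        default="1F1B",
    )
    args = parser.parse_args(
        ["--experimental.pipeline_parallel_schedule", "InterleavedZeroBubble"]
    )
    print(getattr(args, "experimental.pipeline_parallel_schedule"))
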
15 changes: 6 additions & 9 deletions torchtitan/parallelisms/pipelining_utils.py
@@ -5,11 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 from typing import Tuple

-from torch.distributed.pipelining import (
-    ScheduleFlexibleInterleaved1F1B,
-    ScheduleInterleaved1F1B,
-)
-
 from torch.distributed.pipelining.schedules import (
     get_schedule_class,
     PipelineScheduleMulti,
@@ -61,13 +56,15 @@ def generate_split_points(job_config, pp_dim, model_config):


 def build_pipeline_schedule(job_config, stages, loss_fn):
-    looped_schedule = False
-
     schedule_class = get_schedule_class(
         job_config.experimental.pipeline_parallel_schedule
     )
-    if schedule_class in [ScheduleInterleaved1F1B, ScheduleFlexibleInterleaved1F1B]:
-        looped_schedule = True
+    if schedule_class in [PipelineScheduleSingle, PipelineScheduleMulti]:
+        raise ValueError(
+            f"{schedule_class} is not supported as we do not support custom CSV schedules."
+        )
+
+    looped_schedule = issubclass(schedule_class, PipelineScheduleMulti)
     logger.info(
         f"Using pipeline schedule {job_config.experimental.pipeline_parallel_schedule}"
     )
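
Instead of hardcoding which schedule classes count as looped, the helper now rejects the two base classes (reserved for custom CSV schedules) and derives looped_schedule from the class hierarchy. A minimal sketch of that dispatch, assuming a PyTorch build whose get_schedule_class knows the InterleavedZeroBubble name:

    from torch.distributed.pipelining.schedules import (
        get_schedule_class,
        PipelineScheduleMulti,
        PipelineScheduleSingle,
    )

    schedule_class = get_schedule_class("InterleavedZeroBubble")
    if schedule_class in [PipelineScheduleSingle, PipelineScheduleMulti]:
        # The base classes are only used for custom CSV schedules, which
        # build_pipeline_schedule does not support.
        raise ValueError(f"{schedule_class} is not supported")

    # Every multi-stage ("looped") schedule subclasses PipelineScheduleMulti.
    looped_schedule = issubclass(schedule_class, PipelineScheduleMulti)
    print(schedule_class.__name__, looped_schedule)
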
