Support ZBVZeroBubbleSchedule #817

Merged · 2 commits · Feb 10, 2025

12 changes: 12 additions & 0 deletions tests/integration_tests.py
@@ -139,6 +139,18 @@ def build_test_list():
             "pp_looped_zero_bubble",
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--experimental.pipeline_parallel_degree 2",
+                    "--experimental.pipeline_parallel_schedule ZBVZeroBubble",
+                    "--experimental.pipeline_parallel_microbatches 8",
+                ],
+            ],
+            "PP zero bubble test (v shaped)",
+            "pp_zbv",
+            ngpu=2,
+        ),
         OverrideDefinitions(
             [
                 [
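Note: the new entry runs the V-shaped schedule on the smallest layout that exercises it. A rough sanity check of the numbers, sketched in Python; the two-stages-per-rank figure is an assumption about how ZBV splits the model, not something stated in this diff:

# Rough arithmetic behind the overrides above (values copied from the test entry).
# Assumption: a ZBV schedule places exactly two virtual stages on each pipeline
# rank (one on the way "down" the V, one on the way back "up").
pipeline_parallel_degree = 2      # --experimental.pipeline_parallel_degree 2
stages_per_rank = 2               # assumed for the V shape
num_stages = pipeline_parallel_degree * stages_per_rank  # 4 model chunks total
microbatches = 8                  # --experimental.pipeline_parallel_microbatches 8
assert microbatches % pipeline_parallel_degree == 0  # microbatches split evenly across ranks
print(num_stages)  # 4 stages on ngpu=2, i.e. 2 chunks per GPU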
24 changes: 21 additions & 3 deletions torchtitan/parallelisms/pipeline_llama.py
@@ -13,7 +13,10 @@
 import torch.nn as nn
 from torch.distributed import DeviceMesh
 from torch.distributed.pipelining import PipelineStage
-
+from torch.distributed.pipelining.schedules import (
+    get_schedule_class,
+    ScheduleZBVZeroBubble,
+)
 from torchtitan.config_manager import JobConfig
 from torchtitan.logging import logger
 from torchtitan.models.llama.model import ModelArgs
@@ -43,7 +46,16 @@ def pipeline_llama(
 
     pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
 
-    return pp_schedule, models
+    # This is used in the train loop to determine whether to pass in the input_ids and labels
+    has_first_stage = False
+    has_last_stage = False
+    for stage in stages:
+        if stage.is_first:
+            has_first_stage = True
+        if stage.is_last:
+            has_last_stage = True
+
+    return pp_schedule, models, has_first_stage, has_last_stage
 
 
 def pipeline_llama_manual_split(
@@ -103,7 +115,13 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal
 
     stages = []
     models = []
-    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style="loop"):
+
+    schedule_class = get_schedule_class(
+        job_config.experimental.pipeline_parallel_schedule
+    )
+    style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"
+
+    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
         start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
         stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
         stage, model_chunk = _build_stage(
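Note: the only behavioral change in the manual-split path is the style argument. The snippet below is an illustrative stand-in, not torchtitan's stage_ids_this_rank implementation; it only shows the property the "v" style is meant to provide (the first and last stage living on the same rank), with the two-stages-per-rank layout assumed for ZBV:

def illustrative_stage_ids(pp_rank: int, pp_size: int, num_stages: int, style: str):
    """Toy stand-in for stage_ids_this_rank, for intuition only."""
    stages_per_rank = num_stages // pp_size
    if style == "loop":
        # round-robin: rank r gets r, r + pp_size, r + 2*pp_size, ...
        return [pp_rank + s * pp_size for s in range(stages_per_rank)]
    if style == "v":
        # V shape: stages run down the ranks and back up, so rank 0 holds
        # both the first and the last stage (assumes exactly 2 stages per rank)
        assert stages_per_rank == 2
        return [pp_rank, num_stages - 1 - pp_rank]
    raise ValueError(f"unsupported style: {style}")

# pp_size=2, num_stages=4:
#   loop -> rank 0: [0, 2], rank 1: [1, 3]
#   v    -> rank 0: [0, 3], rank 1: [1, 2]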
21 changes: 11 additions & 10 deletions train.py
@@ -151,7 +151,12 @@ def loss_fn(pred, labels):
     # apply parallelisms and initialization
     if parallel_dims.pp_enabled:
         # apply PT-D Pipeline Parallel
-        pp_schedule, model_parts = models_pipelining_fns[model_name](
+        (
+            pp_schedule,
+            model_parts,
+            has_first_stage,
+            has_last_stage,
+        ) = models_pipelining_fns[model_name](
             model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
         )
         # when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead
@@ -285,22 +290,18 @@ def loss_fn(pred, labels):
 
         if parallel_dims.pp_enabled:
             # Pipeline Parallel forward / backward inside step() call
-            is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
-
             with train_context(optional_context_parallel_ctx):
-                if pp_mesh.get_local_rank() == 0:
-                    pp_schedule.step(input_ids)
-                elif is_last_stage:
-                    losses = []
-                    pp_schedule.step(target=labels, losses=losses)
+                targets, losses = (labels, []) if has_last_stage else (None, None)
+                if has_first_stage:
+                    pp_schedule.step(input_ids, target=targets, losses=losses)

Contributor:
what if a schedule has has_last_stage = True and has_first_stage = False for the output layer -- will it miss the chance to feed in losses?

Member (Author):
Oops yeah, that was the issue. Updated it and will let the CI run again

                 else:
-                    pp_schedule.step()
+                    pp_schedule.step(target=targets, losses=losses)
 
             # accumulate losses across pipeline microbatches
             # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
             loss = (
                 torch.mean(torch.stack(losses)).to(device)
-                if is_last_stage
+                if has_last_stage
                 else torch.tensor([-1.0], device=device)
             )
         else:
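Note: tying the review thread back to the final diff, the step dispatch now keys off stage ownership rather than rank position, so a rank that owns the last stage but not the first one still receives target and losses. A minimal self-contained sketch of that dispatch; pp_step is a hypothetical helper name, and the ownership table in the comment is for illustration only:

def pp_step(pp_schedule, input_ids, labels, has_first_stage, has_last_stage):
    # (has_first_stage, has_last_stage) -> what this rank passes to step():
    #   (True,  True ) -> input_ids, target=labels, losses=[]  (e.g. rank 0 under ZBV)
    #   (True,  False) -> input_ids only
    #   (False, True ) -> target=labels, losses=[]  <- the case raised in the review
    #   (False, False) -> no inputs or targets      (e.g. rank 1 under a 2-rank ZBV)
    targets, losses = (labels, []) if has_last_stage else (None, None)
    if has_first_stage:
        pp_schedule.step(input_ids, target=targets, losses=losses)
    else:
        pp_schedule.step(target=targets, losses=losses)
    return losses  # populated only on ranks that own the last stage; None elsewhere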