Support ZBVZeroBubbleSchedule #817

Merged 2 commits on Feb 10, 2025
12 changes: 12 additions & 0 deletions tests/integration_tests.py
@@ -139,6 +139,18 @@ def build_test_list():
"pp_looped_zero_bubble",
ngpu=4,
),
OverrideDefinitions(
[
[
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_schedule ZBVZeroBubble",
"--experimental.pipeline_parallel_microbatches 8",
],
],
"PP zero bubble test (v shaped)",
"pp_zbv",
ngpu=2,
),
OverrideDefinitions(
[
[
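For reference, the new entry exercises the ZBVZeroBubble schedule on 2 GPUs with PP degree 2 and 8 microbatches. A rough sketch of the kind of launch command such an entry corresponds to (the run_llama_train.sh launcher and the NGPU variable are assumptions for illustration, not taken from this diff):

# Hedged sketch: roughly how an OverrideDefinitions entry becomes a test command.
# The launcher script and env var names are assumptions, not part of this PR.
overrides = [
    "--experimental.pipeline_parallel_degree 2",
    "--experimental.pipeline_parallel_schedule ZBVZeroBubble",
    "--experimental.pipeline_parallel_microbatches 8",
]
print("NGPU=2 ./run_llama_train.sh " + " ".join(overrides))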
24 changes: 21 additions & 3 deletions torchtitan/parallelisms/pipeline_llama.py
@@ -13,7 +13,10 @@
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.pipelining import PipelineStage

from torch.distributed.pipelining.schedules import (
get_schedule_class,
ScheduleZBVZeroBubble,
)
from torchtitan.config_manager import JobConfig
from torchtitan.logging import logger
from torchtitan.models.llama.model import ModelArgs
@@ -43,7 +46,16 @@ def pipeline_llama(

pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

return pp_schedule, models
# This is used in the train loop to determine whether to pass in the input_ids and labels
has_first_stage = False
has_last_stage = False
for stage in stages:
if stage.is_first:
has_first_stage = True
if stage.is_last:
has_last_stage = True

return pp_schedule, models, has_first_stage, has_last_stage


def pipeline_llama_manual_split(
@@ -103,7 +115,13 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=False

stages = []
models = []
for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style="loop"):

schedule_class = get_schedule_class(
job_config.experimental.pipeline_parallel_schedule
)
style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"

for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
stage, model_chunk = _build_stage(
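The "v" placement is what the ZBV schedule relies on: each pipeline rank hosts one stage from the downward leg and its mirror from the upward leg, so rank 0 ends up holding both the first and the last stage. A simplified sketch of the idea, assuming exactly two stages per rank in the V case (hypothetical helper, not torchtitan's actual stage_ids_this_rank):

# Simplified illustration of "loop" vs "v" stage placement.
def stage_ids(pp_rank: int, pp_size: int, num_stages: int, style: str):
    per_rank = num_stages // pp_size
    if style == "loop":
        # rank r owns stages r, r + pp_size, r + 2 * pp_size, ...
        return [pp_rank + i * pp_size for i in range(per_rank)]
    if style == "v":
        # stages zig-zag: rank r owns stage r and its mirror, so rank 0
        # gets both the first (0) and the last (num_stages - 1) stage
        assert per_rank == 2, "ZBV assumes two stages per rank"
        return [pp_rank, num_stages - 1 - pp_rank]
    raise ValueError(f"unknown style {style}")

# pp_size=2, num_stages=4:
#   loop -> rank 0: [0, 2], rank 1: [1, 3]
#   v    -> rank 0: [0, 3], rank 1: [1, 2]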
22 changes: 12 additions & 10 deletions train.py
@@ -151,7 +151,12 @@ def loss_fn(pred, labels):
# apply parallelisms and initialization
if parallel_dims.pp_enabled:
# apply PT-D Pipeline Parallel
pp_schedule, model_parts = models_pipelining_fns[model_name](
(
pp_schedule,
model_parts,
has_first_stage,
has_last_stage,
) = models_pipelining_fns[model_name](
model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
)
# when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead
@@ -285,22 +290,19 @@ def loss_fn(pred, labels):

if parallel_dims.pp_enabled:
# Pipeline Parallel forward / backward inside step() call
is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

with train_context(optional_context_parallel_ctx):
if pp_mesh.get_local_rank() == 0:
pp_schedule.step(input_ids)
elif is_last_stage:
losses = []
pp_schedule.step(target=labels, losses=losses)
targets = labels if has_last_stage else None
Contributor:
kinda nit picking, but I feel like if the stage object inside the schedule already knows that it is first or last, we can avoid having the logic in the training loop too.

otoh it seems nice to be explicit at the train.py layer on whether we are asking to compute loss or not.

thoughts? @tianyu-l

Contributor:
It feels nice when we only explicitly pass in meaningful targets/losses when we are not sure if they'll be properly accessed, so I'm OK with these if-else statements.

But how different is input_ids? Can we just unify everything into pp_schedule.step(input_ids, target=targets, losses=losses) and pass input_ids = None when not has_first_stage?

Member (author):
We can't do input_ids=None right now, since we have logic that automatically splits all *args into microbatches. For example, if the user wants to do step(tensors, None), that would be split up into microbatches of (tensor1, None), (tensor2, None), ... We could update the splitting logic, but I'm not sure it is worth it.
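A rough illustration of the splitting issue described above, assuming the schedule chunks every positional arg into microbatches (hypothetical splitter, not the actual pipelining internals):

import torch

# Hypothetical stand-in for the *args-splitting logic: every positional arg
# is chunked into n_microbatches pieces, so a literal None cannot be split.
def split_args(args, n_microbatches):
    chunks = [torch.tensor_split(a, n_microbatches) for a in args]
    return list(zip(*chunks))

x = torch.arange(8).reshape(8, 1)
split_args((x,), 4)         # ok: four (2, 1) microbatches
# split_args((x, None), 4)  # fails: None cannot be chunked into microbatches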

losses = [] if has_last_stage else None
if has_first_stage:
pp_schedule.step(input_ids, target=targets, losses=losses)
Contributor:
What if a schedule has has_last_stage = True and has_first_stage = False for the output layer -- will it miss the chance to feed in losses?

Member (author):
Oops, yeah, that was the issue. Updated it and will let the CI run again.

else:
pp_schedule.step()
pp_schedule.step(target=targets, losses=losses)

# accumulate losses across pipeline microbatches
# TODO: PP+FSDP unexpectedly puts the loss back to the CPU
loss = (
torch.mean(torch.stack(losses)).to(device)
if is_last_stage
if has_last_stage
else torch.tensor([-1.0], device=device)
)
else:
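Putting the train.py fragments above back together, the per-rank stepping logic after this PR reads roughly as follows (reconstructed from the diff, not copied verbatim). Note that a rank can own the last stage without owning the first one, e.g. rank 1 in a 2-rank looped schedule with 4 stages, which is why the else branch also forwards target and losses:

if parallel_dims.pp_enabled:
    # Pipeline Parallel forward / backward inside step() call
    with train_context(optional_context_parallel_ctx):
        targets = labels if has_last_stage else None
        losses = [] if has_last_stage else None
        if has_first_stage:
            pp_schedule.step(input_ids, target=targets, losses=losses)
        else:
            pp_schedule.step(target=targets, losses=losses)

    # accumulate losses across pipeline microbatches
    # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
    loss = (
        torch.mean(torch.stack(losses)).to(device)
        if has_last_stage
        else torch.tensor([-1.0], device=device)
    )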