Support ZBVZeroBubbleSchedule #817
`train.py`:

```diff
@@ -151,7 +151,12 @@ def loss_fn(pred, labels):
     # apply parallelisms and initialization
     if parallel_dims.pp_enabled:
         # apply PT-D Pipeline Parallel
-        pp_schedule, model_parts = models_pipelining_fns[model_name](
+        (
+            pp_schedule,
+            model_parts,
+            has_first_stage,
+            has_last_stage,
+        ) = models_pipelining_fns[model_name](
             model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
         )
         # when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead
```
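The two extra return values tell the training loop whether this rank owns the first and/or last pipeline stage, which is no longer a function of the rank alone once a schedule places several stages per rank. Below is a minimal sketch, not the torchtitan implementation, of how a pipelining function could derive the flags; it assumes the stage objects expose `is_first`/`is_last` attributes (as `torch.distributed.pipelining.PipelineStage` does), with `SimpleNamespace` standing in for real stages:

```python
# Sketch only (not the torchtitan code): derive the flags from whatever
# stage objects this rank builds. Assumes each stage exposes `is_first`
# and `is_last`.
from types import SimpleNamespace

def stage_flags(stages):
    has_first_stage = any(stage.is_first for stage in stages)
    has_last_stage = any(stage.is_last for stage in stages)
    return has_first_stage, has_last_stage

# With a V-shaped (ZBV) placement one rank can own both ends of the pipeline,
# e.g. rank 0 holding stage 0 and stage 7 of an 8-stage pipeline:
rank0_stages = [
    SimpleNamespace(is_first=True, is_last=False),  # stage 0
    SimpleNamespace(is_first=False, is_last=True),  # stage 7
]
print(stage_flags(rank0_stages))  # (True, True)
```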
```diff
@@ -285,22 +290,19 @@ def loss_fn(pred, labels):
         if parallel_dims.pp_enabled:
             # Pipeline Parallel forward / backward inside step() call
-            is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
-
             with train_context(optional_context_parallel_ctx):
-                if pp_mesh.get_local_rank() == 0:
-                    pp_schedule.step(input_ids)
-                elif is_last_stage:
-                    losses = []
-                    pp_schedule.step(target=labels, losses=losses)
+                targets = labels if has_last_stage else None
+                losses = [] if has_last_stage else None
+                if has_first_stage:
+                    pp_schedule.step(input_ids, target=targets, losses=losses)
                 else:
-                    pp_schedule.step()
+                    pp_schedule.step(target=targets, losses=losses)

             # accumulate losses across pipeline microbatches
             # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
             loss = (
                 torch.mean(torch.stack(losses)).to(device)
-                if is_last_stage
+                if has_last_stage
                 else torch.tensor([-1.0], device=device)
             )
         else:
```

**Review thread** on the `pp_schedule.step(input_ids, target=targets, losses=losses)` line (marked as resolved by H-Huang):

> what if a schedule has …

> Oops yeah, that was the issue. Updated it and will let the CI run again.
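The switch from a rank-based `is_last_stage` check to `has_last_stage` is the crux of supporting `ZBVZeroBubbleSchedule`: under the usual ZB-V placement each rank owns two stages laid out in a V, so the last stage lives on rank 0 rather than on the last rank, and `pp_mesh.get_local_rank() == pp_mesh.size() - 1` would look for the loss on the wrong rank. A toy illustration of that placement (assuming the standard V assignment of two stages per rank; not tied to the actual schedule code):

```python
# Toy ZB-V placement: stages run "down" the ranks and then back "up",
# forming a V, so rank 0 owns both the first and the last stage.
def zbv_stage_ids(rank: int, pp_size: int) -> list[int]:
    return [rank, 2 * pp_size - 1 - rank]

def stage_flags(rank: int, pp_size: int) -> tuple[bool, bool]:
    ids = zbv_stage_ids(rank, pp_size)
    num_stages = 2 * pp_size
    return 0 in ids, (num_stages - 1) in ids  # (has_first_stage, has_last_stage)

if __name__ == "__main__":
    pp_size = 4
    for rank in range(pp_size):
        print(rank, zbv_stage_ids(rank, pp_size), stage_flags(rank, pp_size))
    # rank 0 -> [0, 7] -> (True, True); rank 3 -> [3, 4] -> (False, False),
    # so the loss-bearing last stage is on rank 0, not on rank pp_size - 1.
```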
**Reviewer:** Kinda nitpicking, but I feel like if the stage object inside the schedule already knows that it is first or last, we can avoid having the logic in the training loop too. OTOH it seems nice to be explicit at the `train.py` layer on whether we are asking to compute loss or not. Thoughts? @tianyu-l

**tianyu-l:** It feels nice to only explicitly pass in meaningful targets/losses when we are not sure if they'll be properly accessed, so I'm OK with these if-else statements. But how different is `input_ids`? Can we just unify everything into `pp_schedule.step(input_ids, target=targets, losses=losses)` and pass `input_ids = None` when `not has_first_stage`?

**H-Huang:** We can't do `input_ids=None` right now, since we have logic that automatically splits all `*args` into microbatches. For example, if the user wants to do `step(tensors, None)`, that would be split up into microbatches of `(tensor1, None)`, `(tensor2, None)`, ... We could update the splitting logic, but I'm not sure it is worth it.