diff --git a/train.py b/train.py
index d7b384c6..2e2a7ec4 100644
--- a/train.py
+++ b/train.py
@@ -291,8 +291,7 @@ def loss_fn(pred, labels):
             if parallel_dims.pp_enabled:
                 # Pipeline Parallel forward / backward inside step() call
                 with train_context(optional_context_parallel_ctx):
-                    targets = labels if has_last_stage else None
-                    losses = [] if has_last_stage else None
+                    targets, losses = (labels, []) if has_last_stage else (None, None)
                     if has_first_stage:
                         pp_schedule.step(input_ids, target=targets, losses=losses)
                     else: