Fix PP+CP handling freqs_cis buffer
When developing test_pp_cp and chatting with @fegin, we realized that the
freqs_cis buffers were not being handled correctly in torchtitan when
pipelining is enabled.

CP needs to modify the freqs_cis buffer to account for sharding on the seq
dim, but the previous titan code implemented this incorrectly:
`model.freqs_cis` was passed to CP for sharding, while pipelining does not
use `model` at all; it uses the per-stage models contained in the
`model_parts` list. The fix is to register the freqs_cis buffer of each
model in `model_parts` with the CP context.
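For context, the sequence dim differs per buffer: input_ids and labels are
(batch, seq_len), so their seq dim is 1, while freqs_cis is indexed by
position first, so its seq dim is 0. A minimal sketch of the corrected
buffer registration, mirroring the diff below:

```python
# inputs are (batch, seq_len): sequence dim is 1
# freqs_cis is (seq_len, ...): sequence dim is 0, one buffer per pp stage
cp_buffers = [input_ids, labels] + [m.freqs_cis for m in model_parts]
cp_seq_dims = [1, 1] + [0 for _ in model_parts]
```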

Alternatively, we could tie the freqs_cis buffers of the pp stages
together by explicitly aliasing them after calling init_weights for each
pp stage (see the sketch below). However, this is of limited value, so we
skip it.
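For reference, a minimal sketch of what that tying could look like
(hypothetical; not part of this change):

```python
# hypothetical: after init_weights has run for every pp stage, alias the
# buffers so all stages share one freqs_cis tensor
for m in model_parts[1:]:
    m.freqs_cis = model_parts[0].freqs_cis
```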

[ghstack-poisoned]
wconstab committed Jan 15, 2025
1 parent 82f7387 commit a9a6d00
Showing 1 changed file with 5 additions and 2 deletions.
train.py: 7 changes (5 additions, 2 deletions)
@@ -154,6 +154,8 @@ def loss_fn(pred, labels):
 pp_schedule, model_parts = models_pipelining_fns[model_name](
     model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
 )
+# when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead
+del model

 # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
 # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
@@ -269,11 +271,12 @@ def loss_fn(pred, labels):
 optimizers.zero_grad()

 # apply context parallelism if cp is enabled
+# ensure CP handles the separate freqs_cis buffer for each pp stage
 optional_context_parallel_ctx = (
     utils.create_context_parallel_ctx(
         cp_mesh=world_mesh["cp"],
-        cp_buffers=[input_ids, labels, model.freqs_cis],
-        cp_seq_dims=[1, 1, 0],
+        cp_buffers=[input_ids, labels] + [m.freqs_cis for m in model_parts],
+        cp_seq_dims=[1, 1] + [0 for _ in model_parts],
         cp_no_restore_buffers={input_ids, labels},
         cp_rotate_method=job_config.experimental.context_parallel_rotate_method,
     )
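For reference, `utils.create_context_parallel_ctx` is essentially a thin
wrapper over PyTorch's experimental context-parallel API; a sketch under
that assumption (exact signatures may differ across PyTorch versions):

```python
import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method


def create_context_parallel_ctx(
    cp_mesh: DeviceMesh,
    cp_buffers: list[torch.Tensor],
    cp_seq_dims: list[int],
    cp_no_restore_buffers: set[torch.Tensor],
    cp_rotate_method: str,
):
    # choose the allgather/alltoall rotation strategy for ring attention,
    # then return a context manager that shards each buffer along its
    # declared sequence dim and restores the others on exit
    set_rotate_method(cp_rotate_method)
    return context_parallel(
        cp_mesh,
        buffers=cp_buffers,
        buffer_seq_dims=cp_seq_dims,
        no_restore_buffers=cp_no_restore_buffers,
    )
```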
