diff --git a/train.py b/train.py index 21dd9f8b8..3ceacf270 100644 --- a/train.py +++ b/train.py @@ -154,6 +154,8 @@ def loss_fn(pred, labels): pp_schedule, model_parts = models_pipelining_fns[model_name]( model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn ) + # when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead + del model # For PP with looped schedules, each item in model_parts is one stage-model-chunk. # We need to iterate through model_parts to apply SPMD parallelisms, compilation, @@ -269,11 +271,12 @@ def loss_fn(pred, labels): optimizers.zero_grad() # apply context parallelism if cp is enabled + # ensure CP handles the separate freqs_cis buffer for each pp stage optional_context_parallel_ctx = ( utils.create_context_parallel_ctx( cp_mesh=world_mesh["cp"], - cp_buffers=[input_ids, labels, model.freqs_cis], - cp_seq_dims=[1, 1, 0], + cp_buffers=[input_ids, labels] + [m.freqs_cis for m in model_parts], + cp_seq_dims=[1, 1] + [0 for _ in model_parts], cp_no_restore_buffers={input_ids, labels}, cp_rotate_method=job_config.experimental.context_parallel_rotate_method, )