Description
Dear torchtitan team, I have a question regarding gradient norm clipping when using pipeline parallelism (PP), potentially combined with FSDP/DP/TP.
For simplicity, let's assume each process/GPU holds a single PP stage. My understanding is that since the model is manually sharded, calling torch.nn.utils.clip_grad_norm_ will only compute the grad norm over the parameters of the current PP stage:
Lines 298 to 302 in eef8bb2:

```python
# clip gradients
for m in model_parts:
    torch.nn.utils.clip_grad_norm_(
        m.parameters(), job_config.training.max_norm, foreach=True
    )
```
Since grad norm clipping requires computing the norm over the entire model (i.e., across all PP stages), does that mean we need to manually aggregate/reduce the grad norm across PP stages before scaling the gradients? If so, what would be the correct approach, something along the lines of the sketch below?
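For concreteness, here is a rough sketch of what I have in mind. The function name `clip_grad_norm_across_pp` and the `pp_group` process group are placeholders of my own, and this ignores the extra handling that sharded DTensor gradients from FSDP/TP would presumably need:

```python
import torch
import torch.distributed as dist

def clip_grad_norm_across_pp(model_parts, max_norm, pp_group, eps=1e-6):
    # Collect parameters with gradients from the local PP stage(s).
    params = [p for m in model_parts for p in m.parameters() if p.grad is not None]

    # Local squared L2 norm over this stage's gradients.
    local_sq = torch.stack([p.grad.norm(2) ** 2 for p in params]).sum()

    # Sum the squared norms over all PP stages to obtain the global norm.
    dist.all_reduce(local_sq, op=dist.ReduceOp.SUM, group=pp_group)
    total_norm = local_sq.sqrt()

    # Scale local gradients by the global clip coefficient.
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        for p in params:
            p.grad.mul_(clip_coef)
    return total_norm
```

Is something like this the intended pattern, or is there an existing utility in torchtitan/PyTorch that already performs this cross-stage reduction?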
Any clarification or guidance would be greatly appreciated!