Gradient norm clipping with pipeline parallelism (PP) #596

@zijian-hu

Description

Dear torchtitan team, I have a question regarding gradient norm clipping when using pipeline parallelism (PP), potentially combined with FSDP/DP/TP.

For simplicity, let's assume each process/GPU holds a single PP stage. My understanding is that since the model is manually sharded, calling torch.nn.utils.clip_grad_norm_ will only compute the grad norm over the modules of the current PP stage.

torchtitan/train.py, lines 298 to 302 in eef8bb2:

# clip gradients
for m in model_parts:
    torch.nn.utils.clip_grad_norm_(
        m.parameters(), job_config.training.max_norm, foreach=True
    )

Since grad norm clipping requires computing the norm over the entire model (i.e., across all PP stages), does this mean we need to manually aggregate/reduce the grad norm across PP stages before normalizing? If so, what would be the correct approach for doing this?
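
For concreteness, here is a rough sketch of the kind of manual aggregation I have in mind. The function name clip_grad_norm_across_pp, the pp_group handle, and norm_type are just illustrative names of mine, and this ignores any DTensor handling that FSDP/TP sharding would require; it only shows the extra reduction over the PP group.

import torch
import torch.distributed as dist

def clip_grad_norm_across_pp(model_parts, max_norm, pp_group, norm_type=2.0):
    # Gradients owned by this PP stage only.
    grads = [p.grad for m in model_parts for p in m.parameters() if p.grad is not None]

    # Local norm: the same quantity torch.nn.utils.clip_grad_norm_ computes per stage.
    local_norm = torch.linalg.vector_norm(
        torch.stack([torch.linalg.vector_norm(g, norm_type) for g in grads]), norm_type
    )

    # For a finite p-norm, sum local_norm ** p over all PP ranks, then take the p-th root.
    total = local_norm ** norm_type
    dist.all_reduce(total, op=dist.ReduceOp.SUM, group=pp_group)
    total_norm = total ** (1.0 / norm_type)

    # Scale every local gradient by max_norm / total_norm, clamped to at most 1.
    clip_coef = torch.clamp(max_norm / (total_norm + 1e-6), max=1.0)
    for g in grads:
        g.mul_(clip_coef)
    return total_norm

The idea is that every stage scales its gradients by the same global coefficient, so the result should match what clip_grad_norm_ would produce on the unsharded model.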

Any clarification or guidance would be greatly appreciated!

Labels

bug (Something isn't working), release blocking (Issues that are blocking the milestone / release completion)
