
Fix PP clip_grad_norm #649

Merged: 9 commits into pytorch:main on Nov 16, 2024

Conversation

zijian-hu (Contributor)

This PR fixes the pipeline parallel incomplete gradient norm issue. Refer to #596 for more details.

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Oct 24, 2024
tianyu-l (Contributor) left a comment


It may not be easy / feasible, but can we find a way to

  1. first get the global total_norm
  2. then make a function call to (some variant of) torch.nn.utils.clip_grad_norm_ with the global total_norm

The idea is to avoid largely duplicating PyTorch's code.
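
A rough sketch of that idea, assuming a flat list of parameters, a process group `pp_group` spanning the pipeline stages, and a finite `norm_type` (the names and helper here are illustrative, not the code in this PR):

```python
import torch
import torch.distributed as dist

def clip_grad_norm_pp(parameters, max_norm, pp_group, norm_type=2.0):
    grads = [p.grad for p in parameters if p.grad is not None]

    # Norm over this pipeline stage's gradients only.
    local_norm = torch.linalg.vector_norm(
        torch.stack([torch.linalg.vector_norm(g, norm_type) for g in grads]),
        norm_type,
    )

    # Combine per-stage norms into the global total_norm
    # (sum of p-th powers; the inf-norm would need a MAX reduce instead).
    total_norm = local_norm ** norm_type
    dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_group)
    total_norm = total_norm ** (1.0 / norm_type)

    # Reuse the standard clipping logic with the global norm.
    clip_coef = torch.clamp(max_norm / (total_norm + 1e-6), max=1.0)
    for g in grads:
        g.mul_(clip_coef)
    return total_norm
```

The gradient scaling itself is unchanged from torch.nn.utils.clip_grad_norm_; only the norm computation needs the extra cross-stage reduction.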

Review thread on train.py (outdated, resolved)
@wconstab changed the title from "Fix PP clip_grad_nrom" to "Fix PP clip_grad_norm" on Oct 31, 2024
awgu (Contributor) left a comment


Sorry, requesting changes since total_norm is computed w.r.t. parameters instead of gradients.

Review thread on torchtitan/clip_grad_nrom.py (outdated, resolved)
Comment on lines 63 to 65:

```python
total_norm **= norm_type
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group())
total_norm **= 1.0 / norm_type
```
awgu (Contributor)


This is only correct for a p-norm where p is a finite int/float (e.g. this does not work for the inf-norm).

Intuitively, we want a way to give total_norm a _NormPartial(norm_type) placement over the PP mesh and then redistribute it to Replicate(), so that we can reuse DTensor's own reduction logic. cc @tianyu-l @wz337 @XilunWu: is there any way to do this?

zijian-hu (Contributor, Author)


@awgu In torchtitan, PP is done by first manually splitting the Llama model and then wrapping it in the PP wrapper. If my understanding is correct, neither step converts the model weights into DTensors with a PP mesh placement.

zijian-hu (Contributor, Author)


For now, I will just add an if-statement to handle the inf-norm case.
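
A minimal sketch of what such an if-statement could look like, reusing the total_norm, norm_type, and pp_mesh names from the diff excerpt above (an illustration, not necessarily the exact change): the inf-norm of the concatenated gradients is the max of the per-stage norms, so it needs a MAX reduce rather than the sum-of-p-th-powers trick.

```python
import math

import torch
import torch.distributed as dist

def reduce_total_norm_across_pp(total_norm: torch.Tensor, norm_type: float, pp_mesh) -> torch.Tensor:
    """Turn a per-stage gradient norm into the global norm across the PP mesh."""
    if math.isinf(norm_type):
        # inf-norm: the global norm is the max over the per-stage norms.
        dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group())
    else:
        # finite p-norm: sum the p-th powers across stages, then take the p-th root.
        total_norm **= norm_type
        dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group())
        total_norm **= 1.0 / norm_type
    return total_norm
```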

wconstab (Contributor)


> PP is done by first manually splitting the Llama model and then wrapping it in the PP wrapper. If my understanding is correct, neither step converts the model weights into DTensors with a PP mesh placement.

This is true; however, after splitting the Llama model we also apply SPMD parallelisms (TP and DP), which use DTensor for parameters. So by the time we are clipping norms, the parameters would be DTensors. (If we use PP only and no TP/DP, then there would be no DTensor parameters.)

zijian-hu (Contributor, Author) commented Nov 13, 2024


@wconstab You are right about SPMD (TP and/or DP), but in that case the DTensor has placements on the TP/DP mesh, not on the PP mesh. We still need to do some form of reduce across PP, either explicitly (as in my code) or by wrapping total_norm in a DTensor on the PP mesh (maybe via DTensor.redistribute?).

Review thread on torchtitan/utils.py (outdated, resolved)

```python
if pp_mesh is not None:
    if isinstance(total_norm, DTensor):
        # will reach here if PP + other parallelism is used.
        # If only using PP, total_norm will be a local tensor
```
zijian-hu (Contributor, Author)


@wconstab The if-statements are now separated. Please let me know if this makes sense to you.

Contributor


It looks good, thanks!

H-Huang (Member) left a comment


Looks good to me!

```python
# if total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`
# we can simply reduce the DTensor to get the total norm in this tensor's process group
# and then convert it to a local tensor
total_norm = total_norm.redistribute(
```
H-Huang (Member)


nit: use total_norm.full_tensor() instead?
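
For illustration, assuming total_norm is a DTensor whose placements are the partial-norm placement discussed above (and a recent PyTorch where Replicate lives under torch.distributed.tensor; older versions use torch.distributed._tensor): full_tensor() is shorthand for redistributing every mesh dimension to Replicate() and returning the local tensor, so the two alternatives below should be equivalent. The explicit redistribute arguments are an assumption about how the truncated line above continues, not a quote of the diff.

```python
from torch.distributed.tensor import Replicate

# Option A (explicit): reduce the partial norm to a replicated DTensor, then take the local tensor.
norm_a = total_norm.redistribute(placements=[Replicate()] * total_norm.device_mesh.ndim).to_local()

# Option B (suggested shorthand): full_tensor() performs the same redistribute-to-Replicate
# and returns the local tensor in one call.
norm_b = total_norm.full_tensor()
```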

zijian-hu (Contributor, Author)


@H-Huang Good suggestion. Updated!

zijian-hu (Contributor, Author)

> Sorry, requesting changes since total_norm is computed w.r.t. parameters instead of gradients.

@awgu Please let me know if the update makes sense, and feel free to resolve the change request if you believe it is good enough.

awgu (Contributor) left a comment


Sounds good!

@H-Huang merged commit 046de56 into pytorch:main on Nov 16, 2024 (5 checks passed)
mori360 pushed a commit to mori360/torchtitan referencing this pull request on Nov 26, 2024:

This PR fixes the pipeline parallel incomplete gradient norm issue. Refer to pytorch#596 for more details.