[moe training] Add TP support for routed experts #2473
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2473
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 1 Pending as of commit b9e58fe with merge base 01f7352.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
from torch.nn import functional as F

# this feature requires CUDA and SM89+
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
is the bf16 group gemm working on A100? If yes, I would vote for adding an emulation mode and running this test in emulation mode. We do this for float8 and MX training.
I will look into this and add emulation if the bf16 grouped gemm builds on A100.
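A minimal sketch of what such an emulation path could look like, assuming the grouped GEMM is replaced by a plain per-group bf16 matmul loop when emulation is enabled; the function name, argument layout, and offset convention here are illustrative assumptions, not the actual torchao API.

import torch

def grouped_mm_reference(A: torch.Tensor, B: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
    # Emulated grouped GEMM: loop over token groups and use plain matmul.
    # A: (total_tokens, K) activations, groups delimited by `offs`
    # B: (num_experts, K, N) per-expert weights
    # offs: (num_experts,) cumulative end offsets of each token group
    out = torch.empty(A.shape[0], B.shape[-1], dtype=A.dtype, device=A.device)
    start = 0
    for expert_idx in range(B.shape[0]):
        end = int(offs[expert_idx])
        out[start:end] = A[start:end] @ B[expert_idx]
        start = end
    return out

In an emulation mode along these lines, the SM89 guard above could be skipped and this reference path exercised instead, at the cost of not covering the fused kernel itself.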
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor


def _validate_model_conversion(
nit: why do we need recursion to check this?
It's just a generic way of checking that all target FQNs were converted properly and that all non-target FQNs were correctly left unconverted. It can easily be applied when we extend tests to other MoE models beyond the torchtitan llama4 one I started with.
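A rough sketch of the kind of recursive check being described, assuming target FQNs are matched by substring and that converted parameters hold ScaledGroupedMMTensor data; the exact helper in the PR may differ.

import torch.nn as nn
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

def _validate_model_conversion(module: nn.Module, target_fqns: list[str], fqn: str = "") -> None:
    # Walk the module tree recursively so the check works for any MoE model layout.
    for child_name, child in module.named_children():
        child_fqn = f"{fqn}.{child_name}" if fqn else child_name
        _validate_model_conversion(child, target_fqns, child_fqn)

    is_target = any(t in fqn for t in target_fqns)
    for param_name, param in module.named_parameters(recurse=False):
        is_converted = isinstance(param.data, ScaledGroupedMMTensor)
        if is_target:
            assert is_converted, f"{fqn}.{param_name} should have been converted"
        else:
            assert not is_converted, f"{fqn}.{param_name} should not have been converted"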
Force-pushed from 105689f to bb9626e
Force-pushed from cb1eae9 to 7fed93e
Force-pushed from e92f92d to 04a3d2f
We can remove this assertion; TP support for float8 rowwise MoE training was added in this PR stack: pytorch/ao#2473
Stack

Summary
- offs is now optional, to handle the shared_expert case where num_experts=1 (no group offsets needed since there's only 1 token group); see the sketch after the test plan.

Test plan
./test/prototype/moe_training/test_tp.sh
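A hedged sketch of how an optional offs argument could be handled in a scaled grouped matmul wrapper; the function name, signature, and offset convention are illustrative assumptions rather than the exact torchao API. When offs is None (the shared_expert case with a single token group), all tokens form one group and a plain matmul path suffices.

from typing import Optional
import torch

def scaled_grouped_mm_sketch(
    A: torch.Tensor,                        # (total_tokens, K) activations
    B: torch.Tensor,                        # (num_experts, K, N) expert weights
    offs: Optional[torch.Tensor] = None,    # cumulative group end offsets, or None
) -> torch.Tensor:
    if offs is None:
        # shared_expert case: num_experts == 1, all tokens form a single group,
        # so no group offsets are needed and a plain matmul suffices.
        assert B.shape[0] == 1, "offs may only be omitted when there is a single expert"
        return A @ B[0]
    # routed experts case: dispatch each token group to its expert's weights.
    out = torch.empty(A.shape[0], B.shape[-1], dtype=A.dtype, device=A.device)
    start = 0
    for expert_idx in range(B.shape[0]):
        end = int(offs[expert_idx])
        out[start:end] = A[start:end] @ B[expert_idx]
        start = end
    return out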