How can we use te.Linear with weight parallel? #1532
PyTorch FSDP gathers the module params before each forward and backward so that module implementations can just access them like normal. I wonder if your framework could use a similar approach, perhaps using PyTorch module hooks, e.g. all-gathering params in a pre-forward callback and deallocating them in a post-forward callback. Things get trickier with FP8 and MXFP8 support, since caching the FP8/MXFP8 weight is an important performance optimization. If you are just looking for more fine-grained access to our linear layer implementation, we do have some functional APIs. These are experimental, though, and we can't make any guarantees about the stability of their interfaces.
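For what it's worth, a minimal sketch of that hook-based pattern in plain PyTorch might look like the code below. It assumes each rank keeps its shard in a hypothetical flattened weight_shard attribute and the full shape in weight_shape (both names are illustrative, not TE or FSDP API), and it ignores the backward re-gather and the FP8 weight caching mentioned above:

```python
import torch
import torch.distributed as dist

def attach_zero3_hooks(module: torch.nn.Module, group=None):
    """Illustrative ZeRO3-style hooks: all-gather the sharded weight before
    forward and drop the full weight afterwards. `weight_shard` is a
    hypothetical flattened local shard; `weight_shape` is the full shape."""
    world_size = dist.get_world_size(group)

    def gather_weight(mod, inputs):
        # Allocate the full (flattened) weight and all-gather the shards into it.
        full = torch.empty(
            mod.weight_shard.numel() * world_size,
            dtype=mod.weight_shard.dtype,
            device=mod.weight_shard.device,
        )
        dist.all_gather_into_tensor(full, mod.weight_shard, group=group)
        mod.weight = torch.nn.Parameter(full.view(mod.weight_shape))

    def free_weight(mod, inputs, output):
        # Free the gathered weight after forward; a real implementation would
        # re-gather (or keep it alive) for backward.
        mod.weight = None

    module.register_forward_pre_hook(gather_weight)
    module.register_forward_hook(free_weight)
```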
Hi @timmoon10, thanks for your reply! I have tried your approach, where I switch the default linear fwd/bwd ops to the TransformerEngine ones.
Could you please share some insights on how to enable FP8 GEMM kernels with this internal API? @timmoon10
The basic idea of our ZeRO3 weight parallel implementation is to all-gather the full weight before the linear forward/backward and reduce-scatter the weight gradient afterwards. So I just wonder: how could we integrate TE FP8 with our customized ZeRO3 weight parallel implementation?
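For reference, enabling FP8 GEMMs through the public module API (te.Linear inside fp8_autocast) looks roughly like the sketch below. This is the documented, non-sharded path; doing the equivalent through the experimental functional/internal ops with an externally gathered weight is exactly the open question here. Shapes and recipe values are arbitrary:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Public, non-sharded FP8 path: te.Linear owns the full weight and
# fp8_autocast selects the FP8 recipe used for its GEMMs.
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16)
layer = te.Linear(1024, 1024, bias=True, params_dtype=torch.bfloat16).cuda()
inp = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = layer(inp)
out.sum().backward()
```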
Hi developers,
Thanks for introducing such a great project that enables FP8 training.
In my training framework, we have a weight parallel implementation that does weight all-gather and reduce-scatter like ZeRO3. In the forward pass, we all-gather the weight and then call the linear_forward_op (which is actually torch.nn.functional.linear). But when I check the code of te.Linear, there is a torch.autograd.Function named _Linear that handles the FP8 computation. So I just wonder: how can we integrate te.Linear with our weight parallel implementation? From my understanding, the forward and backward ops used in our weight parallel implementation depend on torch.nn.functional.linear, which is not compatible with the ops used in te._Linear. Thanks in advance if anybody could provide some hints!
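To make the setup concrete, here is a minimal sketch of the kind of weight-parallel op described above, assuming an evenly sharded, flattened weight_shard and using torch.nn.functional.linear for the GEMMs (all names are illustrative, and the FP8 path is exactly what this sketch does not cover):

```python
import torch
import torch.distributed as dist

class ShardedLinear(torch.autograd.Function):
    """Illustrative ZeRO3-style linear (no bias): all-gather the weight shard
    for the GEMMs, reduce-scatter the weight gradient back to shards."""

    @staticmethod
    def _gather_weight(weight_shard, out_features, group):
        world_size = dist.get_world_size(group)
        full = torch.empty(
            weight_shard.numel() * world_size,
            dtype=weight_shard.dtype, device=weight_shard.device,
        )
        dist.all_gather_into_tensor(full, weight_shard, group=group)
        return full.view(out_features, -1)  # (out_features, in_features)

    @staticmethod
    def forward(ctx, inp, weight_shard, out_features, group):
        full_weight = ShardedLinear._gather_weight(weight_shard, out_features, group)
        ctx.save_for_backward(inp, weight_shard)
        ctx.group, ctx.out_features = group, out_features
        # linear_forward_op in the description above is torch.nn.functional.linear
        return torch.nn.functional.linear(inp, full_weight)

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight_shard = ctx.saved_tensors
        # Re-gather the weight for the input-gradient GEMM.
        full_weight = ShardedLinear._gather_weight(
            weight_shard, ctx.out_features, ctx.group)
        grad_inp = grad_out @ full_weight
        # Full weight gradient, then reduce-scatter it back to shards.
        grad_w_full = (grad_out.reshape(-1, grad_out.shape[-1]).T
                       @ inp.reshape(-1, inp.shape[-1]))
        grad_w_shard = torch.empty_like(weight_shard)
        dist.reduce_scatter_tensor(
            grad_w_shard, grad_w_full.reshape(-1).contiguous(), group=ctx.group)
        return grad_inp, grad_w_shard, None, None
```

Usage would be something like ShardedLinear.apply(x, weight_shard, out_features, group); swapping the two torch GEMMs here for TE's FP8 kernels is the part I am asking about.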
cc @ksivaman @timmoon10 @cyanguwa