Description
Hi developers,
Thanks for releasing such a great project that enables FP8 training.
In my training framework, we have a weight-parallel implementation that does weight all-gather and reduce-scatter, like ZeRO-3. In its forward pass, we all-gather the weight shards and then call `linear_forward_op` (which is actually `torch.nn.functional.linear`).
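To make the setup concrete, here is a minimal sketch of the kind of weight-parallel linear I mean. The names (`WeightParallelLinearFn`, the dim-0 sharding) are my own illustration, not our actual code; the point is that the compute op in both forward and backward is plain `torch.nn.functional.linear`-style matmuls:

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F


class WeightParallelLinearFn(torch.autograd.Function):
    """Hypothetical ZeRO-3-style linear: all-gather weight in forward,
    reduce-scatter the weight gradient in backward."""

    @staticmethod
    def forward(ctx, inp, weight_shard):
        world_size = dist.get_world_size()
        # All-gather the full weight from the per-rank shards (sharded on dim 0).
        full_weight = torch.empty(
            world_size * weight_shard.shape[0], weight_shard.shape[1],
            dtype=weight_shard.dtype, device=weight_shard.device,
        )
        dist.all_gather_into_tensor(full_weight, weight_shard)
        # Saved for simplicity; a real ZeRO-3 implementation would
        # free the full weight here and re-gather it in backward.
        ctx.save_for_backward(inp, full_weight)
        # The actual compute is plain torch.nn.functional.linear.
        return F.linear(inp, full_weight)

    @staticmethod
    def backward(ctx, grad_out):
        inp, full_weight = ctx.saved_tensors
        grad_inp = grad_out @ full_weight
        # Full weight gradient, then reduce-scatter it back to per-rank shards.
        full_grad_w = grad_out.flatten(0, -2).t() @ inp.flatten(0, -2)
        grad_w_shard = torch.empty(
            full_grad_w.shape[0] // dist.get_world_size(), full_grad_w.shape[1],
            dtype=full_grad_w.dtype, device=full_grad_w.device,
        )
        dist.reduce_scatter_tensor(grad_w_shard, full_grad_w)
        return grad_inp, grad_w_shard
```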
But when I check the code of `te.Linear`, I see that the FP8 computation is handled by a `torch.autograd.Function` named `_Linear`.
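For reference, this is how I currently use `te.Linear` standalone (standard TE usage, following the documented quickstart); the FP8 path runs inside the internal `_Linear` autograd function, so there seems to be no public functional op I could drop in where `F.linear` is called today:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

layer = te.Linear(1024, 1024, bias=True).cuda()
x = torch.randn(16, 1024, device="cuda")

# FP8 execution is enabled by the autocast context; the module's forward
# dispatches to the internal _Linear autograd.Function under the hood.
with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(fp8_format=Format.HYBRID)):
    y = layer(x)
y.sum().backward()
```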
So I just wonder: how can we integrate `te.Linear` with our weight-parallel implementation? From my understanding, the forward and backward ops used in our weight-parallel implementation depend on `torch.nn.functional.linear`, which is not compatible with the op used in te's `_Linear`. A rough sketch of what I have in mind follows below.
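To make the question concrete, conceptually we would want something like the following. None of this is a real TE API, and it is not a working integration (for one thing, the reduce-scatter of the weight gradient back to the shard is omitted); it only illustrates the direction, i.e. handing the gathered full weight to `te.Linear` so its internal FP8 `_Linear` op runs on it:

```python
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te


def weight_parallel_te_forward(inp, weight_shard, te_layer):
    """Hypothetical forward: gather the full weight, then run it through a
    te.Linear so the FP8 path inside _Linear is used instead of F.linear."""
    world_size = dist.get_world_size()
    full_weight = torch.empty(
        world_size * weight_shard.shape[0], weight_shard.shape[1],
        dtype=weight_shard.dtype, device=weight_shard.device,
    )
    dist.all_gather_into_tensor(full_weight, weight_shard)
    # Purely illustrative: swap the gathered weight into the module
    # (shapes must match how te_layer was constructed). Gradients would
    # land in te_layer.weight.grad and still need a reduce-scatter.
    te_layer.weight.data = full_weight
    return te_layer(inp)
```

Is something along these lines feasible, or is there a recommended hook for frameworks that manage their own weight gathering?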
Thanks in advance for any hints!