How can we use te.Linear with weight parallel? #1532

Open

Description

@zigzagcai

Hi developers,

Thanks for introducing such a great project that enables FP8 training.

In my training framework, we have a weight-parallel implementation that does weight all-gather and reduce-scatter, similar to ZeRO-3. In the forward pass, we all-gather the weight shards and then call linear_forward_op (which is ultimately torch.nn.functional.linear).
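
For context, here is a stripped-down sketch of that pattern. Names like WeightParallelLinearFn are invented for this example; the real implementation also overlaps communication with compute, and this sketch assumes the output dimension divides evenly across ranks and that bias is present:

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

class WeightParallelLinearFn(torch.autograd.Function):
    """ZeRO3-style linear: each rank holds a 1/world_size shard of the
    weight, sharded along the output dimension."""

    @staticmethod
    def forward(ctx, input, weight_shard, bias):
        world_size = dist.get_world_size()
        # All-gather the full (out_features, in_features) weight from shards.
        full_weight = torch.empty(
            weight_shard.shape[0] * world_size, weight_shard.shape[1],
            dtype=weight_shard.dtype, device=weight_shard.device,
        )
        dist.all_gather_into_tensor(full_weight, weight_shard.contiguous())
        # A real implementation would free full_weight here and re-gather
        # it in backward; it is kept saved for brevity.
        ctx.save_for_backward(input, full_weight)
        return F.linear(input, full_weight, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, full_weight = ctx.saved_tensors
        grad_input = grad_output @ full_weight
        # Compute the full weight gradient, then reduce-scatter it back
        # onto the per-rank shards.
        go2d = grad_output.flatten(0, -2)
        in2d = input.flatten(0, -2)
        full_grad_weight = go2d.T @ in2d
        grad_weight_shard = torch.empty_like(
            full_grad_weight.chunk(dist.get_world_size())[0]
        )
        dist.reduce_scatter_tensor(grad_weight_shard, full_grad_weight)
        grad_bias = go2d.sum(dim=0)
        return grad_input, grad_weight_shard, grad_bias
```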

But when I check the code of te.Linear, I see a torch.autograd.Function named _Linear that handles the FP8 computation.
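
As far as I can tell, a plain te.Linear forward under fp8_autocast dispatches into that _Linear function rather than into torch.nn.functional.linear, e.g.:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

layer = te.Linear(1024, 1024, bias=True).cuda()
x = torch.randn(32, 1024, device="cuda")

# This forward runs te's _Linear autograd.Function (FP8 GEMM),
# not torch.nn.functional.linear.
with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
    y = layer(x)
y.sum().backward()
```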

So I wonder: how can we integrate te.Linear with our weight-parallel implementation? From my understanding, the forward and backward ops used in our weight-parallel implementation depend on torch.nn.functional.linear, which is not compatible with the op used in te's _Linear.
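
The closest workaround I could come up with is the rough sketch below: keep only the shard as a trainable parameter, all-gather it into the te.Linear weight before each forward so the FP8 path still goes through te's own function, and reduce-scatter the weight gradient with a tensor hook. All names here are hypothetical, and I don't know whether mutating the weight like this is safe with te's FP8 scaling state and weight caching:

```python
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

class WeightParallelTELinear(torch.nn.Module):
    """Hypothetical wrapper, not a working implementation: each rank stores
    a weight shard and materializes the full weight into an inner te.Linear
    right before the FP8 forward."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.inner = te.Linear(in_features, out_features, bias=False)
        rank, world = dist.get_rank(), dist.get_world_size()
        # Keep only this rank's shard as the parameter the optimizer sees.
        shard = self.inner.weight.detach().chunk(world)[rank]
        self.weight_shard = torch.nn.Parameter(shard.clone())

        # When te's backward produces the full-weight gradient,
        # reduce-scatter it back onto the shard.
        def _reduce_scatter_grad(full_grad):
            shard_grad = torch.empty_like(self.weight_shard)
            dist.reduce_scatter_tensor(shard_grad, full_grad.contiguous())
            self.weight_shard.grad = shard_grad
            return full_grad

        self.inner.weight.register_hook(_reduce_scatter_grad)

    def forward(self, x):
        # All-gather the full weight into the te.Linear parameter in place,
        # so quantization and the GEMM still run through te's _Linear.
        dist.all_gather_into_tensor(
            self.inner.weight.data, self.weight_shard.data.contiguous()
        )
        return self.inner(x)
```

With something like this, only weight_shard would be handed to the optimizer. But I am not confident this is correct, which is why I'm asking for the intended way to do it.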

Thanks in advance to anyone who can provide some hints!

cc @ksivaman @timmoon10 @cyanguwa
