
How can we use te.Linear with weight parallel? #1532

Open · zigzagcai opened this issue Mar 4, 2025 · 3 comments
zigzagcai commented Mar 4, 2025

Hi developers,

Thanks for introducing such a great project that enables FP8 training.

In my training framework, we have a weight parallel implementation that does weight all-gather and gradient reduce-scatter like ZeRO3. In the forward pass, we all-gather the weight and then call the linear_forward_op (which is actually torch.nn.functional.linear).

But when I check the code of te.Linear, I see that there is a torch.autograd.Function named _Linear which handles the FP8 computation.

So I just wonder: how can we integrate te.Linear with our weight parallel implementation? From my understanding, the forward and backward ops used in our weight parallel implementation depend on torch.nn.functional.linear, which is not compatible with the ops used in te._Linear.
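
For illustration, a minimal sketch of that pattern (the helper name and arguments are hypothetical, not the framework's real API): gather the full weight from the ZeRO3 shards, call torch.nn.functional.linear, then free the gathered copy.

import torch
import torch.distributed as dist
import torch.nn.functional as F

def wp_linear_forward(x, weight_shard, full_weight_shape, group=None):
    # Hypothetical helper: all-gather the sharded weight, run the linear op, free the copy.
    world_size = dist.get_world_size(group)
    full = torch.empty(weight_shard.numel() * world_size,
                       dtype=weight_shard.dtype, device=weight_shard.device)
    dist.all_gather_into_tensor(full, weight_shard, group=group)
    out = F.linear(x, full.view(full_weight_shape))  # the "linear_forward_op"
    del full  # release the gathered weight
    return out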

Thanks in advance if anybody could provide some hints!

cc @ksivaman @timmoon10 @cyanguwa

timmoon10 (Collaborator) commented:

PyTorch FSDP gathers the module params before each forward and backward so that module implementations can just access them like normal. I wonder if your framework could use a similar approach, perhaps using PyTorch module hooks, e.g. all-gathering params in a pre-forward callback and deallocating them in a post-forward callback. Things get trickier with FP8 and MXFP8 support, since caching the FP8/MXFP8 weight is an important performance optimization.
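
For illustration, a minimal sketch of that hook-based approach, assuming a generic module with a .weight parameter and a flat per-rank shard (the helper name is hypothetical; te.Linear's FP8/MXFP8 weight caching is not handled here, and the backward pass would need the same treatment, e.g. via a full-backward pre-hook):

import torch
import torch.distributed as dist

def attach_weight_gather_hooks(module, weight_shard, full_weight_shape, group=None):
    world_size = dist.get_world_size(group)

    def pre_forward(mod, inputs):
        # All-gather the full weight right before the module's forward runs.
        full = torch.empty(weight_shard.numel() * world_size,
                           dtype=weight_shard.dtype, device=weight_shard.device)
        dist.all_gather_into_tensor(full, weight_shard, group=group)
        mod.weight.data = full.view(full_weight_shape)

    def post_forward(mod, inputs, output):
        # Deallocate the gathered copy after forward; only the local shard is kept.
        mod.weight.data = torch.empty(0, dtype=weight_shard.dtype, device=weight_shard.device)

    module.register_forward_pre_hook(pre_forward)
    module.register_forward_hook(post_forward)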

If you are just looking for more fine-grained access to our linear layer implementation, we do have some functional APIs:

TransformerEngine/transformer_engine/pytorch/ops/basic/basic_linear.py, line 335 in 2ad5da9: def _functional_forward(
TransformerEngine/transformer_engine/pytorch/ops/basic/basic_linear.py, line 539 in 2ad5da9: def _functional_backward(

These are experimental, though, and we can't make any guarantees on the stability of their APIs.
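
For illustration, calling these experimental APIs directly looks roughly like this (argument and return conventions follow the snippet later in this thread; tensor shapes are made up, and the API may change between TE versions):

import torch
from transformer_engine.pytorch.ops.basic.basic_linear import BasicLinear

x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
w = torch.randn(4096, 1024, device="cuda", dtype=torch.bfloat16)

# Forward: the first return value is the output; the rest are auxiliary tensors.
out, *_ = BasicLinear._functional_forward(input=x, weight=w, bias=None)

# Backward: returns (grad_input, grad_weight) for the given grad_output.
grad_out = torch.randn_like(out)
dx, dw = BasicLinear._functional_backward(grad_output=grad_out, input=x, weight=w)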


zigzagcai commented Mar 14, 2025


Hi @timmoon10,

Thanks for your reply! I have tried your approach, where I switch the default linear fwd/bwd ops to TransformerEngine's BasicLinear._functional_forward and BasicLinear._functional_backward. But in the trace, I cannot find any FP8 GEMM kernels. It seems _functional_forward and _functional_backward still call BF16 GEMM kernels, not FP8 GEMMs.

from typing import Optional

import torch
from torch import nn
from torch.cuda.amp import custom_bwd, custom_fwd

from transformer_engine.pytorch.ops.basic.basic_linear import BasicLinear

# NOTE: WPCommunicator is the framework's weight-parallel communicator
# (all-gather / reduce-scatter helper); import it from your own codebase.


class WPFusedDenseFunc(torch.autograd.Function):
    "FusedDenseFunc for weight parallel, which is optimized based on the flash-attn implementation."

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor],
        module: nn.Module,
        communicator: WPCommunicator,
        return_residual=False,
    ):
        ctx.compute_weight_gradient = weight.requires_grad
        ctx.return_residual = return_residual
        ctx.module = module
        ctx.communicator = communicator
        
        assert bias is None
        assert not return_residual

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()

        total_weight = communicator.weight_hook(weight, module=module)
        total_bias = bias if bias is None else communicator.weight_hook(bias, module=module, is_bias=True)

        if torch.is_autocast_enabled():
            total_weight = total_weight.to(dtype=torch.get_autocast_gpu_dtype())
            if total_bias is not None:
                total_bias = total_bias.to(dtype=torch.get_autocast_gpu_dtype())

        total_weight = total_weight.contiguous()
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        if min(batch_dim, n, *total_weight.shape) > 65535 * 32:
            raise RuntimeError("fused_dense only supports matrix dims <= 2M")

        output, _, _ = BasicLinear._functional_forward(input=x, weight=total_weight, bias=total_bias)

        # release memory
        del total_weight
        del total_bias

        # parallel strategy-specific communication callback 2.
        # see more details in the communicator for different parallel strategies.
        # gather seq dim when head parallel_output is False
        if hasattr(communicator, "output_hook"):
            output, _ = communicator.output_hook(output, async_op=False)

        saved_x = None if ctx.compute_weight_gradient is False else x
        ctx.save_for_backward(saved_x, weight, bias)
        
        return output if not return_residual else (output, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        module: nn.Module = ctx.module
        communicator: WPCommunicator = ctx.communicator
        x, weight, bias = ctx.saved_tensors

        # parallel strategy-specific communication callback 3.
        # see more details in the communicator for different parallel strategies.
        if hasattr(communicator, "grad_output_hook"):
            grad_output, _ = communicator.grad_output_hook(grad_output, async_op=False)

        grad_output = grad_output.contiguous()
        if ctx.return_residual:
            (grad_input,) = args
            grad_input = grad_input.contiguous()

        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])

        total_weight = communicator.weight_hook(weight, module=module)

        # compute weight grad
        if ctx.needs_input_grad[1]:
            assert ctx.compute_weight_gradient
            x = x.reshape(batch_dim, x.shape[-1])
            _, grad_weight = BasicLinear._functional_backward(grad_output=grad_output, input=x, weight=total_weight)
            grad_weight, grad_weight_sync = communicator.grad_hook(
                grad_weight, async_op=True, module=module, is_bias=False
            )
        else:
            grad_weight = None

        # bias is asserted to be None in forward, so this stays None in practice
        grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None

        if ctx.needs_input_grad[0]:
            grad_input, _, _ = BasicLinear._functional_forward(input=grad_output, weight=total_weight.t())
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
        else:
            grad_input = None

        del total_weight

        if ctx.needs_input_grad[1]:
            grad_weight_sync.wait()

        # one gradient per forward input: x, weight, bias, module, communicator, return_residual
        return grad_input, grad_weight, grad_bias, None, None, None
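
For context, a hypothetical call site for the autograd function above (x, weight_shard, module, and communicator stand in for the framework's own objects):

out = WPFusedDenseFunc.apply(x, weight_shard, None, module, communicator, False)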

CPU trace: (profiler screenshot omitted)

CUDA trace: (profiler screenshot omitted)

Could you please share some insights about how to enable FP8 GEMM kernels with this internal API? @timmoon10
Thanks in advance!
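
For reference, FP8 execution for TE modules is normally requested by wrapping the compute in te.fp8_autocast. Whether the experimental functional API picks up that context automatically, or instead needs FP8 enabled through its own arguments, is an assumption to verify against the _functional_forward signature in your TE version; a minimal sketch:

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")

# Assumption: the functional op honors the active fp8_autocast context; if it does not,
# FP8 must be enabled through the op's own arguments instead.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    output, *_ = BasicLinear._functional_forward(input=x, weight=total_weight, bias=None)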


zigzagcai commented Mar 17, 2025

The basic idea of our ZeRO3 weight parallel implementation:
In WPFusedDenseFunc (https://github.com/InternLM/InternEvo/blob/feat/refactor-impl/internlm/model/model_ops/modules/linear.py#L171-L315), we all-gather the weights in the forward pass, then all-gather the weights and reduce-scatter the gradients in the backward pass. We then apply this customized autograd function in https://github.com/InternLM/InternEvo/blob/feat/refactor-impl/internlm/model/model_ops/modules/linear.py#L532-L678.

So I just wonder: how can we integrate TE FP8 with our customized ZeRO3 weight parallel implementation?
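
For illustration, the reduce-scatter half of that scheme can be sketched like this (the helper name is hypothetical; it assumes the gradient's element count is divisible by the group size):

import torch
import torch.distributed as dist

def reduce_scatter_weight_grad(full_grad, group=None):
    # Sum the full weight gradient across ranks and keep only this rank's shard.
    world_size = dist.get_world_size(group)
    shard = torch.empty(full_grad.numel() // world_size,
                        dtype=full_grad.dtype, device=full_grad.device)
    dist.reduce_scatter_tensor(shard, full_grad.contiguous().flatten(), group=group)
    return shard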
