From 240240617267cff76178a7f5da58a93806e5a6d2 Mon Sep 17 00:00:00 2001
From: Alp Dener <adener@nvidia.com>
Date: Mon, 13 Jan 2025 14:24:08 -0600
Subject: [PATCH] [PyTorch] Adding TP overlap support for `te.Linear` with
 `parallel_mode="column"` (#1343)

* support AG overlap in sequence-parallel Linear forward and RS overlap in sequence-parallel Linear backward

Signed-off-by: Alp Dener <adener@nvidia.com>

* implemented TP overlap support for column-parallel te.Linear

Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed backward pass for te.Linear column-parallel with TP overlap, updated unit tests

Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* improved error messages for internal failure to infer TP overlap options in te.Linear

Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed linting errors

Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed incorrect TP overlap option asserts

Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Alp Dener <adener@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .../distributed/run_layer_with_overlap.py     |  62 +++-
 .../distributed/test_comm_gemm_overlap.py     |  24 +-
 transformer_engine/pytorch/module/linear.py   | 347 +++++++++++++-----
 3 files changed, 322 insertions(+), 111 deletions(-)

diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py
index e49174c24f..5a67bd616a 100644
--- a/tests/pytorch/distributed/run_layer_with_overlap.py
+++ b/tests/pytorch/distributed/run_layer_with_overlap.py
@@ -51,15 +51,23 @@ def _get_layer_args(config, tp_group, tp_size, reference=False):
     kwargs["ub_overlap_ag"] = not reference
 
     if config.layer_type is te.Linear:
-        input_shape[2] = hidden_size // tp_size
-        args.append(hidden_size)
-        kwargs["parallel_mode"] = "row"
-        kwargs["ub_overlap_rs"] = not reference
-        kwargs["ub_name"] = "proj"
+        if config.linear_parallel_mode == "row":
+            input_shape[2] = hidden_size // tp_size
+            args.append(hidden_size)
+            kwargs["ub_overlap_rs"] = not reference
+        elif config.linear_parallel_mode == "column":
+            input_shape[0] = config.seq_length // tp_size
+            args.append(3 * hidden_size)
+            kwargs["ub_overlap_rs"] = config.overlap_rs_dgrad and not reference
+            kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference
+            kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference
+        kwargs["parallel_mode"] = config.linear_parallel_mode
+        kwargs["ub_name"] = "proj" if config.linear_parallel_mode == "row" else "qkv"
     else:
         input_shape[0] = config.seq_length // tp_size
-        kwargs["ub_bulk_wgrad"] = not reference
-        kwargs["ub_bulk_dgrad"] = not reference
+        kwargs["ub_overlap_rs_dgrad"] = config.overlap_rs_dgrad and not reference
+        kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference
+        kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference
         if config.layer_type is te.LayerNormLinear:
             args.append(3 * hidden_size)
             kwargs["parallel_mode"] = "column"
@@ -125,6 +133,19 @@ def _parse_args(argv=None, namespace=None):
     parser.add_argument(
         "--use-cuda-graphs", action="store_true", default=False, help="Use CUDA Graphs."
     )
+    parser.add_argument(
+        "--linear-parallel-mode",
+        type=str.lower,
+        default="row",
+        choices=["row", "column"],
+        help="Parallel mode for te.Linear.",
+    )
+    parser.add_argument(
+        "--overlap-rs-dgrad",
+        action="store_true",
+        default=False,
+        help="Overlap reduce-scatter with DGRAD in the backward pass instead of bulk overlaps.",
+    )
     parser.add_argument(
         "--debug",
         action="store_true",
@@ -230,12 +251,19 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False):
     dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs")
 
     # Intialize userbuffers
+    ub_cfgs = None
+    if opts.overlap_rs_dgrad:
+        ub_cfgs = {
+            "proj_dgrad": {"method": "ring_exchange"},
+            "qkv_dgrad": {"method": "ring_exchange"},
+        }
     te.module.base.initialize_ub(
         [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
         WORLD_SIZE,
         use_fp8=opts.fp8,
         dtype=torch.bfloat16,
         bootstrap_backend=opts.bootstrap_backend,
+        ub_cfgs=ub_cfgs,
     )
 
     # Initialize the Transformer Engine layer with overlap
@@ -314,27 +342,29 @@ def run_fwd_bwd(model, x):
             ref_grads.append(ref_param.grad)
 
     # Make sure we have the same number of gradients
-    numerics_failed = torch.tensor([0], dtype=torch.uint8, device="cuda")
+    num_grads_failed = torch.tensor([0], dtype=torch.uint8, device="cuda")
     if len(test_grads) != len(ref_grads):
-        numerics_failed[0] = 1
+        num_grads_failed[0] = 1
         numerics_info = (
             "NUMERICAL CHECK FAILED: Incorrect number of gradients, "
             + f"expected {len(ref_grads)} but got {len(test_grads)}."
         )
         dist_print(numerics_info, src=WORLD_RANK, error=True)
-    dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world)
+    dist.all_reduce(num_grads_failed, dist.ReduceOp.MAX, nccl_world)
 
     # Now validate accuracy
-    if not bool(numerics_failed.item()):
+    numerics_failed = torch.zeros(len(test_grads), dtype=torch.uint8, device="cuda")
+    if not bool(num_grads_failed.item()):
         for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)):
             rtol = 0.125 if opts.fp8 else 0.025
             atol = 0.0625 if opts.fp8 else 0.00125
             grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol)
             dist_print(grad_info, src=WORLD_RANK, error=grad_failed)
-            numerics_failed[0] = int(grad_failed)
-            dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world)
-            if bool(numerics_failed.item()):
-                break
+            numerics_failed[i] = int(grad_failed)
+        return_code = torch.max(numerics_failed)
+        dist.all_reduce(return_code, dist.ReduceOp.MAX, nccl_world)
+    else:
+        return_code = num_grads_failed
 
     te.module.base.destroy_ub()
     dist_print("Destroying Userbuffers objects...", debug=True)
@@ -344,7 +374,7 @@ def run_fwd_bwd(model, x):
     if opts.debug and WORLD_RANK == 0:
         print("Exiting...\n", end="", flush=True)
 
-    return numerics_failed[0].item()
+    return return_code.item()
 
 
 if __name__ == "__main__":
diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py
index 240e396534..c285da7fbd 100644
--- a/tests/pytorch/distributed/test_comm_gemm_overlap.py
+++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py
@@ -21,8 +21,10 @@
 BATCH_SIZE: int = 2
 NUM_HEADS: int = 12
 HEAD_DIM: int = 64
+
+# NOTE: te.Linear is intentionally omitted here and manually added later for testing both
+#       row and column parallel layouts.
 TE_LAYERS = [
-    te.Linear,
     te.LayerNormLinear,
     te.LayerNormMLP,
     te.MultiheadAttention,
@@ -86,7 +88,7 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggreg
         raise AssertionError(result.stderr.decode())
 
 
-def _run_layer_with_overlap(layer_type, fp8, fp8_init):
+def _run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init):
     test_path = TEST_ROOT / "run_layer_with_overlap.py"
     test_cmd = LAUNCH_CMD + [
         str(test_path),
@@ -97,6 +99,8 @@ def _run_layer_with_overlap(layer_type, fp8, fp8_init):
         f"--head-dim={HEAD_DIM}",
         f"--layer-type={layer_type}",
     ]
+    if layer_type == te.Linear.__name__:
+        test_cmd.append(f"--linear-parallel-mode={linear_parallel_mode}")
 
     if fp8:
         if not fp8_available:
@@ -245,9 +249,15 @@ def test_bulk_overlaps(comm_type, fp8, connections):
 
 
 @pytest.mark.parametrize(
-    "layer_type",
-    [layer.__name__ for layer in TE_LAYERS],
-    ids=[(" " + layer.__name__ + " ") for layer in TE_LAYERS],
+    "layer_type,linear_parallel_mode",
+    (
+        [(te.Linear.__name__, "row"), (te.Linear.__name__, "column")]
+        + list(zip([layer.__name__ for layer in TE_LAYERS], [None for _ in range(len(TE_LAYERS))]))
+    ),
+    ids=(
+        [f" {te.Linear.__name__} (row-parallel) ", f" {te.Linear.__name__} (column-parallel) "]
+        + [(" " + layer.__name__ + " ") for layer in TE_LAYERS]
+    ),
 )
 @pytest.mark.parametrize(
     "fp8,fp8_init",
@@ -262,8 +272,8 @@ def test_bulk_overlaps(comm_type, fp8, connections):
         " FP8  GEMM - FP8  PARAMS ",
     ],
 )
-def test_layers_with_overlap(layer_type, fp8, fp8_init):
+def test_layers_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init):
     """
     Test Transformer Engine layers with comm+GEMM overlap.
     """
-    _run_layer_with_overlap(layer_type, fp8, fp8_init)
+    _run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init)
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 5fd4dd2fc9..2262d23832 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -3,6 +3,8 @@
 # See LICENSE for license information.
 
 """Linear API"""
+from functools import reduce
+from operator import mul as multiply_op
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 import torch
@@ -43,7 +45,7 @@
     fp8_cast_transpose_fused,
     cast_to_fp8,
 )
-from ..constants import GemmParallelModes, dist_group_type
+from ..constants import GemmParallelModes, dist_group_type, TE_DType
 from ..jit import no_torch_dynamo
 from ..graph import is_graph_capturing
 from ..float8_tensor import Float8Tensor
@@ -80,8 +82,12 @@ def forward(
         activation_dtype: torch.dtype,
         parallel_mode: Union[str, None],
         is_grad_enabled: bool,
-        ub_overlap_rs: bool,
-        ub_overlap_ag: bool,
+        ub_overlap_rs_fprop: bool,
+        ub_overlap_ag_dgrad: bool,
+        ub_overlap_ag_fprop: bool,
+        ub_overlap_rs_dgrad: bool,
+        ub_bulk_dgrad: bool,
+        ub_bulk_wgrad: bool,
         ub_name: str,
         fp8_output: bool,
         fsdp_group: Union[dist_group_type, None],
@@ -99,7 +105,8 @@ def forward(
             assert_dim_for_fp8_exec(weight)
 
         tp_world_size = get_distributed_world_size(tp_group)
-        ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs
+        ub_overlap_ag_fprop = False if tp_world_size == 1 else ub_overlap_ag_fprop
+        ub_overlap_rs_fprop = False if tp_world_size == 1 else ub_overlap_rs_fprop
 
         # Cast input to expected dtype
         inputmat = cast_if_needed(inputmat, activation_dtype)
@@ -150,10 +157,11 @@ def forward(
                 inputmat_scale_inv.fill_(inputmat_scale_inv.item())
 
         # Column Parallel Linear
-        if parallel_mode == "column" and sequence_parallel:
+        if parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop:
             inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
         else:
             inputmat_total = inputmat
+
         if fp8:
             bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
             bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
@@ -165,75 +173,92 @@ def forward(
             assert isinstance(weight_fp8, Float8Tensor)
 
             if fp8_output:
-                proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
+                out_index, meta_tensor, out_tedtype, out_pttype = (
                     tex.FP8FwdTensors.GEMM1_OUTPUT,
                     fp8_meta["scaling_fwd"],
                     fp8_dtype_forward,
                     torch.uint8,
                 )
             else:
-                proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
+                out_index, meta_tensor, out_tedtype, out_pttype = (
                     None,
                     None,
                     None,
                     activation_dtype,
                 )
 
+            ub_obj = None
             ub_algo = None
             rs_out = None
-            if ub_overlap_rs:
-                ub_obj_projout = get_ub(ub_name + "_fprop")
-                out = ub_obj_projout.get_ubuf_output(1)
+            inputmat_data = (
+                inputmat_total._data if isinstance(inputmat_total, Float8Tensor) else inputmat_total
+            )
+            if ub_overlap_rs_fprop:
+                ub_obj = get_ub(ub_name + "_fprop")
+                out = ub_obj.get_ubuf_output(1)
                 dim_size = list(inputmat_total.size())
                 dim_size[0] = dim_size[0] // tp_world_size
                 dim_size[1] = out_features
                 rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
-                if ub_obj_projout.is_p2p_overlap():
-                    if ub_obj_projout.is_atomic_gemm():
+                if ub_obj.is_p2p_overlap():
+                    if ub_obj.is_atomic_gemm():
                         ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P
                     else:
                         ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P
                 else:
-                    if ub_obj_projout.is_atomic_gemm():
+                    if ub_obj.is_atomic_gemm():
                         ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS
                     else:
                         ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS
-                if ub_obj_projout.is_fp8_ubuf():
-                    proj_out_index = tex.FP8FwdTensors.GEMM1_OUTPUT
+                if ub_obj.is_fp8_ubuf():
+                    out_index = tex.FP8FwdTensors.GEMM1_OUTPUT
                     meta_tensor = fp8_meta["scaling_fwd"]
-                    proj_out_tetype = fp8_dtype_forward
-                    proj_out_pttype = torch.uint8
-                    ub_obj_projout.set_ubuf_scale_inv(meta_tensor.scale_inv[proj_out_index])
+                    out_tedtype = fp8_dtype_forward
+                    out_pttype = torch.uint8
+                    ub_obj.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index])
+
+            elif ub_overlap_ag_fprop:
+                ub_obj = get_ub(ub_name + "_fprop")
+                assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM requires FP8 buffer."
+                ub_obj.copy_input_to_ubuf(inputmat_data, True)
+                ub_obj.set_ubuf_scale_inv(inputmat_scale_inv)
+                if ub_obj.is_atomic_gemm():
+                    ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
+                else:
+                    ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
+                out_tedtype = TE_DType[activation_dtype]
+                out_pttype = activation_dtype
+                dim_size = list(inputmat_total.size())
+                dim_size[0] *= tp_size
+                dim_size[1] = out_features
+                out = torch.empty(dim_size, dtype=out_pttype, device=inputmat_total.device)
+
             else:
                 dim_size = list(inputmat_total.size())
                 dim_size[1] = out_features
-                out = torch.empty(dim_size, dtype=proj_out_pttype, device=inputmat_total.device)
+                out = torch.empty(dim_size, dtype=out_pttype, device=inputmat_total.device)
 
             _ = fp8_gemm(
                 weight_fp8._data,
                 weight_fp8._scale_inv,
                 0,
                 weight_fp8._fp8_dtype,
-                (
-                    inputmat_total._data
-                    if isinstance(inputmat_total, Float8Tensor)
-                    else inputmat_total
-                ),
+                inputmat_data,
                 inputmat_scale_inv,
                 0,
                 fp8_dtype_forward,
-                proj_out_pttype,
+                out_pttype,
                 get_workspace(),
                 bias=bias,
                 use_bias=use_bias,
                 use_split_accumulator=_2X_ACC_FPROP,
                 out=out,
-                ub_algo=ub_algo if ub_overlap_rs else None,
-                ub=ub_obj_projout if ub_overlap_rs else None,
-                extra_output_tensor=rs_out if ub_overlap_rs else None,
-                out_index=proj_out_index,
+                ub_algo=ub_algo,
+                ub=ub_obj,
+                extra_output_tensor=rs_out,
+                out_index=out_index,
                 fp8_meta_tensor=meta_tensor,
-                D_dtype=proj_out_tetype,
+                D_dtype=out_tedtype,
             )
             if fp8_output:
                 out = Float8Tensor(
@@ -261,17 +286,30 @@ def forward(
                     -amin, amax
                 ).float()
 
-            if ub_overlap_rs:
-                ub_obj_projout = get_ub(ub_name + "_fprop")
-                out = ub_obj_projout.get_ubuf_output(1)
+            ub_obj = None
+            ub_algo = None
+            rs_out = None
+            if ub_overlap_rs_fprop:
+                ub_obj = get_ub(ub_name + "_fprop")
+                out = ub_obj.get_ubuf_output(1)
                 dim_size = list(inputmat_total.size())
-                dim_size[0] = dim_size[0] // get_distributed_world_size(tp_group)
+                dim_size[0] = dim_size[0] // tp_world_size
                 dim_size[1] = out_features
                 rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
-                if ub_obj_projout.is_p2p_overlap():
+                if ub_obj.is_p2p_overlap():
                     ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P
                 else:
                     ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS
+
+            elif ub_overlap_ag_fprop:
+                ub_obj = get_ub(ub_name + "_fprop")
+                ub_obj.copy_input_to_ubuf(inputmat_total, True)
+                dim_size = list(inputmat_total.size())
+                dim_size[0] *= tp_size  # all-gathered sequence length
+                dim_size[1] = out_features
+                out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
+                ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
+
             else:
                 dim_size = list(inputmat_total.size())
                 dim_size[1] = out_features
@@ -285,9 +323,9 @@ def forward(
                 bias=bias,
                 use_bias=use_bias,
                 out=out,
-                ub_algo=ub_algo if ub_overlap_rs else None,
-                ub=ub_obj_projout if ub_overlap_rs else None,
-                extra_output_tensor=rs_out if ub_overlap_rs else None,
+                ub_algo=ub_algo,
+                ub=ub_obj,
+                extra_output_tensor=rs_out,
             )
 
         if is_grad_enabled:
@@ -343,7 +381,10 @@ def forward(
             ctx.inp_shape = inp_shape
             ctx.parallel_mode = parallel_mode
             ctx.tp_group = tp_group
-            ctx.ub_overlap_ag = ub_overlap_ag
+            ctx.ub_overlap_ag = ub_overlap_ag_dgrad
+            ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad
+            ctx.ub_bulk_dgrad = ub_bulk_dgrad
+            ctx.ub_bulk_wgrad = ub_bulk_wgrad
             ctx.ub_name = ub_name
             ctx.tp_size = tp_size
             ctx.requires_dgrad = inp.requires_grad
@@ -356,12 +397,13 @@ def forward(
                     FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
 
         # Row Parallel Linear
-        if ub_overlap_rs:
-            out = rs_out
-        elif parallel_mode == "row" and sequence_parallel:
-            out, _ = reduce_scatter_along_first_dim(out, tp_group)
-        elif parallel_mode == "row" and tensor_parallel:
-            out, _ = allreduce(out, tp_group)
+        if parallel_mode == "row":
+            if ub_overlap_rs_fprop:
+                out = rs_out
+            elif sequence_parallel:
+                out, _ = reduce_scatter_along_first_dim(out, tp_group)
+            elif tensor_parallel:
+                out, _ = allreduce(out, tp_group)
 
         # [*, in_features] -> [*, out_features] except first dimension changes for SP
         return out.view(-1, *inp_shape[1:-1], out_features)
@@ -401,15 +443,75 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
 
             tp_world_size = get_distributed_world_size(ctx.tp_group)
             ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag
-            ub_algo = None
+            ctx.ub_overlap_rs_dgrad = False if tp_world_size == 1 else ctx.ub_overlap_rs_dgrad
+            ctx.ub_bulk_dgrad = False if tp_world_size == 1 else ctx.ub_bulk_dgrad
+            ctx.ub_bulk_wgrad = False if tp_world_size == 1 else ctx.ub_bulk_wgrad
+
+            ctx.ub_obj_gradout = None
+            ub_obj_wgrad = None
+            ub_algo_wgrad = None
+            ub_algo_dgrad = None
+            rs_out = None
+            dgrad = None
+            dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
             if ctx.ub_overlap_ag:
-                dim_size = list(grad_output.size())
-                dim_size[0] = dim_size[0] * tp_world_size
+                # Overlap grad_output all-gather with dgrad compute
                 ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
                 if ctx.ub_obj_gradout.is_atomic_gemm():
-                    ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
+                    ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
                 else:
-                    ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
+                    ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
+                dgrad = torch.empty(
+                    dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
+                )
+
+            elif ctx.ub_overlap_rs_dgrad:
+                # Overlap dgrad reduce-scatter with dgrad compute
+                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+                dgrad = ctx.ub_obj_gradout.get_ubuf_output(1)
+                if ctx.ub_obj_gradout.is_p2p_overlap():
+                    if ctx.ub_obj_gradout.is_atomic_gemm():
+                        ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P
+                    else:
+                        ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P
+                else:
+                    if ctx.ub_obj_gradout.is_atomic_gemm():
+                        ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_RS
+                    else:
+                        ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS
+                rs_out = torch.empty(
+                    dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
+                )
+                ctx.ub_bulk_dgrad = False
+                ctx.ub_bulk_wgrad = False
+
+            else:
+                if ctx.ub_bulk_dgrad:
+                    # Overlap inputmat all-gather with dgrad compute
+                    ub_algo_dgrad = tex.CommOverlapAlgo.BULK_OVERLAP_AG
+                    ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+                    inputmat_data = (
+                        inputmat._data if isinstance(inputmat, Float8Tensor) else inputmat
+                    )
+                    ctx.ub_obj_gradout.copy_input_to_ubuf(inputmat_data, True)
+                    inputmat_ubuf = ctx.ub_obj_gradout.get_ubuf_output(1)
+                    if isinstance(inputmat, Float8Tensor):
+                        inputmat._data = inputmat_ubuf
+                    else:
+                        inputmat = inputmat_ubuf
+
+                if ctx.ub_bulk_wgrad:
+                    # Overlap dgrad reduce-scatter with wgrad compute
+                    ub_algo_wgrad = tex.CommOverlapAlgo.BULK_OVERLAP_RS
+                    ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
+                    dgrad = ub_obj_wgrad.get_ubuf_output(1)
+
+            if dgrad is None:
+                if ctx.parallel_mode == "column" and ctx.sequence_parallel:
+                    dgrad_shape[0] = dgrad_shape[0] * tp_world_size
+                dgrad = torch.empty(
+                    dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
+                )
 
             (
                 grad_output,
@@ -420,13 +522,17 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                 ctx, grad_output, ctx.parallel_mode == "row"
             )
 
-            # Column Parallel Linear
-            # Overlap input AG with dgrad
+            # Overlap inputmat AG with dgrad via NCCL async comms (no TP overlap via Userbuffers)
             inputmat_total = None
             inputmat_t_total = None
-            handle = None
-            if weight.requires_grad and ctx.parallel_mode == "column" and ctx.sequence_parallel:
-                inputmat_total, handle = gather_along_first_dim(
+            inputmat_gather_handle = None
+            if (
+                weight.requires_grad
+                and ctx.parallel_mode == "column"
+                and ctx.sequence_parallel
+                and not ctx.ub_bulk_dgrad
+            ):
+                inputmat_total, inputmat_gather_handle = gather_along_first_dim(
                     inputmat, ctx.tp_group, async_op=ctx.requires_dgrad
                 )
             else:
@@ -446,13 +552,17 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
 
             if ctx.requires_dgrad:
                 if ctx.fp8:
-                    if ctx.is_input_fp8:
+                    if ctx.is_input_fp8 or (
+                        ctx.ub_overlap_rs_dgrad and ctx.ub_obj_gradout.is_fp8_ubuf()
+                    ):
                         out_index, meta_tensor, output_te_dtype, output_dtype = (
                             tex.FP8BwdTensors.GRAD_INPUT1,
                             ctx.fp8_meta["scaling_bwd"],
                             fp8_dtype_backward,
                             torch.uint8,
                         )
+                        if ctx.ub_overlap_rs_dgrad and ctx.ub_obj_gradout.is_fp8_ubuf():
+                            ctx.ub_obj_gradout.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index])
                     else:
                         out_index, meta_tensor, output_te_dtype, output_dtype = (
                             None,
@@ -460,7 +570,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                             None,
                             ctx.activation_dtype,
                         )
-                    dgrad, _ = fp8_gemm(
+                    _ = fp8_gemm(
                         weight_fp8.transpose_2d(),
                         weight_fp8._scale_inv,
                         0,
@@ -472,12 +582,18 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                         output_dtype,
                         get_workspace(),
                         use_split_accumulator=_2X_ACC_DGRAD,
-                        ub_algo=ub_algo if ctx.ub_overlap_ag else None,
-                        ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None,
+                        ub_algo=ub_algo_dgrad,
+                        ub=ctx.ub_obj_gradout,
+                        out=dgrad,
                         out_index=out_index,
                         fp8_meta_tensor=meta_tensor,
                         D_dtype=output_te_dtype,
+                        extra_output_tensor=rs_out,
                     )
+
+                    if ctx.ub_overlap_rs_dgrad:
+                        dgrad = rs_out
+
                     if output_dtype == torch.uint8:
                         dgrad = Float8Tensor(
                             data=dgrad,
@@ -488,30 +604,34 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                             dtype=ctx.activation_dtype,
                         )
                 else:
-                    dgrad, _, _ = gemm(
+                    _ = gemm(
                         weight,
                         grad_output,
                         ctx.activation_dtype,
                         get_workspace(),
                         layout="NN",
                         grad=True,
-                        ub_algo=(
-                            tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
-                            if ctx.ub_overlap_ag
-                            else None
-                        ),
-                        ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None,
+                        ub_algo=ub_algo_dgrad,
+                        ub=ctx.ub_obj_gradout,
+                        out=dgrad,
+                        extra_output_tensor=rs_out,
                     )
 
-                # Overlap dgrad-RS/AR with wgrad
-                if ctx.parallel_mode == "column" and ctx.sequence_parallel:
-                    if handle is not None:
-                        handle.wait()
-                    dgrad, handle = reduce_scatter_along_first_dim(
+                    if ctx.ub_overlap_rs_dgrad:
+                        dgrad = rs_out
+
+            if inputmat_gather_handle is not None:
+                inputmat_gather_handle.wait()
+
+            # Overlap dgrad RS/AR with wgrad via NCCL async comms (no TP overlap via Userbuffers)
+            dgrad_reduce_handle = None
+            if ctx.requires_dgrad and ctx.parallel_mode == "column":
+                if ctx.sequence_parallel and not (ctx.ub_overlap_rs_dgrad or ctx.ub_bulk_wgrad):
+                    dgrad, dgrad_reduce_handle = reduce_scatter_along_first_dim(
                         dgrad, ctx.tp_group, async_op=True
                     )
-                elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
-                    dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
+                elif ctx.tensor_parallel and not ctx.sequence_parallel:
+                    dgrad, dgrad_reduce_handle = allreduce(dgrad, ctx.tp_group, async_op=True)
 
             wgrad = None
             if weight.requires_grad:
@@ -548,6 +668,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                             accumulate=accumulate_wgrad_into_param_main_grad,
                             out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                             use_split_accumulator=_2X_ACC_WGRAD,
+                            ub=ub_obj_wgrad,
+                            ub_algo=ub_algo_wgrad,
                         )
                     else:
                         wgrad, _, _ = gemm(
@@ -559,6 +681,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                             grad=True,
                             accumulate=accumulate_wgrad_into_param_main_grad,
                             out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
+                            ub=ub_obj_wgrad,
+                            ub_algo=ub_algo_wgrad,
                         )
                 else:
                     # WGRAD
@@ -572,15 +696,20 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                         use_bias=ctx.use_bias,
                         accumulate=accumulate_wgrad_into_param_main_grad,
                         out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
+                        ub=ub_obj_wgrad,
+                        ub_algo=ub_algo_wgrad,
                     )
 
+                if ctx.ub_bulk_wgrad:
+                    dgrad = ub_obj_wgrad.get_ubuf_output(0)
+
                 # Deallocate input tensor
                 clear_tensor_data(inputmat_total)
                 clear_tensor_data(inputmat_t_total)
 
-            # Column Parallel Linear
-            if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
-                handle.wait()
+            # Wait for dgrad reduce-scatter or all-reduce
+            if dgrad_reduce_handle is not None:
+                dgrad_reduce_handle.wait()
 
             if not ctx.use_bias:
                 grad_bias = None
@@ -634,8 +763,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
             None,  # activation_dtype
             None,  # parallel_mode
             None,  # is_grad_enabled
-            None,  # ub_overlap_rs
-            None,  # ub_overlap_ag
+            None,  # ub_overlap_rs_fprop
+            None,  # ub_overlap_ag_dgrad
+            None,  # ub_overlap_ag_fprop
+            None,  # ub_overlap_rs_dgrad
+            None,  # ub_bulk_dgrad
+            None,  # ub_bulk_wgrad
             None,  # ub_name
             None,  # fp8_output
             None,  # fsdp_group
@@ -729,8 +862,10 @@ def __init__(
         parallel_mode: Optional[str] = None,
         parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
         device: Union[torch.device, str] = "cuda",
-        ub_overlap_rs: bool = False,
         ub_overlap_ag: bool = False,
+        ub_overlap_rs: bool = False,
+        ub_bulk_dgrad: bool = False,
+        ub_bulk_wgrad: bool = False,
         ub_name: Optional[str] = None,
     ) -> None:
         super().__init__()
@@ -742,13 +877,6 @@ def __init__(
         self.use_bias = bias
         self.return_bias = return_bias
         self.apply_bias = bias and not return_bias
-        self.ub_overlap_rs = ub_overlap_rs
-        self.ub_overlap_ag = ub_overlap_ag
-        if ub_overlap_rs or ub_overlap_ag:
-            assert ub_name is not None, "Userbuffer name [string] is not set."
-        self.ub_name = ub_name
-        self.get_rng_state_tracker = get_rng_state_tracker
-        self.rng_tracker_name = rng_tracker_name
 
         if device == "meta":
             assert parameters_split is None, "Cannot split module parameters on 'meta' device."
@@ -773,6 +901,45 @@ def __init__(
 
         self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
 
+        # Column parallel TP overlap options
+        self.ub_overlap_ag_fprop = parallel_mode == "column" and sequence_parallel and ub_overlap_ag
+        self.ub_overlap_rs_dgrad = parallel_mode == "column" and sequence_parallel and ub_overlap_rs
+        self.ub_bulk_dgrad = parallel_mode == "column" and sequence_parallel and ub_bulk_dgrad
+        self.ub_bulk_wgrad = parallel_mode == "column" and sequence_parallel and ub_bulk_wgrad
+        if self.ub_overlap_rs_dgrad:
+            self.ub_bulk_dgrad = False
+            self.ub_bulk_wgrad = False
+
+        # Row parallel TP overlap options
+        self.ub_overlap_rs_fprop = parallel_mode == "row" and sequence_parallel and ub_overlap_rs
+        self.ub_overlap_ag_dgrad = parallel_mode == "row" and sequence_parallel and ub_overlap_ag
+
+        if any(
+            [
+                self.ub_overlap_rs_fprop,
+                self.ub_overlap_ag_dgrad,
+                self.ub_overlap_ag_fprop,
+                self.ub_overlap_rs_dgrad,
+                self.ub_bulk_dgrad,
+                self.ub_bulk_wgrad,
+            ]
+        ):
+            assert ub_name is not None, f"Comm+GEMM overlap layer '{ub_name}' is not initialized."
+        self.ub_name = ub_name
+
+        assert not (
+            self.ub_overlap_rs_fprop and self.ub_overlap_ag_fprop
+        ), "Cannot enable AG+GEMM and GEMM+RS overlaps at the same time."
+        assert not (
+            self.ub_overlap_rs_dgrad and self.ub_bulk_dgrad
+        ), "Cannot enable DGRAD+RS and bulk DGRAD overlaps at the same time."
+        assert not (
+            self.ub_overlap_ag_dgrad and (self.ub_overlap_rs_dgrad or self.ub_bulk_dgrad)
+        ), "Cannot enable AG+DGRAD and DGRAD+RS or bulk DGRAD overlaps at the same time."
+
+        self.get_rng_state_tracker = get_rng_state_tracker
+        self.rng_tracker_name = rng_tracker_name
+
         # Initialize params in FP8
         with_fp8_params = FP8GlobalStateManager.with_fp8_parameters()
 
@@ -1017,8 +1184,12 @@ def forward(
                 self.activation_dtype,
                 self.parallel_mode,
                 torch.is_grad_enabled(),
-                self.ub_overlap_rs,
-                self.ub_overlap_ag,
+                self.ub_overlap_rs_fprop,
+                self.ub_overlap_ag_dgrad,
+                self.ub_overlap_ag_fprop,
+                self.ub_overlap_rs_dgrad,
+                self.ub_bulk_dgrad,
+                self.ub_bulk_wgrad,
                 self.ub_name,
                 fp8_output,
                 self.fsdp_group,