From 240240617267cff76178a7f5da58a93806e5a6d2 Mon Sep 17 00:00:00 2001 From: Alp Dener <adener@nvidia.com> Date: Mon, 13 Jan 2025 14:24:08 -0600 Subject: [PATCH] [PyTorch] Adding TP overlap support for `te.Linear` with `parallel_mode="column"` (#1343) * support AG overlap in sequence-parallel Linear forward and RS overlap in sequence-parallel Linear backward Signed-off-by: Alp Dener <adener@nvidia.com> * implemented TP overlap support for column-parallel te.Linear Signed-off-by: Alp Dener <adener@nvidia.com> * fixed backward pass for te.Linear column-parallel with TP overlap, updated unit tests Signed-off-by: Alp Dener <adener@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * improved error messages for internal failure to infer TP overlap options in te.Linear Signed-off-by: Alp Dener <adener@nvidia.com> * fixed linting errors Signed-off-by: Alp Dener <adener@nvidia.com> * fixed incorrect TP overlap option asserts Signed-off-by: Alp Dener <adener@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Alp Dener <adener@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../distributed/run_layer_with_overlap.py | 62 +++- .../distributed/test_comm_gemm_overlap.py | 24 +- transformer_engine/pytorch/module/linear.py | 347 +++++++++++++----- 3 files changed, 322 insertions(+), 111 deletions(-) diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index e49174c24f..5a67bd616a 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -51,15 +51,23 @@ def _get_layer_args(config, tp_group, tp_size, reference=False): kwargs["ub_overlap_ag"] = not reference if config.layer_type is te.Linear: - input_shape[2] = hidden_size // tp_size - args.append(hidden_size) - kwargs["parallel_mode"] = "row" - kwargs["ub_overlap_rs"] = not reference - kwargs["ub_name"] = "proj" + if config.linear_parallel_mode == "row": + input_shape[2] = hidden_size // tp_size + args.append(hidden_size) + kwargs["ub_overlap_rs"] = not reference + elif config.linear_parallel_mode == "column": + input_shape[0] = config.seq_length // tp_size + args.append(3 * hidden_size) + kwargs["ub_overlap_rs"] = config.overlap_rs_dgrad and not reference + kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference + kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference + kwargs["parallel_mode"] = config.linear_parallel_mode + kwargs["ub_name"] = "proj" if config.linear_parallel_mode == "row" else "qkv" else: input_shape[0] = config.seq_length // tp_size - kwargs["ub_bulk_wgrad"] = not reference - kwargs["ub_bulk_dgrad"] = not reference + kwargs["ub_overlap_rs_dgrad"] = config.overlap_rs_dgrad and not reference + kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference + kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference if config.layer_type is te.LayerNormLinear: args.append(3 * hidden_size) kwargs["parallel_mode"] = "column" @@ -125,6 +133,19 @@ def _parse_args(argv=None, namespace=None): parser.add_argument( "--use-cuda-graphs", action="store_true", default=False, help="Use CUDA Graphs." 
) + parser.add_argument( + "--linear-parallel-mode", + type=str.lower, + default="row", + choices=["row", "column"], + help="Parallel mode for te.Linear.", + ) + parser.add_argument( + "--overlap-rs-dgrad", + action="store_true", + default=False, + help="Overlap reduce-scatter with DGRAD in the backward pass instead of bulk overlaps.", + ) parser.add_argument( "--debug", action="store_true", @@ -230,12 +251,19 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs") # Intialize userbuffers + ub_cfgs = None + if opts.overlap_rs_dgrad: + ub_cfgs = { + "proj_dgrad": {"method": "ring_exchange"}, + "qkv_dgrad": {"method": "ring_exchange"}, + } te.module.base.initialize_ub( [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim], WORLD_SIZE, use_fp8=opts.fp8, dtype=torch.bfloat16, bootstrap_backend=opts.bootstrap_backend, + ub_cfgs=ub_cfgs, ) # Initialize the Transformer Engine layer with overlap @@ -314,27 +342,29 @@ def run_fwd_bwd(model, x): ref_grads.append(ref_param.grad) # Make sure we have the same number of gradients - numerics_failed = torch.tensor([0], dtype=torch.uint8, device="cuda") + num_grads_failed = torch.tensor([0], dtype=torch.uint8, device="cuda") if len(test_grads) != len(ref_grads): - numerics_failed[0] = 1 + num_grads_failed[0] = 1 numerics_info = ( "NUMERICAL CHECK FAILED: Incorrect number of gradients, " + f"expected {len(ref_grads)} but got {len(test_grads)}." ) dist_print(numerics_info, src=WORLD_RANK, error=True) - dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) + dist.all_reduce(num_grads_failed, dist.ReduceOp.MAX, nccl_world) # Now validate accuracy - if not bool(numerics_failed.item()): + numerics_failed = torch.zeros(len(test_grads), dtype=torch.uint8, device="cuda") + if not bool(num_grads_failed.item()): for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)): rtol = 0.125 if opts.fp8 else 0.025 atol = 0.0625 if opts.fp8 else 0.00125 grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol) dist_print(grad_info, src=WORLD_RANK, error=grad_failed) - numerics_failed[0] = int(grad_failed) - dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) - if bool(numerics_failed.item()): - break + numerics_failed[i] = int(grad_failed) + return_code = torch.max(numerics_failed) + dist.all_reduce(return_code, dist.ReduceOp.MAX, nccl_world) + else: + return_code = num_grads_failed te.module.base.destroy_ub() dist_print("Destroying Userbuffers objects...", debug=True) @@ -344,7 +374,7 @@ def run_fwd_bwd(model, x): if opts.debug and WORLD_RANK == 0: print("Exiting...\n", end="", flush=True) - return numerics_failed[0].item() + return return_code.item() if __name__ == "__main__": diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index 240e396534..c285da7fbd 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -21,8 +21,10 @@ BATCH_SIZE: int = 2 NUM_HEADS: int = 12 HEAD_DIM: int = 64 + +# NOTE: te.Linear is intentionally omitted here and manually added later for testing both +# row and column parallel layouts. 
TE_LAYERS = [ - te.Linear, te.LayerNormLinear, te.LayerNormMLP, te.MultiheadAttention, @@ -86,7 +88,7 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggreg raise AssertionError(result.stderr.decode()) -def _run_layer_with_overlap(layer_type, fp8, fp8_init): +def _run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init): test_path = TEST_ROOT / "run_layer_with_overlap.py" test_cmd = LAUNCH_CMD + [ str(test_path), @@ -97,6 +99,8 @@ def _run_layer_with_overlap(layer_type, fp8, fp8_init): f"--head-dim={HEAD_DIM}", f"--layer-type={layer_type}", ] + if layer_type == te.Linear.__name__: + test_cmd.append(f"--linear-parallel-mode={linear_parallel_mode}") if fp8: if not fp8_available: @@ -245,9 +249,15 @@ def test_bulk_overlaps(comm_type, fp8, connections): @pytest.mark.parametrize( - "layer_type", - [layer.__name__ for layer in TE_LAYERS], - ids=[(" " + layer.__name__ + " ") for layer in TE_LAYERS], + "layer_type,linear_parallel_mode", + ( + [(te.Linear.__name__, "row"), (te.Linear.__name__, "column")] + + list(zip([layer.__name__ for layer in TE_LAYERS], [None for _ in range(len(TE_LAYERS))])) + ), + ids=( + [f" {te.Linear.__name__} (row-parallel) ", f" {te.Linear.__name__} (column-parallel) "] + + [(" " + layer.__name__ + " ") for layer in TE_LAYERS] + ), ) @pytest.mark.parametrize( "fp8,fp8_init", @@ -262,8 +272,8 @@ def test_bulk_overlaps(comm_type, fp8, connections): " FP8 GEMM - FP8 PARAMS ", ], ) -def test_layers_with_overlap(layer_type, fp8, fp8_init): +def test_layers_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init): """ Test Transformer Engine layers with comm+GEMM overlap. """ - _run_layer_with_overlap(layer_type, fp8, fp8_init) + _run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 5fd4dd2fc9..2262d23832 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -3,6 +3,8 @@ # See LICENSE for license information. 
"""Linear API""" +from functools import reduce +from operator import mul as multiply_op from typing import Any, Callable, Dict, Optional, Tuple, Union import torch @@ -43,7 +45,7 @@ fp8_cast_transpose_fused, cast_to_fp8, ) -from ..constants import GemmParallelModes, dist_group_type +from ..constants import GemmParallelModes, dist_group_type, TE_DType from ..jit import no_torch_dynamo from ..graph import is_graph_capturing from ..float8_tensor import Float8Tensor @@ -80,8 +82,12 @@ def forward( activation_dtype: torch.dtype, parallel_mode: Union[str, None], is_grad_enabled: bool, - ub_overlap_rs: bool, - ub_overlap_ag: bool, + ub_overlap_rs_fprop: bool, + ub_overlap_ag_dgrad: bool, + ub_overlap_ag_fprop: bool, + ub_overlap_rs_dgrad: bool, + ub_bulk_dgrad: bool, + ub_bulk_wgrad: bool, ub_name: str, fp8_output: bool, fsdp_group: Union[dist_group_type, None], @@ -99,7 +105,8 @@ def forward( assert_dim_for_fp8_exec(weight) tp_world_size = get_distributed_world_size(tp_group) - ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs + ub_overlap_ag_fprop = False if tp_world_size == 1 else ub_overlap_ag_fprop + ub_overlap_rs_fprop = False if tp_world_size == 1 else ub_overlap_rs_fprop # Cast input to expected dtype inputmat = cast_if_needed(inputmat, activation_dtype) @@ -150,10 +157,11 @@ def forward( inputmat_scale_inv.fill_(inputmat_scale_inv.item()) # Column Parallel Linear - if parallel_mode == "column" and sequence_parallel: + if parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop: inputmat_total, _ = gather_along_first_dim(inputmat, tp_group) else: inputmat_total = inputmat + if fp8: bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype bias = cast_if_needed(bias, bias_dtype) if use_bias else bias @@ -165,75 +173,92 @@ def forward( assert isinstance(weight_fp8, Float8Tensor) if fp8_output: - proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( + out_index, meta_tensor, out_tedtype, out_pttype = ( tex.FP8FwdTensors.GEMM1_OUTPUT, fp8_meta["scaling_fwd"], fp8_dtype_forward, torch.uint8, ) else: - proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( + out_index, meta_tensor, out_tedtype, out_pttype = ( None, None, None, activation_dtype, ) + ub_obj = None ub_algo = None rs_out = None - if ub_overlap_rs: - ub_obj_projout = get_ub(ub_name + "_fprop") - out = ub_obj_projout.get_ubuf_output(1) + inputmat_data = ( + inputmat_total._data if isinstance(inputmat_total, Float8Tensor) else inputmat_total + ) + if ub_overlap_rs_fprop: + ub_obj = get_ub(ub_name + "_fprop") + out = ub_obj.get_ubuf_output(1) dim_size = list(inputmat_total.size()) dim_size[0] = dim_size[0] // tp_world_size dim_size[1] = out_features rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) - if ub_obj_projout.is_p2p_overlap(): - if ub_obj_projout.is_atomic_gemm(): + if ub_obj.is_p2p_overlap(): + if ub_obj.is_atomic_gemm(): ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P else: ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: - if ub_obj_projout.is_atomic_gemm(): + if ub_obj.is_atomic_gemm(): ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS else: ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - if ub_obj_projout.is_fp8_ubuf(): - proj_out_index = tex.FP8FwdTensors.GEMM1_OUTPUT + if ub_obj.is_fp8_ubuf(): + out_index = tex.FP8FwdTensors.GEMM1_OUTPUT meta_tensor = fp8_meta["scaling_fwd"] - proj_out_tetype = fp8_dtype_forward - proj_out_pttype = torch.uint8 - 
ub_obj_projout.set_ubuf_scale_inv(meta_tensor.scale_inv[proj_out_index]) + out_tedtype = fp8_dtype_forward + out_pttype = torch.uint8 + ub_obj.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index]) + + elif ub_overlap_ag_fprop: + ub_obj = get_ub(ub_name + "_fprop") + assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM requires FP8 buffer." + ub_obj.copy_input_to_ubuf(inputmat_data, True) + ub_obj.set_ubuf_scale_inv(inputmat_scale_inv) + if ub_obj.is_atomic_gemm(): + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P + else: + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P + out_tedtype = TE_DType[activation_dtype] + out_pttype = activation_dtype + dim_size = list(inputmat_total.size()) + dim_size[0] *= tp_size + dim_size[1] = out_features + out = torch.empty(dim_size, dtype=out_pttype, device=inputmat_total.device) + else: dim_size = list(inputmat_total.size()) dim_size[1] = out_features - out = torch.empty(dim_size, dtype=proj_out_pttype, device=inputmat_total.device) + out = torch.empty(dim_size, dtype=out_pttype, device=inputmat_total.device) _ = fp8_gemm( weight_fp8._data, weight_fp8._scale_inv, 0, weight_fp8._fp8_dtype, - ( - inputmat_total._data - if isinstance(inputmat_total, Float8Tensor) - else inputmat_total - ), + inputmat_data, inputmat_scale_inv, 0, fp8_dtype_forward, - proj_out_pttype, + out_pttype, get_workspace(), bias=bias, use_bias=use_bias, use_split_accumulator=_2X_ACC_FPROP, out=out, - ub_algo=ub_algo if ub_overlap_rs else None, - ub=ub_obj_projout if ub_overlap_rs else None, - extra_output_tensor=rs_out if ub_overlap_rs else None, - out_index=proj_out_index, + ub_algo=ub_algo, + ub=ub_obj, + extra_output_tensor=rs_out, + out_index=out_index, fp8_meta_tensor=meta_tensor, - D_dtype=proj_out_tetype, + D_dtype=out_tedtype, ) if fp8_output: out = Float8Tensor( @@ -261,17 +286,30 @@ def forward( -amin, amax ).float() - if ub_overlap_rs: - ub_obj_projout = get_ub(ub_name + "_fprop") - out = ub_obj_projout.get_ubuf_output(1) + ub_obj = None + ub_algo = None + rs_out = None + if ub_overlap_rs_fprop: + ub_obj = get_ub(ub_name + "_fprop") + out = ub_obj.get_ubuf_output(1) dim_size = list(inputmat_total.size()) - dim_size[0] = dim_size[0] // get_distributed_world_size(tp_group) + dim_size[0] = dim_size[0] // tp_world_size dim_size[1] = out_features rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) - if ub_obj_projout.is_p2p_overlap(): + if ub_obj.is_p2p_overlap(): ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS + + elif ub_overlap_ag_fprop: + ub_obj = get_ub(ub_name + "_fprop") + ub_obj.copy_input_to_ubuf(inputmat_total, True) + dim_size = list(inputmat_total.size()) + dim_size[0] *= tp_size # all-gathered sequence length + dim_size[1] = out_features + out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P + else: dim_size = list(inputmat_total.size()) dim_size[1] = out_features @@ -285,9 +323,9 @@ def forward( bias=bias, use_bias=use_bias, out=out, - ub_algo=ub_algo if ub_overlap_rs else None, - ub=ub_obj_projout if ub_overlap_rs else None, - extra_output_tensor=rs_out if ub_overlap_rs else None, + ub_algo=ub_algo, + ub=ub_obj, + extra_output_tensor=rs_out, ) if is_grad_enabled: @@ -343,7 +381,10 @@ def forward( ctx.inp_shape = inp_shape ctx.parallel_mode = parallel_mode ctx.tp_group = tp_group - ctx.ub_overlap_ag = ub_overlap_ag + ctx.ub_overlap_ag = ub_overlap_ag_dgrad + ctx.ub_overlap_rs_dgrad = 
ub_overlap_rs_dgrad + ctx.ub_bulk_dgrad = ub_bulk_dgrad + ctx.ub_bulk_wgrad = ub_bulk_wgrad ctx.ub_name = ub_name ctx.tp_size = tp_size ctx.requires_dgrad = inp.requires_grad @@ -356,12 +397,13 @@ def forward( FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module # Row Parallel Linear - if ub_overlap_rs: - out = rs_out - elif parallel_mode == "row" and sequence_parallel: - out, _ = reduce_scatter_along_first_dim(out, tp_group) - elif parallel_mode == "row" and tensor_parallel: - out, _ = allreduce(out, tp_group) + if parallel_mode == "row": + if ub_overlap_rs_fprop: + out = rs_out + elif sequence_parallel: + out, _ = reduce_scatter_along_first_dim(out, tp_group) + elif tensor_parallel: + out, _ = allreduce(out, tp_group) # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp_shape[1:-1], out_features) @@ -401,15 +443,75 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], tp_world_size = get_distributed_world_size(ctx.tp_group) ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag - ub_algo = None + ctx.ub_overlap_rs_dgrad = False if tp_world_size == 1 else ctx.ub_overlap_rs_dgrad + ctx.ub_bulk_dgrad = False if tp_world_size == 1 else ctx.ub_bulk_dgrad + ctx.ub_bulk_wgrad = False if tp_world_size == 1 else ctx.ub_bulk_wgrad + + ctx.ub_obj_gradout = None + ub_obj_wgrad = None + ub_algo_wgrad = None + ub_algo_dgrad = None + rs_out = None + dgrad = None + dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] if ctx.ub_overlap_ag: - dim_size = list(grad_output.size()) - dim_size[0] = dim_size[0] * tp_world_size + # Overlap grad_output all-gather with dgrad compute ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") if ctx.ub_obj_gradout.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P + ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P + ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P + dgrad = torch.empty( + dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device + ) + + elif ctx.ub_overlap_rs_dgrad: + # Overlap dgrad reduce-scatter with dgrad compute + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + dgrad = ctx.ub_obj_gradout.get_ubuf_output(1) + if ctx.ub_obj_gradout.is_p2p_overlap(): + if ctx.ub_obj_gradout.is_atomic_gemm(): + ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P + else: + ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P + else: + if ctx.ub_obj_gradout.is_atomic_gemm(): + ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_RS + else: + ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS + rs_out = torch.empty( + dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device + ) + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + + else: + if ctx.ub_bulk_dgrad: + # Overlap inputmat all-gather with dgrad compute + ub_algo_dgrad = tex.CommOverlapAlgo.BULK_OVERLAP_AG + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + inputmat_data = ( + inputmat._data if isinstance(inputmat, Float8Tensor) else inputmat + ) + ctx.ub_obj_gradout.copy_input_to_ubuf(inputmat_data, True) + inputmat_ubuf = ctx.ub_obj_gradout.get_ubuf_output(1) + if isinstance(inputmat, Float8Tensor): + inputmat._data = inputmat_ubuf + else: + inputmat = inputmat_ubuf + + if ctx.ub_bulk_wgrad: + # Overlap dgrad reduce-scatter with wgrad compute + ub_algo_wgrad = tex.CommOverlapAlgo.BULK_OVERLAP_RS + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad") + dgrad = 
ub_obj_wgrad.get_ubuf_output(1) + + if dgrad is None: + if ctx.parallel_mode == "column" and ctx.sequence_parallel: + dgrad_shape[0] = dgrad_shape[0] * tp_world_size + dgrad = torch.empty( + dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device + ) ( grad_output, @@ -420,13 +522,17 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ctx, grad_output, ctx.parallel_mode == "row" ) - # Column Parallel Linear - # Overlap input AG with dgrad + # Overlap inputmat AG with dgrad via NCCL async comms (no TP overlap via Userbuffers) inputmat_total = None inputmat_t_total = None - handle = None - if weight.requires_grad and ctx.parallel_mode == "column" and ctx.sequence_parallel: - inputmat_total, handle = gather_along_first_dim( + inputmat_gather_handle = None + if ( + weight.requires_grad + and ctx.parallel_mode == "column" + and ctx.sequence_parallel + and not ctx.ub_bulk_dgrad + ): + inputmat_total, inputmat_gather_handle = gather_along_first_dim( inputmat, ctx.tp_group, async_op=ctx.requires_dgrad ) else: @@ -446,13 +552,17 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.requires_dgrad: if ctx.fp8: - if ctx.is_input_fp8: + if ctx.is_input_fp8 or ( + ctx.ub_overlap_rs_dgrad and ctx.ub_obj_gradout.is_fp8_ubuf() + ): out_index, meta_tensor, output_te_dtype, output_dtype = ( tex.FP8BwdTensors.GRAD_INPUT1, ctx.fp8_meta["scaling_bwd"], fp8_dtype_backward, torch.uint8, ) + if ctx.ub_overlap_rs_dgrad and ctx.ub_obj_gradout.is_fp8_ubuf(): + ctx.ub_obj_gradout.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index]) else: out_index, meta_tensor, output_te_dtype, output_dtype = ( None, @@ -460,7 +570,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, ctx.activation_dtype, ) - dgrad, _ = fp8_gemm( + _ = fp8_gemm( weight_fp8.transpose_2d(), weight_fp8._scale_inv, 0, @@ -472,12 +582,18 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], output_dtype, get_workspace(), use_split_accumulator=_2X_ACC_DGRAD, - ub_algo=ub_algo if ctx.ub_overlap_ag else None, - ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, + ub_algo=ub_algo_dgrad, + ub=ctx.ub_obj_gradout, + out=dgrad, out_index=out_index, fp8_meta_tensor=meta_tensor, D_dtype=output_te_dtype, + extra_output_tensor=rs_out, ) + + if ctx.ub_overlap_rs_dgrad: + dgrad = rs_out + if output_dtype == torch.uint8: dgrad = Float8Tensor( data=dgrad, @@ -488,30 +604,34 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, ) else: - dgrad, _, _ = gemm( + _ = gemm( weight, grad_output, ctx.activation_dtype, get_workspace(), layout="NN", grad=True, - ub_algo=( - tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P - if ctx.ub_overlap_ag - else None - ), - ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, + ub_algo=ub_algo_dgrad, + ub=ctx.ub_obj_gradout, + out=dgrad, + extra_output_tensor=rs_out, ) - # Overlap dgrad-RS/AR with wgrad - if ctx.parallel_mode == "column" and ctx.sequence_parallel: - if handle is not None: - handle.wait() - dgrad, handle = reduce_scatter_along_first_dim( + if ctx.ub_overlap_rs_dgrad: + dgrad = rs_out + + if inputmat_gather_handle is not None: + inputmat_gather_handle.wait() + + # Overlap dgrad RS/AR with wgrad via NCCL async comms (no TP overlap via Userbuffers) + dgrad_reduce_handle = None + if ctx.requires_dgrad and ctx.parallel_mode == "column": + if ctx.sequence_parallel and not (ctx.ub_overlap_rs_dgrad or ctx.ub_bulk_wgrad): + dgrad, 
dgrad_reduce_handle = reduce_scatter_along_first_dim( dgrad, ctx.tp_group, async_op=True ) - elif ctx.parallel_mode == "column" and ctx.tensor_parallel: - dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True) + elif ctx.tensor_parallel and not ctx.sequence_parallel: + dgrad, dgrad_reduce_handle = allreduce(dgrad, ctx.tp_group, async_op=True) wgrad = None if weight.requires_grad: @@ -548,6 +668,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=_2X_ACC_WGRAD, + ub=ub_obj_wgrad, + ub_algo=ub_algo_wgrad, ) else: wgrad, _, _ = gemm( @@ -559,6 +681,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad=True, accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, + ub=ub_obj_wgrad, + ub_algo=ub_algo_wgrad, ) else: # WGRAD @@ -572,15 +696,20 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], use_bias=ctx.use_bias, accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, + ub=ub_obj_wgrad, + ub_algo=ub_algo_wgrad, ) + if ctx.ub_bulk_wgrad: + dgrad = ub_obj_wgrad.get_ubuf_output(0) + # Deallocate input tensor clear_tensor_data(inputmat_total) clear_tensor_data(inputmat_t_total) - # Column Parallel Linear - if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None: - handle.wait() + # Wait for dgrad reduce-scatter or all-reduce + if dgrad_reduce_handle is not None: + dgrad_reduce_handle.wait() if not ctx.use_bias: grad_bias = None @@ -634,8 +763,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, # activation_dtype None, # parallel_mode None, # is_grad_enabled - None, # ub_overlap_rs - None, # ub_overlap_ag + None, # ub_overlap_rs_fprop + None, # ub_overlap_ag_dgrad + None, # ub_overlap_ag_fprop + None, # ub_overlap_rs_dgrad + None, # ub_bulk_dgrad + None, # ub_bulk_wgrad None, # ub_name None, # fp8_output None, # fsdp_group @@ -729,8 +862,10 @@ def __init__( parallel_mode: Optional[str] = None, parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None, device: Union[torch.device, str] = "cuda", - ub_overlap_rs: bool = False, ub_overlap_ag: bool = False, + ub_overlap_rs: bool = False, + ub_bulk_dgrad: bool = False, + ub_bulk_wgrad: bool = False, ub_name: Optional[str] = None, ) -> None: super().__init__() @@ -742,13 +877,6 @@ def __init__( self.use_bias = bias self.return_bias = return_bias self.apply_bias = bias and not return_bias - self.ub_overlap_rs = ub_overlap_rs - self.ub_overlap_ag = ub_overlap_ag - if ub_overlap_rs or ub_overlap_ag: - assert ub_name is not None, "Userbuffer name [string] is not set." - self.ub_name = ub_name - self.get_rng_state_tracker = get_rng_state_tracker - self.rng_tracker_name = rng_tracker_name if device == "meta": assert parameters_split is None, "Cannot split module parameters on 'meta' device." 
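For readers tracing the backward-pass hunk above: with `parallel_mode="column"` and sequence parallelism, the patch selects between two mutually exclusive Userbuffers strategies. The sketch below is illustrative only (the helper name is invented and is not part of Transformer Engine); it simply restates the flag handling from the hunk above.

    def column_parallel_bwd_overlap_modes(ub_overlap_rs_dgrad: bool,
                                          ub_bulk_dgrad: bool,
                                          ub_bulk_wgrad: bool) -> dict:
        """Restate the mutually exclusive backward overlap modes wired up in linear.py."""
        if ub_overlap_rs_dgrad:
            # dgrad GEMM overlapped with its own reduce-scatter via the
            # "<ub_name>_dgrad" Userbuffers object; both bulk overlaps are
            # forced off in this case.
            return {"dgrad": "GEMM+RS overlap", "wgrad": "no Userbuffers overlap"}
        modes = {"dgrad": "no Userbuffers overlap", "wgrad": "no Userbuffers overlap"}
        if ub_bulk_dgrad:
            # Input all-gather (needed for wgrad) bulk-overlapped with the dgrad GEMM.
            modes["dgrad"] = "BULK_OVERLAP_AG of inputmat"
        if ub_bulk_wgrad:
            # dgrad reduce-scatter bulk-overlapped with the wgrad GEMM.
            modes["wgrad"] = "BULK_OVERLAP_RS of dgrad"
        return modes
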
@@ -773,6 +901,45 @@ def __init__( self.sequence_parallel = (self.tp_size > 1) and sequence_parallel + # Column parallel TP overlap options + self.ub_overlap_ag_fprop = parallel_mode == "column" and sequence_parallel and ub_overlap_ag + self.ub_overlap_rs_dgrad = parallel_mode == "column" and sequence_parallel and ub_overlap_rs + self.ub_bulk_dgrad = parallel_mode == "column" and sequence_parallel and ub_bulk_dgrad + self.ub_bulk_wgrad = parallel_mode == "column" and sequence_parallel and ub_bulk_wgrad + if self.ub_overlap_rs_dgrad: + self.ub_bulk_dgrad = False + self.ub_bulk_wgrad = False + + # Row parallel TP overlap options + self.ub_overlap_rs_fprop = parallel_mode == "row" and sequence_parallel and ub_overlap_rs + self.ub_overlap_ag_dgrad = parallel_mode == "row" and sequence_parallel and ub_overlap_ag + + if any( + [ + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, + ] + ): + assert ub_name is not None, f"Comm+GEMM overlap layer '{ub_name}' is not initialized." + self.ub_name = ub_name + + assert not ( + self.ub_overlap_rs_fprop and self.ub_overlap_ag_fprop + ), "Cannot enable AG+GEMM and GEMM+RS overlaps at the same time." + assert not ( + self.ub_overlap_rs_dgrad and self.ub_bulk_dgrad + ), "Cannot enable DGRAD+RS and bulk DGRAD overlaps at the same time." + assert not ( + self.ub_overlap_ag_dgrad and (self.ub_overlap_rs_dgrad or self.ub_bulk_dgrad) + ), "Cannot enable AG+DGRAD and DGRAD+RS or bulk DGRAD overlaps at the same time." + + self.get_rng_state_tracker = get_rng_state_tracker + self.rng_tracker_name = rng_tracker_name + # Initialize params in FP8 with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() @@ -1017,8 +1184,12 @@ def forward( self.activation_dtype, self.parallel_mode, torch.is_grad_enabled(), - self.ub_overlap_rs, - self.ub_overlap_ag, + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, self.ub_name, fp8_output, self.fsdp_group,
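
As a usage illustration (not part of the patch): once the Userbuffers workspaces have been initialized, a column-parallel `te.Linear` can request the new overlaps roughly as sketched below. This is a minimal sketch assuming `torch.distributed` is already initialized with one GPU per rank on hardware that supports comm+GEMM overlap; the sizes and the `ub_name="qkv"` choice follow the test harness above.

    import torch
    import transformer_engine.pytorch as te

    tp_size = torch.distributed.get_world_size()
    seq_length, batch_size, hidden_size = 2048, 2, 768  # illustrative sizes

    # Userbuffers workspaces must exist before any overlap-enabled layer runs.
    te.module.base.initialize_ub(
        [seq_length * batch_size, hidden_size],
        tp_size,
        use_fp8=False,
        dtype=torch.bfloat16,
    )

    qkv = te.Linear(
        hidden_size,
        3 * hidden_size,
        params_dtype=torch.bfloat16,
        parallel_mode="column",
        sequence_parallel=True,
        tp_group=torch.distributed.group.WORLD,
        ub_overlap_ag=True,   # overlap the input all-gather with the fprop GEMM
        ub_bulk_dgrad=True,   # bulk-overlap the input all-gather with the dgrad GEMM
        ub_bulk_wgrad=True,   # bulk-overlap the dgrad reduce-scatter with the wgrad GEMM
        ub_name="qkv",
    )

    # Sequence-parallel input: local shard of the sequence dimension.
    x = torch.randn(
        seq_length // tp_size, batch_size, hidden_size,
        dtype=torch.bfloat16, device="cuda", requires_grad=True,
    )
    y = qkv(x)
    y.sum().backward()

    te.module.base.destroy_ub()

Alternatively, passing `ub_overlap_rs=True` instead of the two bulk flags selects the dgrad GEMM + reduce-scatter overlap added in this patch; the test harness enables that path via `--overlap-rs-dgrad` together with a `ring_exchange` entry for `qkv_dgrad` in `ub_cfgs`.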