Skip to content

Commit

Permalink
[PyTorch] Adding TP overlap support for te.Linear with `parallel_mo…
Browse files Browse the repository at this point in the history
…de="column"` (#1343)

* support AG overlap in sequence-parallel Linear forward and RS overlap in sequence-parallel Linear backward

Signed-off-by: Alp Dener <[email protected]>

* implemented TP overlap support for column-parallel te.Linear

Signed-off-by: Alp Dener <[email protected]>

* fixed backward pass for te.Linear column-parallel with TP overlap, updated unit tests

Signed-off-by: Alp Dener <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* improved error messages for internal failure to infer TP overlap options in te.Linear

Signed-off-by: Alp Dener <[email protected]>

* fixed linting errors

Signed-off-by: Alp Dener <[email protected]>

* fixed incorrect TP overlap option asserts

Signed-off-by: Alp Dener <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Alp Dener <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
denera and pre-commit-ci[bot] authored Jan 13, 2025
1 parent cbc4653 commit 2402406
Show file tree
Hide file tree
Showing 3 changed files with 322 additions and 111 deletions.
62 changes: 46 additions & 16 deletions tests/pytorch/distributed/run_layer_with_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,23 @@ def _get_layer_args(config, tp_group, tp_size, reference=False):
kwargs["ub_overlap_ag"] = not reference

if config.layer_type is te.Linear:
input_shape[2] = hidden_size // tp_size
args.append(hidden_size)
kwargs["parallel_mode"] = "row"
kwargs["ub_overlap_rs"] = not reference
kwargs["ub_name"] = "proj"
if config.linear_parallel_mode == "row":
input_shape[2] = hidden_size // tp_size
args.append(hidden_size)
kwargs["ub_overlap_rs"] = not reference
elif config.linear_parallel_mode == "column":
input_shape[0] = config.seq_length // tp_size
args.append(3 * hidden_size)
kwargs["ub_overlap_rs"] = config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference
kwargs["parallel_mode"] = config.linear_parallel_mode
kwargs["ub_name"] = "proj" if config.linear_parallel_mode == "row" else "qkv"
else:
input_shape[0] = config.seq_length // tp_size
kwargs["ub_bulk_wgrad"] = not reference
kwargs["ub_bulk_dgrad"] = not reference
kwargs["ub_overlap_rs_dgrad"] = config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference
if config.layer_type is te.LayerNormLinear:
args.append(3 * hidden_size)
kwargs["parallel_mode"] = "column"
Expand Down Expand Up @@ -125,6 +133,19 @@ def _parse_args(argv=None, namespace=None):
parser.add_argument(
"--use-cuda-graphs", action="store_true", default=False, help="Use CUDA Graphs."
)
parser.add_argument(
"--linear-parallel-mode",
type=str.lower,
default="row",
choices=["row", "column"],
help="Parallel mode for te.Linear.",
)
parser.add_argument(
"--overlap-rs-dgrad",
action="store_true",
default=False,
help="Overlap reduce-scatter with DGRAD in the backward pass instead of bulk overlaps.",
)
parser.add_argument(
"--debug",
action="store_true",
Expand Down Expand Up @@ -230,12 +251,19 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False):
dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs")

# Intialize userbuffers
ub_cfgs = None
if opts.overlap_rs_dgrad:
ub_cfgs = {
"proj_dgrad": {"method": "ring_exchange"},
"qkv_dgrad": {"method": "ring_exchange"},
}
te.module.base.initialize_ub(
[opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
WORLD_SIZE,
use_fp8=opts.fp8,
dtype=torch.bfloat16,
bootstrap_backend=opts.bootstrap_backend,
ub_cfgs=ub_cfgs,
)

# Initialize the Transformer Engine layer with overlap
Expand Down Expand Up @@ -314,27 +342,29 @@ def run_fwd_bwd(model, x):
ref_grads.append(ref_param.grad)

# Make sure we have the same number of gradients
numerics_failed = torch.tensor([0], dtype=torch.uint8, device="cuda")
num_grads_failed = torch.tensor([0], dtype=torch.uint8, device="cuda")
if len(test_grads) != len(ref_grads):
numerics_failed[0] = 1
num_grads_failed[0] = 1
numerics_info = (
"NUMERICAL CHECK FAILED: Incorrect number of gradients, "
+ f"expected {len(ref_grads)} but got {len(test_grads)}."
)
dist_print(numerics_info, src=WORLD_RANK, error=True)
dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world)
dist.all_reduce(num_grads_failed, dist.ReduceOp.MAX, nccl_world)

# Now validate accuracy
if not bool(numerics_failed.item()):
numerics_failed = torch.zeros(len(test_grads), dtype=torch.uint8, device="cuda")
if not bool(num_grads_failed.item()):
for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)):
rtol = 0.125 if opts.fp8 else 0.025
atol = 0.0625 if opts.fp8 else 0.00125
grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol)
dist_print(grad_info, src=WORLD_RANK, error=grad_failed)
numerics_failed[0] = int(grad_failed)
dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world)
if bool(numerics_failed.item()):
break
numerics_failed[i] = int(grad_failed)
return_code = torch.max(numerics_failed)
dist.all_reduce(return_code, dist.ReduceOp.MAX, nccl_world)
else:
return_code = num_grads_failed

te.module.base.destroy_ub()
dist_print("Destroying Userbuffers objects...", debug=True)
Expand All @@ -344,7 +374,7 @@ def run_fwd_bwd(model, x):
if opts.debug and WORLD_RANK == 0:
print("Exiting...\n", end="", flush=True)

return numerics_failed[0].item()
return return_code.item()


if __name__ == "__main__":
Expand Down
24 changes: 17 additions & 7 deletions tests/pytorch/distributed/test_comm_gemm_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
BATCH_SIZE: int = 2
NUM_HEADS: int = 12
HEAD_DIM: int = 64

# NOTE: te.Linear is intentionally omitted here and manually added later for testing both
# row and column parallel layouts.
TE_LAYERS = [
te.Linear,
te.LayerNormLinear,
te.LayerNormMLP,
te.MultiheadAttention,
Expand Down Expand Up @@ -86,7 +88,7 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggreg
raise AssertionError(result.stderr.decode())


def _run_layer_with_overlap(layer_type, fp8, fp8_init):
def _run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init):
test_path = TEST_ROOT / "run_layer_with_overlap.py"
test_cmd = LAUNCH_CMD + [
str(test_path),
Expand All @@ -97,6 +99,8 @@ def _run_layer_with_overlap(layer_type, fp8, fp8_init):
f"--head-dim={HEAD_DIM}",
f"--layer-type={layer_type}",
]
if layer_type == te.Linear.__name__:
test_cmd.append(f"--linear-parallel-mode={linear_parallel_mode}")

if fp8:
if not fp8_available:
Expand Down Expand Up @@ -245,9 +249,15 @@ def test_bulk_overlaps(comm_type, fp8, connections):


@pytest.mark.parametrize(
"layer_type",
[layer.__name__ for layer in TE_LAYERS],
ids=[(" " + layer.__name__ + " ") for layer in TE_LAYERS],
"layer_type,linear_parallel_mode",
(
[(te.Linear.__name__, "row"), (te.Linear.__name__, "column")]
+ list(zip([layer.__name__ for layer in TE_LAYERS], [None for _ in range(len(TE_LAYERS))]))
),
ids=(
[f" {te.Linear.__name__} (row-parallel) ", f" {te.Linear.__name__} (column-parallel) "]
+ [(" " + layer.__name__ + " ") for layer in TE_LAYERS]
),
)
@pytest.mark.parametrize(
"fp8,fp8_init",
Expand All @@ -262,8 +272,8 @@ def test_bulk_overlaps(comm_type, fp8, connections):
" FP8 GEMM - FP8 PARAMS ",
],
)
def test_layers_with_overlap(layer_type, fp8, fp8_init):
def test_layers_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init):
"""
Test Transformer Engine layers with comm+GEMM overlap.
"""
_run_layer_with_overlap(layer_type, fp8, fp8_init)
_run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init)
Loading

0 comments on commit 2402406

Please sign in to comment.