Add NVTX ranges to categorize execution (#1447)
Signed-off-by: Jaemin Choi <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Jaemin Choi <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
3 people authored Feb 12, 2025
1 parent b87e539 commit 49a4535
Showing 4 changed files with 150 additions and 6 deletions.
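
The hunks below import nvtx_range_push and nvtx_range_pop from transformer_engine.pytorch.utils; the utils.py side of this commit is not included in this excerpt. As an assumed sketch only (not the actual implementation), the helpers are presumably thin wrappers around torch.cuda.nvtx that take the range label:

from typing import Optional

import torch


def nvtx_range_push(msg: str) -> None:
    # Open an NVTX range labeled with msg (wraps torch.cuda.nvtx.range_push).
    torch.cuda.nvtx.range_push(msg)


def nvtx_range_pop(msg: Optional[str] = None) -> None:
    # Close the innermost open NVTX range. NVTX pops ranges by nesting order;
    # the optional label only documents which range the call site is closing.
    torch.cuda.nvtx.range_pop()
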
18 changes: 17 additions & 1 deletion transformer_engine/pytorch/attention.py
@@ -23,7 +23,11 @@

import transformer_engine_torch as tex
import transformer_engine as te
from transformer_engine.pytorch.utils import get_cudnn_version
from transformer_engine.pytorch.utils import (
get_cudnn_version,
nvtx_range_pop,
nvtx_range_push,
)
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
fused_attn_fwd,
fused_attn_bwd,
@@ -1834,6 +1838,7 @@ def forward(
quantizers,
):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward")
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)

@@ -2756,12 +2761,14 @@ def forward(
ctx.fp8_meta = fp8_meta
ctx.is_input_fp8 = is_input_fp8
ctx.is_output_fp8 = is_output_fp8
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.forward")

return out_ret

@staticmethod
def backward(ctx, dout):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.backward")
cp_size_a2a = ctx.cp_size_a2a
rank_a2a = ctx.rank_a2a

@@ -3602,6 +3609,7 @@ def backward(ctx, dout):
dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, ctx.qkv_dtype)
dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, ctx.qkv_dtype)
dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, ctx.qkv_dtype)
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.backward")

return (
None,
@@ -3688,6 +3696,7 @@ def forward(
cp_stream,
):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.forward")
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)

@@ -3904,11 +3913,13 @@ def forward(
ctx.attn_mask_type = attn_mask_type
ctx.deterministic = deterministic
ctx.use_fused_attention = use_fused_attention
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward")
return out

@staticmethod
def backward(ctx, dout):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.backward")
cp_size = get_distributed_world_size(ctx.cp_group)
rank = get_distributed_rank(ctx.cp_group)

@@ -4092,6 +4103,7 @@ def backward(ctx, dout):
dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :])
dk = dk.movedim(0, seq_dim).contiguous()
dv = dv.movedim(0, seq_dim).contiguous()
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward")

return (
None,
@@ -4151,6 +4163,7 @@ def forward(
quantizers,
):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)

@@ -4403,11 +4416,13 @@ def forward(
ctx.fp8_meta = fp8_meta
ctx.is_input_fp8 = is_input_fp8
ctx.is_output_fp8 = is_output_fp8
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
return out_ret

@staticmethod
def backward(ctx, dout):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward")
cp_size = get_distributed_world_size(ctx.cp_group)

(
@@ -4592,6 +4607,7 @@ def backward(ctx, dout):
dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype)
if not ctx.is_input_fp8:
dq, dk, dv = [x.dequantize() for x in [dq, dk, dv]]
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward")

return (
None,
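
The attention.py changes follow one pattern across the three context-parallel implementations (AttnFuncWithCPAndKVP2P, AttnFuncWithCPAndKVAllGather, AttnFuncWithCPAndQKVOA2A): push a label at the top of forward/backward and pop it just before returning, so each autograd call appears as a single named range. A minimal self-contained illustration of that bracketing, using the sketched helpers above (example code, not Transformer Engine itself):

import torch


class _DoubleFunc(torch.autograd.Function):
    # Toy autograd function bracketed the same way as the attention functions above.

    @staticmethod
    def forward(ctx, inp):
        nvtx_range_push("example._DoubleFunc.forward")
        out = inp * 2
        nvtx_range_pop("example._DoubleFunc.forward")
        return out

    @staticmethod
    def backward(ctx, grad_out):
        nvtx_range_push("example._DoubleFunc.backward")
        grad_in = grad_out * 2
        nvtx_range_pop("example._DoubleFunc.backward")
        return grad_in
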
44 changes: 41 additions & 3 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -24,12 +24,14 @@
)
from ..fp8 import FP8GlobalStateManager
from ..utils import (
assert_dim_for_fp8_exec,
cast_if_needed,
clear_tensor_data,
divide,
get_default_init_method,
init_method_constant,
cast_if_needed,
assert_dim_for_fp8_exec,
clear_tensor_data,
nvtx_range_pop,
nvtx_range_push,
requires_grad,
)
from ..distributed import (
@@ -112,6 +114,12 @@ def forward(
skip_fp8_weight_update: bool,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring

# NVTX label for profiling
nvtx_label = "transformer_engine._LayerNormLinear.forward"
if ub_name is not None:
nvtx_label = f"{nvtx_label}.{ub_name}"

# Make sure input dimensions are compatible
out_features, in_features = weight.shape
inp_shape = inp.shape
@@ -121,10 +129,12 @@
assert_dim_for_fp8_exec(inputmat, weight)

# Cast for native AMP
nvtx_range_push(f"{nvtx_label}.norm_input_cast")
inputmat = cast_if_needed(inputmat, activation_dtype)
ln_weight = cast_if_needed(ln_weight, activation_dtype)
if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype)
nvtx_range_pop(f"{nvtx_label}.norm_input_cast")

tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_ag_fprop = (
@@ -175,6 +185,7 @@
)

# Apply normalization
nvtx_range_push(f"{nvtx_label}.norm")
ln_out, mu, rsigma = apply_normalization(
inputmat,
ln_out,
@@ -188,9 +199,11 @@
zero_centered_gamma,
)
ln_out_return = ln_out if return_layernorm_output else None
nvtx_range_pop(f"{nvtx_label}.norm")

# Prepare GEMM input
# Note: Cast to expected dtype and perform tensor-parallel communication
nvtx_range_push(f"{nvtx_label}.gemm_input_cast_comm")
if with_input_all_gather and not ub_overlap_ag_fprop:
with_quantized_all_gather = fp8
if return_layernorm_output and return_layernorm_output_gathered:
@@ -217,6 +230,7 @@
elif backward_needs_input:
ln_out.update_usage(rowwise_usage=True, columnwise_usage=True)
ln_out_total = ln_out
nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm")

# Cast weight to expected dtype
weightmat = weight
@@ -275,6 +289,7 @@
assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM inputs requires FP8 buffer."
ln_out_total = ub_obj.get_buffer(input_quantizer)

nvtx_range_push(f"{nvtx_label}.gemm")
out, *_, rs_out = general_gemm(
weightmat,
ln_out_total,
@@ -287,6 +302,8 @@
ub_type=ub_type,
extra_output=rs_out,
)
nvtx_range_pop(f"{nvtx_label}.gemm")

if not weight.requires_grad:
if not return_layernorm_output:
ln_out = ln_out_total = None
@@ -307,6 +324,7 @@
# Scatter intermediate/activation tensors saved for the backward pass
# NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.distributed.FSDP already
# shards/unshards the base weights so we don't do it ourselves
nvtx_range_push(f"{nvtx_label}.fsdp_scatter")
ctx.fsdp_group = fsdp_group
ctx.fsdp_shapes = _fsdp_scatter_tensors(
fsdp_group,
@@ -315,6 +333,7 @@
weightmat if quantized_weight else None,
ln_out if weight.requires_grad else None,
)
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

tensors_to_save, tensor_objects = prepare_for_saving(
inputmat,
@@ -372,10 +391,12 @@ def forward(
if ub_overlap_rs_fprop:
out = rs_out
elif parallel_mode == "row":
nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
if sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif tensor_parallel:
out, _ = allreduce(out, tp_group)
nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")

# [*, in_features] -> [*, out_features] except first dimension changes for SP
out = out.view(-1, *inp_shape[1:-1], out_features)
@@ -394,6 +415,11 @@ def backward(
) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring

# NVTX label for profiling
nvtx_label = "transformer_engine._LayerNormLinear.backward"
if ctx.ub_name is not None:
nvtx_label = f"{nvtx_label}.{ctx.ub_name}"

with torch.cuda.nvtx.range("_LayerNormLinear_backward"):
if (
ctx.fp8
@@ -433,6 +459,7 @@ def backward(
# Gather intermediate/activation tensors if needed
# NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.distributed.FSDP already
# shards/unshards the base weights so we don't do it ourselves
nvtx_range_push(f"{nvtx_label}.fsdp_gather")
_fsdp_gather_tensors(
ctx.fsdp_group,
ctx.fsdp_shapes,
@@ -441,6 +468,7 @@ def backward(
weight if ctx.fp8 and ctx.quantized_weight else None,
ln_out,
)
nvtx_range_pop(f"{nvtx_label}.fsdp_gather")

# For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
# we need to connect them into one.
@@ -515,12 +543,14 @@ def backward(
if ctx.fp8:
quantizer = ctx.input_quantizer
quantizer.set_usage(rowwise=True, columnwise=True)
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
ln_out_total, ln_out_total_work = gather_along_first_dim(
ln_out,
ctx.tp_group,
async_op=True,
quantizer=quantizer,
)
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
else:
ln_out_total = ln_out

@@ -536,6 +566,7 @@
if ctx.grad_input_quantizer is not None:
ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)

nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
dgrad, *_ = general_gemm(
weight,
grad_output,
@@ -551,12 +582,14 @@
extra_output=rs_out,
bulk_overlap=ctx.ub_bulk_dgrad,
)
nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")

# Launch tensor-parallel communication
dgrad_work = None
if ctx.ub_overlap_rs_dgrad:
dgrad = rs_out
elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad:
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
if ctx.sequence_parallel:
if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered:
dgrad = dgrad + grad_outputs[1].view_as(dgrad)
@@ -567,6 +600,7 @@
)
else:
dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True)
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")

# Compute grad weight tensor
wgrad = None
@@ -603,6 +637,7 @@

# wgrad GEMM
# Note: Fuse with bgrad computation if needed
nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
wgrad, grad_bias_, *_, rs_out = general_gemm(
ln_out_total,
grad_output,
@@ -621,6 +656,7 @@
extra_output=rs_out,
bulk_overlap=ctx.ub_bulk_wgrad,
)
nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")

if ctx.ub_bulk_wgrad:
if ub_obj_wgrad.is_fp8_ubuf():
@@ -657,6 +693,7 @@
# Norm gradient
dgamma = None
dbeta = None
nvtx_range_push(f"{nvtx_label}.norm")
if ctx.normalization == "LayerNorm":
dgrad, dgamma, dbeta = tex.layernorm_bwd(
dgrad,
Expand All @@ -679,6 +716,7 @@ def backward(
)
dgrad = dgrad.reshape(inputmat.size())
dbeta = None
nvtx_range_pop(f"{nvtx_label}.norm")
clear_tensor_data(mu)
clear_tensor_data(rsigma)

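
With these ranges in place, an NVTX-aware profiler (for example Nsight Systems with NVTX tracing enabled) shows each forward/backward broken down into the labeled phases above, e.g. transformer_engine._LayerNormLinear.forward.norm, .gemm_input_cast_comm, .gemm, and .row_parallel_comm. A minimal way to exercise them (hypothetical usage, not part of this commit; the layer size and input shape are placeholders):

import torch
import transformer_engine.pytorch as te

# Run one forward/backward through a layer touched by this change so the new
# ranges show up in a capture.
layer = te.LayerNormLinear(1024, 1024).cuda()
inp = torch.randn(8, 1024, device="cuda", requires_grad=True)

out = layer(inp)          # covered by transformer_engine._LayerNormLinear.forward.* ranges
out.sum().backward()      # covered by transformer_engine._LayerNormLinear.backward.* ranges
torch.cuda.synchronize()
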