[PyTorch] te.Linear FP8 DGRAD+RS output bugfix (#1412)
* corrected RS overlap BF16 output clashing with Float8Tensor constructor

Signed-off-by: Alp Dener <[email protected]>

* fixed empty dgrad buffer dtype at initialization

Signed-off-by: Alp Dener <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Alp Dener <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
denera and pre-commit-ci[bot] authored Jan 16, 2025
1 parent 3d63cbb commit c2937c5
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions transformer_engine/pytorch/module/linear.py
@@ -506,13 +506,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                 ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
                 dgrad = ub_obj_wgrad.get_ubuf_output(1)
 
-            if dgrad is None:
-                if ctx.parallel_mode == "column" and ctx.sequence_parallel:
-                    dgrad_shape[0] = dgrad_shape[0] * tp_world_size
-                dgrad = torch.empty(
-                    dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
-                )
-
             (
                 grad_output,
                 grad_output_c,
@@ -550,6 +543,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                 fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                 fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
 
+            output_dtype = ctx.activation_dtype
             if ctx.requires_dgrad:
                 if ctx.fp8:
                     if ctx.is_input_fp8 or (
@@ -570,6 +564,14 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                             None,
                             ctx.activation_dtype,
                         )
+
+            if dgrad is None:
+                if ctx.parallel_mode == "column" and ctx.sequence_parallel:
+                    dgrad_shape[0] = dgrad_shape[0] * tp_world_size
+                dgrad = torch.empty(dgrad_shape, dtype=output_dtype, device=grad_output.device)
+
+            if ctx.requires_dgrad:
+                if ctx.fp8:
                     _ = fp8_gemm(
                         weight_fp8.transpose_2d(),
                         weight_fp8._scale_inv,
@@ -593,8 +595,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
 
                     if ctx.ub_overlap_rs_dgrad:
                         dgrad = rs_out
-
-                    if output_dtype == torch.uint8:
+                    elif output_dtype == torch.uint8:
                         dgrad = Float8Tensor(
                             data=dgrad,
                             fp8_meta=ctx.fp8_meta,
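In short: before this change, the empty dgrad workspace was allocated with ctx.activation_dtype before the backward pass had decided whether dgrad would be produced in FP8, and with reduce-scatter (RS) overlap enabled the BF16 RS output could still fall through to the Float8Tensor constructor. The patch introduces output_dtype, moves the buffer allocation after the dtype decision, and turns the Float8Tensor wrapping into an elif so it is skipped whenever the RS overlap already produced the final dgrad. The sketch below illustrates that ordering in isolation; the function name, arguments, and the wrap-in-Float8Tensor placeholder are simplified assumptions for illustration, not the actual TransformerEngine internals.

    import torch


    def finalize_dgrad_sketch(
        dgrad,                 # pre-existing userbuffer output, or None
        dgrad_shape,           # list of ints, e.g. [seq * batch, hidden]
        activation_dtype,      # e.g. torch.bfloat16
        fp8_output,            # True when dgrad should be returned as FP8
        ub_overlap_rs_dgrad,   # reduce-scatter overlap for DGRAD enabled
        rs_out,                # reduce-scatter output (already in activation_dtype)
        sequence_parallel,     # column-parallel + sequence-parallel case
        tp_world_size,
        device,
    ):
        # 1) Decide the dgrad output dtype first (this is what the added
        #    `output_dtype = ctx.activation_dtype` line makes possible).
        output_dtype = torch.uint8 if fp8_output else activation_dtype

        # 2) Only then allocate the empty dgrad buffer, so its dtype matches
        #    what the DGRAD GEMM will actually write.
        if dgrad is None:
            shape = list(dgrad_shape)
            if sequence_parallel:
                shape[0] *= tp_world_size
            dgrad = torch.empty(shape, dtype=output_dtype, device=device)

        # ... DGRAD GEMM runs here, writing into `dgrad` (or into `rs_out`
        #     when reduce-scatter overlap is enabled) ...

        # 3) `elif` instead of a second `if`: when RS overlap already produced
        #    the BF16 dgrad, it must not go through the Float8Tensor constructor.
        if ub_overlap_rs_dgrad:
            dgrad = rs_out
        elif output_dtype == torch.uint8:
            # The real code wraps the raw uint8 buffer here, roughly:
            # dgrad = Float8Tensor(data=dgrad, fp8_meta=..., fp8_meta_index=..., ...)
            pass
        return dgrad

Compared with the pre-fix code, the only behavioural changes are the buffer dtype in step 2 and the elif in step 3.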
