Skip to content

Commit 8509ef6

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent a9f2eb9 commit 8509ef6

File tree

2 files changed

+5
-13
lines changed

2 files changed

+5
-13
lines changed

transformer_engine/pytorch/module/layernorm_linear.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -895,22 +895,13 @@ def wgrad_gemm(
895895
del grad_bias_
896896

897897
# Deallocate input tensor if permitted
898-
if (
899-
not ctx.return_layernorm_output
900-
and not ctx.return_layernorm_output_gathered
901-
):
898+
if not ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered:
902899
# Do not need to return layernorm output
903900
clear_tensor_data(ln_out)
904-
elif (
905-
ctx.return_layernorm_output_gathered
906-
and ctx.ln_out_needs_gather
907-
):
901+
elif ctx.return_layernorm_output_gathered and ctx.ln_out_needs_gather:
908902
# ln_out is not the returned tensor
909903
clear_tensor_data(ln_out)
910-
if (
911-
ctx.ln_out_needs_gather
912-
and not ctx.ub_bulk_dgrad
913-
):
904+
if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad:
914905
clear_tensor_data(ln_out_total)
915906

916907
# Update grad input if overlapping reduce-scatter with wgrad GEMM

transformer_engine/pytorch/module/linear.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -890,7 +890,8 @@ def wgrad_gemm(
890890
clear_tensor_data(inputmat_total)
891891

892892
if (
893-
ctx.parallel_mode == "row" and ctx.sequence_parallel
893+
ctx.parallel_mode == "row"
894+
and ctx.sequence_parallel
894895
and not ctx.ub_overlap_ag
895896
):
896897
clear_tensor_data(grad_output)

0 commit comments

Comments (0)