File tree Expand file tree Collapse file tree 2 files changed +5
-13
lines changed
transformer_engine/pytorch/module Expand file tree Collapse file tree 2 files changed +5
-13
lines changed Original file line number Diff line number Diff line change @@ -895,22 +895,13 @@ def wgrad_gemm(
895 895         del grad_bias_
896 896
897 897         # Deallocate input tensor if permitted
898-            if (
899-                not ctx.return_layernorm_output
900-                and not ctx.return_layernorm_output_gathered
901-            ):
898+           if not ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered:
902 899             # Do not need to return layernorm output
903 900             clear_tensor_data(ln_out)
904-            elif (
905-                ctx.return_layernorm_output_gathered
906-                and ctx.ln_out_needs_gather
907-            ):
901+           elif ctx.return_layernorm_output_gathered and ctx.ln_out_needs_gather:
908 902             # ln_out is not the returned tensor
909 903             clear_tensor_data(ln_out)
910-            if (
911-                ctx.ln_out_needs_gather
912-                and not ctx.ub_bulk_dgrad
913-            ):
904+           if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad:
914 905             clear_tensor_data(ln_out_total)
915906
916907 # Update grad input if overlapping reduce-scatter with wgrad GEMM
Original file line number Diff line number Diff line change @@ -890,7 +890,8 @@ def wgrad_gemm(
890 890         clear_tensor_data(inputmat_total)
891 891
892 892         if (
893-                ctx.parallel_mode == "row" and ctx.sequence_parallel
893+               ctx.parallel_mode == "row"
894+               and ctx.sequence_parallel
894 895             and not ctx.ub_overlap_ag
895 896         ):
896 897             clear_tensor_data(grad_output)
You can’t perform that action at this time.
0 commit comments