
Commit a9f2eb9

fix memory overhead of all gather from sequence parallel
Parent: 8dba296

3 files changed: 36 additions, 2 deletions

transformer_engine/pytorch/module/layernorm_linear.py

Lines changed: 20 additions & 2 deletions
@@ -353,8 +353,11 @@ def forward(
 
         # Deallocate GEMM input tensor if no longer needed
         if not weight.requires_grad and not return_layernorm_output:
-            ln_out = ln_out_total = None
             clear_tensor_data(ln_out, ln_out_total)
+            ln_out = ln_out_total = None
+        elif ln_out_total is not ln_out_return and not ub_overlap_ag_fprop:
+            clear_tensor_data(ln_out_total)
+            ln_out_total = None
 
         # ------------------------------------------------------
         # Prepare output tensor
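What this hunk fixes: previously the names were rebound to None before calling clear_tensor_data, so the call only ever received None and the layernorm output buffers (including the all-gathered copy used for sequence parallelism) stayed allocated. The fix clears the buffers first and then drops the references, and also frees the gathered copy when it is a separate internal tensor. A minimal sketch of the clear-then-drop pattern, using a toy stand-in for TE's clear_tensor_data:

import torch

def clear_tensor_data(*tensors):
    # Toy stand-in: release the underlying storage while keeping the
    # Python objects alive (TE's real helper also handles quantized tensors).
    for t in tensors:
        if t is not None:
            t.data = torch.Tensor()

ln_out = torch.randn(2048, 1024)
ln_out_total = torch.randn(8 * 2048, 1024)   # all-gathered copy

# Buggy order: the names are rebound to None first, so clear_tensor_data
# only sees None and the buffers survive until garbage collection.
#   ln_out = ln_out_total = None
#   clear_tensor_data(ln_out, ln_out_total)

# Fixed order: free the storage first, then drop the references.
clear_tensor_data(ln_out, ln_out_total)
ln_out = ln_out_total = None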
@@ -892,7 +895,22 @@ def wgrad_gemm(
                 del grad_bias_
 
             # Deallocate input tensor if permitted
-            if not ctx.return_layernorm_output:
+            if (
+                not ctx.return_layernorm_output
+                and not ctx.return_layernorm_output_gathered
+            ):
+                # Do not need to return layernorm output
+                clear_tensor_data(ln_out)
+            elif (
+                ctx.return_layernorm_output_gathered
+                and ctx.ln_out_needs_gather
+            ):
+                # ln_out is not the returned tensor
+                clear_tensor_data(ln_out)
+            if (
+                ctx.ln_out_needs_gather
+                and not ctx.ub_bulk_dgrad
+            ):
                 clear_tensor_data(ln_out_total)
 
             # Update grad input if overlapping reduce-scatter with wgrad GEMM
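Reading of this hunk: ln_out may be cleared whenever it is not the tensor handed back to the caller, and the all-gathered ln_out_total is an internal buffer whenever a gather was actually needed and the userbuffers bulk-dgrad overlap does not own it, so it is now freed in that case as well. A small sketch of the decision logic, with plain booleans standing in for the ctx flags used above:

def ln_out_free_plan(return_ln_out, return_ln_out_gathered,
                     ln_out_needs_gather, ub_bulk_dgrad):
    # ln_out can be cleared when it is not returned to the caller: either
    # nothing is returned, or the gathered variant is returned and ln_out
    # itself is only the local shard.
    free_ln_out = (
        (not return_ln_out and not return_ln_out_gathered)
        or (return_ln_out_gathered and ln_out_needs_gather)
    )
    # The gathered copy is internal whenever a gather happened and it was
    # not produced by the userbuffers bulk-dgrad overlap.
    free_ln_out_total = ln_out_needs_gather and not ub_bulk_dgrad
    return free_ln_out, free_ln_out_total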

transformer_engine/pytorch/module/linear.py

Lines changed: 13 additions & 0 deletions
@@ -317,6 +317,11 @@ def forward(
         # Finished forward GEMM...
         # ------------------------------------------------------
 
+        # Deallocate GEMM input tensor if no longer needed
+        if with_input_all_gather_nccl:
+            clear_tensor_data(inputmat_total)
+            inputmat_total = None
+
         # ------------------------------------------------------
         # Prepare output tensor
         # Note: Perform tensor-parallel communication
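The idea of this addition: with sequence parallelism the GEMM input is all-gathered across the tensor-parallel group (via NCCL, per the with_input_all_gather_nccl flag), and once the forward GEMM has consumed it the gathered copy only costs memory, so it is released immediately rather than surviving to the end of forward. A simplified sketch of the gather → GEMM → free pattern (no FP8, no userbuffers; tp_group and the shapes are illustrative assumptions):

import torch
import torch.distributed as dist

def sp_linear_forward(inputmat, weight, tp_group):
    world = dist.get_world_size(tp_group)
    # All-gather the sequence-parallel input shards into one GEMM input.
    inputmat_total = torch.empty(
        world * inputmat.size(0), inputmat.size(1),
        dtype=inputmat.dtype, device=inputmat.device,
    )
    dist.all_gather_into_tensor(inputmat_total, inputmat, group=tp_group)
    out = inputmat_total @ weight.t()
    # The gathered copy is world_size times the shard and is only needed
    # for this GEMM, so release its storage right away.
    inputmat_total.data = torch.Tensor()
    inputmat_total = None
    return out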
@@ -881,6 +886,14 @@ def wgrad_gemm(
             # Deallocate input tensor if permitted
             if ctx.owns_input:
                 clear_tensor_data(inputmat_total)
+            elif ctx.backward_input_needs_gather and not ctx.ub_bulk_dgrad:
+                clear_tensor_data(inputmat_total)
+
+            if (
+                ctx.parallel_mode == "row" and ctx.sequence_parallel
+                and not ctx.ub_overlap_ag
+            ):
+                clear_tensor_data(grad_output)
 
             # Update grad input if overlapping reduce-scatter with wgrad GEMM
             if ctx.ub_bulk_wgrad:
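This hunk extends the backward-side cleanup: the gathered input is also freed when it was re-gathered just for the wgrad GEMM (and not owned by the userbuffers bulk-dgrad overlap), and for a row-parallel sequence-parallel layer the gathered grad_output is freed once both backward GEMMs have consumed it. A rough sketch of that grad_output lifetime (illustrative shapes, no communication/compute overlap):

import torch
import torch.distributed as dist

def sp_row_linear_backward(grad_output_shard, inputmat_total, weight, tp_group):
    world = dist.get_world_size(tp_group)
    # Row-parallel + sequence-parallel: the incoming grad is sharded along
    # the sequence dimension and must be gathered before the GEMMs.
    grad_output = torch.empty(
        world * grad_output_shard.size(0), grad_output_shard.size(1),
        dtype=grad_output_shard.dtype, device=grad_output_shard.device,
    )
    dist.all_gather_into_tensor(grad_output, grad_output_shard, group=tp_group)
    dgrad = grad_output @ weight               # gradient w.r.t. input shard
    wgrad = grad_output.t() @ inputmat_total   # gradient w.r.t. weight shard
    # Both consumers of the gathered grad_output have run; drop its storage
    # now instead of holding world_size times the shard until backward ends.
    grad_output.data = torch.Tensor()
    return dgrad, wgrad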

transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py

Lines changed: 3 additions & 0 deletions
@@ -349,9 +349,12 @@ def _create_columnwise(self):
     def _transpose_columnwise_data(self):
         """Plainly transpose the columnwise data and scale inv."""
        if self._columnwise_data is not None:
+            _old_data = self._columnwise_data
             self._columnwise_data = tex.fp8_transpose(
                 self._columnwise_data, self._fp8_dtype, out=None
             )
+            _old_data.data = _empty_tensor()
+            del _old_data
 
     def __repr__(self):
         if self._rowwise_data is not None:
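Context for this change: tex.fp8_transpose with out=None allocates a fresh buffer, so after the attribute is rebound the pre-transpose columnwise data is dead weight; blanking its .data releases that storage promptly even if some other Python object still holds a reference to the old tensor (TE's _empty_tensor() returns an empty tensor for this purpose). A toy illustration of the same pattern with plain torch tensors:

import torch

old = torch.randn(4096, 4096)
alias = old                      # some other holder of the same tensor object

# Rebinding one name to the transposed copy does not free the original
# storage while alias still points at the old tensor object.
new = old.t().contiguous()

# Blanking .data on the old object releases its storage for every holder.
old.data = torch.Tensor()
del old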
