
Commit 83bc8d2

yuzhongw-nvidia, pre-commit-ci[bot], and timmoon10 authored and committed
Fix memory overhead of linear layer when all gather from sequence parallel (NVIDIA#2125)
* Fix memory overhead of all gather from sequence parallel
  Signed-off-by: Yuzhong Wang <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  For more information, see https://pre-commit.ci
* Update transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py
  Signed-off-by: Tim Moon <[email protected]>
* Quick fix for the errors with UB buffers
  Signed-off-by: Yuzhong Wang <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  For more information, see https://pre-commit.ci
* Update transformer_engine/pytorch/module/linear.py
  Signed-off-by: Tim Moon <[email protected]>
* Avoid deallocating FP8 scale-invs since they are reused
  Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Yuzhong Wang <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
1 parent 64c5581 commit 83bc8d2

4 files changed: +46 −7 lines
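In sequence-parallel layers, the input that gets all-gathered for the forward GEMM is a full-sequence copy of the activation, so keeping it alive after the GEMM multiplies activation memory by the tensor-parallel size. The changes below free these internal buffers explicitly instead of waiting for Python's garbage collector. As a rough illustration only (plain PyTorch, not TransformerEngine's `clear_tensor_data`), the underlying trick is to replace a tensor's `.data` with an empty tensor so its storage can be reclaimed immediately:

```python
import torch

def free_tensor_storage(t: torch.Tensor) -> None:
    """Release ``t``'s storage without deleting the Python object.

    Assumes no other tensor aliases the same storage; otherwise the memory
    is only reclaimed once every alias has been released as well.
    """
    t.data = torch.empty(0, dtype=t.dtype, device=t.device)

x = torch.randn(4096, 4096)
free_tensor_storage(x)
print(x.numel())  # 0 -- the large buffer is released, the object stays valid
```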

transformer_engine/pytorch/module/layernorm_linear.py

Lines changed: 19 additions & 4 deletions
@@ -353,8 +353,11 @@ def forward(
 
         # Deallocate GEMM input tensor if no longer needed
         if not weight.requires_grad and not return_layernorm_output:
-            ln_out = ln_out_total = None
             clear_tensor_data(ln_out, ln_out_total)
+            ln_out = ln_out_total = None
+        elif with_input_all_gather and not return_layernorm_output_gathered:
+            clear_tensor_data(ln_out_total)
+            ln_out_total = None
 
         # ------------------------------------------------------
         # Prepare output tensor

@@ -891,9 +894,19 @@ def wgrad_gemm(
                 grad_bias = grad_bias_
             del grad_bias_
 
-            # Deallocate input tensor if permitted
-            if not ctx.return_layernorm_output:
+            # Deallocate input tensors if permitted
+            if not ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered:
+                # Input tensors have not been exposed externally
+                clear_tensor_data(ln_out)
+            elif ctx.ln_out_needs_gather and ctx.return_layernorm_output_gathered:
+                # Non-gathered input has not been exposed externally
+                clear_tensor_data(ln_out)
+            if ctx.ln_out_needs_gather:
+                # Gathered input is internal
                 clear_tensor_data(ln_out_total)
+            if ctx.parallel_mode == "row" and ctx.sequence_parallel:
+                # Gathered grad output tensor is internal
+                clear_tensor_data(grad_output)
 
             # Update grad input if overlapping reduce-scatter with wgrad GEMM
             if ctx.ub_bulk_wgrad:

@@ -1169,7 +1182,9 @@ def __init__(
         self.return_bias = return_bias
         self.apply_bias = self.use_bias and not return_bias
         self.return_layernorm_output = return_layernorm_output
-        self.return_layernorm_output_gathered = return_layernorm_output_gathered
+        self.return_layernorm_output_gathered = (
+            return_layernorm_output_gathered if return_layernorm_output else False
+        )
         self.zero_centered_gamma = zero_centered_gamma
         self.symmetric_ar_type = symmetric_ar_type
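One behavioral consequence of the constructor change above: `return_layernorm_output_gathered` is now silently disabled unless `return_layernorm_output` is also set, which is what lets the backward pass treat an un-returned LayerNorm output as an internal buffer and free it. A hedged usage sketch follows; the import path, sizes, and the tuple return are assumptions, while the two `return_layernorm_output*` flags come from the diff.

```python
# Sketch only -- requires Transformer Engine and a CUDA device.
import torch
from transformer_engine.pytorch import LayerNormLinear  # assumed import path

layer = LayerNormLinear(
    1024,                                   # in_features (assumed size)
    1024,                                   # out_features (assumed size)
    return_layernorm_output=True,           # must be set...
    return_layernorm_output_gathered=True,  # ...for this flag to take effect
).cuda()

x = torch.randn(128, 1024, device="cuda")
out, ln_out = layer(x)  # LN output is returned, so it is not freed internally
```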

transformer_engine/pytorch/module/linear.py

Lines changed: 15 additions & 1 deletion
@@ -317,6 +317,13 @@ def forward(
         # Finished forward GEMM...
         # ------------------------------------------------------
 
+        # Deallocate GEMM input tensor if no longer needed
+        # TODO(yuzhongw, tmoon): Figure out why inputmat_total is not automatically
+        # deallocated by GC. Manually deallocating is a temporary hack.
+        if with_input_all_gather_nccl:
+            clear_tensor_data(inputmat_total)
+            inputmat_total = None
+
         # ------------------------------------------------------
         # Prepare output tensor
         # Note: Perform tensor-parallel communication

@@ -878,9 +885,16 @@ def wgrad_gemm(
                 grad_bias = grad_bias_
             del grad_bias_
 
-            # Deallocate input tensor if permitted
+            # Deallocate tensors if permitted
             if ctx.owns_input:
+                # Input tensor is internal
+                clear_tensor_data(inputmat_total)
+            elif ctx.backward_input_needs_gather:
+                # Gathered input tensor is internal
                 clear_tensor_data(inputmat_total)
+            if ctx.parallel_mode == "row" and ctx.sequence_parallel:
+                # Gathered grad output tensor is internal
+                clear_tensor_data(grad_output)
 
             # Update grad input if overlapping reduce-scatter with wgrad GEMM
             if ctx.ub_bulk_wgrad:
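The forward-pass change above frees the gathered input right after the GEMM when it was all-gathered over NCCL, since the backward pass can re-gather it from the local shard. A minimal sketch of that pattern in plain PyTorch (hypothetical helper, assumes an initialized tensor-parallel process group; not the actual Transformer Engine code path):

```python
import torch
import torch.distributed as dist

def sp_linear_forward(x_shard: torch.Tensor, weight: torch.Tensor, tp_group) -> torch.Tensor:
    """All-gather the sequence-parallel input, run the GEMM, then drop the gathered copy."""
    world = dist.get_world_size(tp_group)
    x_total = torch.empty(
        (x_shard.shape[0] * world, *x_shard.shape[1:]),
        dtype=x_shard.dtype,
        device=x_shard.device,
    )
    dist.all_gather_into_tensor(x_total, x_shard.contiguous(), group=tp_group)
    out = x_total @ weight.t()
    # Free the gathered activation immediately; backward re-gathers from x_shard.
    x_total.data = torch.empty(0, dtype=x_total.dtype, device=x_total.device)
    return out
```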

transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py

Lines changed: 5 additions & 0 deletions
@@ -349,9 +349,14 @@ def _create_columnwise(self):
     def _transpose_columnwise_data(self):
         """Plainly transpose the columnwise data and scale inv."""
         if self._columnwise_data is not None:
+            # TODO(yuzhongw, tmoon): Figure out why _old_data is not automatically
+            # deallocated by GC. Manually deallocating is a temporary hack.
+            _old_data = self._columnwise_data
             self._columnwise_data = tex.fp8_transpose(
                 self._columnwise_data, self._fp8_dtype, out=None
             )
+            _old_data.data = _empty_tensor()
+            del _old_data
 
     def __repr__(self):
         if self._rowwise_data is not None:
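The same idea in plain PyTorch, with a regular transpose standing in for `tex.fp8_transpose`: hold a reference to the old buffer, build the transposed copy, then release the old storage explicitly rather than relying on the garbage collector to reclaim it promptly (illustrative sketch only):

```python
import torch

def transpose_and_free(buf: torch.Tensor) -> torch.Tensor:
    """Return a transposed copy of ``buf`` and release the original storage."""
    old = buf
    new = old.t().contiguous()  # stand-in for tex.fp8_transpose
    # Drop the old storage now instead of waiting for GC.
    old.data = torch.empty(0, dtype=old.dtype, device=old.device)
    del old
    return new
```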

transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py

Lines changed: 7 additions & 2 deletions
@@ -95,8 +95,13 @@ def __new__(
         return instance
 
     def clear(self):
-        """Deallocate this tensor's memory. Typically not needed and must be used carefully."""
-        for t in (self._data, self._transpose, self._scale_inv):
+        """Deallocate this tensor's memory. Typically not needed and must be used carefully.
+
+        Scale-inv tensor is not deallocated because it's often shared
+        between multiple FP8 tensors.
+
+        """
+        for t in (self._data, self._transpose):
             if t is not None:
                 t.data = _empty_tensor()
         self._transpose_invalid = True
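The revised `clear()` contract, mirrored in a toy container (hypothetical class, not the real `Float8TensorBase`): release the large data and transpose buffers, but leave the scale-inverse untouched because several FP8 tensors may share it.

```python
import torch
from typing import Optional

class ToyFp8Holder:
    """Toy stand-in for an FP8 tensor wrapper that owns data/transpose buffers."""

    def __init__(
        self,
        data: torch.Tensor,
        transpose: Optional[torch.Tensor],
        scale_inv: torch.Tensor,
    ):
        self._data = data
        self._transpose = transpose
        self._scale_inv = scale_inv  # possibly shared with other wrappers
        self._transpose_invalid = False

    def clear(self) -> None:
        """Free data and transpose storage; keep the (possibly shared) scale-inv."""
        for t in (self._data, self._transpose):
            if t is not None:
                t.data = torch.empty(0, dtype=t.dtype, device=t.device)
        self._transpose_invalid = True
```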
