Only cache column-wise input in LayerNormLinear
Signed-off-by: Tim Moon <[email protected]>
timmoon10 committed Feb 25, 2025
1 parent 03d95e5 commit 2099726
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -310,6 +310,12 @@ def forward(
         clear_tensor_data(ln_out, ln_out_total)

         if is_grad_enabled:
+
+            # Input with column-wise usage is needed for dgrad GEMM
+            if backward_needs_input:
+                if isinstance(ln_out, QuantizedTensor):
+                    ln_out.update_usage(rowwise_usage=False)
+
             if cpu_offloading:
                 if fp8 and weightmat is not None:
                     set_offloading_param(weightmat, "weight_offloading", True)
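For context on the change: Transformer Engine's quantized tensors can hold both a row-wise copy of their data (consumed by the forward GEMM) and a column-wise, transposed copy (consumed by the dgrad GEMM in backward). Once the forward GEMM has run, only the column-wise copy needs to survive in the autograd context, so dropping the row-wise copy before the input is saved for backward reduces what the cached activation pins in memory. The snippet below is a minimal, hypothetical sketch of that pattern; ToyQuantizedTensor and its buffer names are invented for illustration and are not Transformer Engine's actual QuantizedTensor implementation.

import torch

class ToyQuantizedTensor:
    """Hypothetical stand-in for Transformer Engine's QuantizedTensor."""

    def __init__(self, data: torch.Tensor):
        # Row-wise copy: consumed by the forward GEMM.
        self.rowwise_data = data.clone()
        # Column-wise (transposed) copy: consumed by the dgrad GEMM in backward.
        self.columnwise_data = data.t().contiguous()

    def update_usage(self, rowwise_usage: bool = True, columnwise_usage: bool = True):
        # Dropping a usage releases the corresponding buffer so it is not
        # kept alive in the autograd context until backward.
        if not rowwise_usage:
            self.rowwise_data = None
        if not columnwise_usage:
            self.columnwise_data = None

# After the forward GEMM, only the column-wise copy is still needed for
# dgrad, mirroring the update_usage call added in the commit above.
ln_out = ToyQuantizedTensor(torch.randn(8, 16))
ln_out.update_usage(rowwise_usage=False)
assert ln_out.rowwise_data is None and ln_out.columnwise_data is not None

Note that the commit gates the call on backward_needs_input: if the backward pass does not need the input at all, there is nothing to trim, so the usage update is skipped entirely.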
