[Pytorch] Added missing assert_dim_for_fp8_exec for Linear
* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* reshape inp

Signed-off-by: Pawel Gadzinski <[email protected]>

---------

Signed-off-by: Pawel Gadzinski <[email protected]>
pggPL authored Feb 24, 2025
1 parent 7f2dcf9 commit d668f18
Showing 1 changed file with 3 additions and 1 deletion.
transformer_engine/pytorch/module/linear.py
@@ -27,6 +27,7 @@
     divide,
     init_method_constant,
     non_tn_fp8_gemm_supported,
+    assert_dim_for_fp8_exec,
     nvtx_range_pop,
     nvtx_range_push,
     requires_grad,
@@ -118,13 +119,14 @@ def forward(
         # Prepare input tensor
         # Note: Cast to expected dtype and perform tensor-parallel communication
         nvtx_range_push(f"{nvtx_label}.input_cast_comm")
-        inputmat = inp
+        inputmat = inp.view(-1, in_features)
         inputmat_total = None
         with_input_all_gather_nccl = (
             parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
         )
         own_quantized_input = False
         if fp8:
             assert_dim_for_fp8_exec(inputmat, weight)
             if (
                 any([ub_overlap_ag_fprop, ub_overlap_rs_fprop])
                 and not FP8GlobalStateManager.get_fp8_recipe().delayed()
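
In context, the fix guards the FP8 path: FP8 GEMM kernels constrain tensor shapes, and the input is first flattened to 2D so the check (and the subsequent GEMM) sees the matrix shapes that will actually be multiplied. Below is a minimal sketch of the idea, assuming the constraint documented elsewhere in Transformer Engine that FP8 execution needs the leading dimension divisible by 8 and the last dimension divisible by 16; the helper name with the `_sketch` suffix and all shapes are illustrative, not the library's exact source.

import torch

def assert_dim_for_fp8_exec_sketch(*tensors: torch.Tensor) -> None:
    # Assumption: FP8 (TN-layout) GEMMs need dim 0 % 8 == 0 and dim -1 % 16 == 0.
    for t in tensors:
        assert t.shape[0] % 8 == 0 and t.shape[-1] % 16 == 0, (
            f"FP8 execution requires dims divisible by (8, 16), got {tuple(t.shape)}"
        )

# Why the reshape matters: a (batch, seq, hidden) input is 3D, so the check
# must run on the flattened 2D view that the GEMM will consume.
in_features, out_features = 256, 512
inp = torch.randn(2, 12, in_features)     # hypothetical (batch, seq, hidden)
weight = torch.randn(out_features, in_features)
inputmat = inp.view(-1, in_features)      # (24, 256): 24 % 8 == 0, 256 % 16 == 0
assert_dim_for_fp8_exec_sketch(inputmat, weight)

Checking both `inputmat` and `weight` in one call matches the two-argument usage added by this commit, so an incompatible shape fails fast with a clear message instead of surfacing as a kernel error inside the FP8 GEMM.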
