From d668f18f4a2b93c92d9d63f8f6d14ab3f075ec0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Mon, 24 Feb 2025 14:50:49 +0100 Subject: [PATCH] [Pytorch] Added missing assert_dim_for_fp8_exec for Linear * fix Signed-off-by: Pawel Gadzinski * reshape inp Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/module/linear.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index e51513630f..bae21eebfd 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -27,6 +27,7 @@ divide, init_method_constant, non_tn_fp8_gemm_supported, + assert_dim_for_fp8_exec, nvtx_range_pop, nvtx_range_push, requires_grad, @@ -118,13 +119,14 @@ def forward( # Prepare input tensor # Note: Cast to expected dtype and perform tensor-parallel communication nvtx_range_push(f"{nvtx_label}.input_cast_comm") - inputmat = inp + inputmat = inp.view(-1, in_features) inputmat_total = None with_input_all_gather_nccl = ( parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop ) own_quantized_input = False if fp8: + assert_dim_for_fp8_exec(inputmat, weight) if ( any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not FP8GlobalStateManager.get_fp8_recipe().delayed()