diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh
index 9ca052e761..438ab3d8fd 100644
--- a/qa/L0_pytorch_unittest/test.sh
+++ b/qa/L0_pytorch_unittest/test.sh
@@ -2,23 +2,27 @@
 #
 # See LICENSE for license information.
 
-set -e
 
 : ${TE_PATH:=/opt/transformerengine}
 
 pip install pytest==8.2.1
-pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
-pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py
-pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
-PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
-NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
-pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
-pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
-pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
-pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py
-pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py
-pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
-pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py
-pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py
-NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
-NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py
+
+FAIL=0
+
+pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1
+pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || FAIL=1
+pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || FAIL=1
+PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1
+NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || FAIL=1
+pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || FAIL=1
+pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || FAIL=1
+pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || FAIL=1
+pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || FAIL=1
+pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || FAIL=1
+pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || FAIL=1
+pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || FAIL=1
+pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1
+NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1
+NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py || FAIL=1
+
+exit $FAIL
diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh
index 8ee0be1af5..5e3823d85c 100644
--- a/qa/L1_pytorch_distributed_unittest/test.sh
+++ b/qa/L1_pytorch_distributed_unittest/test.sh
@@ -2,14 +2,17 @@
 #
 # See LICENSE for license information.
 
-set -e
-
 : ${TE_PATH:=/opt/transformerengine}
 
 pip install pytest==8.2.1
-pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py
-pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py
-pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py
-pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
+
+FAIL=0
+
+pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || FAIL=1
+pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || FAIL=1
+pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || FAIL=1
+pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || FAIL=1
 # pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py ### TODO Debug UB support with te.Sequential
-pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
+pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || FAIL=1
+
+exit $FAIL
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index e51513630f..bae21eebfd 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -27,6 +27,7 @@
     divide,
     init_method_constant,
     non_tn_fp8_gemm_supported,
+    assert_dim_for_fp8_exec,
     nvtx_range_pop,
     nvtx_range_push,
     requires_grad,
@@ -118,13 +119,14 @@ def forward(
         # Prepare input tensor
         # Note: Cast to expected dtype and perform tensor-parallel communication
         nvtx_range_push(f"{nvtx_label}.input_cast_comm")
-        inputmat = inp
+        inputmat = inp.view(-1, in_features)
         inputmat_total = None
         with_input_all_gather_nccl = (
             parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
         )
         own_quantized_input = False
         if fp8:
+            assert_dim_for_fp8_exec(inputmat, weight)
             if (
                 any([ub_overlap_ag_fprop, ub_overlap_rs_fprop])
                 and not FP8GlobalStateManager.get_fp8_recipe().delayed()