
Commit cdf9485

revert accidental change

Signed-off-by: Varun Thumbe <[email protected]>

Restrict the number of cases for unfused quantization; some fp8->fp8 cases are handled by cuBLAS.
Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
Signed-off-by: Varun Thumbe <[email protected]>

fix merge conflict
Signed-off-by: Varun Thumbe <[email protected]>

bug: missed a } in the code
Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
Signed-off-by: Varun Thumbe <[email protected]>

Add cuBLASMp-backed GEMM-like API to TE common (#1824)

* Pick up cuBLASMp during build
* Saving...
* Change lib order to fix link error
* Saving...
* Context creation, incomplete...
* Test fixture
* Saving...
* A sanity AgGemm test, failing...
* Saving...
* Fix axes
* Take care of uneven distribution
* Use MPI to get position of local matrices
* Refactor
* Refactor & fixes
* Saving...
* Gemm-RS
* Gemm-AR, not working...
* Fixes
* Setting all-reduce epilogue for gemm-ar
* Use supported shapes for GEMM-AR
* Tweak tolerance
* First shot at fp8
* Use TensorHolder in tests
* More test configs
* Support comm_sm_count
* Parametrize dtypes for A, B and D separately
* Tweak scaling
* Amax ptr
* Flags parity with cublas_gemm, saving...
* Cleanup
* Bias tests
* Fix bias test
* Aux, saving...
* aux_ld
* A fix
* Use test::Tensor
* Set scale inv
* Remove unsupported test configs
* Tweak tests
* Replace libcal with NCCL
* Add NVTX markers to API functions
* Tweak GemmAr tests
* More test config
* Fix merge fallout
* Remove MPI dependency, comment API, add algo parameter
* Fix nvshmem dependency
* Fix nvshmem build
* Exclude CommGemm tests from L0_cppunittest
* Add cpp_distributed sh file for CI
* Adapt to TensorAllocator
* Skip GemmAr test on unsupported HW
* Oversubscribe is needed on some clusters
* Fix incomplete libcal removal
* Move CI tests to L1
* Rename context to include NVTE prefix
* Remove leftover code
* NVTE_WITH_CUBLASMP off by default
* More detailed NVTE_CHECK diag
* Comment API
* Include stdbool header for legacy C compilers
* Remove now unused argument
* Abstract away cuBLASMp algo behind our own enum
* More detailed shape diag messages
* Update transformer_engine/common/include/transformer_engine/comm_gemm.h
* Add license

---------
Signed-off-by: Vladimir Cherepanov <[email protected]>
Co-authored-by: Vladimir Cherepanov <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086)

* FP8 AllGather in FP8 GroupedGEMM
  1. Support current-scaling FP8 quantization with a given amax.
  2. Support FP8 AG in fwd and BF16 RS in bwd.
  3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM (a sketch of this workflow appears after this commit message).
* Slightly refactor
* Adding documents of new args.
* Adding unit-tests.
* Adding license.
* Move unit-tests to L1.
* Move quantizer store/reset into FP8 only.
* Adding all layout support for Blackwell+
* Adopt the feedback from code-review.
* Fixed the wrong stream used by d2d in groupedGEMM FFI.

---------
Signed-off-by: Ming Huang <[email protected]>
Co-authored-by: Phuong Nguyen <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

[JAX] Delay MeshResource validation until first usage (#2124)

Signed-off-by: Jeremy Berchtold <[email protected]>
Co-authored-by: Phuong Nguyen <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

[JAX] Decouple Recipe and ScalingMode (#1728)

* Decouple recipe and scaling mode
* Expose global QuantizeConfig instance as a getter
* Format and lint
* Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling
* Rename UsageType to TensorSource
* Update test_layer.py

---------
Signed-off-by: Jeremy Berchtold <[email protected]>
Signed-off-by: jberchtold-nvidia <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

[JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128)

* add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED

---------
Signed-off-by: Phuong Nguyen <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

[JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118)

* add amax input to DBiasQuantizePrimitive and FFI
* make sure amax is init with zero
* fix sharding rule

---------
Signed-off-by: Phuong Nguyen <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <[email protected]>

Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121)

Signed-off-by: Kshitij Lakhani <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

Temporarily remove comm_gemm tests (#2133)

Signed-off-by: Vladimir Cherepanov <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

[PyTorch] Disable determinism for sm100 (#2130)

* disable determinism for sm100+ and cudnn<9.14
* fix remaining CI failures
* revert some changes
* revert more changes
* remove sm100 from determinism table

---------
Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <[email protected]>

[PyTorch] ONNX export of FP8 Current Scaling (#2068)

* Compute amax in normalization forward in current scaling in untuned kernels
* fix
* fix
* fix
* code drop
* fix
* fix
* apply Tim's suggestions

---------
Signed-off-by: Jan Bielak <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Co-authored-by: Jan Bielak <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <[email protected]>

[PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero tokens (#2134)

use torch.empty for empty shape instead of from_blob

Signed-off-by: zhongboz <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

build: pull cached wheels (#2127)

* build: pull cached wheels
* Update setup.py

---------
Signed-off-by: oliver könig <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

feat: Add support for multiple quantization modes in the UB communicators (#2043)

Signed-off-by: Varun Thumbe <[email protected]>

[Common] Add checks to CUDA kernel launch and CUDA API calls (#2074)

* add checks to cuda kernel launch and cuda API calls
* Remove exceptions from destructors
* fix weird dispatch in ln/rmsnorm

---------
Signed-off-by: Xin Yao <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

[PyTorch] Support bf16+fp8 cudagraph (#2098)

* support bf16+fp8 model
* update
* update

---------
Signed-off-by: Robin Zhang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

Dropout with 8-bit RNG (#2014)

* Add dropout kernel with 8-bit RNG
* Fix license
* Avoid ambiguous types
* Do not enforce dropout prob is representable in 8 bits
* Expand error message
* Fix small statistical bug from using less-equal instead of less-than (a numeric illustration follows this commit message); refactor kernel implementations and add comments; interpret masks as bytes rather than 16-bit uints
* Fix linter warning
* Remove unnecessary helper function in PyTorch extensions

---------
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Vasudevan Rengasamy <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <[email protected]>

Create GPU reload buffers on main stream (#2131)

* Create GPU reload buffers on main stream
* Fixed typo
* Fixed typo

---------
Signed-off-by: Selvaraj Anandaraj <[email protected]>
Co-authored-by: Selvaraj Anandaraj <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Paweł Gadziński <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

mxfp8 unfused quant support, refined unit test, remove unnecessary quantization code
Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
Signed-off-by: Varun Thumbe <[email protected]>

missed a quant code removal
Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
Signed-off-by: Varun Thumbe <[email protected]>

minor bug fix
Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
minor code cleanup
Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

minor cosmetics
Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Address review comment
Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

minor comment update
Signed-off-by: Varun Thumbe <[email protected]>

Fix CI failures for UB overlap changes (#2149)

Signed-off-by: djns99 <[email protected]>

minor bug: quantizer should not be None for unfused quantization
Signed-off-by: Varun Thumbe <[email protected]>

[JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (#2135)

* Fix failing tests for dropout=0.1 and bias for fused attn on Blackwell
* Fix the skip message
* Assert in fused attn bwd pass for sm100; add check for sm100
* Add support to get all devs in the process for JAX
* Code clean up
* Make get_all_device_compute_capability more pythonic, thereby avoiding unnecessary type conversion
* Represent attn bias using enum instead of string

---------
Signed-off-by: Kshitij Lakhani <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

fix linting error
Signed-off-by: Varun Thumbe <[email protected]>

[PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Graph (#2119)

* add noop to comp amax
* fix for fp8 blockwise recipe
* resolve comments

---------
Signed-off-by: zhongboz <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <[email protected]>

address review comments
Signed-off-by: Varun Thumbe <[email protected]>
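As referenced in the #2086 notes above, the forward workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM. Below is a minimal sketch of that flow using plain PyTorch collectives. It is illustrative only, not Transformer Engine's implementation; the fp8_allgather helper name, the e4m3 scaling constant, and the byte-level gather are assumptions of this sketch.

import torch
import torch.distributed as dist

def fp8_allgather(x, group=None):
    """Hypothetical helper mirroring the #2086 forward flow (illustrative only)."""
    # 1) AR-max: agree on a single global amax across ranks.
    amax = x.abs().amax().to(torch.float32)
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    # 2) FP8 Quant: current scaling from the agreed amax (448 ~ e4m3 max normal).
    scale = 448.0 / torch.clamp(amax, min=1e-12)
    x_fp8 = (x.to(torch.float32) * scale).to(torch.float8_e4m3fn)
    # 3) FP8 AG: gather the FP8 payload as raw bytes, halving the all-gather
    #    traffic relative to a bf16 gather.
    world = dist.get_world_size(group)
    out = torch.empty((world * x.shape[0], *x.shape[1:]),
                      dtype=torch.uint8, device=x.device)
    dist.all_gather_into_tensor(out, x_fp8.view(torch.uint8), group=group)
    # 4) The gathered FP8 tensor plus scale_inv would feed the FP8 GroupedGEMM.
    return out.view(torch.float8_e4m3fn), scale.reciprocal()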
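The less-equal versus less-than fix noted in the Dropout with 8-bit RNG entry (#2014) is easy to check numerically: with a byte threshold t, r < t is true for exactly t of the 256 possible RNG bytes, while r <= t is true for t + 1 of them, biasing the probability by 1/256. A small self-contained check (the threshold convention here is an assumption; the kernel's actual code may differ):

import numpy as np

p_drop = 0.25
t = round(p_drop * 256)             # byte threshold; p_drop need not be exact in 8 bits
r = np.arange(256, dtype=np.uint8)  # all possible 8-bit RNG outputs, uniform

print((r < t).mean())   # 0.25   -> correct probability
print((r <= t).mean())  # 0.2539 -> off by 1/256, the "small statistical bug"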
1 parent 63d631c · commit cdf9485

File tree

109 files changed: +3461 additions, -683 deletions


docs/api/pytorch.rst

Lines changed: 4 additions & 1 deletion
@@ -49,7 +49,7 @@ pyTorch
 
 .. autoapifunction:: transformer_engine.pytorch.moe_permute
 
-.. autoapifunction:: transformer_engine.pytorch.moe_permute_with_probs
+.. autoapifunction:: transformer_engine.pytorch.moe_permute_with_probs
 
 .. autoapifunction:: transformer_engine.pytorch.moe_unpermute
 
@@ -62,3 +62,6 @@ pyTorch
 .. autoapifunction:: transformer_engine.pytorch.initialize_ub
 
 .. autoapifunction:: transformer_engine.pytorch.destroy_ub
+
+.. autoapiclass:: transformer_engine.pytorch.UserBufferQuantizationMode
+   :members: FP8, NONE

docs/examples/onnx/onnx_export.ipynb

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 "\n",
 "<b>Note:</b>\n",
 "\n",
-"Currently, export to ONNX is supported only for high precision, FP8 delayed scaling and MXFP8.\n",
+"Currently, export to ONNX is supported only for high precision, FP8 delayed scaling, FP8 current scaling and MXFP8.\n",
 "\n",
 "</div>\n",
 "\n",

examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py

Lines changed: 7 additions & 1 deletion
@@ -263,7 +263,13 @@ def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False)
     te.module.base.initialize_ub(
         [batched_size, hidden_size],
         tp_size,
-        use_fp8=opts.fp8,
+        quantization_modes=[
+            (
+                te.module.base.UserBufferQuantizationMode.FP8
+                if opts.fp8
+                else te.module.base.UserBufferQuantizationMode.NONE
+            )
+        ],
         dtype=torch.bfloat16,
         bootstrap_backend=opts.bootstrap_backend,
     )
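With #2043, initialize_ub takes a list of per-communicator quantization modes in place of the old use_fp8 flag, as the hunk above shows. A minimal calling sketch, assuming torch.distributed is already initialized; the sizes are placeholders, not values from this repo:

import torch
import transformer_engine.pytorch as te

# Illustrative sizes only; the example script above derives these from CLI options.
batched_size, hidden_size, tp_size = 4096, 8192, 8

te.module.base.initialize_ub(
    [batched_size, hidden_size],
    tp_size,
    quantization_modes=[te.module.base.UserBufferQuantizationMode.FP8],
    dtype=torch.bfloat16,
    bootstrap_backend="nccl",  # as in the example above; requires an initialized process group
)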

qa/L0_cppunittest/test.sh

Lines changed: 1 addition & 1 deletion
@@ -17,4 +17,4 @@ cd $TE_PATH/tests/cpp
 cmake -GNinja -Bbuild .
 cmake --build build
 export OMP_NUM_THREADS=$((NUM_PHYSICAL_CORES / NUM_PARALLEL_JOBS))
-ctest --test-dir build -j$NUM_PARALLEL_JOBS
+ctest --test-dir build -j$NUM_PARALLEL_JOBS -E '(AgGemm|GemmRs|GemmAr)'

qa/L1_cpp_distributed/test.sh

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+set -e
+
+# Find TE
+: ${TE_PATH:=/opt/transformerengine}
+TE_LIB_PATH=$(pip3 show transformer-engine | grep -E "Location:|Editable project location:" | tail -n 1 | awk '{print $NF}')
+export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH
+
+cd $TE_PATH/tests/cpp
+cmake -GNinja -S. -Bbuild
+cmake --build build
+mpirun --allow-run-as-root --np 4 --oversubscribe ./build/comm_gemm/test_comm_gemm

qa/L1_jax_distributed_unittest/test.sh

Lines changed: 1 addition & 0 deletions
@@ -9,3 +9,4 @@ set -xe
 mkdir -p "$XML_LOG_DIR"
 
 NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
+SCRIPT_NAME=test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh

setup.py

Lines changed: 13 additions & 0 deletions
@@ -4,6 +4,7 @@
 
 """Installation script."""
 
+from importlib import metadata
 import os
 import time
 from pathlib import Path
@@ -66,6 +67,18 @@ def setup_common_extension() -> CMakeExtension:
     if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
         cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")
 
+    if bool(int(os.getenv("NVTE_WITH_CUBLASMP", "0"))):
+        cmake_flags.append("-DNVTE_WITH_CUBLASMP=ON")
+        cublasmp_dir = os.getenv("CUBLASMP_HOME") or metadata.distribution(
+            "nvidia-cublasmp-cu12"
+        ).locate_file("nvidia/cublasmp/cu12")
+        cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}")
+        nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution(
+            "nvidia-nvshmem-cu12"
+        ).locate_file("nvidia/nvshmem")
+        cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}")
+        print("CMAKE_FLAGS:", cmake_flags[-2:])
+
     # Add custom CMake arguments from environment variable
     nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
     if nvte_cmake_extra_args:
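The new block resolves the cuBLASMp and NVSHMEM directories from an environment variable first, falling back to the installed wheel's location. A standalone sketch of that fallback chain; it assumes the nvidia-cublasmp-cu12 and nvidia-nvshmem-cu12 wheels are installed, and the _locate helper is hypothetical:

import os
from importlib import metadata

# Same env-var-then-wheel fallback as the setup.py hunk above.
def _locate(env_var, dist_name, subpath):
    # Raises importlib.metadata.PackageNotFoundError if neither is available.
    return os.getenv(env_var) or str(metadata.distribution(dist_name).locate_file(subpath))

print(_locate("CUBLASMP_HOME", "nvidia-cublasmp-cu12", "nvidia/cublasmp/cu12"))
print(_locate("NVSHMEM_HOME", "nvidia-nvshmem-cu12", "nvidia/nvshmem"))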

tests/cpp/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@ find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_
 message(STATUS "Found transformer_engine library: ${TE_LIB}")
 include_directories(../../transformer_engine/common/include)
 include_directories(../../transformer_engine/common)
+include_directories(../../transformer_engine)
 include_directories(${CMAKE_SOURCE_DIR})
 
 find_package(CUDAToolkit REQUIRED)

tests/cpp/comm_gemm/CMakeLists.txt

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+add_executable(test_comm_gemm
+               test_comm_gemm.cu
+               ../test_common.cu)
+
+find_package(OpenMP REQUIRED)
+find_package(MPI REQUIRED)
+find_library(NCCL_LIB
+             NAMES nccl libnccl
+             PATH_SUFFIXES lib
+             REQUIRED)
+target_include_directories(test_comm_gemm PRIVATE ${MPI_CXX_INCLUDE_PATH} $ENV{CUBLASMP_HOME}/include)
+target_link_libraries(test_comm_gemm PUBLIC CUDA::cuda_driver CUDA::cudart GTest::gtest ${TE_LIB} CUDA::nvrtc CUDNN::cudnn MPI::MPI_CXX ${NCCL_LIB} OpenMP::OpenMP_CXX)
+
+include(GoogleTest)
+gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600)
