Fix memory overhead of linear layer when all gather from sequence parallel #2125
Conversation
Thanks! It does seem to be a problem if the tensor GC only happens after a long period of time, while forcing GC directly brings too much overhead. CC @timmoon10 to also take a look at this PR.
    _old_data = self._columnwise_data
    self._columnwise_data = tex.fp8_transpose(
        self._columnwise_data, self._fp8_dtype, out=None
    )
    _old_data.data = _empty_tensor()
    del _old_data
Shouldn't Python refcounting deallocate the old data automatically? If not, then there's a larger bug in tex.fp8_transpose that we should fix. The code will become unmanageable if we have to apply this kind of trick everywhere we call a C++ extension.
I see that we return at::Tensor instead of Pybind11 objects. I wonder if it is not properly handling the refcount.
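For context, a minimal sketch (not TE code) of the refcount behavior being questioned here, with a plain out-of-place transpose standing in for tex.fp8_transpose and a toy Holder class standing in for the quantized tensor:

```python
import weakref
import torch

class Holder:
    """Toy stand-in for a quantized tensor object holding a data buffer."""
    def __init__(self, data):
        self.data = data

holder = Holder(torch.randn(4, 8))
old_ref = weakref.ref(holder.data)

# Rebind the attribute, as the caller of tex.fp8_transpose does.
holder.data = holder.data.t().contiguous()

# With ordinary CPython refcounting, the old tensor object is deallocated as
# soon as no strong reference remains and its storage goes back to the
# allocator, so the weakref should already be dead here.
print("old tensor collected:", old_ref() is None)
```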
It is a very good question that also confuses me. I have a similar question: why do we need clear_tensor_data to release unused tensors, instead of letting the Python GC deallocate them? I suspect the root cause is similar, and my current fix uses the same trick as clear_tensor_data.
I'm not very familiar with the code and the deeper implementation. Could you please share this problem with some TE experts to help solve it? Or do you think we could merge this PR first and track down the root cause later? It is fairly urgent for the feasibility and performance of DSV3 / MLA long-context training.
We call clear_tensor_data to manually deallocate memory before Python GC. A good example is here in the LayerNormLinear backward:
    clear_tensor_data(ln_out_total)
The forward GEMM input tensor is stored within the autograd ctx, so GC will not deallocate it until after the backward has finished. However, we don't need this buffer after the wgrad GEMM, and ideally it would be reused for the LayerNorm grad.
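To illustrate the point, a toy sketch (not TE code) of this pattern: a tensor saved for backward stays alive until the autograd graph is freed, so clear_tensor_data-style code releases its storage right after the wgrad GEMM. `_release_storage` below is a hypothetical stand-in for clear_tensor_data.

```python
import torch

def _release_storage(t: torch.Tensor) -> None:
    # Hypothetical stand-in for clear_tensor_data: rebind .data to an empty
    # tensor so the large buffer can be reused immediately, instead of waiting
    # for the autograd graph (which still holds the saved tensor) to be freed.
    t.data = torch.empty(0, device=t.device, dtype=t.dtype)

class ToyLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight):
        ctx.save_for_backward(inp, weight)  # keeps `inp` alive until backward runs
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        inp, weight = ctx.saved_tensors
        grad_weight = grad_output.t() @ inp  # wgrad GEMM: last use of `inp`
        _release_storage(inp)                # free the gathered input early
        grad_input = grad_output @ weight    # dgrad GEMM does not need `inp`
        return grad_input, grad_weight
```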
As I described in this thread, I think this is not the responsibility of the C++ extension but of Float8BlockQuantizer.make_empty (in this case the unreleased tensor was also created by make_empty; the C++ extension only creates the new tensor). In other words, I believe the issue is that tensors created by Float8BlockQuantizer.make_empty can only be released by manually replacing _columnwise_data.data and _rowwise_data.data. (I believe that is also why clear_tensor_data has to replace _columnwise_data.data and _rowwise_data.data, rather than just deleting _columnwise_data and _rowwise_data.)
Force-pushed from 8509ef6 to 57e22c5
    # Deallocate GEMM input tensor if no longer needed
    if with_input_all_gather_nccl:
        clear_tensor_data(inputmat_total)
        inputmat_total = None
Will this have an effect? The only benefit from deallocating the input tensor is that the buffer might be reused for the tensor-parallel reduce-scatter on the output. However, if we've done a tensor-parallel all-gather on the input, we won't do a reduce-scatter.
Yes, it will have an effect. I have just confirmed this with an experiment: if I remove this line, inputmat_total stays in memory for a very long time.
I think that's the most serious problem: the data of FP8 tensors created by make_empty may not be released by Python GC as expected. Do you have any idea about it?
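For reference, a sketch of the kind of experiment mentioned above; `model` and `inp` are placeholders (not part of this PR), and the comparison is between runs with and without the clear_tensor_data(inputmat_total) call:

```python
import torch

def peak_memory_for_step(model, inp):
    """Measure peak CUDA memory for one forward/backward step, in GiB."""
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    out = model(inp)          # Linear with sequence-parallel all-gather inside
    out.sum().backward()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**30

# Running this once with and once without clear_tensor_data(inputmat_total)
# shows whether the gathered input buffer is freed before the peak is reached.
```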
Force-pushed from 2e9603d to 3ceff63
Force-pushed from b458d30 to b6443bb
transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py (review thread resolved)
/te-ci pytorch L1

/te-ci pytorch L1

/te-ci pytorch L1
Fix memory overhead of linear layer when all gather from sequence parallel (NVIDIA#2125)
* fix memory overhead of all gather from sequence parallel
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Update transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py
* quick fix the errors for UB buffers
* Update transformer_engine/pytorch/module/linear.py
* Avoid deallocating FP8 scale-invs since they are reused
Signed-off-by: Yuzhong Wang <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <[email protected]>
Description
gather_along_first_dim (all gather from the sequence parallel region) in Linear and LayerNormLinear introduces memory overhead: tensors such as grad_output, inputmat_total, and ln_out_total are not released in time. In addition, when gathering FP8 blockwise tensors, _post_process_fp8_blockwise_gather calls Float8BlockwiseQTensorBase._transpose_columnwise_data, whose old implementation is reconstructed below; it makes the old _columnwise_data become unreachable and unable to be released.

These two bugs are especially harmful for long-context training, because the tensors gathered from the SP region are very large and the overhead is unacceptable. For example, in the MoE layer of DSV3 there are 6 tensors that are not released in time, and in the dense layer of DSV3 there are 18.
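A reconstruction of that old implementation, based on the diff shown earlier in this thread (not copied verbatim from the repository):

```python
# The old _transpose_columnwise_data simply rebound the attribute, leaving the
# previous _columnwise_data buffer allocated instead of returning it promptly.
def _transpose_columnwise_data(self):
    self._columnwise_data = tex.fp8_transpose(
        self._columnwise_data, self._fp8_dtype, out=None
    )
```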
Suppose seq_len = 65536: there is at least 18 * 65536 * 7168 = 7.875 GB of memory allocation overhead, and this overhead grows proportionally with the sequence length. Furthermore, it usually leads to much more overhead in actual device memory usage. Still taking the 64K context length as an example, the device memory usage decreases from 59187 MB to 40777 MB for DeepSeek-V3-TP16PP8EP32VPP4CP1-full-recompute with this fix. With this large improvement, we can further switch to selective recomputation to greatly improve the E2E performance (267 TFLOPS -> 307 TFLOPS).

Fixes # (issue)
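A quick sanity check of the figure above, assuming one byte per element (FP8) and the DSV3 hidden size of 7168:

```python
num_tensors = 18      # tensors not released in time in a DSV3 dense layer
seq_len = 65536       # 64K context length
hidden_size = 7168    # DSV3 hidden dimension (assumed here)
overhead_bytes = num_tensors * seq_len * hidden_size  # 1 byte per FP8 element
print(f"{overhead_bytes / 2**30:.3f} GiB")  # -> 7.875
```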
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: