@Jianbing-D Jianbing-D commented Aug 12, 2025

Description

This pull request adds an optimized implementation of MXFP8 quantization for cast-only cases. It improves casting performance by 5% to 20%.

It supports:

  • BF16 or FP16 as inputs
  • E5M2 or E4M3 as outputs
  • GPU architecture >= sm_100
  • rowwise only, or both rowwise and colwise

Performance gain:
[4 images, not preserved]

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Added an environment variable ENABLE_CAST_ONLY to select the optimized kernel. If the optimized kernel does not support the provided inputs, it automatically falls back to the original kernels.
    1. If ENABLE_CAST_ONLY is not set or is set to 0, then original kernels will be used.
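
The selection rule above, combined with the support list from the description, can be sketched in Python. This is only a model of the logic: the real dispatch lives in the C++/CUDA sources, the names `is_supported` and `use_cast_only_kernel` are hypothetical, and treating every value other than unset or `0` as "enabled" is an assumption the PR does not state.

```python
SUPPORTED_SRC = {"bfloat16", "float16"}
SUPPORTED_DST = {"E4M3", "E5M2"}


def is_supported(src_dtype, dst_dtype, sm_arch, rowwise, colwise):
    """Support matrix taken from the PR description (hypothetical helper)."""
    # Rowwise must be requested; colwise may be added on top of it.
    # Colwise-only is not in the supported list, so `colwise` alone is not enough.
    return bool(
        src_dtype in SUPPORTED_SRC
        and dst_dtype in SUPPORTED_DST
        and sm_arch >= 100
        and rowwise
    )


def use_cast_only_kernel(env, **config):
    """Return True for the optimized kernel, False for the original kernels."""
    flag = env.get("ENABLE_CAST_ONLY", "0")
    if flag == "0":
        # Unset or set to 0: always use the original kernels.
        return False
    # Assumption: any other value enables the optimized path,
    # with automatic fallback when the inputs are unsupported.
    return is_supported(**config)
```

Passing the environment as a plain mapping (for example `os.environ`) keeps the rule easy to test without touching the process environment.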

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@Jianbing-D Jianbing-D force-pushed the feat-fast-cast-mxfp8 branch 3 times, most recently from 00fbe4f to 7295b7d Compare August 12, 2025 04:30
Author

Jianbing-D commented Aug 12, 2025

Steps to reproduce performance numbers

  1. Start a container with image nvcr.io/nvidia/pytorch:25.06-py3 on a GB200 cluster.
  2. Uninstall the pre-installed TE.
  3. Manually install this branch with export PYTHONUSERBASE=/tmp/python unset PIP_CONSTRAINT && NVTE_CUDA_ARCHS="100a" NVTE_BUILD_THREADS_PER_JOB=8 NVTE_FRAMEWORK=pytorch pip install --no-build-isolation -v -e ./TransformerEngine
  4. Run the following script under NCU, which reports the kernel duration and memory bandwidth:
ncu --section=MemoryWorkloadAnalysis --section=SpeedOfLight  --clock-control=none --nvtx --nvtx-include="Update Quantized/" --nvtx-include="reference kernel/"  python quantize.py
# quantize.py

import os
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp"
os.environ["TRITON_CACHE_DIR"] = "/tmp"
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "0"

if "USER" not in os.environ:
    os.environ["USER"] = "you"

import torch

from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
import transformer_engine_torch as tex

import argparse

torch.cuda.manual_seed(233376)

def run(args):
    check_consistency = args.check_consistency

    direction = {"rowwise": not args.no_rowwise, "colwise": not args.no_colwise}
    src_dtype = args.src_dtype
    dst_dtype = args.dst_dtype
    size_h, size_w = args.size_h, args.size_w

    msg_candidates = {"TrueTrue": "rowwise and colwise",
                    "TrueFalse": "rowwise",
                    "FalseTrue": "colwise",
                    "FalseFalse": None}
    msg = msg_candidates[f"{direction['rowwise']}{direction['colwise']}"]
    if msg is None:
        raise ValueError(f"Invalid direction: {direction}")
    print("=" * 120)
    print(f"checking {msg}, "
          f"src_dtype: {src_dtype}, dst_dtype: {dst_dtype}, size_h: {size_h}, size_w: {size_w}")
    print("=" * 120)

    with torch.cuda.nvtx.range("Ctor"):
        quantizer = MXFP8Quantizer(
            fp8_dtype=dst_dtype,
            rowwise=direction["rowwise"],
            columnwise=direction["colwise"],
        )

    with torch.cuda.nvtx.range("Create Input"):
        bf16_tensor = torch.randn(size_h, size_w, dtype=src_dtype, device="cuda")
        # bf16_tensor = torch.arange(size_h * size_w, dtype=src_dtype, device="cuda").reshape(size_h, size_w)
        # # Print every element in bf16_tensor
        # print("Elements of bf16_tensor:")
        # for i in range(bf16_tensor.shape[0]):
        #     print("row: ", i, end=": ")
        #     for j in range(bf16_tensor.shape[1]):
        #         print(f"{bf16_tensor[i, j].item():.4f}\t", end="")
        #     print()
        # amax = torch.abs(bf16_tensor).amax(axis=0, keepdim=True)
        # print(amax)

    if check_consistency:
        with torch.cuda.nvtx.range("reference"):
            fp8_tensor_ref = quantizer.make_empty(
                bf16_tensor.shape,
                dtype=bf16_tensor.dtype,
                device=bf16_tensor.device,
            )
            with torch.cuda.nvtx.range("reference kernel"):
                quantizer.update_quantized(bf16_tensor, fp8_tensor_ref)


    with torch.cuda.nvtx.range("Make Empty"):
        fp8_tensor = quantizer.make_empty(
            bf16_tensor.shape,
            dtype=bf16_tensor.dtype,
            device=bf16_tensor.device,
        )


    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    if check_consistency:
        os.environ["ENABLE_CAST_ONLY"] = "1"
    start.record()
    with torch.cuda.nvtx.range("Update Quantized"):
        quantizer.update_quantized(bf16_tensor, fp8_tensor)
    end.record()

    torch.cuda.synchronize()

    ms = start.elapsed_time(end)


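    # Bytes moved: 2-byte input elements, 1-byte FP8 rowwise output, and one
    # 1-byte scale per 32 elements along each row. Columnwise outputs are not counted.
    io_bytes = size_h * size_w * 2
    io_bytes += size_h * size_w * 1
    io_bytes += size_h * (size_w // 32) * 1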
    print(f"Io Bytes: {io_bytes / 1e6} MB")
    print(f"Duration: {ms} ms")
    print(f"Bandwidth: {(io_bytes * 1e-9) / (ms * 1e-3)} GB/s")

    # print(fp8_tensor)
    if check_consistency:
        # print(fp8_tensor_ref)

        if direction["rowwise"]:
            torch.testing.assert_close(fp8_tensor._rowwise_data, fp8_tensor_ref._rowwise_data)
            print("rowwise data passed")

            # print(fp8_tensor._rowwise_scale_inv.shape)
            # for i in range(fp8_tensor._rowwise_scale_inv.shape[0]):
            #     print(f"row: {i}", end=": ")
            #     for j in range(fp8_tensor._rowwise_scale_inv.shape[1]):
            #         print(f"{fp8_tensor._rowwise_scale_inv[i, j].item():d},", end="")
            #     print("")
            # print("-------------ref tensor-------------------")
            # for i in range(fp8_tensor_ref._rowwise_scale_inv.shape[0]):
            #     print(f"row: {i}", end=": ")
            #     for j in range(fp8_tensor_ref._rowwise_scale_inv.shape[1]):
            #         print(f"{fp8_tensor_ref._rowwise_scale_inv[i, j].item():d},", end="")
            #     print("")

            torch.testing.assert_close(fp8_tensor._rowwise_scale_inv, fp8_tensor_ref._rowwise_scale_inv)
            print("rowwise scale_inv passed")
        if direction["colwise"]:
            torch.testing.assert_close(fp8_tensor._columnwise_data, fp8_tensor_ref._columnwise_data)
            print("colwise data passed")
            torch.testing.assert_close(fp8_tensor._columnwise_scale_inv, fp8_tensor_ref._columnwise_scale_inv)
            print("colwise scale_inv passed")
        torch.testing.assert_close(fp8_tensor, fp8_tensor_ref)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Profile MXFP8 quantization")
    parser.add_argument("--no_rowwise", action="store_true", default=False, help="Disable rowwise quantization")
    parser.add_argument("--no_colwise", action="store_true", default=False, help="Disable colwise quantization")
    parser.add_argument("--src_dtype", type=str, default="bfloat16", choices=["bfloat16", "float16"], help="Source dtype")
    parser.add_argument("--dst_dtype", type=str, default="kFloat8E4M3", choices=["kFloat8E4M3", "kFloat8E5M2"], help="Destination dtype")
    parser.add_argument("--size_h", type=int, default=4096, help="Input tensor height")
    parser.add_argument("--size_w", type=int, default=7168, help="Input tensor width")
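    # BooleanOptionalAction (Python 3.9+) also registers --no-check_consistency, so the check can be turned off.
    parser.add_argument("--check_consistency", action=argparse.BooleanOptionalAction, default=True, help="Check consistency against the reference kernels")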
    args = parser.parse_args()

    if args.src_dtype == "bfloat16":
        src_dtype = torch.bfloat16
        args.src_dtype = src_dtype
    elif args.src_dtype == "float16":
        src_dtype = torch.float16
        args.src_dtype = src_dtype
    elif args.src_dtype == "float32":
        src_dtype = torch.float32
        args.src_dtype = src_dtype
    else:
        raise ValueError(f"Unsupported src_dtype: {args.src_dtype}")

    if args.dst_dtype == "kFloat8E4M3":
        dst_dtype = tex.DType.kFloat8E4M3
        args.dst_dtype = dst_dtype
    elif args.dst_dtype == "kFloat8E5M2":
        dst_dtype = tex.DType.kFloat8E5M2
        args.dst_dtype = dst_dtype
    else:
        raise ValueError(f"Unsupported dst_dtype: {args.dst_dtype}")

    run(args)
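
The bandwidth figure printed by quantize.py can be reproduced by hand with a small standalone helper. This is a sketch that mirrors the script's own accounting (2-byte inputs, 1-byte rowwise outputs, one 1-byte scale per 32 elements of a row); the function names are hypothetical, and like the script it does not count columnwise outputs.

```python
def io_bytes(size_h, size_w, src_bytes=2, block=32):
    """Bytes read and written, using the same accounting as quantize.py."""
    read = size_h * size_w * src_bytes        # BF16/FP16 input
    write = size_h * size_w * 1               # FP8 rowwise output
    scales = size_h * (size_w // block) * 1   # one scale byte per block of a row
    return read + write + scales


def bandwidth_gb_s(num_bytes, ms):
    """Effective bandwidth in GB/s for a kernel that took `ms` milliseconds."""
    return (num_bytes * 1e-9) / (ms * 1e-3)


if __name__ == "__main__":
    total = io_bytes(4096, 7168)  # default sizes of the script
    print(f"Io Bytes: {total / 1e6} MB")
```

For the default 4096 x 7168 input this gives 88,997,888 bytes, so the reported bandwidth is that figure divided by the measured duration.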

@yaox12 yaox12 requested a review from timmoon10 August 12, 2025 05:14
@Jianbing-D Jianbing-D force-pushed the feat-fast-cast-mxfp8 branch from 623f960 to 0089f01 Compare August 13, 2025 03:42
Jianbing Dong and others added 22 commits August 12, 2025 20:43
Signed-off-by: Jianbing Dong <[email protected]>
@Jianbing-D Jianbing-D force-pushed the feat-fast-cast-mxfp8 branch from 0089f01 to c0ae662 Compare August 13, 2025 03:43
@Jianbing-D Jianbing-D requested a review from ptrendx August 13, 2025 03:46
@Oleg-Goncharov Oleg-Goncharov self-requested a review August 28, 2025 09:50
Collaborator

@Oleg-Goncharov Oleg-Goncharov left a comment

Solid kernel optimizations, the perf comes right when we need it.


break;
}
case ScalingType::COLWISE: {
Collaborator

Could you please modify the logic so that column-wise scaling is not skipped and the original implementation is used as a fallback?


#include "state_counter.cuh"
#include "swizzle.cuh"

Collaborator

@Oleg-Goncharov Oleg-Goncharov Aug 28, 2025

Could you please include "ptx.cuh", as the code isn't compiling?


namespace {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
template <typename IType, typename OType>
Collaborator

Would it be possible to use the members of the original Quantized_Limits class and derive their bit representation at compile time using, e.g., bit_cast, to ensure the limits are defined in one place?
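
To illustrate the idea of deriving bit patterns from a single definition instead of hard-coding them, here is a small Python model. It is not the requested C++ change (which would use a compile-time bit_cast): `float32_bits` is a hypothetical helper, and the two constants are the largest finite E4M3 and E5M2 values, used on the assumption that limits of this kind are what the kernel needs.

```python
import struct


def float32_bits(value):
    """Return the IEEE-754 binary32 bit pattern of `value` as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", value))[0]


# Single source of truth for the limits (illustrative values only).
FP8_E4M3_MAX = 448.0
FP8_E5M2_MAX = 57344.0

# Derived, never written out by hand.
LIMIT_BITS = {
    "e4m3_max": float32_bits(FP8_E4M3_MAX),
    "e5m2_max": float32_bits(FP8_E5M2_MAX),
}

if __name__ == "__main__":
    for name, bits in LIMIT_BITS.items():
        print(f"{name}: 0x{bits:08X}")
```

If a limit ever changes, the derived pattern follows automatically, which is the point of keeping the definition in one place.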

coords.x = block_coords.x + iter_n * CastTraits::blockIterDimN;

if (coords.y < rows && coords.x < cols) {
int32_t offset = coords.y * cols + coords.x;
Collaborator

Please use size_t data type for offsets, especially when global memory is addressed, as int32_t overflows when tensors are really large. (We fixed this problem just recently, as some TE kernels used 4-byte ints previously for offsets)
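
The overflow is easy to reproduce numerically. The sketch below models the offset computation in Python, using `ctypes.c_int32` to emulate two's-complement wraparound (in C++ signed overflow is undefined behavior, so wrapping is only one possible outcome). The function names and the tensor shape are made up for illustration.

```python
import ctypes

INT32_MAX = 2**31 - 1


def offset_int32(y, x, cols):
    """Offset as a 32-bit signed int would hold it (wraparound emulated)."""
    return ctypes.c_int32(y * cols + x).value


def offset_size_t(y, x, cols):
    """Offset with a wide unsigned type; Python ints stand in for 64-bit size_t."""
    return y * cols + x


if __name__ == "__main__":
    rows, cols = 70_000, 40_000  # hypothetical large tensor, 2.8e9 elements
    y, x = rows - 1, cols - 1    # last element
    print("wide offset :", offset_size_t(y, x, cols))
    print("int32 offset:", offset_int32(y, x, cols))
```

Any element whose linear index exceeds 2,147,483,647 gets a wrong (here negative) offset with the 32-bit type, while small tensors are unaffected, which is why the bug only shows up at scale.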

smem_alignment + smem_rowwise_scale + smem_colwise_reduce);
};

#define ALIGN_TO(x, align) (((x) + (align) - 1) & ~((align) - 1))
Collaborator

Please replace the macro with a device function.
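
For reference, the arithmetic behind ALIGN_TO is shown below as a plain function, written in Python only to make the behavior easy to check; the actual replacement would be a CUDA device function. The name `align_to` and the power-of-two check are assumptions, though the bit trick itself is only valid for power-of-two alignments.

```python
def align_to(x, align):
    """Round `x` up to the next multiple of `align` (a power of two)."""
    if align <= 0 or align & (align - 1):
        raise ValueError("align must be a positive power of two")
    # Same expression as the macro: add align - 1, then clear the low bits.
    return (x + align - 1) & ~(align - 1)


if __name__ == "__main__":
    for addr in (0, 1, 127, 128, 129):
        print(addr, "->", align_to(addr, 128))
```

A function evaluates each argument exactly once and is type-checked, which the macro cannot guarantee.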


extern __shared__ char smem[];
char *smemAligned =
reinterpret_cast<char *>(ALIGN_TO((intptr_t)smem, CastTraits::smem_alignment));
Collaborator

Please replace C-style casting with static_cast<>.

float2 amaxs;
#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \
(__CUDA_ARCH_HAS_FEATURE__(SM120_ALL)))
asm volatile("redux.sync.max.abs.f32 %0, %1, 0xFFFFFFFF;"
Collaborator

Please wrap the inline PTX into device functions with some reasonable names (e.g., which resemble the operation, like redux_sync_max_abs_f32() here), and move them to ptx.cuh

: "=f"(amaxs.y)
: "f"(values.y));
#else
asm volatile(
Collaborator

Same here, and in other places with "raw" inline PTX.

extern __shared__ char smem[];
char *smemAligned =
reinterpret_cast<char *>(ALIGN_TO((intptr_t)smem, CastTraits::smem_alignment));
typename CastTraits::IType *sInput = reinterpret_cast<typename CastTraits::IType *>(smemAligned);
Collaborator

nit: since these long type names are used a few times in the kernel, it may be better to use type aliases to improve readability. E.g.:

using IType =  typename CastTraits::IType;
using OType = typename CastTraits::rowOutputUnitType;
using ColwiseReduceType = typename CastTraits::ColwiseReduceDataType;

3 participants