diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index e388dd794b..a979c896fe 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -168,7 +168,7 @@ list(APPEND transformer_engine_cuda_sources list(APPEND transformer_engine_cuda_arch_specific_sources gemm/cutlass_grouped_gemm.cu - util/cast.cu + cast/cast.cu activation/gelu.cu activation/relu.cu activation/swiglu.cu @@ -336,8 +336,7 @@ option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --u if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) list(APPEND nvte_sources_with_fast_math activation/gelu.cu activation/relu.cu - activation/swiglu.cu - util/cast.cu) + activation/swiglu.cu) endif() foreach(cuda_source IN LISTS nvte_sources_with_fast_math) diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index 1d9a3fb43c..7353c3e1d5 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -14,26 +14,17 @@ #include #include +#include "../cast/dispatch/gated.cuh" +#include "../cast/dispatch/quantize.cuh" #include "../common.h" -#include "../util/cast_gated_kernels.cuh" -#include "../util/cast_kernels.cuh" -#include "../util/math.h" -#include "../util/vectorized_pointwise.h" namespace transformer_engine { template void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) { using namespace detail; - constexpr bool IS_DBIAS = false; - constexpr bool IS_DACT = false; constexpr bool IS_ACT = true; - constexpr NVTETensor dbias = nullptr; - constexpr NVTETensor workspace = nullptr; - constexpr const NVTETensor grad = nullptr; - - quantize_helper(input, grad, output, dbias, workspace, - nullptr, stream); + dispatch::quantize_fwd_helper(input, output, nullptr, stream); } template @@ -42,20 +33,17 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, using namespace detail; constexpr bool IS_DBIAS = false; constexpr bool IS_DACT = true; - constexpr bool IS_ACT = false; constexpr NVTETensor dbias = nullptr; constexpr NVTETensor workspace = nullptr; - quantize_helper(input, grad, output, dbias, workspace, - nullptr, stream); + dispatch::quantize_bwd_helper(grad, input, output, dbias, workspace, + nullptr, stream); } template void gated_act_fn(const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) { using namespace detail; - constexpr bool IS_DGATED = false; - constexpr NVTETensor grad = nullptr; - quantize_gated_helper(grad, input, output, p, stream); + dispatch::quantize_gated_fwd_helper(input, output, p, stream); } template (grad, input, output, p, stream); + dispatch::quantize_gated_bwd_helper(grad, input, output, p, stream); } } // namespace transformer_engine diff --git a/transformer_engine/common/activation/gelu.cu b/transformer_engine/common/activation/gelu.cu index 4949ba5906..4979023ef1 100644 --- a/transformer_engine/common/activation/gelu.cu +++ b/transformer_engine/common/activation/gelu.cu @@ -20,6 +20,19 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output dact_fn>(grad, input, output, stream); } +void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dgelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = 
true; + constexpr bool IS_DACT = true; + + dispatch::quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_geglu); using namespace transformer_engine; @@ -48,6 +61,19 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu dact_fn>(grad, input, output, stream); } +void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dqgelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_qgeglu); using namespace transformer_engine; diff --git a/transformer_engine/common/activation/relu.cu b/transformer_engine/common/activation/relu.cu index c74fc6eee9..c0ef9fd65a 100644 --- a/transformer_engine/common/activation/relu.cu +++ b/transformer_engine/common/activation/relu.cu @@ -20,6 +20,19 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output dact_fn>(grad, input, output, stream); } +void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_drelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_reglu); using namespace transformer_engine; @@ -48,6 +61,19 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu dact_fn>(grad, input, output, stream); } +void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dsrelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_sreglu); using namespace transformer_engine; diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index cafc48abba..6957a91e61 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -20,6 +20,19 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output dact_fn>(grad, input, output, stream); } +void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dsilu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} 
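// Usage sketch for the fused dbias + dact quantize entry points added above
// (nvte_quantize_dbias_dgelu / _dqgelu / _drelu / _dsrelu / _dsilu), all of which
// route through dispatch::quantize_bwd_helper with IS_DBIAS = true and
// IS_DACT = true. This is a minimal caller-side illustration, not part of the
// patch: the tensor handles are assumed to have been created elsewhere with the
// NVTE tensor API, the header location is assumed, and the two-call
// workspace-query convention is the usual Transformer Engine pattern (an empty
// workspace on the first call only reports the required workspace shape/dtype),
// assumed to apply to these new entry points as well.

#include <cuda_runtime.h>
#include <transformer_engine/cast.h>  // assumed to declare the nvte_quantize_dbias_d* API

// Hypothetical helper: fused dSiLU + dbias + quantization of the gradient.
void run_dbias_dsilu_example(NVTETensor grad_in,     // incoming gradient, high precision
                             NVTETensor act_in,      // activation input from the forward pass
                             NVTETensor quant_out,   // quantized dgrad output
                             NVTETensor dbias_out,   // reduced bias gradient
                             NVTETensor workspace,   // empty (null data pointer) on first call
                             cudaStream_t stream) {
  // First call: with an empty workspace this is assumed to only fill in the
  // required workspace shape and dtype instead of launching the kernels.
  nvte_quantize_dbias_dsilu(grad_in, act_in, quant_out, dbias_out, workspace, stream);

  // ... allocate workspace storage according to the queried shape ...

  // Second call performs the fused dact + dbias + quantize on `stream`.
  nvte_quantize_dbias_dsilu(grad_in, act_in, quant_out, dbias_out, workspace, stream);
}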
+ void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_swiglu); using namespace transformer_engine; diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu new file mode 100644 index 0000000000..1ed46a3359 --- /dev/null +++ b/transformer_engine/common/cast/cast.cu @@ -0,0 +1,102 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include "../common.h" +#include "../transpose/cast_transpose.h" +#include "../util/multi_stream.h" +#include "../utils.cuh" +#include "dispatch/dequantize.cuh" +#include "dispatch/quantize.cuh" +#include "transformer_engine/transpose.h" + +void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize); + using namespace transformer_engine; + + constexpr bool IS_ACT = false; + dispatch::quantize_fwd_helper(input, output, nullptr, stream); +} + +void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_noop); + using namespace transformer_engine; + + // Create config with noop tensor + QuantizationConfig quant_config; + quant_config.noop_tensor = noop; + + nvte_quantize_v2(input, output, reinterpret_cast(&quant_config), stream); +} + +void nvte_quantize_v2(const NVTETensor input, NVTETensor output, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_v2); + using namespace transformer_engine; + + constexpr bool IS_ACT = false; + dispatch::quantize_fwd_helper(input, output, quant_config, stream); +} + +void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = false; + constexpr const NVTETensor activation_input = nullptr; + + dispatch::quantize_bwd_helper( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + +void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_dequantize); + using namespace transformer_engine; + dispatch::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output), + stream); +} + +void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, + const NVTEQuantizationConfig quant_configs, + const size_t num_tensors, cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_tensor_quantize); + using namespace transformer_engine; + + constexpr bool IS_ACT = false; + + const size_t num_streams = nvte_get_num_compute_streams(); + + int num_stream_used = std::min(num_streams, num_tensors); + // wait for current stream to finish + NVTE_CHECK_CUDA(cudaEventRecord(detail::get_compute_stream_event(0), stream)); + for (int s = 0; s < num_stream_used; s++) { + NVTE_CHECK_CUDA( + cudaStreamWaitEvent(detail::get_compute_stream(s), detail::get_compute_stream_event(0))); + } + + for (int i = 0; i < num_tensors; i++) { + dispatch::quantize_fwd_helper( + inputs[i], outputs[i], quant_configs, detail::get_compute_stream(i % num_streams)); + } + + // record events on compute streams + for (int s = 0; s < 
num_stream_used; s++) { + NVTE_CHECK_CUDA( + cudaEventRecord(detail::get_compute_stream_event(s), detail::get_compute_stream(s))); + } + // wait for all compute streams to finish + for (int s = 0; s < num_stream_used; s++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s))); + } +} diff --git a/transformer_engine/common/cast/core/common.cuh b/transformer_engine/common/cast/core/common.cuh new file mode 100644 index 0000000000..b750142f5b --- /dev/null +++ b/transformer_engine/common/cast/core/common.cuh @@ -0,0 +1,97 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file common.cuh + * \brief Common functions in quantize. + */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../utils.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace common { +inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_block) { + const size_t N = product(t->data.shape); + const bool isFullTile = (N % elems_per_block == 0); + return isFullTile; +} + +inline bool dimensions_supported_by_TMA(const Tensor *const t) { + const size_t cols = t->flat_last_dim(); + constexpr size_t TMA_bytes = 16; + const size_t alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype()); + return cols % alignment_requirement == 0; +} + +namespace kernel { + +constexpr size_t THREADS_PER_BLOCK = 256; +template +__global__ void __launch_bounds__(THREADS_PER_BLOCK) + reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial, + const size_t rows, const size_t cols) { + using ComputeVec = Vec; + using OutputVec = Vec; + + const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; + + if (thread_id * nvec >= cols) { + return; + } + + const float *const thread_in_base = dbias_partial + thread_id * nvec; + OType *const thread_out_base = dbias_output + thread_id * nvec; + + ComputeVec ldg_vec; + ComputeVec acc_vec; + acc_vec.clear(); + for (int i = 0; i < rows; ++i) { + ldg_vec.load_from(thread_in_base + i * cols); +#pragma unroll + for (int e = 0; e < nvec; ++e) { + acc_vec.data.elt[e] += ldg_vec.data.elt[e]; + } + } + + OutputVec stg_vec; +#pragma unroll + for (int e = 0; e < nvec; ++e) { + stg_vec.data.elt[e] = static_cast(acc_vec.data.elt[e]); + } + stg_vec.store_to(thread_out_base); +} +} // namespace kernel + +template +void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols, + cudaStream_t stream) { + using namespace kernel; + constexpr size_t reduce_dbias_store_bytes = 8; // stg.64 + constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType); + + NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape."); + const size_t reduce_dbias_num_blocks = DIVUP(cols, THREADS_PER_BLOCK * reduce_dbias_nvec); + + reduce_dbias_kernel + <<>>( + reinterpret_cast(dbias->data.dptr), workspace_ptr, rows, cols); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace common +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_ diff --git a/transformer_engine/common/cast/dispatch/dequantize.cuh 
b/transformer_engine/common/cast/dispatch/dequantize.cuh new file mode 100644 index 0000000000..b8547915c7 --- /dev/null +++ b/transformer_engine/common/cast/dispatch/dequantize.cuh @@ -0,0 +1,56 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file dequantize.cuh + * \brief Dequantize dispatcher. + */ + +#ifndef TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_ +#define TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_ + +#include + +#include "../../common.h" +#include "../fp8/dequantize_fp8.cuh" +#include "../mxfp8/dequantize_mxfp8.cuh" +#include "../nvfp4/dequantize_nvfp4.cuh" + +namespace transformer_engine { +namespace dispatch { + +inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) { + CheckInputTensor(input, "cast_input"); + CheckOutputTensor(*output, "cast_output"); + + switch (input.scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); + NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + fp8::dequantize(input, output, stream); + break; + } + case NVTE_MXFP8_1D_SCALING: { + if (is_supported_by_CC_100()) { + mxfp8::dequantize(input, output, stream); + } else { + NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); + } + break; + } + case NVTE_NVFP4_1D_SCALING: { + nvfp4::dequantize(input, output, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); + } +} + +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_ diff --git a/transformer_engine/common/cast/dispatch/gated.cuh b/transformer_engine/common/cast/dispatch/gated.cuh new file mode 100644 index 0000000000..4373090b72 --- /dev/null +++ b/transformer_engine/common/cast/dispatch/gated.cuh @@ -0,0 +1,161 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file gated.cuh + * \brief Gated dispatcher. + */ + +#ifndef TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_ +#define TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_ + +#include + +#include "../../common.h" +#include "../../utils.cuh" +#include "../fp8/gated_fp8.cuh" +#include "../mxfp8/gated_mxfp8.cuh" + +namespace transformer_engine { +namespace dispatch { + +template +void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_output, ParamOP &p, + cudaStream_t stream) { + const Tensor input = *convertNVTETensorCheck(nvte_input); + Tensor *output = convertNVTETensorCheck(nvte_output); + + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output", /*allow_empty=*/false); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim() / 2; + + NVTE_CHECK(input.flat_last_dim() % 2 == 0, + "Wrong input shape. 
Expected (after flattening) last dimension to be even, ", "got [", + input.flat_first_dim(), ", ", input.flat_last_dim(), "]."); + NVTE_CHECK(output->flat_last_dim() == cols, + "Wrong output shape. Expected (after flattening) [*, ", cols, "], got [", + output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + + NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + switch (output->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); + if (use_tma_kernels) { + Tensor dummy_grad_tensor; + fp8::cast_gated_tma(input, dummy_grad_tensor, + output, p, stream); + } else { + fp8::cast_gated_fwd(input, output, p, stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + NVTE_CHECK(cols % 32 == 0, + "Invalid input shape. Expected the last dimension to be " + "divisible by 32, but got ", + cols, "."); + if (output->has_data()) { + NVTE_CHECK(is_fp8_dtype(output->data.dtype), + "The type of the output tensor should be FP8."); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), + "The type of the columnwise output tensor should be FP8."); + } + NVTE_CHECK(is_supported_by_CC_100(), + "Gated FWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); + Tensor dummy_grad_tensor; + mxfp8::quantize_gated(input, dummy_grad_tensor, + output, p, stream); + break; + } + default: + NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); + } +} + +template +void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte_gated_input, + NVTETensor nvte_output, ParamOP &p, cudaStream_t stream) { + const Tensor &grad = *(convertNVTETensorCheck(nvte_grad)); + const Tensor gated_input = *convertNVTETensorCheck(nvte_gated_input); + Tensor *output = convertNVTETensorCheck(nvte_output); + + CheckInputTensor(grad, "grad"); + CheckInputTensor(gated_input, "gated_input"); + CheckOutputTensor(*output, "output", /*allow_empty=*/false); + + NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even, but got ", + gated_input.flat_last_dim(), "."); + + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + + NVTE_CHECK(!is_fp8_dtype(grad.data.dtype), "Grad input must be in higher precision."); + NVTE_CHECK(grad.data.dtype == gated_input.data.dtype, "Types of both inputs must match."); + + NVTE_CHECK(grad.flat_first_dim() == rows, + "Wrong Grad shape. Expected first dimension (after flattening) [", rows, ", *], got [", + grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); + NVTE_CHECK(grad.flat_last_dim() == cols, + "Wrong Grad shape. Expected last dimension (after flattening) [", cols, ", *], got [", + grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); + + NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + NVTE_CHECK(output->flat_first_dim() == rows, "Wrong output shape. Expected (after flattening) [", + rows, ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + NVTE_CHECK(output->flat_last_dim() == cols * 2, + "Wrong output shape. Expected (after flattening) [*, ", cols * 2, "], got [", + output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + NVTE_CHECK(gated_input.data.shape == output->data.shape, + "Gated input and output shapes must match. 
Input shape: ", gated_input.data.shape, + ", output shape: ", output->data.shape, "."); + + switch (output->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); + if (use_tma_kernels) { + fp8::cast_gated_tma(gated_input, grad, output, p, + stream); + } else { + fp8::cast_gated_bwd(gated_input, grad, output, p, stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + NVTE_CHECK(cols % 32 == 0, + "Invalid input shape. Expected the last dimension to be " + "divisible by 32, but got ", + cols, "."); + if (output->has_data()) { + NVTE_CHECK(is_fp8_dtype(output->data.dtype), + "The type of the output tensor should be FP8."); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), + "The type of the columnwise output tensor should be FP8."); + } + NVTE_CHECK(is_supported_by_CC_100(), + "Gated BWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); + + mxfp8::quantize_gated(gated_input, grad, output, p, + stream); + break; + } + default: + NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); + } +} +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_ diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh new file mode 100644 index 0000000000..9f7a4a9b01 --- /dev/null +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -0,0 +1,326 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize.cuh + * \brief Quantize dispatcher. 
+ */ + +#ifndef TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_ +#define TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_ + +#include + +#include "../../common.h" +#include "../../transpose/cast_transpose.h" +#include "../../util/vectorized_pointwise.h" +#include "../core/common.cuh" +#include "../fp8/quantize_fp8.cuh" +#include "../mxfp8/quantize_mxfp8.cuh" +#include "../nvfp4/quantize_nvfp4.cuh" +#include "../nvfp4/quantize_transpose_nvfp4.cuh" + +namespace transformer_engine { +namespace dispatch { + +template +void quantize_fwd_helper(const NVTETensor input, NVTETensor output, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + using namespace detail; + + const Tensor *input_tensor = convertNVTETensorCheck(input); + Tensor *output_tensor = convertNVTETensorCheck(output); + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + // Check for unsupported options + if (quant_config_cpp.stochastic_rounding) { + NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, + "Stochastic rounding is only supported for NVFP4 quantization."); + } + + NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + // Dispatch to quantization kernel depending on data format + switch (output_tensor->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + const Tensor *dummy_input_tensor = nullptr; + Tensor *dummy_dbias_tensor = nullptr; + Tensor *dummy_workspace_tensor = nullptr; + if (output_tensor->has_columnwise_data()) { + NVTE_CHECK(output_tensor->has_data(), + "Quantizing in only the columnwise direction not supported yet!"); + if constexpr (!IS_ACT) { + cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream); + } else { + cast_transpose_fused( + *input_tensor, dummy_input_tensor, output_tensor, dummy_dbias_tensor, + dummy_workspace_tensor, stream); + } + } else if (output_tensor->has_data()) { + fp8::quantize( + *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, + dummy_workspace_tensor, stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + const Tensor *dummy_input_tensor = nullptr; + Tensor *dummy_dbias_tensor = nullptr; + Tensor *dummy_workspace_tensor = nullptr; + mxfp8::quantize( + *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, + dummy_workspace_tensor, stream); + break; + } + case NVTE_NVFP4_1D_SCALING: { + NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); + + // Check tensors + CheckNoopTensor(*noop_tensor, "cast_noop"); + CheckInputTensor(*input_tensor, "input"); + CheckOutputTensor(*output_tensor, "output", false); + + // Choose kernel + int32_t rows = input_tensor->flat_first_dim(); + int32_t cols = input_tensor->flat_last_dim(); + auto dtype = input_tensor->dtype(); + bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && + (cols % 32 == 0) && output_tensor->has_data(); + + // Launch NVFP4 quantize kernel + if (use_optimized_kernel) { + if (quant_config_cpp.nvfp4_2d_quantization) { + nvfp4::quantize_transpose( + *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } else { + nvfp4::quantize_transpose( + *input_tensor, noop_tensor, 
output_tensor, &quant_config_cpp, stream); + } + } else { + auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax + : output_tensor->columnwise_amax; + quantize_transpose_vector_blockwise_fp4( + /*input=*/input_tensor->data, /*global_amax=*/global_amax, + /*scale_inv=*/output_tensor->scale_inv, + /*scale_inv_t=*/output_tensor->columnwise_scale_inv, + /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, + /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), + /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, + /*swizzled_scale=*/false, + /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, + /*rng_state=*/quant_config_cpp.rng_state, + /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, + /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + } + break; + } + case NVTE_BLOCK_SCALING_2D: { + // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. + NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D"); + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; + quantize_transpose_square_blockwise( + input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, epsilon, + /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, + /*noop_tensor=*/noop_tensor->data, stream); + break; + } + case NVTE_BLOCK_SCALING_1D: { + // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. + NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_1D"); + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; + FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; + FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; + if (output_tensor->has_data()) { + bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format == + Float8BlockScaleTensorFormat::COMPACT); + rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT + : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; + } + if (output_tensor->has_columnwise_data()) { + bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format == + Float8BlockScaleTensorFormat::COMPACT); + columnwise_option = columnwise_compact + ? 
FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT + : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; + } + quantize_transpose_vector_blockwise( + input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, + columnwise_option, force_pow_2_scales, noop_tensor->data, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); + } +} + +template +void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETensor output, + NVTETensor dbias, NVTETensor workspace, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + using namespace detail; + + const Tensor *grad_tensor = convertNVTETensorCheck(grad); + const Tensor *input_tensor = convertNVTETensor(input); + + Tensor *output_tensor = convertNVTETensorCheck(output); + Tensor *dbias_tensor = convertNVTETensor(dbias); + Tensor *workspace_tensor = convertNVTETensor(workspace); + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + // Check for unsupported options + if (quant_config_cpp.stochastic_rounding) { + NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, + "Stochastic rounding is only supported for NVFP4 quantization."); + } + + NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + // Dispatch to quantization kernel depending on data format + switch (output_tensor->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + if (output_tensor->has_columnwise_data()) { + NVTE_CHECK(output_tensor->has_data(), + "Quantizing in only the columnwise direction not supported yet!"); + if constexpr (!IS_DBIAS && !IS_DACT) { + cast_transpose(*grad_tensor, *noop_tensor, output_tensor, stream); + } else { + cast_transpose_fused( + *grad_tensor, input_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); + } + } else if (output_tensor->has_data()) { + fp8::quantize( + *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + mxfp8::quantize( + *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + stream); + break; + } + case NVTE_NVFP4_1D_SCALING: { + NVTE_CHECK((!IS_DBIAS && !IS_DACT), + "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING"); + + // Check tensors + CheckNoopTensor(*noop_tensor, "cast_noop"); + CheckInputTensor(*grad_tensor, "input"); + CheckOutputTensor(*output_tensor, "output", false); + + // Choose kernel + int32_t rows = grad_tensor->flat_first_dim(); + int32_t cols = grad_tensor->flat_last_dim(); + auto dtype = grad_tensor->dtype(); + bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && + (cols % 32 == 0) && output_tensor->has_data(); + + // Launch NVFP4 quantize kernel + if (use_optimized_kernel) { + if (quant_config_cpp.nvfp4_2d_quantization) { + nvfp4::quantize_transpose( + *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } else { + nvfp4::quantize_transpose( + *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, 
stream); + } + } else { + auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax + : output_tensor->columnwise_amax; + quantize_transpose_vector_blockwise_fp4( + /*input=*/grad_tensor->data, /*global_amax=*/global_amax, + /*scale_inv=*/output_tensor->scale_inv, + /*scale_inv_t=*/output_tensor->columnwise_scale_inv, + /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, + /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), + /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, + /*swizzled_scale=*/false, + /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, + /*rng_state=*/quant_config_cpp.rng_state, + /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, + /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + } + break; + } + case NVTE_BLOCK_SCALING_2D: { + // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. + NVTE_CHECK((!IS_DBIAS && !IS_DACT), + "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_2D"); + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; + quantize_transpose_square_blockwise( + grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, epsilon, + /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, + /*noop_tensor=*/noop_tensor->data, stream); + break; + } + case NVTE_BLOCK_SCALING_1D: { + // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. + NVTE_CHECK((!IS_DBIAS && !IS_DACT), + "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_1D"); + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; + FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; + FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; + if (output_tensor->has_data()) { + bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format == + Float8BlockScaleTensorFormat::COMPACT); + rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT + : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; + } + if (output_tensor->has_columnwise_data()) { + bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format == + Float8BlockScaleTensorFormat::COMPACT); + columnwise_option = columnwise_compact + ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT + : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; + } + quantize_transpose_vector_blockwise( + grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, + columnwise_option, force_pow_2_scales, noop_tensor->data, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); + } +} + +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_ diff --git a/transformer_engine/common/cast/fp8/dequantize_fp8.cuh b/transformer_engine/common/cast/fp8/dequantize_fp8.cuh new file mode 100644 index 0000000000..2514758b5a --- /dev/null +++ b/transformer_engine/common/cast/fp8/dequantize_fp8.cuh @@ -0,0 +1,54 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file dequantize_fp8.cuh + * \brief CUDA kernels to dequantize from FP8. + */ + +#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_FP8_CUH_ +#define TRANSFORMER_ENGINE_DEQUANTIZE_FP8_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/vectorized_pointwise.h" +#include "../../utils.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace fp8 { +struct DequantizeParam { + const float *scale_inv; +}; + +__device__ inline float dequantize_func(float value, const DequantizeParam ¶m) { + return value * (*(param.scale_inv)); +} + +inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { + const size_t N = product(input.data.shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + output->data.dtype, OType, + + constexpr int nvec = 32 / sizeof(OType); + DequantizeParam p; p.scale_inv = reinterpret_cast(input.scale_inv.dptr); + VectorizedUnaryKernelLauncher( + reinterpret_cast(input.data.dptr), nullptr, + reinterpret_cast(output->data.dptr), nullptr, nullptr, nullptr, N, p, + stream);); // NOLINT(*) + ); // NOLINT(*) +} +} // namespace fp8 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_DEQUANTIZE_FP8_CUH_ diff --git a/transformer_engine/common/cast/fp8/gated_fp8.cuh b/transformer_engine/common/cast/fp8/gated_fp8.cuh new file mode 100644 index 0000000000..225ef93ed9 --- /dev/null +++ b/transformer_engine/common/cast/fp8/gated_fp8.cuh @@ -0,0 +1,394 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file gated_fp8.cuh + * \brief CUDA kernels to cast to FP8 with gated activations. 
+ */ + +#ifndef TRANSFORMER_ENGINE_GATED_FP8_CUH_ +#define TRANSFORMER_ENGINE_GATED_FP8_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../util/vectorized_pointwise.h" +#include "../../utils.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace fp8 { +namespace kernel { + +constexpr size_t CHUNK_DIM_Y = 128; +constexpr size_t CHUNK_DIM_X = 128; +constexpr size_t THREADS_PER_CHUNK = 512; +constexpr size_t THREADS_PER_CHUNK_X = CHUNK_DIM_X; +constexpr size_t THREADS_PER_CHUNK_Y = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X; // 4 = 512 / 128 +constexpr size_t BUFFERS_NUM = 2; +constexpr size_t BUFFER_DIM_Y = 32; +constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128 +constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 32 +constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128 + +constexpr size_t BUFFER_STAGES_NUM = BUFFER_DIM_Y / THREADS_PER_CHUNK_Y; // 8 = 32 / 4 +constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 4 = 128 / 32 +static_assert(ITERATIONS >= 1); + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + cast_fp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, + const __grid_constant__ CUtensorMap tensor_map_input_act, + const __grid_constant__ CUtensorMap tensor_map_input_gate, + const __grid_constant__ CUtensorMap tensor_map_output_act, + const __grid_constant__ CUtensorMap tensor_map_output_gate, + float *const amax_ptr, float *const scale_inv_ptr, + const float *const scale_ptr, const size_t rows, const size_t cols, + const ParamOP p) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t chunk_offset_X = blockIdx.x * CHUNK_DIM_X; + + const size_t tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; + const size_t tid_X = threadIdx.x % THREADS_PER_CHUNK_X; + + const size_t thread_offset_Y = tid_Y; + const size_t thread_offset_X = tid_X; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; + constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + + constexpr size_t grad_mem = IS_BWD ? 
buff_size_aligned_in : 0; + + constexpr size_t in_act_mem = buff_size_aligned_in; + constexpr size_t in_gate_mem = buff_size_aligned_in; + constexpr size_t in_mem = in_act_mem + in_gate_mem; + + constexpr size_t out_act_mem = buff_size_aligned_out; + constexpr size_t in_transaction_size = buff_elems * sizeof(IType); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_grad_sh = reinterpret_cast(dshmem); + IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); + IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); + OType *out_act_sh = reinterpret_cast(dshmem + grad_mem + in_mem); + OType *out_gate_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_act_mem); + + const uint64_t *TMAP_grad_in = reinterpret_cast(&tensor_map_grad); + const uint64_t *TMAP_in_act = reinterpret_cast(&tensor_map_input_act); + const uint64_t *TMAP_in_gate = reinterpret_cast(&tensor_map_input_gate); + const uint64_t *TMAP_output_act = reinterpret_cast(&tensor_map_output_act); + const uint64_t *TMAP_output_gate = reinterpret_cast(&tensor_map_output_gate); + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + // Prefetch data of the first stage + + if constexpr (IS_BWD) { + copy_2d_to_sharedx3(in_grad_sh, TMAP_grad_in, chunk_offset_X, chunk_offset_Y, in_act_sh, + TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, TMAP_in_gate, + chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], + is_master_thread); + } else { + copy_2d_to_sharedx2(in_act_sh, TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, + TMAP_in_gate, chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], + is_master_thread); + } + +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + const size_t buff = it % BUFFERS_NUM; + const size_t next_it = it + 1; + if (next_it < ITERATIONS) { + const size_t next_buff = next_it % BUFFERS_NUM; + const size_t chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; + if constexpr (IS_BWD) { + copy_2d_to_sharedx3( + &in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y, + &in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, chunk_it_offset_y, + &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, chunk_it_offset_x, chunk_it_offset_y, + in_transaction_size, &mbar[next_it], is_master_thread); + } else { + copy_2d_to_sharedx2(&in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, + chunk_it_offset_y, &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, + chunk_it_offset_x, chunk_it_offset_y, in_transaction_size, + &mbar[next_it], is_master_thread); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[it], parity); + + IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems; + IType *in_act_sh_curr = in_act_sh + buff * buff_elems; + IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems; + OType *out_act_sh_curr = out_act_sh + buff * buff_elems; + OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems; +#pragma unroll + for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { + const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y; + const size_t 
shmem_offset_y = thread_offset_Y + stage_offset_Y; + const size_t shmem_offset_x = thread_offset_X; + const size_t shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + + float act_elt = static_cast(in_act_sh_curr[shmem_idx]); + float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); + bool dgate_elt = true; // gating is ideally an identity function + if constexpr (std::is_same::value) { + // In case of GPT OSS, clamp the activation and gate values + dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp + gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1; + } + + if constexpr (IS_BWD) { + float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); + + const float x = act_elt; + float act_x; + float dact_x; + if constexpr (std::is_same::value) { + const float x = min(act_elt, p.limit); + const float s = sigmoidf(p.alpha * x); + act_x = x * s; + if (act_elt <= p.limit) { + dact_x = s + s * (1 - s) * p.alpha * x; + } else { + dact_x = 0.0f; + } + } else { + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, p); + dact_x = DActOP(x, p); + } + } + float after_dact = dact_x * grad_elt * gate_elt; + float after_dgate = dgate_elt ? act_x * grad_elt : 0.0f; + + out_act_sh_curr[shmem_idx] = static_cast(scale * after_dact); + out_gate_sh_curr[shmem_idx] = static_cast(scale * after_dgate); + + amax = fmaxf(amax, fabsf(after_dact)); + amax = fmaxf(amax, fabsf(after_dgate)); + } else { + const float after_act = ActOP(act_elt, p) * gate_elt; + out_act_sh_curr[shmem_idx] = static_cast(scale * after_act); + amax = fmaxf(amax, fabsf(after_act)); + } + } + + // Wait for shared memory writes to be visible to TMA engine (cross-proxy fence) + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const size_t chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; + + // dGeLU + ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x, + chunk_it_offset_y, + reinterpret_cast(out_act_sh_curr)); + + if constexpr (IS_BWD) { + // dGate + ptx::cp_async_bulk_tensor_2d_shared_to_global( + TMAP_output_gate, chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(out_gate_sh_curr)); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + amax = reduce_max(amax, warp_id); + // Update the global amax + if (is_master_thread) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } + + // Destroy the barriers. This invalidates the memory region of the barrier. + // If further computations were to take place in the kernel, this allows the + // memory location of the shared memory barrier to be reused. 
+ if (is_master_thread) { +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + ptx::mbarrier_invalid(&mbar[it]); + } + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace kernel + +template +void cast_gated_tma(const Tensor &gated_input, const Tensor &grad, Tensor *output, ParamOP &p, + cudaStream_t stream) { + using namespace kernel; + checkCuDriverContext(stream); + + NVTE_CHECK(!output->has_columnwise_data(), "Only rowwise cast supported in this function."); + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + const size_t output_cols = (IS_BWD ? 2 : 1) * cols; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); + float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block_dim(THREADS_PER_CHUNK); + const dim3 grid_dim(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + gated_input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_grad{}; + alignas(64) CUtensorMap tensor_map_input_act{}; + alignas(64) CUtensorMap tensor_map_input_gate{}; + alignas(64) CUtensorMap tensor_map_output_act{}; + alignas(64) CUtensorMap tensor_map_output_gate{}; + + if constexpr (IS_BWD) { + create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, + cols, 0, typeToNumBits(gated_input.dtype())); + } + + const uint32_t tensor_stride_elems = output_cols; + + create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype())); + create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype())); + create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype())); + create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, tensor_stride_elems, cols, + typeToNumBits(output->dtype())); + + const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; + const size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + const size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + const size_t grad_mem = (IS_BWD ? 
buff_size_aligned_in : 0); + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + const size_t out_act_mem = buff_size_aligned_out; + const size_t out_gate_mem = buff_size_aligned_out; + + const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) + + (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT; + + auto kernel = cast_fp8_gated_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + shmem_size)); + + kernel<<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act, + tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, cols, p); + NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) + ); // NOLINT(*) +} + +template +void cast_gated_fwd(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t stream) { + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->dtype(), OType, + + constexpr int nvec = 32 / sizeof(IType); + GatedActivationKernelLauncher( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), input.flat_first_dim(), + output->flat_last_dim(), p, stream);); // NOLINT(*) + ); // NOLINT(*) +} + +template +void cast_gated_bwd(const Tensor &input, const Tensor &grad, Tensor *output, ParamOP &p, + cudaStream_t stream) { + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->dtype(), OType, + + constexpr int nvec = 32 / sizeof(IType); + DGatedActivationKernelLauncher( + reinterpret_cast(grad.data.dptr), + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), grad.flat_first_dim(), + grad.flat_last_dim(), p, stream);); // NOLINT(*) + ); // NOLINT(*) +} +} // namespace fp8 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_GATED_FP8_CUH_ diff --git a/transformer_engine/common/cast/fp8/quantize_fp8.cuh b/transformer_engine/common/cast/fp8/quantize_fp8.cuh new file mode 100644 index 0000000000..efc5015b75 --- /dev/null +++ b/transformer_engine/common/cast/fp8/quantize_fp8.cuh @@ -0,0 +1,580 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_fp8.cuh + * \brief CUDA kernels to quantize to FP8. 
+ */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_FP8_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_FP8_CUH_ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../transpose/cast_transpose.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../util/vectorized_pointwise.h" +#include "../../utils.cuh" +#include "../core/common.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace fp8 { +namespace quantize_2D_kernel { + +constexpr size_t FP8_CHUNK_DIM_Y = 128; +constexpr size_t FP8_CHUNK_DIM_X = 128; +constexpr size_t FP8_THREADS_PER_CHUNK = 128; +constexpr size_t FP8_BUFFERS_NUM = 2; +constexpr size_t FP8_PREFETCH_BUFFERS_NUM = 1; +static_assert(FP8_PREFETCH_BUFFERS_NUM < FP8_BUFFERS_NUM); + +constexpr size_t FP8_BUFFER_DIM_Y = 16; +constexpr size_t FP8_BUFFER_DIM_X = FP8_CHUNK_DIM_X; // 128 +constexpr size_t FP8_SHMEM_DIM_Y = FP8_BUFFER_DIM_Y; // 16 +constexpr size_t FP8_SHMEM_DIM_X = FP8_BUFFER_DIM_X; // 128 + +constexpr size_t FP8_BUFF_STAGES_NUM = FP8_BUFFER_DIM_Y; // 16 +constexpr size_t FP8_ITERATIONS = FP8_CHUNK_DIM_Y / FP8_BUFFER_DIM_Y; // 8 = 128 / 16 +static_assert(FP8_ITERATIONS >= FP8_PREFETCH_BUFFERS_NUM); + +template +__global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) + cast_fp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_act_input, + const __grid_constant__ CUtensorMap tensor_map_output, + float *const dbias_workspace, float *const amax_ptr, + float *const scale_inv_ptr, const float *const scale_ptr, const size_t rows, + const size_t cols) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + const size_t block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X; + + const size_t tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK; + const size_t tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK; + + const size_t thread_offset_Y = tid_Y; + const size_t thread_offset_X = tid_X; + + const size_t dbias_offset_Y = blockIdx.y + tid_Y; + const size_t my_column = blockIdx.x * FP8_CHUNK_DIM_X + thread_offset_X; + const bool col_out_of_bounds = my_column >= cols; + const size_t dbias_stride = cols; + + float partial_dbias = 0.f; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; + + // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned + __shared__ alignas(TMA_SHMEM_ALIGNMENT) + IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) + IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) + OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + + constexpr size_t shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. 
+#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[FP8_ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + const size_t chunk_offset_Y = block_offset_Y; + const size_t chunk_offset_X = block_offset_X; + +#pragma unroll + for (int prefetch_buff = 0; prefetch_buff < FP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { + const size_t chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y; + const size_t chunk_stage_offset_X = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, + chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input, + chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size, + &mbar[prefetch_buff], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, + chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff], + is_master_thread); + } + } + +#pragma unroll + for (int iter = 0; iter < FP8_ITERATIONS; ++iter) { + const size_t buff = iter % FP8_BUFFERS_NUM; + const size_t next_iter = iter + FP8_PREFETCH_BUFFERS_NUM; + const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y; + if (next_iter < FP8_ITERATIONS) { + const size_t next_buff = next_iter % FP8_BUFFERS_NUM; + const size_t chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, + chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, + chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], + is_master_thread); + } else { + copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, + chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread); + } + } + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + +#pragma unroll + for (int stage = 0; stage < FP8_BUFF_STAGES_NUM; ++stage) { + const size_t stage_offset_Y = stage; + const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y; + const size_t shmem_offset_x = thread_offset_X; + const size_t row = row_base + shmem_offset_y; + const bool row_out_of_bounds = row >= rows; + const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds; + + float elt = static_cast(in_sh[buff][shmem_offset_y][shmem_offset_x]); + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[buff][shmem_offset_y][shmem_offset_x]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + if constexpr (IS_DACT) { + if (!out_of_bounds) { + partial_dbias += elt; + } + } else { + // If no activation, elt is 0 so we can safely do this + partial_dbias += elt; + } + } + __builtin_assume(amax >= 0); + if (IS_DACT) { + if (!out_of_bounds) { + amax = fmaxf(amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + amax = fmaxf(amax, fabsf(elt)); + } + out_sh[buff][shmem_offset_y][shmem_offset_x] = static_cast(elt * scale); + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. 
+ + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const size_t chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), chunk_it_offset_x, + chunk_it_offset_y, reinterpret_cast(&out_sh[buff])); + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + parity ^= 1; + + if constexpr (IS_DBIAS) { + const size_t dbias_offset_X = my_column; + const size_t dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X; + if (!col_out_of_bounds) { + dbias_workspace[dbias_offset] = partial_dbias; + } + } + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + amax = reduce_max(amax, warp_id); + // Update the global amax + if (is_master_thread) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace quantize_2D_kernel + +namespace quantize_1D_kernel { +using namespace quantize_2D_kernel; + +constexpr size_t CHUNKS_PER_BLOCK = 128; +constexpr size_t THREADS_PER_BLOCK = FP8_THREADS_PER_CHUNK; +constexpr size_t CHUNK_SIZE = THREADS_PER_BLOCK; +constexpr size_t ELEMS_PER_BLOCK = CHUNKS_PER_BLOCK * CHUNK_SIZE; +constexpr size_t CHUNKS_PER_ITERATION = 32; +constexpr size_t SHMEM_DIM = CHUNKS_PER_ITERATION * CHUNK_SIZE; +constexpr size_t ITERATIONS = CHUNKS_PER_BLOCK / CHUNKS_PER_ITERATION; +constexpr size_t SHMEM_BUFFERS = 2; +static_assert(CHUNKS_PER_BLOCK % CHUNKS_PER_ITERATION == 0); + +template +__global__ void __launch_bounds__(THREADS_PER_BLOCK) + cast_fp8_1D_kernel(const IType *input_ptr, OType *output_ptr, float *const amax_ptr, + float *const scale_inv_ptr, const float *const scale_ptr, const size_t N) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + const size_t block_offset = blockIdx.x * ELEMS_PER_BLOCK; + const IType *input = input_ptr + block_offset; + OType *output = output_ptr + block_offset; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; + + // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned + __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM]; + + constexpr size_t transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS; + constexpr size_t transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS; + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. 
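// A quick numeric sketch of the 1-D blocking defined above (illustrative, assuming the
// constants in this file): every block consumes ELEMS_PER_BLOCK = 128 chunks * 128 elements
// = 16384 values, streamed as ITERATIONS = 4 double-buffered shared tiles of SHMEM_DIM =
// 4096 elements; this is why the host-side quantize_1D path requires N % ELEMS_PER_BLOCK == 0.
static_assert(ELEMS_PER_BLOCK == CHUNKS_PER_BLOCK * CHUNK_SIZE);      // 16384 == 128 * 128
static_assert(ELEMS_PER_BLOCK == ITERATIONS * SHMEM_DIM);             // 16384 ==   4 * 4096
static_assert(SHMEM_DIM == CHUNKS_PER_ITERATION * CHUNK_SIZE);        //  4096 ==  32 * 128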
+#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + copy_1d_to_shared(&(in_sh[0]), input, transaction_size_IN, &(mbar[0]), is_master_thread); + +#pragma unroll + for (int iter = 0; iter < ITERATIONS; ++iter) { + const size_t buff = iter % SHMEM_BUFFERS; + const size_t it_offset = iter * SHMEM_DIM; + + const size_t next_iter = iter + 1; + const size_t next_buff = next_iter % SHMEM_BUFFERS; + const size_t next_iter_offset = next_iter * SHMEM_DIM; + + if (next_iter < ITERATIONS) { + copy_1d_to_shared(&(in_sh[next_buff]), input + next_iter_offset, transaction_size_IN, + &(mbar[next_iter]), is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + +#pragma unroll + for (int chunk = 0; chunk < CHUNKS_PER_ITERATION; ++chunk) { + const size_t shmem_offset = chunk * CHUNK_SIZE + threadIdx.x; + float elt = static_cast(in_sh[buff][shmem_offset]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(elt)); + out_sh[buff][shmem_offset] = static_cast(elt * scale); + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + ptx::cp_async_bulk_tensor_1d_shared_to_global( + reinterpret_cast(output + it_offset), + reinterpret_cast(&out_sh[buff]), transaction_size_OUT); + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. 
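// A common way the block-level amax ends up in global memory right after this loop
// (a sketch of an atomicMaxFloat-style helper; the utility these kernels actually call may
// be implemented differently): for non-negative IEEE-754 floats the integer bit patterns
// preserve ordering, so an integer atomicMax is sufficient.
__device__ inline void atomic_max_nonnegative(float *addr, float value) {
  atomicMax(reinterpret_cast<int *>(addr), __float_as_int(value));  // valid since amax >= 0
}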
+ ptx::cp_async_bulk_wait_group_read<1>(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + amax = reduce_max(amax, warp_id); + // Update the global amax + if (is_master_thread) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace quantize_1D_kernel + +template +void quantize_1D(const Tensor &input, Tensor *output, cudaStream_t stream) { + using namespace quantize_1D_kernel; + const size_t N = product(input.data.shape); + + const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); + NVTE_CHECK(isFullTile, "Only full tiles are supported."); + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + + const size_t chunks = DIVUP(N, CHUNK_SIZE); + const size_t blocks = DIVUP(chunks, CHUNKS_PER_BLOCK); + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); + const float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block(THREADS_PER_BLOCK); + const dim3 grid(blocks); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + const IType *input_ptr = reinterpret_cast(input.data.dptr); + OType *output_ptr = reinterpret_cast(output->data.dptr); + + cast_fp8_1D_kernel<<>>( + input_ptr, output_ptr, amax_ptr, scale_inv_ptr, scale_ptr, N);); // NOLINT(*) + ); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +template +void quantize_2D(const Tensor &input, const Tensor *act_input, Tensor *output, Tensor *dbias, + Tensor *workspace, cudaStream_t stream) { + using namespace quantize_2D_kernel; + checkCuDriverContext(stream); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + const size_t chunks_Y = DIVUP(rows, FP8_CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(cols, FP8_CHUNK_DIM_X); + const size_t blocks_Y = chunks_Y; + const size_t blocks_X = chunks_X; + + const size_t dbias_rows = blocks_Y; + const size_t dbias_cols = cols; + + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input."); + NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.dtype = DType::kFloat32; + return; + } + } + float *const workspace_ptr = IS_DBIAS ? 
reinterpret_cast(workspace->data.dptr) : nullptr; + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); + float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block(FP8_THREADS_PER_CHUNK); + const dim3 grid(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->data.dtype, OType, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, FP8_SHMEM_DIM_Y, + FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype)); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, FP8_SHMEM_DIM_Y, + FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype)); + } + + create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y, + FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(output->data.dtype)); + + cast_fp8_2D_kernel + <<>>(tensor_map_input, tensor_map_act_input, tensor_map_output, + workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows, + cols); + NVTE_CHECK_CUDA(cudaGetLastError()); + + if constexpr (IS_DBIAS) { + common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + }); // NOLINT(*) + ); // NOLINT(*) +} + +namespace detail { +using Empty = transformer_engine::Empty; +__device__ inline float identity(float value, const Empty &) { return value; } +} // namespace detail + +template +void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output, + cudaStream_t stream) { + constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; + const size_t N = product(input.data.shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->data.dtype, OType, + if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) { + constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryKernelLauncher( + reinterpret_cast(input.data.dptr), + reinterpret_cast(noop->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), N, {}, stream); + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + }); // NOLINT(*) + ); // NOLINT(*) +} + +template +void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output, + cudaStream_t stream) { + constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? 
detail::identity : OP; + const size_t N = product(input->data.shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input->data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->data.dtype, OType, + if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) { + constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryGradKernelLauncher( + reinterpret_cast(grad.data.dptr), + reinterpret_cast(input->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), N, {}, stream); + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + }); // NOLINT(*) + ); // NOLINT(*) +} + +template +void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, Tensor *output, + Tensor *dbias, Tensor *workspace, cudaStream_t stream) { + using namespace quantize_1D_kernel; + CheckNoopTensor(*noop, "cast_noop"); + CheckInputTensor(input, "cast_input"); + CheckOutputTensor(*output, "cast_output"); + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias != nullptr); + CheckOutputTensor(*dbias, "dbias"); + } + if constexpr (IS_DACT) { + NVTE_CHECK(act_input != nullptr); + CheckInputTensor(*act_input, "activation_input"); + NVTE_CHECK(input.dtype() == act_input->dtype(), "Types of both inputs must match."); + NVTE_CHECK(input.data.shape == act_input->data.shape, "Shapes of both inputs must match."); + } + + NVTE_CHECK(!is_fp8_dtype(input.dtype()), "Input must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + + // Supported by the Arch >= 10.0 + if (is_supported_by_CC_100()) { + if (!IS_DBIAS && !IS_DACT) { + if (common::full_tile_1D_tensor(output, ELEMS_PER_BLOCK) && is_fp8_dtype(output->dtype()) && + is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && + is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT)) { + // Aligned AND FP8 + quantize_1D(input, output, stream); + } else { + // Unaligned + CastVectorizedUnaryKernelLauncher(input, noop, output, stream); + } + } else if (!IS_DBIAS && IS_DACT) { + if (common::dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype()) && + is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && + is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT) && + is_aligned_tensor_data(*act_input, TMA_GMEM_ALIGNMENT)) { + // Aligned AND FP8 (+dAct) + quantize_2D(input, act_input, output, dbias, workspace, + stream); + } else { + // Unaligned + CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); + } + } else { + quantize_2D(input, act_input, output, dbias, workspace, + stream); + } + } else { + if (IS_DBIAS) { + // zhongboz: should we just ignore IS_ACT here? 
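// An illustrative CPU reference of the two-step dbias reduction used by the 2-D path above
// (not the GPU reduce_dbias): cast_fp8_2D_kernel writes one FP32 partial row per row of
// blocks into a [blocks_Y, cols] workspace, and a second pass sums those rows into dbias.
#include <cstddef>
inline void reduce_dbias_reference(const float *workspace, float *dbias, std::size_t blocks_Y,
                                   std::size_t cols) {
  for (std::size_t c = 0; c < cols; ++c) {
    float acc = 0.f;
    for (std::size_t b = 0; b < blocks_Y; ++b) {
      acc += workspace[b * cols + c];  // partial column sums produced by the kernel
    }
    dbias[c] = acc;  // the real dbias tensor is cast back to the input dtype on the GPU
  }
}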
+ NVTE_ERROR("Not implemented scaling mode or fusion: " + to_string(output->scaling_mode) + + " or IS_DBIAS=true" + " on GPU with compute capability < 10.0."); + } + if (!IS_DACT) { + CastVectorizedUnaryKernelLauncher(input, noop, output, stream); + } else { + CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); + } + } +} + +} // namespace fp8 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_FP8_CUH_ diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh similarity index 69% rename from transformer_engine/common/util/dequantize_kernels.cuh rename to transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh index 9f70ce4cd4..fb43fce96b 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh @@ -4,36 +4,27 @@ * See LICENSE for license information. ************************************************************************/ -/*! \file dequantize_kernels.cuh - * \brief CUDA kernels to cast from MXFP8. +/*! \file dequantize_mxfp8.cuh + * \brief CUDA kernels to dequantize from MXFP8. */ -#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_ -#define TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_ +#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_MXFP8_CUH_ +#define TRANSFORMER_ENGINE_DEQUANTIZE_MXFP8_CUH_ #include #include #include -#include - -#include -#include -#include -#include - -#include "../common.h" -#include "../transpose/cast_transpose.h" -#include "../util/vectorized_pointwise.h" -#include "../utils.cuh" -#include "math.h" -#include "ptx.cuh" -#include "transformer_engine/activation.h" -#include "transformer_engine/transformer_engine.h" -#include "transformer_engine/transpose.h" +#include -namespace transformer_engine { +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" -namespace dequantization { +namespace transformer_engine { +namespace dispatch { +namespace mxfp8 { +namespace dequantize_kernel { constexpr size_t CHUNK_DIM_Y = 128; constexpr size_t CHUNK_DIM_X = 128; @@ -228,29 +219,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +} // namespace dequantize_kernel -void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { - NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); - NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); - NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); - - const size_t N = product(input.data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( - output->data.dtype, OType, - - constexpr int nvec = 32 / sizeof(OType); - detail::DequantizeParam p; - p.scale_inv = reinterpret_cast(input.scale_inv.dptr); - VectorizedUnaryKernelLauncher( - reinterpret_cast(input.data.dptr), nullptr, - reinterpret_cast(output->data.dptr), nullptr, nullptr, nullptr, N, p, - stream);); // NOLINT(*) - ); // NOLINT(*) -} - -void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { +inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { + using namespace dequantize_kernel; bool use_rowwise_scaling = input.has_data(); bool use_colwise_scaling = input.has_columnwise_data(); checkCuDriverContext(stream); @@ 
-334,113 +306,8 @@ void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) ); // NOLINT(*) NVTE_CHECK_CUDA(cudaGetLastError()); } - -#if CUDA_VERSION >= 12080 -template -__global__ void __launch_bounds__(512) - dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales, - const float *const tensor_amax, const size_t N, const size_t M, - const size_t scale_stride) { - const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; - const size_t x = thread_idx % M; - const size_t y = thread_idx / M; - - union fp4vec { - uint64_t vec; - fp4e2m1x4 small_vec[4]; - }; - using OVec = Vec; - const uint64_t *const input_vectorized = reinterpret_cast(input); - OVec *output_vec = reinterpret_cast(output); - - const size_t my_index = x + y * M; - const size_t my_scale_index = x + y * scale_stride; - const size_t my_output_index = (x + y * M) * 4; - fp4vec value; - value.vec = input_vectorized[my_index]; - fp8e4m3 scale = scales[my_scale_index]; - float amax = *tensor_amax; - constexpr float factor_inv = 1.0 / (6.0 * 448.0); - float final_scale = static_cast(scale) * amax * factor_inv; -#pragma unroll - for (int i = 0; i < 4; i++) { - float4 current = static_cast(value.small_vec[i]); - OVec out; - out.data.elt[0] = static_cast(current.x * final_scale); - out.data.elt[1] = static_cast(current.y * final_scale); - out.data.elt[2] = static_cast(current.z * final_scale); - out.data.elt[3] = static_cast(current.w * final_scale); - output_vec[my_output_index + i] = out; - } -} -#endif // CUDA_VERSION - -void fp4_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { -#if CUDA_VERSION >= 12080 - CheckInputTensor(input, "input"); - CheckOutputTensor(*output, "output"); - NVTE_CHECK(input.data.dtype == DType::kFloat4E2M1, "Input must have FP4 type."); - NVTE_CHECK(is_high_precision_dtype(output->data.dtype), "Output must be in higher precision."); - NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); - - constexpr int FP4_BLOCK_SIZE = 16; - const size_t N = input.flat_first_dim(); - const size_t M = input.flat_last_dim(); - - NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ", - FP4_BLOCK_SIZE, ", but got ", input.data.shape, "."); - - const size_t Mread = M / FP4_BLOCK_SIZE; - const size_t total = N * Mread; - const size_t threads = 512; - const size_t blocks = DIVUP(total, threads); - - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( - output->data.dtype, OType, - - dequantize_fp4_kernel<<>>( - input.data.dptr, reinterpret_cast(output->data.dptr), - reinterpret_cast(input.scale_inv.dptr), - reinterpret_cast(input.amax.dptr), N, Mread, - input.scale_inv.shape.back());); // NOLINT(*) - NVTE_CHECK_CUDA(cudaGetLastError()); -#else - NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); -#endif // CUDA_VERSION >= 12080 -} - -} // namespace dequantization - -namespace detail { - -void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(input, "cast_input"); - CheckOutputTensor(*output, "cast_output"); - - switch (input.scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - dequantization::fp8_dequantize(input, output, stream); - break; - } - case NVTE_MXFP8_1D_SCALING: { - if (is_supported_by_CC_100()) { - dequantization::mxfp8_dequantize(input, output, stream); - } else { - NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); - } - break; - } - case NVTE_NVFP4_1D_SCALING: { - 
dequantization::fp4_dequantize(input, output, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); - } -} - -} // namespace detail - +} // namespace mxfp8 +} // namespace dispatch } // namespace transformer_engine -#endif // TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_ +#endif // TRANSFORMER_ENGINE_DEQUANTIZE_MXFP8_CUH_ diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh similarity index 53% rename from transformer_engine/common/util/cast_gated_kernels.cuh rename to transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh index 93086bd827..4f0e1b80f7 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh @@ -4,280 +4,27 @@ * See LICENSE for license information. ************************************************************************/ -/*! \file cast_gated_kernels.cuh - * \brief CUDA gated activations kernels to cast to/from FP8/MXFP8. +/*! \file gated_mxfp8.cuh + * \brief CUDA kernels to cast to MXFP8 with gated activations. */ -#ifndef TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_ -#define TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_ +#ifndef TRANSFORMER_ENGINE_GATED_MXFP8_CUH_ +#define TRANSFORMER_ENGINE_GATED_MXFP8_CUH_ #include #include #include -#include -#include +#include -#include - -#include "../common.h" -#include "../util/vectorized_pointwise.h" -#include "../utils.cuh" -#include "math.h" -#include "ptx.cuh" +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" namespace transformer_engine { - -namespace gated_kernels { - -constexpr size_t CHUNK_DIM_Y = 128; -constexpr size_t CHUNK_DIM_X = 128; -constexpr size_t THREADS_PER_CHUNK = 512; -constexpr size_t THREADS_PER_CHUNK_X = CHUNK_DIM_X; -constexpr size_t THREADS_PER_CHUNK_Y = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X; // 4 = 512 / 128 -constexpr size_t BUFFERS_NUM = 2; -constexpr size_t BUFFER_DIM_Y = 32; -constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128 -constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 32 -constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128 - -constexpr size_t BUFFER_STAGES_NUM = BUFFER_DIM_Y / THREADS_PER_CHUNK_Y; // 8 = 32 / 4 -constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 4 = 128 / 32 -static_assert(ITERATIONS >= 1); - -__device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); } - -template -__global__ void __launch_bounds__(THREADS_PER_CHUNK) - cast_fp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, - const __grid_constant__ CUtensorMap tensor_map_input_act, - const __grid_constant__ CUtensorMap tensor_map_input_gate, - const __grid_constant__ CUtensorMap tensor_map_output_act, - const __grid_constant__ CUtensorMap tensor_map_output_gate, - float *const amax_ptr, float *const scale_inv_ptr, - const float *const scale_ptr, const size_t rows, const size_t cols, - const ParamOP p) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - - const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const size_t chunk_offset_X = blockIdx.x * CHUNK_DIM_X; - - const size_t tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; - const size_t tid_X = threadIdx.x % THREADS_PER_CHUNK_X; - - const size_t thread_offset_Y = tid_Y; - const size_t thread_offset_X = tid_X; - - float amax = 0; - const float scale = (scale_ptr != nullptr) ? 
*scale_ptr : 1; - - extern __shared__ char dynamic_shmem[]; - uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); - // Manually align dynamic SHMEM per TMA requirements using padding - // __align__(128) Does not guarantee the pointer to be aligned! - uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & - ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); - - constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; - constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - - constexpr size_t grad_mem = IS_DGATED ? buff_size_aligned_in : 0; - - constexpr size_t in_act_mem = buff_size_aligned_in; - constexpr size_t in_gate_mem = buff_size_aligned_in; - constexpr size_t in_mem = in_act_mem + in_gate_mem; - - constexpr size_t out_act_mem = buff_size_aligned_out; - constexpr size_t in_transaction_size = buff_elems * sizeof(IType); - - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned - IType *in_grad_sh = reinterpret_cast(dshmem); - IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); - IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); - OType *out_act_sh = reinterpret_cast(dshmem + grad_mem + in_mem); - OType *out_gate_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_act_mem); - - const uint64_t *TMAP_grad_in = reinterpret_cast(&tensor_map_grad); - const uint64_t *TMAP_in_act = reinterpret_cast(&tensor_map_input_act); - const uint64_t *TMAP_in_gate = reinterpret_cast(&tensor_map_input_gate); - const uint64_t *TMAP_output_act = reinterpret_cast(&tensor_map_output_act); - const uint64_t *TMAP_output_gate = reinterpret_cast(&tensor_map_output_gate); - - const bool is_master_thread = (threadIdx.x == 0); - -// Initialize shared memory barrier with the number of threads participating in the barrier. 
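// A minimal sketch of the dynamic shared-memory alignment trick used by the gated TMA
// kernels, with a hypothetical 128-byte constant standing in for TMA_SHMEM_ALIGNMENT:
// the extern __shared__ base pointer is not guaranteed to satisfy the TMA alignment, so it
// is rounded up by hand, and the launch requests TMA_SHMEM_ALIGNMENT extra bytes of dynamic
// shared memory to cover the padding.
#include <cstdint>
__global__ void aligned_dynamic_shmem_sketch() {
  constexpr std::uintptr_t kAlign = 128;  // stand-in for TMA_SHMEM_ALIGNMENT
  extern __shared__ char dynamic_shmem[];
  const std::uintptr_t base = reinterpret_cast<std::uintptr_t>(dynamic_shmem);
  char *aligned = reinterpret_cast<char *>((base + kAlign - 1) & ~(kAlign - 1));
  (void)aligned;  // the shared buffers would be carved out starting from this pointer
}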
-#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[ITERATIONS]; - - initialize_barriers(mbar, is_master_thread); - - int parity = 0; - - // Prefetch data of the first stage - - if constexpr (IS_DGATED) { - copy_2d_to_sharedx3(in_grad_sh, TMAP_grad_in, chunk_offset_X, chunk_offset_Y, in_act_sh, - TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, TMAP_in_gate, - chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], - is_master_thread); - } else { - copy_2d_to_sharedx2(in_act_sh, TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, - TMAP_in_gate, chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], - is_master_thread); - } - -#pragma unroll - for (int it = 0; it < ITERATIONS; ++it) { - const size_t buff = it % BUFFERS_NUM; - const size_t next_it = it + 1; - if (next_it < ITERATIONS) { - const size_t next_buff = next_it % BUFFERS_NUM; - const size_t chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; - const size_t chunk_it_offset_x = chunk_offset_X; - if constexpr (IS_DGATED) { - copy_2d_to_sharedx3( - &in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y, - &in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, chunk_it_offset_y, - &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, chunk_it_offset_x, chunk_it_offset_y, - in_transaction_size, &mbar[next_it], is_master_thread); - } else { - copy_2d_to_sharedx2(&in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, - chunk_it_offset_y, &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, - chunk_it_offset_x, chunk_it_offset_y, in_transaction_size, - &mbar[next_it], is_master_thread); - } - } - - ptx::fence_proxy_async_shared_cta(); - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[it], parity); - - IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems; - IType *in_act_sh_curr = in_act_sh + buff * buff_elems; - IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems; - OType *out_act_sh_curr = out_act_sh + buff * buff_elems; - OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems; -#pragma unroll - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y; - const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y; - const size_t shmem_offset_x = thread_offset_X; - const size_t shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; - - float act_elt = static_cast(in_act_sh_curr[shmem_idx]); - float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); - bool dgate_elt = true; // gating is ideally an identity function - if constexpr (std::is_same::value) { - // In case of GPT OSS, clamp the activation and gate values - dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp - gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1; - } - - if constexpr (IS_DGATED) { - float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); - - const float x = act_elt; - float act_x; - float dact_x; - if constexpr (std::is_same::value) { - const float x = min(act_elt, p.limit); - const float s = sigmoidf(p.alpha * x); - act_x = x * s; - if (act_elt <= p.limit) { - dact_x = s + s * (1 - s) * p.alpha * x; - } else { - dact_x = 0.0f; - } - } else { - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = sigmoidf(x); - act_x = x * s; - dact_x = x * s * (1 - s) + s; - } else { - act_x = ActOP(x, p); - dact_x = DActOP(x, p); - } - } - float after_dact = dact_x * grad_elt * gate_elt; - float 
after_dgate = dgate_elt ? act_x * grad_elt : 0.0f; - - out_act_sh_curr[shmem_idx] = static_cast(scale * after_dact); - out_gate_sh_curr[shmem_idx] = static_cast(scale * after_dgate); - - amax = fmaxf(amax, fabsf(after_dact)); - amax = fmaxf(amax, fabsf(after_dgate)); - } else { - const float after_act = ActOP(act_elt, p) * gate_elt; - out_act_sh_curr[shmem_idx] = static_cast(scale * after_act); - amax = fmaxf(amax, fabsf(after_act)); - } - } - - // Wait for shared memory writes to be visible to TMA engine (cross-proxy fence) - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const size_t chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; - const size_t chunk_it_offset_x = chunk_offset_X; - - // dGeLU - ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x, - chunk_it_offset_y, - reinterpret_cast(out_act_sh_curr)); - - if constexpr (IS_DGATED) { - // dGate - ptx::cp_async_bulk_tensor_2d_shared_to_global( - TMAP_output_gate, chunk_it_offset_x, chunk_it_offset_y, - reinterpret_cast(out_gate_sh_curr)); - } - - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - - // Wait for TMA transfer to have finished reading shared memory. - ptx::cp_async_bulk_wait_group_read(); - } - } - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - amax = reduce_max(amax, warp_id); - // Update the global amax - if (is_master_thread) { - atomicMaxFloat(amax_ptr, amax); - } - } - - // Update scale-inverse - if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { - reciprocal(scale_inv_ptr, scale); - } - - // Destroy the barriers. This invalidates the memory region of the barrier. - // If further computations were to take place in the kernel, this allows the - // memory location of the shared memory barrier to be reused. 
- if (is_master_thread) { -#pragma unroll - for (int it = 0; it < ITERATIONS; ++it) { - ptx::mbarrier_invalid(&mbar[it]); - } - } -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} - -namespace mxfp8_kernel { +namespace dispatch { +namespace mxfp8 { +namespace gated_kernel { constexpr size_t CHUNK_DIM_Y = 64; constexpr size_t CHUNK_DIM_X = 64; @@ -302,20 +49,21 @@ constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 // Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 -template __global__ void __launch_bounds__(THREADS_PER_CHUNK) - cast_mxfp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, - const __grid_constant__ CUtensorMap tensor_map_input_act, - const __grid_constant__ CUtensorMap tensor_map_input_gate, - const __grid_constant__ CUtensorMap tensor_map_output_act_rowwise, - const __grid_constant__ CUtensorMap tensor_map_output_gate_rowwise, - const __grid_constant__ CUtensorMap tensor_map_output_act_colwise, - const __grid_constant__ CUtensorMap tensor_map_output_gate_colwise, - e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, - const size_t rows, const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise, const ParamOP p) { + quantize_gated_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, + const __grid_constant__ CUtensorMap tensor_map_input_act, + const __grid_constant__ CUtensorMap tensor_map_input_gate, + const __grid_constant__ CUtensorMap tensor_map_output_act_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_gate_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_act_colwise, + const __grid_constant__ CUtensorMap tensor_map_output_gate_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + const size_t rows, const size_t cols, + const size_t scale_stride_rowwise, + const size_t scale_stride_colwise, const ParamOP p) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) using IType2 = typename ptx::FPx2; using OType2 = typename ptx::FPx2; @@ -385,14 +133,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr size_t buff_size_aligned_out = DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); + const size_t grad_mem = (IS_BWD ? buff_size_aligned_in : 0); const size_t in_act_mem = buff_size_aligned_in; const size_t in_gate_mem = buff_size_aligned_in; const size_t in_mem = in_act_mem + in_gate_mem; const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = (IS_DGATED ? buff_size_aligned_out : 0); + const size_t out_gate_mem = (IS_BWD ? 
buff_size_aligned_out : 0); const size_t out_mem = out_act_mem + out_gate_mem; // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned @@ -427,7 +175,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) int parity = 0; - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { copy_2d_to_sharedx3(&in_grad_sh[0], &tensor_map_grad, block_offset_X, block_offset_Y, &in_act_sh[0], &tensor_map_input_act, block_offset_X, block_offset_Y, &in_gate_sh[0], &tensor_map_input_gate, block_offset_X, block_offset_Y, @@ -454,7 +202,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; const size_t global_offset_X = block_offset_X; const size_t next_buff_offset = next_buff * BUFF_DIM; - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { copy_2d_to_sharedx3(&in_grad_sh[next_buff_offset], &tensor_map_grad, global_offset_X, global_offset_Y, &in_act_sh[next_buff_offset], &tensor_map_input_act, global_offset_X, global_offset_Y, &in_gate_sh[next_buff_offset], @@ -497,7 +245,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f; } - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { float grad_elt = static_cast(in_grad_sh[shmem_offset_colwise]); const float x = act_elt; float act_x; @@ -526,20 +274,20 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 if constexpr (!std::is_same_v) { after_act_elt = static_cast(static_cast(after_act_elt)); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { after_gate_elt = static_cast(static_cast(after_gate_elt)); } } after_act_colwise[i] = after_act_elt; - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { after_gate_colwise[i] = after_gate_elt; } // Cache computed activations to avoid computing them again in the 2nd pass along another dimension if constexpr (IS_CACHED_ACT_OP) { cached_act_sh[shmem_offset_colwise] = static_cast(after_act_elt); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { cached_gate_sh[shmem_offset_colwise] = static_cast(after_gate_elt); } } @@ -549,7 +297,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if (!out_of_bounds) { thread_amax_act = fmaxf(thread_amax_act, fabsf(after_act_elt)); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { thread_amax_gate = fmaxf(thread_amax_gate, fabsf(after_gate_elt)); } } @@ -578,7 +326,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // All threads read the reduced amax (ACT) thread_amax_act = subamax_colwise_buff[0][tid_X_colwise]; - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { // Make sure the previous read of the ACT values has been completed, // so the data are not rewritten __syncthreads(); @@ -622,7 +370,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float block_scale_inverse_act = ptx::exp2f_rcp(biased_exponent_act); float block_scale_inverse_gate; - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { const e8m0_t biased_exponent_gate = ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); @@ -639,7 +387,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) { const size_t shmem_offset_elt = shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X; - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { OType2 out_pair; ptx::floatx2 
in_pair = {after_act_colwise[i], after_gate_colwise[i]}; const ptx::floatx2 block_scale_inverse_2x_pair = {block_scale_inverse_act, @@ -685,7 +433,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Load cached elements in_cached_act[w].load_from(&cached_act_sh[shmem_offset_rowwise]); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { in_cached_gate[w].load_from(&cached_gate_sh[shmem_offset_rowwise]); } // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) @@ -695,7 +443,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #pragma unroll for (int e = 0; e < PACK_SIZE; ++e) { thread_amax_act = fmaxf(thread_amax_act, fabsf(in_cached_act[w].data.elt[e])); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { thread_amax_gate = fmaxf(thread_amax_gate, fabsf(in_cached_gate[w].data.elt[e])); } } @@ -705,7 +453,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const IType2 in_cached_2x_act = {in_cached_act[w].data.elt[e], in_cached_act[w].data.elt[e + 1]}; ptx::abs_max_2x(thread_amax_2x_act, thread_amax_2x_act, in_cached_2x_act); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { const IType2 in_cached_2x_gate = {in_cached_gate[w].data.elt[e], in_cached_gate[w].data.elt[e + 1]}; ptx::abs_max_2x(thread_amax_2x_gate, thread_amax_2x_gate, in_cached_2x_gate); @@ -717,7 +465,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (!std::is_same_v) { thread_amax_act = static_cast( __hmax(__habs(thread_amax_2x_act.x), __habs(thread_amax_2x_act.y))); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { thread_amax_gate = static_cast( __hmax(__habs(thread_amax_2x_gate.x), __habs(thread_amax_2x_gate.y))); } @@ -735,7 +483,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) in_act.load_from(&in_act_sh[shmem_offset_rowwise]); in_gate.load_from(&in_gate_sh[shmem_offset_rowwise]); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { in_grad.load_from(&in_grad_sh[shmem_offset_rowwise]); } @@ -753,7 +501,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f; } - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { float grad_elt = static_cast(in_grad.data.elt[e]); const float x = act_elt; float act_x; @@ -786,7 +534,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 if constexpr (!std::is_same_v) { after_act_elt = static_cast(static_cast(after_act_elt)); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { after_gate_elt = static_cast(static_cast(after_gate_elt)); } } @@ -796,7 +544,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); if (!out_of_bounds) { thread_amax_act = fmaxf(thread_amax_act, fabsf(after_act_elt)); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { thread_amax_gate = fmaxf(thread_amax_gate, fabsf(after_gate_elt)); } } @@ -822,7 +570,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float block_scale_inverse_gate; ptx::floatx2 block_scale_inverse_2x_gate; - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { const e8m0_t biased_exponent_gate = ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise; @@ -853,7 +601,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } 
ptx::mul_cvt_2x(out_act_pair, in_act, block_scale_inverse_2x_act); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { IType2 in_gate; OType2 &out_gate_pair = reinterpret_cast(out_gate.data.elt[e]); @@ -873,7 +621,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; out_act.store_to(&out_act_rowwise_sh[shmem_offset_rowwise]); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { out_gate.store_to(&out_gate_rowwise_sh[shmem_offset_rowwise]); } } @@ -894,7 +642,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ptx::cp_async_bulk_tensor_2d_shared_to_global( reinterpret_cast(&tensor_map_output_act_rowwise), global_offset_X, global_offset_Y, reinterpret_cast(&out_act_rowwise_sh[buff_offset])); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { ptx::cp_async_bulk_tensor_2d_shared_to_global( reinterpret_cast(&tensor_map_output_gate_rowwise), global_offset_X, global_offset_Y, reinterpret_cast(&out_gate_rowwise_sh[buff_offset])); @@ -904,7 +652,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ptx::cp_async_bulk_tensor_2d_shared_to_global( reinterpret_cast(&tensor_map_output_act_colwise), global_offset_X, global_offset_Y, reinterpret_cast(&out_act_colwise_sh[buff_offset])); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { ptx::cp_async_bulk_tensor_2d_shared_to_global( reinterpret_cast(&tensor_map_output_gate_colwise), global_offset_X, global_offset_Y, reinterpret_cast(&out_gate_colwise_sh[buff_offset])); @@ -920,94 +668,13 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) destroy_barriers(mbar, is_master_thread); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } -} // namespace mxfp8_kernel +} // namespace gated_kernel -template -void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p, +void quantize_gated(const Tensor &gated_input, const Tensor &grad, Tensor *output, ParamOP &p, cudaStream_t stream) { - checkCuDriverContext(stream); - - if (output->has_data()) { - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); - } - if (output->has_columnwise_data()) { - NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); - } - - NVTE_CHECK(!output->has_columnwise_data(), "Only rowwise cast supported in this function."); - const size_t rows = gated_input.flat_first_dim(); - const size_t cols = gated_input.flat_last_dim() / 2; - const size_t output_cols = (IS_DGATED ? 
2 : 1) * cols; - - const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); - - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); - float *const scale_ptr = reinterpret_cast(output->scale.dptr); - - const dim3 block_dim(THREADS_PER_CHUNK); - const dim3 grid_dim(blocks_X, blocks_Y); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - gated_input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, - - alignas(64) CUtensorMap tensor_map_grad{}; - alignas(64) CUtensorMap tensor_map_input_act{}; - alignas(64) CUtensorMap tensor_map_input_gate{}; - alignas(64) CUtensorMap tensor_map_output_act{}; - alignas(64) CUtensorMap tensor_map_output_gate{}; - - if constexpr (IS_DGATED) { - create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, - cols, 0, typeToNumBits(gated_input.dtype())); - } - - const uint32_t tensor_stride_elems = output_cols; - - create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype())); - create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype())); - create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype())); - create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, tensor_stride_elems, cols, - typeToNumBits(output->dtype())); - - const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; - const size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - const size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); - const size_t in_act_mem = buff_size_aligned_in; - const size_t in_gate_mem = buff_size_aligned_in; - const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = buff_size_aligned_out; - - const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) + - (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT; - - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_fp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - - cast_fp8_gated_kernel - <<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act, - tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, cols, p); - NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) - ); // NOLINT(*) -} - -template -void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p, - cudaStream_t stream) { + using namespace gated_kernel; checkCuDriverContext(stream); const bool USE_ROWWISE_SCALING = output->has_data(); @@ -1031,17 +698,11 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out const size_t rows = gated_input.flat_first_dim(); const size_t cols = gated_input.flat_last_dim() / 2; - const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; - - constexpr size_t BUFF_DIM_Y = mxfp8_kernel::BUFF_DIM_Y; - constexpr size_t BUFF_DIM_X = mxfp8_kernel::BUFF_DIM_X; - constexpr size_t BUFFS_NUM = mxfp8_kernel::BUFFS_NUM; + const size_t output_cols = (IS_BWD ? 
2 : 1) * cols; - const size_t blocks_Y = DIVUP(rows, mxfp8_kernel::CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, mxfp8_kernel::CHUNK_DIM_X); + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); - constexpr size_t THREADS_PER_CHUNK_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_COLWISE; - constexpr size_t THREADS_PER_CHUNK_NON_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_NON_COLWISE; const size_t THREADS_PER_CHUNK = (scaling_type == ScalingType::COLWISE) ? THREADS_PER_CHUNK_COLWISE : THREADS_PER_CHUNK_NON_COLWISE; @@ -1073,7 +734,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out constexpr size_t input_type_bit_size = TypeInfo::size; constexpr size_t output_type_bit_size = TypeInfo::size; - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, input_type_bit_size); } @@ -1110,238 +771,68 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out const size_t buff_size_aligned_out = DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); - const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); + const size_t grad_mem = (IS_BWD ? buff_size_aligned_in : 0); const size_t in_act_mem = buff_size_aligned_in; const size_t in_gate_mem = buff_size_aligned_in; const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = (IS_DGATED ? buff_size_aligned_out : 0); + const size_t out_gate_mem = (IS_BWD ? buff_size_aligned_out : 0); size_t out_mem = out_act_mem + out_gate_mem; + if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; switch (scaling_type) { - case ScalingType::ROWWISE: + case ScalingType::ROWWISE: { + auto kernel = + quantize_gated_mxfp8_kernel; NVTE_CHECK_CUDA(cudaFuncSetAttribute( - mxfp8_kernel::cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - - mxfp8_kernel::cast_mxfp8_gated_kernel - <<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, - tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, - tensor_map_output_act_colwise, tensor_map_output_gate_colwise, - scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise, p); - NVTE_CHECK_CUDA(cudaGetLastError()); + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + kernel<<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, + scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p); break; - case ScalingType::COLWISE: + } + case ScalingType::COLWISE: { + auto kernel = + quantize_gated_mxfp8_kernel; NVTE_CHECK_CUDA(cudaFuncSetAttribute( - mxfp8_kernel::cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - - mxfp8_kernel::cast_mxfp8_gated_kernel - <<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, - tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, - tensor_map_output_act_colwise, tensor_map_output_gate_colwise, - scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise, p); - NVTE_CHECK_CUDA(cudaGetLastError()); + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 
shmem_size)); + + kernel<<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, + scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p); break; - case ScalingType::BIDIMENSIONAL: + } + case ScalingType::BIDIMENSIONAL: { + auto kernel = + quantize_gated_mxfp8_kernel; NVTE_CHECK_CUDA(cudaFuncSetAttribute( - mxfp8_kernel::cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - mxfp8_kernel::cast_mxfp8_gated_kernel - <<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, - tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, - tensor_map_output_act_colwise, tensor_map_output_gate_colwise, - scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise, p); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - }); // NOLINT(*) - ); // NOLINT(*) -} - -template -void cast_gated(const Tensor &input, Tensor *output, ParamOP p, cudaStream_t stream) { - CheckInputTensor(input, "gated_act_input"); - CheckOutputTensor(*output, "gated_act_output"); - NVTE_CHECK(input.flat_last_dim() % 2 == 0, - "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [", - input.flat_first_dim(), ", ", input.flat_last_dim(), "]."); - NVTE_CHECK(output->flat_last_dim() == input.flat_last_dim() / 2, - "Wrong output shape. Expected (after flattening) [*, ", input.flat_last_dim() / 2, - "], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->dtype(), OType, - - if (!is_fp8_dtype(output->data.dtype) || - is_delayed_tensor_scaling(output->scaling_mode)) { - constexpr int nvec = 32 / sizeof(IType); - GatedActivationKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), input.flat_first_dim(), - output->flat_last_dim(), p, stream); - } else { - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - }); // NOLINT(*) - ); // NOLINT(*) -} - -template -void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP p, - cudaStream_t stream) { - CheckInputTensor(grad, "dgated_act_grad"); - CheckInputTensor(input, "dgated_act_input"); - CheckOutputTensor(*output, "dgated_act_output"); - NVTE_CHECK(output->flat_first_dim() == grad.flat_first_dim(), - "Wrong output shape. Expected (after flattening) [", grad.flat_first_dim(), - ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - NVTE_CHECK(output->flat_last_dim() == grad.flat_last_dim() * 2, - "Wrong output shape. Expected (after flattening) [*, ", grad.flat_last_dim() * 2, - "], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - NVTE_CHECK(input.data.shape == output->data.shape, - "Input and output shapes must match. 
Input shape: ", input.data.shape, - ", output shape: ", output->data.shape, "."); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->dtype(), OType, - - if (!is_fp8_dtype(output->data.dtype) || - is_delayed_tensor_scaling(output->scaling_mode)) { - constexpr int nvec = 32 / sizeof(IType); - DGatedActivationKernelLauncher( - reinterpret_cast(grad.data.dptr), - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), grad.flat_first_dim(), - grad.flat_last_dim(), p, stream); - } else { - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - }); // NOLINT(*) - ); // NOLINT(*) -} - -template -void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p, - cudaStream_t stream) { - constexpr bool allow_empty = false; - CheckInputTensor(gated_input, "gated_input"); - CheckOutputTensor(*output, "output", allow_empty); - - NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even."); + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - const size_t rows = gated_input.flat_first_dim(); - const size_t cols = gated_input.flat_last_dim() / 2; - const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; - - if constexpr (IS_DGATED) { - CheckInputTensor(grad, "grad"); - NVTE_CHECK(!is_fp8_dtype(grad.data.dtype), "Grad input must be in higher precision."); - NVTE_CHECK(grad.data.dtype == gated_input.data.dtype, "Types of both inputs must match."); - NVTE_CHECK(grad.flat_first_dim() == rows, "Wrong dimension of the grad input."); - NVTE_CHECK(grad.flat_last_dim() == cols, "Wrong dimension of the grad input."); - } - - NVTE_CHECK(output->has_data() || output->has_columnwise_data(), - "Either rowwise or columnwise output data need to be allocated."); - - bool is_fp8_rowwise_output = true; - bool is_fp8_colwise_output = true; - if (output->has_data()) { - is_fp8_rowwise_output = is_fp8_dtype(output->data.dtype); - NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output."); - NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output."); - } - if (output->has_columnwise_data()) { - is_fp8_colwise_output = is_fp8_dtype(output->columnwise_data.dtype); - NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output."); - NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output."); - } - - const bool use_tma_kernels = is_fp8_rowwise_output && is_fp8_colwise_output && cols % 32 == 0; - - if (is_delayed_tensor_scaling(output->scaling_mode)) { - if (use_tma_kernels) { - cast_fp8_gated(grad, gated_input, output, p, stream); - } else { - if constexpr (IS_DGATED) { - cast_dgated(grad, gated_input, output, p, stream); - } else { - cast_gated(gated_input, output, p, stream); - } - } - } else if (is_mxfp8_scaling(output->scaling_mode)) { - if (use_tma_kernels) { - cast_mxfp8_gated(grad, gated_input, output, p, stream); - } else { - NVTE_ERROR("Invalid input shape. 
Expected the last dimension to be divisible ", - "by 32, got input of shape ", gated_input.data.shape); - } - } else { - NVTE_ERROR("Not supported scaling mode"); - } -} -} // namespace gated_kernels - -namespace detail { - -template -void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output, - ParamOP p, cudaStream_t stream) { - using namespace gated_kernels; - Tensor grad_empty_tensor; - const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor; - const Tensor gated_input_tensor = *convertNVTETensorCheck(gated_input); - Tensor *output_tensor = convertNVTETensorCheck(output); - - if (is_supported_by_CC_100()) { - quantize_gated(grad_tensor, gated_input_tensor, - output_tensor, p, stream); - } else { - if (is_delayed_tensor_scaling(output_tensor->scaling_mode)) { - if constexpr (IS_DGATED) { - cast_dgated(grad_tensor, gated_input_tensor, output_tensor, p, - stream); - } else { - cast_gated(gated_input_tensor, output_tensor, p, stream); - } - } else { - // MX scaling - NVTE_ERROR("Not supported by the Arch < 10.0"); - } - } + kernel<<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, + scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p); + break; + } + } NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) + ); // NOLINT(*) } -} // namespace detail +} // namespace mxfp8 +} // namespace dispatch } // namespace transformer_engine -#endif // TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_ +#endif // TRANSFORMER_ENGINE_GATED_MXFP8_CUH_ diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh new file mode 100644 index 0000000000..5505de6050 --- /dev/null +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -0,0 +1,722 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_mxfp8.cuh + * \brief CUDA kernels to quantize to MXFP8. 
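+ *
+ *        Each MXFP8 block (1x32 rowwise, 32x1 colwise) shares one E8M0
+ *        (power-of-two) scale derived from the block amax. The scalar sketch
+ *        below is illustrative only: `fp8_max` and `float_to_e8m0` stand in for
+ *        Quantized_Limits<OType>::max_norm_rcp and ptx::float_to_e8m0 used in
+ *        the kernels, and rounding details are omitted.
+ *
+ *          e8m0_t e         = float_to_e8m0(block_amax / fp8_max);  // biased exponent, scale = 2^(e-127)
+ *          float  scale_inv = exp2f(127 - static_cast<int>(e));     // reciprocal of the block scale
+ *          fp8    q         = static_cast<fp8>(x * scale_inv);      // applied to every block element
+ *
+ *        The biased exponent itself is what gets written to the scale_inv tensor.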
+ */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_MXFP8_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_MXFP8_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" +#include "../core/common.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace mxfp8 { +namespace quantize_kernel { + +constexpr size_t SCALE_DIM_Y = 32; +constexpr size_t SCALE_DIM_X = 32; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t PACK_SIZE = 4; +constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; + +// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + quantize_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_act_input, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + const float *noop, float *const dbias_workspace, float *const amax_ptr, + const size_t rows, const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; + constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; + + using IType2 = typename ptx::FPx2; + using OType2 = typename ptx::FPx2; + + if constexpr (NO_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; + + constexpr size_t BUFF_DIM_Y = THREADS_Y; + constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; + constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X; + static_assert(BUFF_DIM_Y == 32); + + constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; + static_assert(STAGES >= 1); + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; + + const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; + const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; + const size_t scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; + const size_t scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; + + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X; + const size_t tid_Y_colwise = 0; + const size_t tid_X_colwise = threadIdx.x; + + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const size_t thread_offset_Y_colwise = tid_Y_colwise; + const size_t thread_offset_X_colwise = tid_X_colwise; + + const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const size_t row_base_colwise = block_offset_Y + thread_offset_Y_colwise; + const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const size_t scales_offset_Y_rowwise = 
scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; + + // helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0); + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + IType *act_in_sh = reinterpret_cast(dshmem + elt_input_mem); + + OType *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); + OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + float partial_dbias_colwise = 0.0f; + float thread_dbias_rowwise[SCALE_DIM_X]; + if constexpr (IS_DBIAS) { +#pragma unroll + for (int j = 0; j < SCALE_DIM_X; ++j) { + thread_dbias_rowwise[j] = 0.0f; + } + } + + float block_amax = 0.0f; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0], + &tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + } + +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. 
the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_DIM; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input, + global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], + is_master_thread); + } else { + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], parity); + + float thread_amax = 0.0f; + if constexpr (COLWISE_SCALING) { + const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; + thread_amax = 0.0f; + float in_compute_colwise[BUFF_DIM_Y]; + IType in_colwise_IType[BUFF_DIM_Y]; + + // 1. Read/Compute elements. Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType thread_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); + } + thread_amax = static_cast(thread_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + partial_dbias_colwise += elt; + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + in_compute_colwise[i] = elt; + } + } + + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + + const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; + const size_t global_scales_offset_X = scales_offset_X_colwise; + const size_t scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + scales_colwise[scale_idx] = biased_exponent; + + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + +// 3. 
Scale elements +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; + } + const float scaled_out = in * block_scale_inverse; + + const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); + } + } + + if constexpr (ROWWISE_SCALING) { + const size_t shmem_offset_base_rowwise = + buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; + thread_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM_X]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + } + if constexpr (!std::is_same_v) { + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); + if constexpr (IS_DACT) { + act_in.load_from(&act_in_sh[shmem_offset_rowwise]); + } +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in.data.elt[e]); + elt *= OP(act_in_elt, {}); + } + + // If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again + if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { + thread_dbias_rowwise[j] += elt; + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = + (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + in_compute_rowwise[j] = elt; + } + } + } + + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const int stage_scales_offset_X = scales_offset_X_rowwise; + const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + if (rowwise_scale_is_within_bounds) { + scales_rowwise[scale_idx] = biased_exponent; + } + + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + + // 3. 
Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + IType2 in; + OType2 &out_pair = reinterpret_cast(out.data.elt[e]); + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = in_IType[w].data.elt[e]; + } else if constexpr (IS_CACHED_ACT_OP) { + in.x = in_cached[w].data.elt[2 * e]; + in.y = in_cached[w].data.elt[2 * e + 1]; + } else { + const int j = w * PACK_SIZE + 2 * e; + in.x = in_compute_rowwise[j]; + in.y = in_compute_rowwise[j + 1]; + } + ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); + } + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); + } + } + + __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, thread_amax); + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset = buff * BUFF_DIM; + + if constexpr (ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset])); + } + if constexpr (COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. 
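+      // The commit below closes the group containing the store(s) issued above, so the
+      // cp_async_bulk_wait_group_read<1>() call at the top of the next stage can wait for
+      // them before the shared-memory buffer is reused. Sketch of the pattern (illustrative
+      // only, not the exact surrounding code):
+      //
+      //   cp_async_bulk_tensor_2d_shared_to_global(...);  // enqueue TMA store(s) from shmem
+      //   cp_async_bulk_commit_group();                   // seal them into one bulk async-group
+      //   ...                                             // later, before reusing the buffer:
+      //   cp_async_bulk_wait_group_read<1>();             // allow at most one group still in flight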
+ ptx::cp_async_bulk_commit_group(); + } + } + + parity ^= 1; + + if constexpr (IS_DBIAS) { + float thread_partial_dbias = 0.0f; + if constexpr (COLWISE_SCALING) { + thread_partial_dbias = partial_dbias_colwise; + } else { + // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH] + // HEIGHT = THREADS_Y + // WIDTH = THREADS_X * (SCALE_DIM_X + 1) + // Added extra 1-element padding per thread_X to reduce bank conflicts + float *partial_dbias_rowwise = reinterpret_cast(dshmem); + + constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); + + const int shmem_thread_offset = + tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + const int shmem_elt_idx = swizzled_group_offset + e; + partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; + } + } + __syncthreads(); +#pragma unroll + for (int i = 0; i < THREADS_Y; ++i) { + // Add extra element offset per MXFP8 scaling block [1x32] + const int scaling_block = threadIdx.x / SCALE_DIM_X; + thread_partial_dbias += + partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; + } + } + const int dbias_stride = cols; + const int dbias_offset_Y = blockIdx.y; + const int dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x; + const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; + const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); + if (!col_out_of_bounds_dbias) { + dbias_workspace[dbias_idx] = thread_partial_dbias; + } + } + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + block_amax = reduce_max(block_amax, warp_id); + } + + if (is_master_thread && amax_ptr != nullptr) { + atomicMaxFloat(amax_ptr, block_amax); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace quantize_kernel + +template +void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, // TODO (ksivamani) + Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { + using namespace quantize_kernel; + checkCuDriverContext(stream); + + bool use_rowwise_scaling = output->has_data(); + bool use_colwise_scaling = output->has_columnwise_data(); + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + + if (use_rowwise_scaling) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + } + if (use_colwise_scaling) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Columnwise scaling tensor must be allocated"); + } + CheckNoopTensor(*noop, "cast_noop"); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + + constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT); + + constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 128 : 64; + constexpr size_t CHUNK_DIM_X = CAST_DBIAS_ONLY ? 128 : 64; + constexpr size_t THREADS_PER_CHUNK = CAST_DBIAS_ONLY ? 
128 : 64; + + constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; + constexpr size_t BUFF_DIM_Y = THREADS_Y; + constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + const dim3 grid(blocks_X, blocks_Y); + const size_t block_size = THREADS_PER_CHUNK; + + const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; + const size_t scale_stride_colwise = + use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; + + e8m0_t *const scales_rowwise_ptr = + use_rowwise_scaling ? reinterpret_cast(output->scale_inv.dptr) : nullptr; + e8m0_t *const scales_colwise_ptr = + use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; + const size_t dbias_rows = blocks_Y; + const size_t dbias_cols = cols; + + ScalingType scaling_type; + if (use_rowwise_scaling && (!use_colwise_scaling)) { + scaling_type = ScalingType::ROWWISE; + } else if ((!use_rowwise_scaling) && use_colwise_scaling) { + scaling_type = ScalingType::COLWISE; + } else if (use_rowwise_scaling && use_colwise_scaling) { + scaling_type = ScalingType::BIDIMENSIONAL; + } + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); + NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.dtype = DType::kFloat32; + return; + } + } + + float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, + cols, 0, input_type_bit_size); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols, 0, input_type_bit_size); + } + + if (use_rowwise_scaling) { + create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols, 0, output_type_bit_size); + } + + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, + BUFF_DIM_Y, BUFF_DIM_X, cols, 0, output_type_bit_size); + } + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; + constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + 
constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); + const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0); + const size_t out_mem = out_rowwise_mem + out_colwise_mem; + + const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; + + switch (scaling_type) { + case ScalingType::ROWWISE: { + auto kernel = + quantize_mxfp8_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + } + case ScalingType::COLWISE: { + auto kernel = + quantize_mxfp8_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + } + case ScalingType::BIDIMENSIONAL: { + auto kernel = + quantize_mxfp8_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + } + } + + if constexpr (IS_DBIAS) { + common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + }); // NOLINT(*) + ); // NOLINT(*) +} + +} // namespace mxfp8 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_MXFP8_CUH_ diff --git a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh new file mode 100644 index 0000000000..cff8464903 --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh @@ -0,0 +1,112 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file core_nvfp4.cuh + * \brief Core functions used in NVFP4. 
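+ *
+ *        NVFP4 uses two levels of scaling: a per-16-element block scale stored as
+ *        FP8 E4M3, plus a tensor-wide FP32 "second stage" scale derived from the
+ *        global amax. A scalar sketch of how the helpers below fit together
+ *        (illustrative only; the constants correspond to
+ *        TypeExtrema<fp8e4m3>::max == 448 and TypeExtrema<fp4e2m1>::max == 6):
+ *
+ *          float   S_enc = (448.0f * 6.0f) / global_amax;       // global encode scale
+ *          fp8e4m3 S_dec = fp8e4m3(block_amax / 6.0f * S_enc);  // per-block decode scale
+ *          float   S_b   = S_enc / float(S_dec);                // per-block encode scale
+ *          fp4     q     = fp4(x * S_b);                        // quantized element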
+ */ + +#ifndef TRANSFORMER_ENGINE_CORE_NVFP4_CUH_ +#define TRANSFORMER_ENGINE_CORE_NVFP4_CUH_ + +#include +#include +#include + +#include + +#include "../../common.h" +#include "../../util/curanddx.hpp" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" + +#if FP4_TYPE_SUPPORTED +#include +#endif // FP4_TYPE_SUPPORTED + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { + +using nvfp4_scale_t = fp8e4m3; + +namespace quantization_and_transposition_SF { +#if FP4_TYPE_SUPPORTED +// Used in transpose variant +// Compute per-block E4M3 encoding/decoding scaling factor +__device__ __forceinline__ nvfp4_scale_t compute_decoding_scaling_factor(const float block_amax, + const float S_enc) { + // constexpr float rcp_6f = 1.0f / 6.0f; + // const float S_dec_b = block_amax * rcp_6f; + // const nvfp4_scale_t S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + // return S_dec_b_fp8; + // NOTE: Divide by 6.0f is not elegant and not efficient. + // However, this is part of the emulation code to ensure exact match. + using namespace detail; + constexpr float fp4_max = TypeExtrema::max; // 6.0f; + const float S_dec_b = block_amax / fp4_max * S_enc; + return static_cast(fminf(S_dec_b, TypeExtrema::max)); +} +#endif // FP4_TYPE_SUPPORTED +} // namespace quantization_and_transposition_SF + +namespace quantization_SF { +#if FP4_TYPE_SUPPORTED +// Used in non-transpose variant +// Compute per-block E4M3 encoding/decoding scaling factor +__device__ __forceinline__ fp8e4m3 compute_decoding_scaling_factor(const float block_amax, + const float S_enc) { + constexpr float rcp_6f = 1.0f / 6.0f; + // const float S_dec_b = block_amax * rcp_6f; + // const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + // return S_dec_b_fp8; + return static_cast(block_amax * rcp_6f * S_enc); +} +#endif // FP4_TYPE_SUPPORTED +} // namespace quantization_SF + +namespace core { + +#if FP4_TYPE_SUPPORTED +using namespace ptx; + +// Compute the global encode scale factor for a given global amax +__device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const float global_amax) { + using namespace detail; + constexpr float fp8_max = TypeExtrema::max; // 448.0f; + constexpr float fp4_max = TypeExtrema::max; // 6.0f; + float global_encode_scale = fp8_max * fp4_max / global_amax; + // If scale is infinity, return max value of float32 + global_encode_scale = fminf(global_encode_scale, TypeExtrema::max); + // If global amax is 0 or infinity, return 1 + if (global_amax == 0.0f || global_encode_scale == 0.0f) { + return 1.0f; + } + return global_encode_scale; +} + +__device__ __forceinline__ uint32_t +get_rbits(transformer_engine::curanddx::detail::philox4x32_native_state<10> &rng, + // philox4x32_native_state<10>: 10 rounds of philox4_32 + uint4 &random_uint4, int &rnd_idx) { + if (rnd_idx == 4) { + rnd_idx = 0; + random_uint4 = rng.generate4(); + } + // Treat uint4 as an array of 4x uint32_t elements for indexing + const uint32_t *const rbits_arr = reinterpret_cast(&random_uint4); + const uint32_t rbits = rbits_arr[rnd_idx++]; + return rbits; +} + +#endif // FP4_TYPE_SUPPORTED + +} // namespace core +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_CORE_NVFP4_CUH_ diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh new file mode 100644 index 0000000000..bf7b535be4 --- /dev/null +++ 
b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -0,0 +1,111 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file dequantize_nvfp4.cuh + * \brief CUDA kernels to dequantize from NVFP4. + */ + +#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_NVFP4_CUH_ +#define TRANSFORMER_ENGINE_DEQUANTIZE_NVFP4_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" + +#if FP4_TYPE_SUPPORTED +#include +#endif // FP4_TYPE_SUPPORTED + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { +namespace dequantize_kernel { +#if FP4_TYPE_SUPPORTED +template +__global__ void __launch_bounds__(512) + dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales, + const float *const tensor_amax, const size_t N, const size_t M, + const size_t scale_stride) { + const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t x = thread_idx % M; + const size_t y = thread_idx / M; + + union fp4vec { + uint64_t vec; + fp4e2m1x4 small_vec[4]; + }; + using OVec = Vec; + const uint64_t *const input_vectorized = reinterpret_cast(input); + OVec *output_vec = reinterpret_cast(output); + + const size_t my_index = x + y * M; + const size_t my_scale_index = x + y * scale_stride; + const size_t my_output_index = (x + y * M) * 4; + fp4vec value; + value.vec = input_vectorized[my_index]; + fp8e4m3 scale = scales[my_scale_index]; + float amax = *tensor_amax; + constexpr float factor_inv = 1.0 / (6.0 * 448.0); + float final_scale = static_cast(scale) * amax * factor_inv; +#pragma unroll + for (int i = 0; i < 4; i++) { + float4 current = static_cast(value.small_vec[i]); + OVec out; + out.data.elt[0] = static_cast(current.x * final_scale); + out.data.elt[1] = static_cast(current.y * final_scale); + out.data.elt[2] = static_cast(current.z * final_scale); + out.data.elt[3] = static_cast(current.w * final_scale); + output_vec[my_output_index + i] = out; + } +} +#endif // FP4_TYPE_SUPPORTED +} // namespace dequantize_kernel + +inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + using namespace dequantize_kernel; + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output"); + NVTE_CHECK(input.data.dtype == DType::kFloat4E2M1, "Input must have FP4 type."); + NVTE_CHECK(is_high_precision_dtype(output->data.dtype), "Output must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + + constexpr int FP4_BLOCK_SIZE = 16; + const size_t N = input.flat_first_dim(); + const size_t M = input.flat_last_dim(); + + NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ", + FP4_BLOCK_SIZE, ", but got ", input.data.shape, "."); + + const size_t Mread = M / FP4_BLOCK_SIZE; + const size_t total = N * Mread; + const size_t threads = 512; + const size_t blocks = DIVUP(total, threads); + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + output->data.dtype, OType, + + dequantize_fp4_kernel<<>>( + input.data.dptr, reinterpret_cast(output->data.dptr), + reinterpret_cast(input.scale_inv.dptr), + reinterpret_cast(input.amax.dptr), N, Mread, + input.scale_inv.shape.back());); // 
NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); +#endif // FP4_TYPE_SUPPORTED +} +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_DEQUANTIZE_NVFP4_CUH_ diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh new file mode 100644 index 0000000000..83ad8fd40b --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh @@ -0,0 +1,688 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_nvfp4.cuh + * \brief CUDA kernels to cast to NVFP4. + */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_NVFP4_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_NVFP4_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" +#include "core_nvfp4.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { +namespace quantize_kernel { + +using namespace ptx; +using namespace quantization_SF; +using namespace core; + +constexpr size_t SCALE_DIM_Y = 32; +constexpr size_t SCALE_DIM_X = 16; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t BUFF_DIM_Y = 32; + +constexpr size_t PACK_SIZE = 8; +constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; + +// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 8 = 128 / 16 + +#define DIRECT_SCALING_FACTORS_STORE 1 + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + quantize_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_colwise, + fp8e4m3 *const scales_rowwise_e4m3, e8m0_t *const scales_colwise_e8m0, + const float *noop, float *const amax_ptr, + const float *const nvfp4_second_stage_scale_ptr, const size_t rows, + const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool ROWWISE_SCALING = true; + constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = + (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); + + using IType2 = typename ptx::FPx2; + + if constexpr (!COMPUTE_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + constexpr size_t NVFP4_SCALING_FACTORS_PER_CHUNK_ROW = CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t THREADS_X_ROWWISE = NVFP4_SCALING_FACTORS_PER_CHUNK_ROW; + constexpr size_t THREADS_Y_ROWWISE = THREADS_PER_CHUNK / THREADS_X_ROWWISE; + + static_assert(BUFF_DIM_Y >= SCALE_DIM_Y && + "Number of buffer rows must be greater or equal to the size of the columwise " + "scaling block\0"); + static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y); + static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE && + "Number of buffer rows must be greater or equal to the number of rowwise " + "processing threads in Y dimension\0"); + + constexpr size_t BUFF_IN_DIM_X = CHUNK_DIM_X; + constexpr size_t 
BUFF_OUT_DIM_X = (CHUNK_DIM_X * 4) / 8; // Holds 2 elements of 4-bit size + constexpr size_t BUFF_IN_DIM = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t BUFF_OUT_DIM = BUFF_DIM_Y * BUFF_OUT_DIM_X; + + constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; + + constexpr size_t ITERATIONS_ROWWISE = BUFF_DIM_Y / THREADS_Y_ROWWISE; + // static_assert(THREADS_PER_CHUNK >= CHUNK_DIM_X); // there should be a sufficient number of + // // threads to process one row in a single iteration + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; + + const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int block_offset_X = blockIdx.x * CHUNK_DIM_X; + const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; + const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; + const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; + + const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const int tid_Y_colwise = 0; + const int tid_X_colwise = threadIdx.x; + + const int thread_offset_Y_rowwise = tid_Y_rowwise; + const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const int thread_offset_Y_colwise = tid_Y_colwise; + const int thread_offset_X_colwise = tid_X_colwise; // Each thread processes two adjacent elements + + const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise; + const int col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; + const bool colwise_scale_is_within_bounds = scales_offset_X_colwise < cols; + + // helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_nvfp4 = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_mxfp8 = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + + constexpr size_t buff_size_nvfp4_scales = + CHUNK_DIM_Y * (CHUNK_DIM_X / SCALE_DIM_X) * sizeof(fp8e4m3); + constexpr size_t buff_size_mxfp8_scales = + (CHUNK_DIM_Y / SCALE_DIM_Y) * CHUNK_DIM_X * sizeof(fp8e8m0); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_mem_rowwise_data = (ROWWISE_SCALING ? buff_size_aligned_out_nvfp4 : 0); + constexpr size_t out_mem_colwise_data = (COLWISE_SCALING ? buff_size_aligned_out_mxfp8 : 0); + constexpr size_t out_mem_rowwise_scales = (ROWWISE_SCALING ? buff_size_nvfp4_scales : 0); + constexpr size_t out_mem_colwise_scales = (COLWISE_SCALING ? 
buff_size_mxfp8_scales : 0); + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + fp4e2m1x2 *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); + OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + fp8e4m3 *out_rowwise_scales_sh = + reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); + e8m0_t *out_colwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + // Compute a global encoding/decoding scaling factor for all S_dec_b + const float S_enc = + (nvfp4_second_stage_scale_ptr == nullptr) ? 1.0f : 1.0f / (*nvfp4_second_stage_scale_ptr); + + float thread_amax = 0.0f; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + initialize_barriers(mbar, is_master_thread); + + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const int buff = stage % BUFFS_NUM; + const int next_stage = stage + 1; + const int stage_offset_Y = stage * BUFF_DIM_Y; + + const int buff_offset_in = buff * BUFF_IN_DIM; + const int buff_offset_out = buff * BUFF_OUT_DIM; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const int next_buff = next_stage % BUFFS_NUM; + const int next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const int global_offset_Y = block_offset_Y + next_stage_offset_Y; + const int global_offset_X = block_offset_X; + const int next_buff_offset = next_buff * BUFF_IN_DIM; + + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + float block_amax = 0.0f; + if constexpr (COLWISE_SCALING) { + const int shmem_offset_base_colwise = buff_offset_in + tid_X_colwise; + + block_amax = 0.0f; + float in_compute_colwise[SCALE_DIM_Y]; + IType in_colwise_IType[SCALE_DIM_Y]; + + // 1. Read/Compute elements. 
Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType block_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i])); + } + block_amax = static_cast(block_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; + + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_colwise[i] = elt; + } + } + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(block_amax * Quantized_Limits::max_norm_rcp); + + const int global_scales_offset_Y = scales_offset_Y_colwise + stage; + const int global_scales_offset_X = scales_offset_X_colwise; + const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + if (colwise_scale_is_within_bounds) { + scales_colwise_e8m0[scale_idx] = biased_exponent; + } + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + +// 3. Scale elements +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; + } + const float scaled_out = in * block_scale_inverse; + + const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); + } + } + + if constexpr (ROWWISE_SCALING) { + const int stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; +#pragma unroll + for (int it = 0; it < ITERATIONS_ROWWISE; ++it) { + const int it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; + + const int shmem_offset_base_rowwise_in = + buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X; + const int shmem_offset_base_rowwise_out = + buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X; + + const int it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE; + + block_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM_X]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. 
Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + } + if constexpr (!std::is_same_v) { + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = + (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = + (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no 
activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_rowwise[j] = elt; + } + } + } + + // 2. Compute E4M3 scaling factor + const fp8e4m3 S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc); + +#if DIRECT_SCALING_FACTORS_STORE + // Check boundaries + if (rowwise_scale_is_within_bounds) { + const int scales_offset_Y = + scales_offset_Y_rowwise + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; + const int scales_offset_X = scales_offset_X_rowwise; + const int scale_idx_global = scales_offset_Y * scale_stride_rowwise + scales_offset_X; + scales_rowwise_e4m3[scale_idx_global] = S_dec_b_fp8; + } +#else + const int shmem_scales_offset_Y = + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise; + const int shmem_scales_offset_X = tid_X_rowwise; + const int scale_idx = + shmem_scales_offset_Y * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW + shmem_scales_offset_X; + out_rowwise_scales_sh[scale_idx] = S_dec_b_fp8; +#endif + // Compute "correct" per-block encoding scaling factor + const float block_scale_inverse = + __fdiv_rn(S_enc, static_cast(S_dec_b_fp8)); // S_enc_b_fp8 + +// 3. Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; // Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 4; ++e) { + IType2 in01; + IType2 in23; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + in01 = in_IType[w].data.elt[2 * e]; + in23 = in_IType[w].data.elt[2 * e + 1]; + } else if constexpr (IS_CACHED_ACT_OP) { + in01.x = in_cached[w].data.elt[4 * e]; + in01.y = in_cached[w].data.elt[4 * e + 1]; + in23.x = in_cached[w].data.elt[4 * e + 2]; + in23.y = in_cached[w].data.elt[4 * e + 3]; + } else { + const int j = w * PACK_SIZE + 4 * e; + in01.x = in_compute_rowwise[j]; + in01.y = in_compute_rowwise[j + 1]; + in23.x = in_compute_rowwise[j + 2]; + in23.y = in_compute_rowwise[j + 3]; + } + fp4e2m1x4 &out_quad = reinterpret_cast(out.data.elt[e]); + ptx::mul_cvt_4x(out_quad, in01, in23, block_scale_inverse); + } + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); + } + } + } + + __builtin_assume(thread_amax >= 0); + __builtin_assume(block_amax >= 0); + thread_amax = fmaxf(thread_amax, block_amax); + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset_nvfp4 = buff * BUFF_OUT_DIM; + const int buff_offset_mxfp8 = buff * BUFF_IN_DIM; + + if constexpr (ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset_nvfp4])); + } + if constexpr (COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset_mxfp8])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. 
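+      // Note: the rowwise buffer holds packed fp4e2m1x2 pairs (two 4-bit values per byte),
+      // so the TMA store above copies BUFF_OUT_DIM_X bytes per row, i.e. half the number of
+      // input elements, while the colwise MXFP8 buffer keeps one byte per element.
+      // Committing here groups both stores for the wait at the start of the next stage.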
+ ptx::cp_async_bulk_commit_group(); + } + } + +#if !DIRECT_SCALING_FACTORS_STORE + // Vectorized store of scaling factors. + // Each thread stores multiple scaling factors in one store instruction. + if constexpr (ROWWISE_SCALING) { + // Number of scaling factors = CHUNK_DIM_X / SCALE_DIM_X + const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + threadIdx.x; + const int scales_offset_X_rowwise = scales_block_offset_X_rowwise; + const int scale_idx_global = + scales_offset_Y_rowwise * scale_stride_rowwise + scales_offset_X_rowwise; + const int scale_idx_shmem = threadIdx.x * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW; + + if ((threadIdx.x < CHUNK_DIM_Y) && (scales_offset_Y_rowwise < rows) && + (scales_offset_X_rowwise < (cols / SCALE_DIM_X))) { + using ScalesVec_t = Vec; + const ScalesVec_t &scales = + *reinterpret_cast(&out_rowwise_scales_sh[scale_idx_shmem]); + scales.store_to(&scales_rowwise_e4m3[scale_idx_global]); + } + } +#endif + + float chunk_amax = 0.0f; + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + chunk_amax = reduce_max(thread_amax, warp_id); + } + + if (is_master_thread && amax_ptr != nullptr) { + atomicMaxFloat(amax_ptr, chunk_amax); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace quantize_kernel + +// This kernel supports only two scaling cases: +// 1. r16c0 - Rowwise NVFP4 +// 2. r16c32 - Rowwise NVFP4 AND Colwise MXFP8 +inline void quantize(const Tensor &input, const Tensor *noop, Tensor *output, cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + using namespace quantize_kernel; + using namespace ptx; + checkCuDriverContext(stream); + + constexpr bool COMPUTE_ACTIVATIONS = false; + using ParamOP = Empty; + constexpr float (*OP)(float, const ParamOP &) = nullptr; + + NVTE_CHECK(output->has_data(), "NVFP4 Output tensor must be allocated."); + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + + bool use_colwise_scaling = output->has_columnwise_data(); + if (use_colwise_scaling) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Columnwise scaling tensor must be allocated"); + } + CheckNoopTensor(*noop, "cast_noop"); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + + constexpr size_t CHUNK_DIM_Y = 128; + constexpr size_t CHUNK_DIM_X = 128; + constexpr size_t THREADS_PER_CHUNK = 128; + + constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + const dim3 grid(blocks_X, blocks_Y); + const size_t block_size = THREADS_PER_CHUNK; + + const size_t scale_stride_rowwise = output->scale_inv.shape[1]; + const size_t scale_stride_colwise = + use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; + + fp8e4m3 *const scales_rowwise_e4m3_ptr = reinterpret_cast(output->scale_inv.dptr); + e8m0_t *const scales_colwise_e8m0_ptr = + use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; + + const ScalingType scaling_type = + use_colwise_scaling ? 
ScalingType::BIDIMENSIONAL : ScalingType::ROWWISE; + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + const float *const nvfp4_second_stage_scale_ptr = + reinterpret_cast(output->scale.dptr); + + // Output data type is only required for the column-wise MXFP8 scaling. + // It has no effect for the row-wise NVFP4 scaling, but is set to the default E4M3 for the macros to work + const DType output_data_type = + use_colwise_scaling ? output->columnwise_data.dtype : DType::kFloat8E4M3; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output_data_type, OType, alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, + cols, 0, sizeof(IType) * 8); + + create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols, 0, 4); + + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, + BUFF_DIM_Y, BUFF_DIM_X, cols, 0, sizeof(OType) * 8); + } + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_nvfp4 = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_mxfp8 = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_nvfp4_scales = + (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(fp8e4m3); + constexpr size_t buff_size_mxfp8_scales = + (CHUNK_DIM_Y * CHUNK_DIM_X) / 32 * sizeof(e8m0_t); + + constexpr size_t in_mem = buff_size_aligned_in; + + const size_t out_rowwise_data_mem = buff_size_aligned_out_nvfp4; + const size_t out_colwise_data_mem = use_colwise_scaling ? buff_size_aligned_out_mxfp8 : 0; + + const size_t out_rowwise_scales_mem = buff_size_nvfp4_scales; + const size_t out_colwise_scales_mem = use_colwise_scaling ? 
buff_size_mxfp8_scales : 0; + + const size_t out_mem = out_rowwise_data_mem + out_colwise_data_mem + + out_rowwise_scales_mem + out_colwise_scales_mem + + TMA_SHMEM_ALIGNMENT; + + const size_t dshmem_size = in_mem + out_mem; + + switch (scaling_type) { + case ScalingType::ROWWISE: { + auto kernel = + quantize_nvfp4_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size); + + kernel<<>>( + tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise, + scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr, + nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + } + case ScalingType::BIDIMENSIONAL: { + auto kernel = + quantize_nvfp4_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size); + + kernel<<>>( + tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise, + scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr, + nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + } + } NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) + ); // NOLINT(*) +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif // FP4_TYPE_SUPPORTED +} + +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_NVFP4_CUH_ diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh new file mode 100644 index 0000000000..7322bf2655 --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -0,0 +1,1287 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_transpose_nvfp4.cuh + * \brief CUDA kernels to cast to NVFP4 and transpose. + */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" +#include "core_nvfp4.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { + +namespace quantize_transpose_kernel { + +using namespace quantization_and_transposition_SF; +using namespace core; +using namespace ptx; + +#if FP4_TYPE_SUPPORTED + +constexpr size_t SCALE_DIM = 16; // NVFP4 block (x16 elts) + +constexpr size_t CHUNK_DIM_Y = 128; +constexpr size_t CHUNK_DIM_X = 128; +constexpr size_t THREADS_NUM = 128; + +constexpr size_t SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM; +constexpr size_t SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM; + +constexpr size_t SCALES_PER_THREAD = 2 * (CHUNK_DIM_Y * CHUNK_DIM_X) / SCALE_DIM / THREADS_NUM; + +// Each call generates 4x uint32_t random numbers +constexpr size_t RNG_GENS_PER_THREAD = SCALES_PER_THREAD / 4; + +constexpr size_t TILE_DIM_Y = 32; +constexpr size_t TILE_DIM_X = 128; + +// SHould this be SCALE_DIM or BLOCK_DIM? 
Both are 16, should work for both 1D and 2D +constexpr size_t SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; +constexpr size_t SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 128 / 16 = 8 + +constexpr size_t TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y; +constexpr size_t TILES_X = CHUNK_DIM_X / TILE_DIM_X; +constexpr size_t STAGES = TILES_Y * TILES_X; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t BUFF_DIM_Y = TILE_DIM_Y; +constexpr size_t BUFF_DIM_X = TILE_DIM_X; +constexpr size_t BUFF_SIZE = BUFF_DIM_Y * BUFF_DIM_X; +constexpr size_t BUFF_SIZE_TOTAL = BUFF_SIZE * BUFFS_NUM; + +// Input buffer (BF16) +constexpr size_t BUFF_IN_DIM_Y = BUFF_DIM_Y; +constexpr size_t BUFF_IN_DIM_X = BUFF_DIM_X; +constexpr size_t BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; + +// Output buffer (NVFP4) +constexpr size_t BUFF_OUT_DIM_Y = BUFF_DIM_Y; +constexpr size_t BUFF_OUT_DIM_X = (BUFF_DIM_X * 4) / 8; +constexpr size_t BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; + +// Output transpose buffer (NVFP4) +constexpr size_t BUFF_OUT_T_DIM_Y = BUFF_DIM_X; +constexpr size_t BUFF_OUT_T_DIM_X = (BUFF_DIM_Y * 4) / 8; +constexpr size_t BUFF_OUT_T_SIZE = BUFF_OUT_T_DIM_Y * BUFF_OUT_T_DIM_X; + +// Manual swizzling parameters to reduce SHMEM bank conflicts +constexpr size_t PACK_SIZE = 8; +constexpr size_t WAVES = SCALE_DIM / PACK_SIZE; + +constexpr size_t SCALING_FACTORS_PER_TILE_X = TILE_DIM_X / SCALE_DIM; +constexpr size_t THREADS_X_ROWWISE = SCALING_FACTORS_PER_TILE_X; // 128 / 16 = 8 +constexpr size_t THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 128 / 8 = 16 + +constexpr size_t ITERATIONS_NORMAL = BUFF_DIM_Y / THREADS_Y_ROWWISE; // 32/ 16 = 2 +constexpr size_t ITERATIONS_TRANSPOSE = BUFF_IN_DIM_Y / SCALE_DIM; +constexpr size_t BUFF_OUT_IT_OFFSET = BUFF_OUT_T_DIM_X / ITERATIONS_TRANSPOSE; + +static_assert(BUFF_DIM_Y >= SCALE_DIM && + "Number of buffer rows must be greater or equal to the size of the columwise " + "scaling block\0"); +static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y); +static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE && + "Number of buffer rows must be greater or equal to the number of rowwise " + "processing threads in Y dimension\0"); + +// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM; // 8 = 128 / 16 + +template +__global__ void __launch_bounds__(THREADS_NUM) + quantize_transpose_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + const __grid_constant__ CUtensorMap tensor_map_output_t, + nvfp4_scale_t *const scales_ptr, + nvfp4_scale_t *const scales_t_ptr, const float *noop, + const float *const amax_rowwise_ptr, + const float *const amax_colwise_ptr, const size_t rows, + const size_t cols, const size_t scale_stride, + const size_t scale_stride_t, const size_t *rng_state) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = + (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); + + using IType2 = typename ptx::FPx2; + + if constexpr (!COMPUTE_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + + const size_t rng_sequence = + threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM; + const size_t rng_seed = rng_state != nullptr ? 
rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + rng.init(rng_seed, rng_sequence, rng_offset); + uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0}; + // Index of the random number. It increments each time when used and resets to 0 if reaches 4x + int rnd_idx = 0; + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS; + + const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; + + const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y; + + const size_t chunk_rows = rows - block_offset_Y; + + const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X; + const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y; + + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const size_t tid_X_colwise = threadIdx.x; + const size_t tid_Y_t = tid_X_colwise; + // const size_t tid_X_t = 0; + + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM; + const size_t thread_offset_X_colwise = tid_X_colwise; + + const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const size_t row_base_colwise = block_offset_Y; + const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t; + const size_t scales_offset_X_t = scales_block_offset_X_t; + + const size_t SFs_per_row = cols / SCALE_DIM; + + const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row; + const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols; + + // Helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_mem_rowwise_data = buff_size_aligned_out; + constexpr size_t out_mem_colwise_data = buff_size_aligned_out; + constexpr size_t out_mem_rowwise_scales = 0; + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! 
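The round-up on the next line is the usual align-up idiom for power-of-two alignments. A minimal standalone restatement follows; the helper name align_up is illustrative, and TMA_SHMEM_ALIGNMENT is assumed to be a power of two such as 128.

// Hedged sketch of the align-up idiom applied to the dynamic SHMEM base below.
#include <cstdint>

// alignment must be a power of two
constexpr uintptr_t align_up(uintptr_t p, uintptr_t alignment) {
  return (p + alignment - 1) & ~(alignment - 1);
}

static_assert(align_up(0, 128) == 0, "aligned pointers are unchanged");
static_assert(align_up(1, 128) == 128, "misaligned pointers round up");
static_assert(align_up(256, 128) == 256, "exact multiples stay put");
static_assert(align_up(257, 128) == 384, "otherwise the next multiple is chosen");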
+ uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + fp4e2m1x2 *out_data_sh = reinterpret_cast(dshmem + in_mem); + fp4e2m1x2 *out_t_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + + nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); + nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + // Compute a global encoding/decoding scaling factors for all S_dec_b + const float S_enc_rowwise = (amax_rowwise_ptr == nullptr) + ? 1.0f + : compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); + // NOTE: This is to match with how emulation code was written. + const float S_dec_rowwise = 1.0 / S_enc_rowwise; + + const float S_enc_colwise = (amax_colwise_ptr == nullptr) + ? S_enc_rowwise + : compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); + const float S_dec_colwise = 1.0 / S_enc_colwise; + + float thread_amax = 0.0f; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + initialize_barriers(mbar, is_master_thread); + + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + +#pragma unroll + for (size_t stage = 0; stage < STAGES; ++stage) { + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + + const size_t buff_offset_in = buff * BUFF_IN_SIZE; + const size_t buff_offset_out = buff * BUFF_OUT_SIZE; + const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. 
the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_IN_SIZE; + + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + float block_amax = 0.0f; + + // COLWISE scaling + if constexpr (RETURN_TRANSPOSE) { +#pragma unroll + for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) { + const size_t in_thread_offset_Y = 0 + it * SCALE_DIM; + const size_t in_thread_offset_X = thread_offset_X_colwise; + + const size_t out_t_thread_offset_Y = thread_offset_X_colwise; + const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET; + + const size_t shmem_offset_base_colwise_in = + buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X; + const size_t shmem_offset_base_colwise_out_t = + buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X; + + block_amax = 0.0f; + float in_compute_colwise[SCALE_DIM]; + IType in_colwise_IType[SCALE_DIM]; + // 1. Read/Compute elements. Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType block_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i])); + } + block_amax = static_cast(block_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_colwise = + (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_colwise[i] = elt; + } + } + // 2. 
Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_colwise); + + // Store scaling factors through SHMEM + const size_t scale_idx_sh = + tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; + out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + + // 3. Scale elements + fp4e2m1x4 regs[SCALE_DIM / 4]; + +#pragma unroll + for (int e = 0; e < SCALE_DIM / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_colwise_IType[4 * e]); + regs[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else { + const float2 in01 = *reinterpret_cast(&in_compute_colwise[4 * e]); + const float2 in23 = *reinterpret_cast(&in_compute_colwise[4 * e + 2]); + regs[e] = ptx::mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } + } + + const int group = thread_lane / 16; + uint32_t val[2]; + uint32_t *regs_4x = reinterpret_cast(regs); + + // Helps reducing bank conflicts + switch (group) { + case 0: + val[0] = regs_4x[0]; + val[1] = regs_4x[1]; + break; + case 1: + val[0] = regs_4x[1]; + val[1] = regs_4x[0]; + + break; + } + uint32_t *out_t_data_sh_as_uint32_t = + reinterpret_cast(&out_t_data_sh[shmem_offset_base_colwise_out_t]); + out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2; + out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2; + } + } + + // ROWWISE scaling + { + const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; +#pragma unroll + for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) { + const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; + + const size_t shmem_offset_base_rowwise_in = + buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X; + const size_t shmem_offset_base_rowwise_out = + buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X; + + const size_t it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE; + + block_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. 
Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + } + if constexpr (!std::is_same_v) { + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const size_t j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = + (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = + (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + 
} + } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_rowwise[j] = elt; + } + } + } + + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + + // Check boundaries + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const size_t scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + + // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; + const bool rowwise_scale_is_within_bounds_Y = + (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; + if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { + scales_ptr[scale_idx_global] = S_dec_b_fp8; + } + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + +// 3. Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + IType2 in01; + IType2 in23; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_IType[w].data.elt[2 * e]); + out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else if constexpr (IS_CACHED_ACT_OP) { + const uint64_t elts = *reinterpret_cast(&in_cached[w].data.elt[4 * e]); + out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else { + const int j = w * PACK_SIZE + 4 * e; + const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]); + const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]); + out.data.elt[e] = ptx::mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } + } + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; + out.store_to(&out_data_sh[shmem_offset_rowwise]); + } + } + } + + __builtin_assume(thread_amax >= 0); + thread_amax = fmaxf(thread_amax, block_amax); + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. 
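Both the rowwise pass above and the columnwise pass earlier follow the same two-level NVFP4 scaling recipe: a per-tensor FP32 encode scale derived from the global amax, a per-16-element decode scale stored as FP8 E4M3, and a per-element multiplier equal to their quotient applied before the FP4 conversion. The host-side sketch below restates that arithmetic for a single block. The bodies of global_encode_scale and block_decode_scale are assumptions standing in for compute_global_encode_scaling_factor_FP4 and compute_decoding_scaling_factor, which live in core_nvfp4.cuh rather than in this file, and a plain float stands in for the E4M3 scale.

// Hedged reference of the two-level NVFP4 scaling for one 16-element block.
#include <algorithm>
#include <cstdio>

constexpr float FP4_E2M1_MAX = 6.0f;    // largest NVFP4 magnitude
constexpr float FP8_E4M3_MAX = 448.0f;  // largest E4M3 scale value

// Assumed stand-in for compute_global_encode_scaling_factor_FP4.
float global_encode_scale(float tensor_amax) {
  return tensor_amax == 0.0f ? 1.0f : FP8_E4M3_MAX * FP4_E2M1_MAX / tensor_amax;
}

// Assumed stand-in for compute_decoding_scaling_factor (E4M3 rounding omitted).
float block_decode_scale(float block_amax, float S_enc) {
  return std::min(block_amax * S_enc / FP4_E2M1_MAX, FP8_E4M3_MAX);
}

int main() {
  const float tensor_amax = 3.7f, block_amax = 0.9f, x = 0.9f;
  const float S_enc = global_encode_scale(tensor_amax);
  const float S_dec = 1.0f / S_enc;                             // per-tensor decode scale
  const float S_dec_b = block_decode_scale(block_amax, S_enc);  // stored per-block scale
  const float block_scale_inverse = 1.0f / (S_dec_b * S_dec);   // == S_enc / S_dec_b, as in the kernel

  const float x_scaled = x * block_scale_inverse;       // value handed to the FP4 conversion
  const float x_restored = x_scaled * S_dec_b * S_dec;  // dequantization path
  printf("scaled: %.3f (<= %.1f), restored: %.3f\n", x_scaled, FP4_E2M1_MAX, x_restored);
  return 0;
}

With exact arithmetic the scaled value equals 6 * x / block_amax, so in-block elements land inside the E2M1 range; the fminf(..., float_max) in the kernel additionally keeps the inverse scale finite when the stored block scale is tiny.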
+ + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; + const size_t global_offset_X = block_offset_X; + + const size_t global_offset_Y_t = block_offset_Y_t; + const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y; + + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), global_offset_X, global_offset_Y, + reinterpret_cast(&out_data_sh[buff_offset_out])); + + if constexpr (RETURN_TRANSPOSE) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_t), global_offset_X_t, + global_offset_Y_t, reinterpret_cast(&out_t_data_sh[buff_offset_out_t])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + } + } // end of stages + + // Vectorized store scaling factors through SHMEM + if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) { + using ScalesVec = Vec; + const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y; + ScalesVec &scales_vec = *reinterpret_cast(&out_colwise_scales_sh[scale_idx_sh]); + const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t; + const size_t count = // number of scales in Y dimension of this chunk + (chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM); + nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global]; + constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t); + if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast(dst) % vec_bytes == 0)) { + // Fast path: vectorized store when destination is properly aligned + scales_vec.store_to(dst); + } else { + // Safe path: element-wise store for tails or unaligned destinations + scales_vec.store_to_elts(dst, 0, count); + } + } + + destroy_barriers(mbar, is_master_thread); +#else + NVTE_DEVICE_ERROR("sm_100 or higher is required."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +template +__global__ void __launch_bounds__(THREADS_NUM) + quantize_transpose_nvfp4_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + const __grid_constant__ CUtensorMap tensor_map_output_t, + nvfp4_scale_t *const scales_ptr, + nvfp4_scale_t *const scales_t_ptr, const float *noop, + const float *const amax_rowwise_ptr, + const float *const amax_colwise_ptr, const size_t rows, + const size_t cols, const size_t scale_stride, + const size_t scale_stride_t, const size_t *rng_state) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = + (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); + + using IType2 = typename ptx::FPx2; + + if constexpr (!COMPUTE_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + const size_t rng_sequence = + threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM; + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + rng.init(rng_seed, rng_sequence, rng_offset); + uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0}; + int rnd_idx = + 0; // Index of the random number. 
It increments each time when used and resets to 0 if reaches 4x + + // NEW: 2D Block-based scaling constants + constexpr size_t BLOCK_DIM = 16; + constexpr size_t BLOCKS_PER_TILE_Y = TILE_DIM_Y / BLOCK_DIM; // 32/16 = 2 + constexpr size_t BLOCKS_PER_TILE_X = TILE_DIM_X / BLOCK_DIM; // 128/16 = 8 + constexpr size_t ITERATIONS_BLOCK = 2; // iterations to calculate 2d block amaxes of 1 tile + constexpr size_t BLOCKS_PER_WARP = BLOCKS_PER_TILE_X / (THREADS_NUM / 32); // 8 / (128/32) = 2 + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS; + + const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; + + const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y; + + const size_t chunk_rows = rows - block_offset_Y; + + const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X; + const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y; + + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const size_t tid_X_colwise = threadIdx.x; + const size_t tid_Y_t = tid_X_colwise; + + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM; + const size_t thread_offset_X_colwise = tid_X_colwise; + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t; + const size_t scales_offset_X_t = scales_block_offset_X_t; + + const size_t SFs_per_row = cols / SCALE_DIM; + + const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row; + const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols; + + // Helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_mem_rowwise_data = buff_size_aligned_out; + constexpr size_t out_mem_colwise_data = buff_size_aligned_out; + constexpr size_t out_mem_rowwise_scales = 0; + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! 
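For the 2D (16x16) block path that follows, the thread-to-block geometry is fixed entirely by the constants above. The standalone sketch below only restates that arithmetic (the names mirror the kernel's, and the static_asserts add no new behaviour): each of the four warps owns two adjacent 16-wide blocks, and two block iterations cover the 32 rows of a tile.

// Hedged restatement of the 2D block-scaling geometry used by the kernel above.
#include <cstddef>

constexpr size_t THREADS_NUM = 128;
constexpr size_t TILE_DIM_Y = 32;
constexpr size_t TILE_DIM_X = 128;
constexpr size_t BLOCK_DIM = 16;  // 16x16 NVFP4 scaling block

constexpr size_t WARPS = THREADS_NUM / 32;                     // 4
constexpr size_t BLOCKS_PER_TILE_Y = TILE_DIM_Y / BLOCK_DIM;   // 2
constexpr size_t BLOCKS_PER_TILE_X = TILE_DIM_X / BLOCK_DIM;   // 8
constexpr size_t BLOCKS_PER_WARP = BLOCKS_PER_TILE_X / WARPS;  // 2
constexpr size_t ITERATIONS_BLOCK = BLOCKS_PER_TILE_Y;         // 2

static_assert(WARPS * BLOCKS_PER_WARP == BLOCKS_PER_TILE_X,
              "the four warps together cover one row of eight 16-wide blocks");
static_assert(ITERATIONS_BLOCK * BLOCK_DIM == TILE_DIM_Y,
              "two block iterations cover the 32 rows of a tile");
static_assert(BLOCKS_PER_WARP * BLOCK_DIM == 32,
              "each warp spans 32 columns, so lane_id indexes the column inside its pair of blocks");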
+ uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + fp4e2m1x2 *out_data_sh = reinterpret_cast(dshmem + in_mem); + fp4e2m1x2 *out_t_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + + nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); + nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + // Compute a global encoding/decoding scaling factors for all S_dec_b + const float S_enc_rowwise = (amax_rowwise_ptr == nullptr) + ? 1.0f + : compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); + // NOTE: This is to match with how emulation code was written. + const float S_dec_rowwise = 1.0 / S_enc_rowwise; + + const float S_enc_colwise = (amax_colwise_ptr == nullptr) + ? S_enc_rowwise + : compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); + const float S_dec_colwise = 1.0 / S_enc_colwise; + + const size_t warp_id = threadIdx.x / 32; + const size_t lane_id = threadIdx.x % 32; + float thread_amax = 0.0f; + const size_t block_in_warp = lane_id / BLOCKS_PER_WARP; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + __shared__ __align__(16) float block_amax_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X + 1]; + + // Helper function for warp reduction + auto warp_reduce_amax = [](float thread_amax, int block_in_warp) -> float { +#pragma unroll + for (int delta = 8; delta >= 1; delta /= 2) { + float other_amax = __shfl_xor_sync(0xffffffff, thread_amax, delta); + thread_amax = fmaxf(thread_amax, other_amax); + } + return thread_amax; + }; + + initialize_barriers(mbar, is_master_thread); + + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + +#pragma unroll + for (size_t stage = 0; stage < STAGES; ++stage) { + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + + const size_t buff_offset_in = buff * BUFF_IN_SIZE; + const size_t buff_offset_out = buff * BUFF_OUT_SIZE; + const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. 
the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_IN_SIZE; + + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + float block_amax = 0.0f; + +#pragma unroll + for (size_t block_iter = 0; block_iter < ITERATIONS_BLOCK; ++block_iter) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; + const size_t block_in_tile_y = block_iter; + const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM; + + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + for (int elem = 0; elem < BLOCK_DIM; elem += 2) { + const size_t elem_0_row = block_iter * BLOCK_DIM + elem; + const size_t elem_1_row = elem_0_row + 1; + const size_t elem_0_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id; + const size_t elem_1_col = elem_0_col; + + const size_t shmem_offset_0 = buff_offset_in + elem_0_row * BUFF_IN_DIM_X + elem_0_col; + const size_t shmem_offset_1 = buff_offset_in + elem_1_row * BUFF_IN_DIM_X + elem_1_col; + + IType2 val_2x; + val_2x.x = in_sh[shmem_offset_0]; + val_2x.y = in_sh[shmem_offset_1]; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, val_2x); + } + + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else { + for (int elem = 0; elem < BLOCK_DIM; ++elem) { + const size_t elem_row = block_iter * BLOCK_DIM + elem; + const size_t elem_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id; + + // Bounds checking + const bool row_out_of_bounds = (block_offset_Y + stage_offset_Y + elem_row >= rows); + const bool col_out_of_bounds = (block_offset_X + elem_col >= cols); + if (!row_out_of_bounds && !col_out_of_bounds) { + const size_t shmem_offset = buff_offset_in + elem_row * BUFF_IN_DIM_X + elem_col; + float elt = static_cast(in_sh[shmem_offset]); + + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset] = static_cast(elt); + } + + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } + } + // Warp reduction to get block amax + block_amax = warp_reduce_amax(thread_amax, block_in_warp); + + if (lane_id == 0 || lane_id == 16) { + block_amax_matrix[block_in_tile_y][block_in_tile_x] = block_amax; + } + } + + // sync thread to ensure block_amax_matrix is done storing + __syncthreads(); + + // COLWISE scaling + if constexpr (RETURN_TRANSPOSE) { +#pragma unroll + for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) { + const size_t block_in_tile_y = it; + const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM; + + const size_t in_thread_offset_Y = 0 + it * SCALE_DIM; + const size_t in_thread_offset_X = thread_offset_X_colwise; + + const size_t out_t_thread_offset_Y = thread_offset_X_colwise; + const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET; + + const size_t shmem_offset_base_colwise_in = + buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X; + const size_t shmem_offset_base_colwise_out_t = + 
buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X; + + block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x]; + float in_compute_colwise[SCALE_DIM]; + IType in_colwise_IType[SCALE_DIM]; + // 3. Scale elements + + // Load data in + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + } + } else { + for (int i = 0; i < SCALE_DIM; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + + in_compute_colwise[i] = elt; + } + } + + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_colwise); + + // // Store scaling factors through SHMEM + const size_t scale_idx_sh = + tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; + out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + + fp4e2m1x4 regs[SCALE_DIM / 4]; +#pragma unroll + for (int e = 0; e < SCALE_DIM / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_colwise_IType[4 * e]); + regs[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else { + const float2 in01 = *reinterpret_cast(&in_compute_colwise[4 * e]); + const float2 in23 = *reinterpret_cast(&in_compute_colwise[4 * e + 2]); + regs[e] = ptx::mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } + } + + const int group = thread_lane / 16; + uint32_t val[2]; + uint32_t *regs_4x = reinterpret_cast(regs); + + // Helps reducing bank conflicts + switch (group) { + case 0: + val[0] = regs_4x[0]; + val[1] = regs_4x[1]; + break; + case 1: + val[0] = regs_4x[1]; + val[1] = regs_4x[0]; + break; + } + uint32_t *out_t_data_sh_as_uint32_t = + reinterpret_cast(&out_t_data_sh[shmem_offset_base_colwise_out_t]); + out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2; + out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2; + } + } + + // ROWWISE scaling + { + const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; +#pragma unroll + for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) { + const size_t block_in_tile_y = it; + const size_t block_in_tile_x = tid_X_rowwise; + const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; + + const size_t shmem_offset_base_rowwise_in = + buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X; + const size_t shmem_offset_base_rowwise_out = + buff_offset_out + it_thread_offset_Y_rowwise 
* BUFF_OUT_DIM_X; + + block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x]; + float in_compute_rowwise[SCALE_DIM]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); + } + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const size_t j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + in_compute_rowwise[j] = elt; + } + } + } + + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + + // Check boundaries + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const size_t scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + + // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; + const bool rowwise_scale_is_within_bounds_Y = + (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; + if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { + scales_ptr[scale_idx_global] = S_dec_b_fp8; + } + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + + // 3. 
Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + IType2 in01; + IType2 in23; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_IType[w].data.elt[2 * e]); + out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else if constexpr (IS_CACHED_ACT_OP) { + const uint64_t elts = *reinterpret_cast(&in_cached[w].data.elt[4 * e]); + out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else { + const int j = w * PACK_SIZE + 4 * e; + const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]); + const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]); + out.data.elt[e] = ptx::mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } + } + + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; + out.store_to(&out_data_sh[shmem_offset_rowwise]); + } + } + } + + __builtin_assume(thread_amax >= 0); + thread_amax = fmaxf(thread_amax, block_amax); + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; + const size_t global_offset_X = block_offset_X; + + const size_t global_offset_Y_t = block_offset_Y_t; + const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y; + + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), global_offset_X, global_offset_Y, + reinterpret_cast(&out_data_sh[buff_offset_out])); + + if constexpr (RETURN_TRANSPOSE) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_t), global_offset_X_t, + global_offset_Y_t, reinterpret_cast(&out_t_data_sh[buff_offset_out_t])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + } + } // end of stages + + // Vectorized store scaling factors through SHMEM + if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) { + using ScalesVec = Vec; + const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y; + ScalesVec &scales_vec = *reinterpret_cast(&out_colwise_scales_sh[scale_idx_sh]); + const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t; + const size_t count = // number of scales in Y dimension of this chunk + (chunk_rows >= CHUNK_DIM_Y) ? 
SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM); + nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global]; + constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t); + if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast(dst) % vec_bytes == 0)) { + // Fast path: vectorized store when destination is properly aligned + scales_vec.store_to(dst); + } else { + // Safe path: element-wise store for tails or unaligned destinations + scales_vec.store_to_elts(dst, 0, count); + } + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +#endif // FP4_TYPE_SUPPORTED +} // namespace quantize_transpose_kernel + +template +void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, + const QuantizationConfig *quant_config, cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + using namespace quantize_transpose_kernel; + using namespace ptx; + bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; + + // If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to + // return the transposed data. + // TODO(Frank): Is there a better way to do this? + bool return_transpose = output->has_columnwise_data(); + + constexpr bool COMPUTE_ACTIVATIONS = false; + using ParamOP = Empty; + constexpr float (*OP)(float, const ParamOP &) = nullptr; + + checkCuDriverContext(stream); + CheckNoopTensor(*noop, "cast_noop"); + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output", false); + + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + if (return_transpose) { + NVTE_CHECK(output->has_columnwise_data(), "NVFP4 transposed output tensor must be allocated."); + NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), + "Transposed output must have FP4 type."); + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Transposed scaling tensor must be allocated"); + } + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + + NVTE_CHECK(rows % 32 == 0, + "Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA + NVTE_CHECK(cols % 32 == 0, + "Number of tensor cols must be a multiple of 32"); // 16B alignment for TMA + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + const dim3 grid(blocks_X, blocks_Y); + const size_t block_size = THREADS_NUM; + + const size_t scale_stride = output->scale_inv.shape[1]; + const size_t scale_stride_transpose = + return_transpose ? output->columnwise_scale_inv.shape[1] : 0; + + nvfp4_scale_t *const scales_ptr = reinterpret_cast(output->scale_inv.dptr); + nvfp4_scale_t *const scales_transpose_ptr = + reinterpret_cast(output->columnwise_scale_inv.dptr); + + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + const float *const amax_rowwise_ptr = reinterpret_cast(output->amax.dptr); + const float *const amax_colwise_ptr = + reinterpret_cast(output->columnwise_amax.dptr); + + const NVTETensor rng_state_tensor = (quant_config != nullptr) ? 
quant_config->rng_state : nullptr; + const size_t *rng_state = nullptr; + if (rng_state_tensor != nullptr) { + Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor); + NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_te_tensor.data.dptr); + } + + using IType = bf16; + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + alignas(64) CUtensorMap tensor_map_output_transpose{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, + sizeof(IType) * 8); + + create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, + 4); + if (return_transpose) { + create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows, + BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4); + } + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_scales = (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(nvfp4_scale_t); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_data_mem = buff_size_aligned_out; + constexpr size_t out_data_transpose_mem = buff_size_aligned_out; + constexpr size_t out_scales_transpose_mem = buff_size_scales; + + constexpr size_t out_mem = out_data_mem + out_data_transpose_mem; + + constexpr size_t dshmem_size = in_mem + out_mem + out_scales_transpose_mem + TMA_SHMEM_ALIGNMENT; + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, + + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = quantize_transpose_nvfp4_kernel; + + if constexpr (use_2d_quantization) { + kernel = quantize_transpose_nvfp4_2D_kernel; + } + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, + scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, + scale_stride, scale_stride_transpose, rng_state); + });); +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif // FP4_TYPE_SUPPORTED +} + +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_CUH_ diff --git a/transformer_engine/common/util/cast.cu b/transformer_engine/common/util/cast.cu deleted file mode 100644 index 107965d342..0000000000 --- a/transformer_engine/common/util/cast.cu +++ /dev/null @@ -1,201 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include "../common.h" -#include "../transpose/cast_transpose.h" -#include "../util/multi_stream.h" -#include "../util/vectorized_pointwise.h" -#include "../utils.cuh" -#include "cast_kernels.cuh" -#include "dequantize_kernels.cuh" -#include "math.h" -#include "ptx.cuh" -#include "transformer_engine/activation.h" -#include "transformer_engine/transpose.h" - -void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = false; - constexpr bool IS_DACT = false; - constexpr bool IS_ACT = false; - constexpr NVTETensor dbias = nullptr; - constexpr NVTETensor workspace = nullptr; - constexpr const NVTETensor grad = nullptr; - - detail::quantize_helper(input, grad, output, dbias, - workspace, nullptr, stream); -} - -void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, - cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize_noop); - using namespace transformer_engine; - - // Create config with noop tensor - QuantizationConfig quant_config; - quant_config.noop_tensor = noop; - - nvte_quantize_v2(input, output, reinterpret_cast(&quant_config), stream); -} - -void nvte_quantize_v2(const NVTETensor input, NVTETensor output, - const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize_v2); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = false; - constexpr bool IS_DACT = false; - constexpr bool IS_ACT = false; - constexpr NVTETensor dbias = nullptr; - constexpr NVTETensor workspace = nullptr; - constexpr const NVTETensor grad = nullptr; - - detail::quantize_helper( - input, grad, output, dbias, workspace, quant_config, stream); -} - -void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, - NVTETensor workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize_dbias); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = false; - constexpr bool IS_ACT = false; - constexpr const NVTETensor activation_input = nullptr; - - detail::quantize_helper( - activation_input, input, output, dbias, workspace, nullptr, stream); -} - -void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input, - NVTETensor output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize_dbias_dgelu); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = true; - constexpr bool IS_ACT = false; - - detail::quantize_helper>( - activation_input, input, output, dbias, workspace, nullptr, stream); -} - -void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input, - NVTETensor output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize_dbias_dsilu); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = true; - constexpr bool IS_ACT = false; - - detail::quantize_helper>( - activation_input, input, output, dbias, workspace, nullptr, stream); -} - -void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input, - NVTETensor output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream) { - 
NVTE_API_CALL(nvte_quantize_dbias_drelu); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = true; - constexpr bool IS_ACT = false; - - detail::quantize_helper>( - activation_input, input, output, dbias, workspace, nullptr, stream); -} - -void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input, - NVTETensor output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize_dbias_dqgelu); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = true; - constexpr bool IS_ACT = false; - - detail::quantize_helper>( - activation_input, input, output, dbias, workspace, nullptr, stream); -} - -void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input, - NVTETensor output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize_dbias_dsrelu); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = true; - constexpr bool IS_ACT = false; - - detail::quantize_helper>( - activation_input, input, output, dbias, workspace, nullptr, stream); -} - -void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { - NVTE_API_CALL(nvte_dequantize); - using namespace transformer_engine; - detail::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream); -} - -void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, - const NVTEQuantizationConfig quant_configs, - const size_t num_tensors, cudaStream_t stream) { - NVTE_API_CALL(nvte_multi_tensor_quantize); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = false; - constexpr bool IS_DACT = false; - constexpr bool IS_ACT = false; - constexpr NVTETensor dbias = nullptr; - constexpr NVTETensor workspace = nullptr; - constexpr const NVTETensor grad = nullptr; - - const size_t num_streams = nvte_get_num_compute_streams(); - - int num_stream_used = std::min(num_streams, num_tensors); - // wait for current stream to finish - NVTE_CHECK_CUDA(cudaEventRecord(detail::get_compute_stream_event(0), stream)); - for (int s = 0; s < num_stream_used; s++) { - NVTE_CHECK_CUDA( - cudaStreamWaitEvent(detail::get_compute_stream(s), detail::get_compute_stream_event(0))); - } - - for (int i = 0; i < num_tensors; i++) { - detail::quantize_helper( - inputs[i], grad, outputs[i], dbias, workspace, nullptr, - detail::get_compute_stream(i % num_streams)); - } - - // record events on compute streams - for (int s = 0; s < num_stream_used; s++) { - NVTE_CHECK_CUDA( - cudaEventRecord(detail::get_compute_stream_event(s), detail::get_compute_stream(s))); - } - // wait for all compute streams to finish - for (int s = 0; s < num_stream_used; s++) { - NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s))); - } -} diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh deleted file mode 100644 index b0498602b5..0000000000 --- a/transformer_engine/common/util/cast_kernels.cuh +++ /dev/null @@ -1,2188 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -/*! 
\file cast_kernels.cuh - * \brief CUDA kernels to cast to/from FP8/MXFP8. - */ - -#ifndef TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ -#define TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ - -#include -#include -#include -#include - -#include - -#include "../common.h" -#include "../transpose/cast_transpose.h" -#include "../util/vectorized_pointwise.h" -#include "../utils.cuh" -#include "math.h" -#include "nvfp4_transpose.cuh" -#include "ptx.cuh" -#include "transformer_engine/transformer_engine.h" - -namespace transformer_engine { - -namespace mxfp8_kernel { - -constexpr size_t SCALE_DIM_Y = 32; -constexpr size_t SCALE_DIM_X = 32; - -constexpr size_t BUFFS_NUM = 2; -constexpr size_t PACK_SIZE = 4; -constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; - -// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory -constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 - -// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory -constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 - -template -__global__ void __launch_bounds__(THREADS_PER_CHUNK) - cast_mxfp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_act_input, - const __grid_constant__ CUtensorMap tensor_map_output_rowwise, - const __grid_constant__ CUtensorMap tensor_map_output_colwise, - e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, - const float *noop, float *const dbias_workspace, float *const amax_ptr, - const size_t rows, const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; - constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; - - using IType2 = typename ptx::FPx2; - using OType2 = typename ptx::FPx2; - - if constexpr (NO_ACTIVATIONS) { - if (noop != nullptr && noop[0] == 1.0f) { - return; - } - } - constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; - constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; - - constexpr size_t BUFF_DIM_Y = THREADS_Y; - constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; - constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X; - static_assert(BUFF_DIM_Y == 32); - - constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; - static_assert(STAGES >= 1); - - constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; - - const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; - const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; - const size_t scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; - const size_t scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; - const size_t scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; - - const size_t tid_Y_rowwise = threadIdx.x / THREADS_X; - const size_t tid_X_rowwise = threadIdx.x % THREADS_X; - const size_t tid_Y_colwise = 0; - const size_t tid_X_colwise = threadIdx.x; - - const size_t thread_offset_Y_rowwise = tid_Y_rowwise; - const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; - const size_t thread_offset_Y_colwise = tid_Y_colwise; - const size_t thread_offset_X_colwise = tid_X_colwise; - - const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; - const size_t row_base_colwise = block_offset_Y + thread_offset_Y_colwise; - const size_t col_base_colwise = block_offset_X + 
thread_offset_X_colwise; - - const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); - - const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; - const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; - const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; - const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; - - const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; - - // helps resolving bank conflicts in shmem - const int thread_lane = threadIdx.x % THREADS_PER_WARP; - const int bank_group = thread_lane / THREADS_PER_BANK; - - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; - constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - - constexpr size_t elt_input_mem = buff_size_aligned_in; - constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); - constexpr size_t in_mem = elt_input_mem + act_input_mem; - - constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0); - - extern __shared__ char dynamic_shmem[]; - uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); - // Manually align dynamic SHMEM per TMA requirements using padding - // __align__(128) Does not guarantee the pointer to be aligned! - uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & - ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); - - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned - IType *in_sh = reinterpret_cast(dshmem); - IType *act_in_sh = reinterpret_cast(dshmem + elt_input_mem); - - OType *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); - OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); - IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer - - constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; - - const bool is_master_thread = (threadIdx.x == 0); - - float partial_dbias_colwise = 0.0f; - float thread_dbias_rowwise[SCALE_DIM_X]; - if constexpr (IS_DBIAS) { -#pragma unroll - for (int j = 0; j < SCALE_DIM_X; ++j) { - thread_dbias_rowwise[j] = 0.0f; - } - } - - float block_amax = 0.0f; - -// Initialize shared memory barrier with the number of threads participating in the barrier. -#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[STAGES]; - - initialize_barriers(mbar, is_master_thread); - - int parity = 0; - - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0], - &tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], is_master_thread); - } else { - copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], is_master_thread); - } - -#pragma unroll - for (int stage = 0; stage < STAGES; ++stage) { - const size_t buff = stage % BUFFS_NUM; - const size_t next_stage = stage + 1; - const size_t stage_offset_Y = stage * BUFF_DIM_Y; - - if (next_stage < STAGES) { - // Wait for TMA transfer to have finished reading shared memory. - // I.e. 
the buffer is ready to be written to - ptx::cp_async_bulk_wait_group_read<1>(); - - const size_t next_buff = next_stage % BUFFS_NUM; - const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; - const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; - const size_t global_offset_X = block_offset_X; - const size_t next_buff_offset = next_buff * BUFF_DIM; - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, - global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input, - global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], - is_master_thread); - } else { - copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, - global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); - } - } - - ptx::fence_proxy_async_shared_cta(); - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[stage], parity); - - float thread_amax = 0.0f; - if constexpr (COLWISE_SCALING) { - const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; - thread_amax = 0.0f; - float in_compute_colwise[BUFF_DIM_Y]; - IType in_colwise_IType[BUFF_DIM_Y]; - - // 1. Read/Compute elements. Find MXFP8-block AMAX - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - IType thread_amax_f16 = static_cast(0.0f); -#pragma unroll - for (int i = 0; i < BUFF_DIM_Y; ++i) { - const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - in_colwise_IType[i] = in_sh[shmem_offset_colwise]; - thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); - } - thread_amax = static_cast(thread_amax_f16); - } else { -#pragma unroll - for (int i = 0; i < BUFF_DIM_Y; ++i) { - const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - - float elt = static_cast(in_sh[shmem_offset_colwise]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); - elt *= OP(act_in_elt, {}); - } - if constexpr (IS_DBIAS) { - partial_dbias_colwise += elt; - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - // Cache computed activations to avoid computing them again in the 2nd pass along another dimension - if constexpr (IS_CACHED_ACT_OP) { - cached_act_sh[shmem_offset_colwise] = static_cast(elt); - } - - if constexpr (COMPUTE_ACTIVATIONS) { - const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); - const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); - if (!out_of_bounds) { - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - in_compute_colwise[i] = elt; - } - } - - // 2. Compute E8M0 scaling factor - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - - const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; - const size_t global_scales_offset_X = scales_offset_X_colwise; - const size_t scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - scales_colwise[scale_idx] = biased_exponent; - - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; - -// 3. 
Scale elements -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - float in; - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - in = static_cast(in_colwise_IType[i]); - } else { - in = in_compute_colwise[i]; - } - const float scaled_out = in * block_scale_inverse; - - const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; - out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); - } - } - - if constexpr (ROWWISE_SCALING) { - const size_t shmem_offset_base_rowwise = - buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; - thread_amax = 0.0f; - float in_compute_rowwise[SCALE_DIM_X]; - Vec in_cached[WAVES]; - - // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY - Vec in_IType[WAVES]; - - // 1. Read/Compute elements. Find MXFP8-block AMAX - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - // Load elements - in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); - } - } - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } else if constexpr (IS_CACHED_ACT_OP) { - // ensures that all writes to cache made in the section above are visible to all threads - __syncthreads(); - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - - const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); - const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); - const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); - - // Load cached elements - in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); - // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) - // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries - if (!out_of_bounds) { - if constexpr (std::is_same_v) { -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); - } - } else { -#pragma unroll - for (int e = 0; e < PACK_SIZE; e += 2) { - const IType2 in_cached_2x = {in_cached[w].data.elt[e], - in_cached[w].data.elt[e + 1]}; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); - } - } - } - } - if constexpr (!std::is_same_v) { - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } - } else { -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - - Vec in; - Vec act_in; - - in.load_from(&in_sh[shmem_offset_rowwise]); - if constexpr (IS_DACT) { - act_in.load_from(&act_in_sh[shmem_offset_rowwise]); - } -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - // Compute element - float elt = static_cast(in.data.elt[e]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in.data.elt[e]); - elt *= OP(act_in_elt, {}); - } - - // If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again - if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { - thread_dbias_rowwise[j] += elt; - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - if constexpr (COMPUTE_ACTIVATIONS) { - const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); - const bool swizzled_col_out_of_bounds = - (block_offset_X + swizzled_thread_idx >= cols); - const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); - if (!out_of_bounds) { - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - in_compute_rowwise[j] = elt; - } - } - } - - // 2. Compute E8M0 scaling factor - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; - const int stage_scales_offset_X = scales_offset_X_rowwise; - const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; - if (rowwise_scale_is_within_bounds) { - scales_rowwise[scale_idx] = biased_exponent; - } - - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; - - // 3. 
Scale elements -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - Vec out; -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - IType2 in; - OType2 &out_pair = reinterpret_cast(out.data.elt[e]); - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - in = in_IType[w].data.elt[e]; - } else if constexpr (IS_CACHED_ACT_OP) { - in.x = in_cached[w].data.elt[2 * e]; - in.y = in_cached[w].data.elt[2 * e + 1]; - } else { - const int j = w * PACK_SIZE + 2 * e; - in.x = in_compute_rowwise[j]; - in.y = in_compute_rowwise[j + 1]; - } - ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); - } - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; - out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); - } - } - - __builtin_assume(block_amax >= 0); - __builtin_assume(thread_amax >= 0); - block_amax = fmaxf(block_amax, thread_amax); - - // Wait for shared memory writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const int global_offset_Y = block_offset_Y + stage_offset_Y; - const int global_offset_X = block_offset_X; - const int buff_offset = buff * BUFF_DIM; - - if constexpr (ROWWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset])); - } - if constexpr (COLWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_colwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); - } - - // Create a "bulk async-group" out of the previous bulk copy operation. 
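// A minimal sketch of the commit/wait pairing that the double-buffered TMA stores in this loop
// rely on (illustrative only; it reuses the ptx:: wrappers already used in this file and elides
// the pointer casts):
//   ptx::cp_async_bulk_tensor_2d_shared_to_global(&tensor_map_output_rowwise, x, y, &out_sh[buff]);
//   ptx::cp_async_bulk_commit_group();          // close the group of bulk copies issued so far
//   ...                                         // proceed with the next stage's work
//   ptx::cp_async_bulk_wait_group_read<1>();    // block until at most one group still reads smem
//   // only after this wait may the older shared-memory buffer be overwritten by a new TMA load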
- ptx::cp_async_bulk_commit_group(); - } - } - - parity ^= 1; - - if constexpr (IS_DBIAS) { - float thread_partial_dbias = 0.0f; - if constexpr (COLWISE_SCALING) { - thread_partial_dbias = partial_dbias_colwise; - } else { - // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH] - // HEIGHT = THREADS_Y - // WIDTH = THREADS_X * (SCALE_DIM_X + 1) - // Added extra 1-element padding per thread_X to reduce bank conflicts - float *partial_dbias_rowwise = reinterpret_cast(dshmem); - - constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); - - const int shmem_thread_offset = - tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - const int shmem_elt_idx = swizzled_group_offset + e; - partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; - } - } - __syncthreads(); -#pragma unroll - for (int i = 0; i < THREADS_Y; ++i) { - // Add extra element offset per MXFP8 scaling block [1x32] - const int scaling_block = threadIdx.x / SCALE_DIM_X; - thread_partial_dbias += - partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; - } - } - const int dbias_stride = cols; - const int dbias_offset_Y = blockIdx.y; - const int dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x; - const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; - const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); - if (!col_out_of_bounds_dbias) { - dbias_workspace[dbias_idx] = thread_partial_dbias; - } - } - - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - block_amax = reduce_max(block_amax, warp_id); - } - - if (is_master_thread && amax_ptr != nullptr) { - atomicMaxFloat(amax_ptr, block_amax); - } - - destroy_barriers(mbar, is_master_thread); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} -} // namespace mxfp8_kernel - -namespace nvfp4_kernel { - -using namespace ptx; - -constexpr size_t SCALE_DIM_Y = 32; -constexpr size_t SCALE_DIM_X = 16; - -constexpr size_t BUFFS_NUM = 2; -constexpr size_t BUFF_DIM_Y = 32; - -constexpr size_t PACK_SIZE = 8; -constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; - -// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory -constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 - -// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory -constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 8 = 128 / 16 - -// Compute per-block E4M3 encoding/decoding scaling factor -__device__ __forceinline__ fp8e4m3 compute_decoding_scaling_factor(const float block_amax, - const float S_enc) { - constexpr float rcp_6f = 1.0f / 6.0f; - // const float S_dec_b = block_amax * rcp_6f; - // const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); - // return S_dec_b_fp8; - return static_cast(block_amax * rcp_6f * S_enc); -} - -#define DIRECT_SCALING_FACTORS_STORE 1 - -template -__global__ void __launch_bounds__(THREADS_PER_CHUNK) - cast_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_output_rowwise, - const __grid_constant__ CUtensorMap tensor_map_output_colwise, - fp8e4m3 *const scales_rowwise_e4m3, e8m0_t *const 
scales_colwise_e8m0, - const float *noop, float *const amax_ptr, - const float *const nvfp4_second_stage_scale_ptr, const size_t rows, - const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - constexpr bool ROWWISE_SCALING = true; - constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = - (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); - - using IType2 = typename ptx::FPx2; - - if constexpr (!COMPUTE_ACTIVATIONS) { - if (noop != nullptr && noop[0] == 1.0f) { - return; - } - } - constexpr size_t NVFP4_SCALING_FACTORS_PER_CHUNK_ROW = CHUNK_DIM_X / SCALE_DIM_X; - constexpr size_t THREADS_X_ROWWISE = NVFP4_SCALING_FACTORS_PER_CHUNK_ROW; - constexpr size_t THREADS_Y_ROWWISE = THREADS_PER_CHUNK / THREADS_X_ROWWISE; - - static_assert(BUFF_DIM_Y >= SCALE_DIM_Y && - "Number of buffer rows must be greater or equal to the size of the columwise " - "scaling block\0"); - static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y); - static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE && - "Number of buffer rows must be greater or equal to the number of rowwise " - "processing threads in Y dimension\0"); - - constexpr size_t BUFF_IN_DIM_X = CHUNK_DIM_X; - constexpr size_t BUFF_OUT_DIM_X = (CHUNK_DIM_X * 4) / 8; // Holds 2 elements of 4-bit size - constexpr size_t BUFF_IN_DIM = BUFF_DIM_Y * BUFF_IN_DIM_X; - constexpr size_t BUFF_OUT_DIM = BUFF_DIM_Y * BUFF_OUT_DIM_X; - - constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; - - constexpr size_t ITERATIONS_ROWWISE = BUFF_DIM_Y / THREADS_Y_ROWWISE; - // static_assert(THREADS_PER_CHUNK >= CHUNK_DIM_X); // there should be a sufficient number of - // // threads to process one row in a single iteration - - constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; - - const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const int block_offset_X = blockIdx.x * CHUNK_DIM_X; - const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; - const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; - const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; - const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; - - const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; - const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; - const int tid_Y_colwise = 0; - const int tid_X_colwise = threadIdx.x; - - const int thread_offset_Y_rowwise = tid_Y_rowwise; - const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; - const int thread_offset_Y_colwise = tid_Y_colwise; - const int thread_offset_X_colwise = tid_X_colwise; // Each thread processes two adjacent elements - - const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; - const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise; - const int col_base_colwise = block_offset_X + thread_offset_X_colwise; - - const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); - - const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; - const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; - const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; - const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; - - const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; - const bool colwise_scale_is_within_bounds = scales_offset_X_colwise < cols; - - // helps resolving bank conflicts in shmem - const 
int thread_lane = threadIdx.x % THREADS_PER_WARP; - const int bank_group = thread_lane / THREADS_PER_BANK; - - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; - constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; - - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out_nvfp4 = - DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out_mxfp8 = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - - constexpr size_t buff_size_nvfp4_scales = - CHUNK_DIM_Y * (CHUNK_DIM_X / SCALE_DIM_X) * sizeof(fp8e4m3); - constexpr size_t buff_size_mxfp8_scales = - (CHUNK_DIM_Y / SCALE_DIM_Y) * CHUNK_DIM_X * sizeof(fp8e8m0); - - constexpr size_t in_mem = buff_size_aligned_in; - - constexpr size_t out_mem_rowwise_data = (ROWWISE_SCALING ? buff_size_aligned_out_nvfp4 : 0); - constexpr size_t out_mem_colwise_data = (COLWISE_SCALING ? buff_size_aligned_out_mxfp8 : 0); - constexpr size_t out_mem_rowwise_scales = (ROWWISE_SCALING ? buff_size_nvfp4_scales : 0); - constexpr size_t out_mem_colwise_scales = (COLWISE_SCALING ? buff_size_mxfp8_scales : 0); - - extern __shared__ char dynamic_shmem[]; - uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); - // Manually align dynamic SHMEM per TMA requirements using padding - // __align__(128) Does not guarantee the pointer to be aligned! - uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & - ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); - - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned - IType *in_sh = reinterpret_cast(dshmem); - fp4e2m1x2 *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); - OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); - fp8e4m3 *out_rowwise_scales_sh = - reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); - e8m0_t *out_colwise_scales_sh = reinterpret_cast( - dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); - IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer - - constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; - - const bool is_master_thread = (threadIdx.x == 0); - - // Compute a global encoding/decoding scaling factor for all S_dec_b - const float S_enc = - (nvfp4_second_stage_scale_ptr == nullptr) ? 1.0f : 1.0f / (*nvfp4_second_stage_scale_ptr); - - float thread_amax = 0.0f; - -// Initialize shared memory barrier with the number of threads participating in the barrier. -#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[STAGES]; - - initialize_barriers(mbar, is_master_thread); - - copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], is_master_thread); - -#pragma unroll - for (int stage = 0; stage < STAGES; ++stage) { - const int buff = stage % BUFFS_NUM; - const int next_stage = stage + 1; - const int stage_offset_Y = stage * BUFF_DIM_Y; - - const int buff_offset_in = buff * BUFF_IN_DIM; - const int buff_offset_out = buff * BUFF_OUT_DIM; - - if (next_stage < STAGES) { - // Wait for TMA transfer to have finished reading shared memory. - // I.e. 
the buffer is ready to be written to - ptx::cp_async_bulk_wait_group_read<1>(); - - const int next_buff = next_stage % BUFFS_NUM; - const int next_stage_offset_Y = next_stage * BUFF_DIM_Y; - const int global_offset_Y = block_offset_Y + next_stage_offset_Y; - const int global_offset_X = block_offset_X; - const int next_buff_offset = next_buff * BUFF_IN_DIM; - - copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, - global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); - } - - ptx::fence_proxy_async_shared_cta(); - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[stage], 0); - - float block_amax = 0.0f; - if constexpr (COLWISE_SCALING) { - const int shmem_offset_base_colwise = buff_offset_in + tid_X_colwise; - - block_amax = 0.0f; - float in_compute_colwise[SCALE_DIM_Y]; - IType in_colwise_IType[SCALE_DIM_Y]; - - // 1. Read/Compute elements. Find MXFP8-block AMAX - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { - IType block_amax_f16 = static_cast(0.0f); -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; - in_colwise_IType[i] = in_sh[shmem_offset_colwise]; - block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i])); - } - block_amax = static_cast(block_amax_f16); - } else { -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; - - float elt = static_cast(in_sh[shmem_offset_colwise]); - if constexpr (COMPUTE_ACTIVATIONS) { - elt = OP(elt, {}); - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - // Cache computed activations to avoid computing them again in the 2nd pass along another dimension - if constexpr (IS_CACHED_ACT_OP) { - cached_act_sh[shmem_offset_colwise] = static_cast(elt); - } - - if constexpr (COMPUTE_ACTIVATIONS) { - const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); - const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); - if (!out_of_bounds) { - block_amax = fmaxf(block_amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - block_amax = fmaxf(block_amax, fabsf(elt)); - } - in_compute_colwise[i] = elt; - } - } - // 2. Compute E8M0 scaling factor - const e8m0_t biased_exponent = - ptx::float_to_e8m0(block_amax * Quantized_Limits::max_norm_rcp); - - const int global_scales_offset_Y = scales_offset_Y_colwise + stage; - const int global_scales_offset_X = scales_offset_X_colwise; - const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - if (colwise_scale_is_within_bounds) { - scales_colwise_e8m0[scale_idx] = biased_exponent; - } - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - -// 3. 
Scale elements -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - float in; - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { - in = static_cast(in_colwise_IType[i]); - } else { - in = in_compute_colwise[i]; - } - const float scaled_out = in * block_scale_inverse; - - const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; - out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); - } - } - - if constexpr (ROWWISE_SCALING) { - const int stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; -#pragma unroll - for (int it = 0; it < ITERATIONS_ROWWISE; ++it) { - const int it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; - - const int shmem_offset_base_rowwise_in = - buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X; - const int shmem_offset_base_rowwise_out = - buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X; - - const int it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE; - - block_amax = 0.0f; - float in_compute_rowwise[SCALE_DIM_X]; - Vec in_cached[WAVES]; - - // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY - Vec in_IType[WAVES]; - - // 1. Read/Compute elements. Find NVFP4-block AMAX - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; - // Load elements - in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); - } - } - block_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } else if constexpr (IS_CACHED_ACT_OP) { - // ensures that all writes to cache made in the section above are visible to all threads - __syncthreads(); - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; - - const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); - const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); - const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); - - // Load cached elements - in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); - // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) - // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries - if (!out_of_bounds) { - if constexpr (std::is_same_v) { -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e])); - } - } else { -#pragma unroll - for (int e = 0; e < PACK_SIZE; e += 2) { - const IType2 in_cached_2x = {in_cached[w].data.elt[e], - in_cached[w].data.elt[e + 1]}; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); - } - } - } - } - if constexpr (!std::is_same_v) { - block_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } - } else { -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; - - Vec in; - Vec act_in; - - in.load_from(&in_sh[shmem_offset_rowwise]); -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - // Compute element - float elt = static_cast(in.data.elt[e]); - if constexpr (COMPUTE_ACTIVATIONS) { - elt = OP(elt, {}); - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - if constexpr (COMPUTE_ACTIVATIONS) { - const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); - const bool swizzled_col_out_of_bounds = - (block_offset_X + swizzled_thread_idx >= cols); - const bool out_of_bounds = - (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); - if (!out_of_bounds) { - block_amax = fmaxf(block_amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - block_amax = fmaxf(block_amax, fabsf(elt)); - } - in_compute_rowwise[j] = elt; - } - } - } - - // 2. Compute E4M3 scaling factor - const fp8e4m3 S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc); - -#if DIRECT_SCALING_FACTORS_STORE - // Check boundaries - if (rowwise_scale_is_within_bounds) { - const int scales_offset_Y = - scales_offset_Y_rowwise + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; - const int scales_offset_X = scales_offset_X_rowwise; - const int scale_idx_global = scales_offset_Y * scale_stride_rowwise + scales_offset_X; - scales_rowwise_e4m3[scale_idx_global] = S_dec_b_fp8; - } -#else - const int shmem_scales_offset_Y = - stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise; - const int shmem_scales_offset_X = tid_X_rowwise; - const int scale_idx = - shmem_scales_offset_Y * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW + shmem_scales_offset_X; - out_rowwise_scales_sh[scale_idx] = S_dec_b_fp8; -#endif - // Compute "correct" per-block encoding scaling factor - const float block_scale_inverse = - __fdiv_rn(S_enc, static_cast(S_dec_b_fp8)); // S_enc_b_fp8 - -// 3. 
Scale elements -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - Vec out; // Vec out; -#pragma unroll - for (int e = 0; e < PACK_SIZE / 4; ++e) { - IType2 in01; - IType2 in23; - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { - in01 = in_IType[w].data.elt[2 * e]; - in23 = in_IType[w].data.elt[2 * e + 1]; - } else if constexpr (IS_CACHED_ACT_OP) { - in01.x = in_cached[w].data.elt[4 * e]; - in01.y = in_cached[w].data.elt[4 * e + 1]; - in23.x = in_cached[w].data.elt[4 * e + 2]; - in23.y = in_cached[w].data.elt[4 * e + 3]; - } else { - const int j = w * PACK_SIZE + 4 * e; - in01.x = in_compute_rowwise[j]; - in01.y = in_compute_rowwise[j + 1]; - in23.x = in_compute_rowwise[j + 2]; - in23.y = in_compute_rowwise[j + 3]; - } - fp4e2m1x4 &out_quad = reinterpret_cast(out.data.elt[e]); - ptx::mul_cvt_4x(out_quad, in01, in23, block_scale_inverse); - } - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const int shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; - out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); - } - } - } - - __builtin_assume(thread_amax >= 0); - __builtin_assume(block_amax >= 0); - thread_amax = fmaxf(thread_amax, block_amax); - - // Wait for shared memory writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const int global_offset_Y = block_offset_Y + stage_offset_Y; - const int global_offset_X = block_offset_X; - const int buff_offset_nvfp4 = buff * BUFF_OUT_DIM; - const int buff_offset_mxfp8 = buff * BUFF_IN_DIM; - - if constexpr (ROWWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset_nvfp4])); - } - if constexpr (COLWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_colwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset_mxfp8])); - } - - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - } - } - -#if !DIRECT_SCALING_FACTORS_STORE - // Vectorized store of scaling factors. - // Each thread stores multiple scaling factors in one store instruction. 
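// Worked size example (illustrative; the arithmetic assumes CHUNK_DIM_X == 128): with
// SCALE_DIM_X == 16, each chunk row carries 128 / 16 = 8 one-byte E4M3 scaling factors, so the
// vector store below flushes an entire row of scales from shared memory to global memory in a
// single 8-byte transaction instead of eight scalar stores.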
- if constexpr (ROWWISE_SCALING) { - // Number of scaling factors = CHUNK_DIM_X / SCALE_DIM_X - const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + threadIdx.x; - const int scales_offset_X_rowwise = scales_block_offset_X_rowwise; - const int scale_idx_global = - scales_offset_Y_rowwise * scale_stride_rowwise + scales_offset_X_rowwise; - const int scale_idx_shmem = threadIdx.x * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW; - - if ((threadIdx.x < CHUNK_DIM_Y) && (scales_offset_Y_rowwise < rows) && - (scales_offset_X_rowwise < (cols / SCALE_DIM_X))) { - using ScalesVec_t = Vec; - const ScalesVec_t &scales = - *reinterpret_cast(&out_rowwise_scales_sh[scale_idx_shmem]); - scales.store_to(&scales_rowwise_e4m3[scale_idx_global]); - } - } -#endif - - float chunk_amax = 0.0f; - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - chunk_amax = reduce_max(thread_amax, warp_id); - } - - if (is_master_thread && amax_ptr != nullptr) { - atomicMaxFloat(amax_ptr, chunk_amax); - } - - destroy_barriers(mbar, is_master_thread); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} -} // namespace nvfp4_kernel - -constexpr size_t FP8_CHUNK_DIM_Y = 128; -constexpr size_t FP8_CHUNK_DIM_X = 128; -constexpr size_t FP8_THREADS_PER_CHUNK = 128; -constexpr size_t FP8_BUFFERS_NUM = 2; -constexpr size_t FP8_PREFETCH_BUFFERS_NUM = 1; -static_assert(FP8_PREFETCH_BUFFERS_NUM < FP8_BUFFERS_NUM); - -constexpr size_t FP8_BUFFER_DIM_Y = 16; -constexpr size_t FP8_BUFFER_DIM_X = FP8_CHUNK_DIM_X; // 128 -constexpr size_t FP8_SHMEM_DIM_Y = FP8_BUFFER_DIM_Y; // 16 -constexpr size_t FP8_SHMEM_DIM_X = FP8_BUFFER_DIM_X; // 128 - -constexpr size_t FP8_BUFF_STAGES_NUM = FP8_BUFFER_DIM_Y; // 16 -constexpr size_t FP8_ITERATIONS = FP8_CHUNK_DIM_Y / FP8_BUFFER_DIM_Y; // 8 = 128 / 16 -static_assert(FP8_ITERATIONS >= FP8_PREFETCH_BUFFERS_NUM); - -template -__global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) - cast_fp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_act_input, - const __grid_constant__ CUtensorMap tensor_map_output, - float *const dbias_workspace, float *const amax_ptr, - float *const scale_inv_ptr, const float *const scale_ptr, const size_t rows, - const size_t cols) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - - const size_t block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y; - const size_t block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X; - - const size_t tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK; - const size_t tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK; - - const size_t thread_offset_Y = tid_Y; - const size_t thread_offset_X = tid_X; - - const size_t dbias_offset_Y = blockIdx.y + tid_Y; - const size_t my_column = blockIdx.x * FP8_CHUNK_DIM_X + thread_offset_X; - const bool col_out_of_bounds = my_column >= cols; - const size_t dbias_stride = cols; - - float partial_dbias = 0.f; - - float amax = 0; - const float scale = (scale_ptr != nullptr) ? 
*scale_ptr : 1; - - // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned - __shared__ alignas(TMA_SHMEM_ALIGNMENT) - IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - __shared__ alignas(TMA_SHMEM_ALIGNMENT) - IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - __shared__ alignas(TMA_SHMEM_ALIGNMENT) - OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - - constexpr size_t shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; - - const bool is_master_thread = (threadIdx.x == 0); - -// Initialize shared memory barrier with the number of threads participating in the barrier. -#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[FP8_ITERATIONS]; - - initialize_barriers(mbar, is_master_thread); - - int parity = 0; - - const size_t chunk_offset_Y = block_offset_Y; - const size_t chunk_offset_X = block_offset_X; - -#pragma unroll - for (int prefetch_buff = 0; prefetch_buff < FP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { - const size_t chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y; - const size_t chunk_stage_offset_X = chunk_offset_X; - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, - chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input, - chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size, - &mbar[prefetch_buff], is_master_thread); - } else { - copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, - chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff], - is_master_thread); - } - } - -#pragma unroll - for (int iter = 0; iter < FP8_ITERATIONS; ++iter) { - const size_t buff = iter % FP8_BUFFERS_NUM; - const size_t next_iter = iter + FP8_PREFETCH_BUFFERS_NUM; - const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y; - if (next_iter < FP8_ITERATIONS) { - const size_t next_buff = next_iter % FP8_BUFFERS_NUM; - const size_t chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y; - const size_t chunk_it_offset_x = chunk_offset_X; - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, - chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, - chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], - is_master_thread); - } else { - copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, - chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread); - } - } - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[iter], parity); - -#pragma unroll - for (int stage = 0; stage < FP8_BUFF_STAGES_NUM; ++stage) { - const size_t stage_offset_Y = stage; - const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y; - const size_t shmem_offset_x = thread_offset_X; - const size_t row = row_base + shmem_offset_y; - const bool row_out_of_bounds = row >= rows; - const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds; - - float elt = static_cast(in_sh[buff][shmem_offset_y][shmem_offset_x]); - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in_sh[buff][shmem_offset_y][shmem_offset_x]); - elt *= OP(act_in_elt, {}); - } - if constexpr (IS_DBIAS) { - if constexpr (IS_DACT) { - if (!out_of_bounds) { - partial_dbias += elt; - } - } else { - // If no activation, elt is 0 so we can safely do this - partial_dbias += elt; - } - } - __builtin_assume(amax >= 0); - if (IS_DACT) { - if 
(!out_of_bounds) { - amax = fmaxf(amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - amax = fmaxf(amax, fabsf(elt)); - } - out_sh[buff][shmem_offset_y][shmem_offset_x] = static_cast(elt * scale); - } - - // Wait for shared memory writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const size_t chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y; - const size_t chunk_it_offset_x = chunk_offset_X; - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output), chunk_it_offset_x, - chunk_it_offset_y, reinterpret_cast(&out_sh[buff])); - - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - - // Wait for TMA transfer to have finished reading shared memory. - ptx::cp_async_bulk_wait_group_read(); - } - } - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - - parity ^= 1; - - if constexpr (IS_DBIAS) { - const size_t dbias_offset_X = my_column; - const size_t dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X; - if (!col_out_of_bounds) { - dbias_workspace[dbias_offset] = partial_dbias; - } - } - - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - amax = reduce_max(amax, warp_id); - // Update the global amax - if (is_master_thread) { - atomicMaxFloat(amax_ptr, amax); - } - } - - // Update scale-inverse - if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { - reciprocal(scale_inv_ptr, scale); - } - - destroy_barriers(mbar, is_master_thread); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} - -constexpr size_t CHUNKS_PER_BLOCK = 128; -constexpr size_t THREADS_PER_BLOCK = FP8_THREADS_PER_CHUNK; -constexpr size_t CHUNK_SIZE = THREADS_PER_BLOCK; -constexpr size_t ELEMS_PER_BLOCK = CHUNKS_PER_BLOCK * CHUNK_SIZE; -constexpr size_t CHUNKS_PER_ITERATION = 32; -constexpr size_t SHMEM_DIM = CHUNKS_PER_ITERATION * CHUNK_SIZE; -constexpr size_t ITERATIONS = CHUNKS_PER_BLOCK / CHUNKS_PER_ITERATION; -constexpr size_t SHMEM_BUFFERS = 2; -static_assert(CHUNKS_PER_BLOCK % CHUNKS_PER_ITERATION == 0); - -template -__global__ void __launch_bounds__(THREADS_PER_BLOCK) - cast_fp8_1D_kernel(const IType *input_ptr, OType *output_ptr, float *const amax_ptr, - float *const scale_inv_ptr, const float *const scale_ptr, const size_t N) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - - const size_t block_offset = blockIdx.x * ELEMS_PER_BLOCK; - const IType *input = input_ptr + block_offset; - OType *output = output_ptr + block_offset; - - float amax = 0; - const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; - - // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned - __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM]; - __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM]; - - constexpr size_t transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS; - constexpr size_t transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS; - - const bool is_master_thread = (threadIdx.x == 0); - -// Initialize shared memory barrier with the number of threads participating in the barrier. 
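// A minimal sketch of the producer/consumer barrier protocol used below (illustrative; only
// helpers that already appear in this file are referenced):
//   initialize_barriers(mbar, is_master_thread);                           // one mbarrier per iteration
//   copy_1d_to_shared(&in_sh[0], input, transaction_size_IN, &mbar[0], is_master_thread);  // TMA load arms mbar[0]
//   ptx::mbarrier_wait_parity(&mbar[iter], parity);                        // consumers wait for the data to land
// Each of the ITERATIONS barriers is consumed exactly once here, so the parity bit never needs
// to flip in this kernel.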
-#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[ITERATIONS]; - - initialize_barriers(mbar, is_master_thread); - - int parity = 0; - - copy_1d_to_shared(&(in_sh[0]), input, transaction_size_IN, &(mbar[0]), is_master_thread); - -#pragma unroll - for (int iter = 0; iter < ITERATIONS; ++iter) { - const size_t buff = iter % SHMEM_BUFFERS; - const size_t it_offset = iter * SHMEM_DIM; - - const size_t next_iter = iter + 1; - const size_t next_buff = next_iter % SHMEM_BUFFERS; - const size_t next_iter_offset = next_iter * SHMEM_DIM; - - if (next_iter < ITERATIONS) { - copy_1d_to_shared(&(in_sh[next_buff]), input + next_iter_offset, transaction_size_IN, - &(mbar[next_iter]), is_master_thread); - } - - ptx::fence_proxy_async_shared_cta(); - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[iter], parity); - -#pragma unroll - for (int chunk = 0; chunk < CHUNKS_PER_ITERATION; ++chunk) { - const size_t shmem_offset = chunk * CHUNK_SIZE + threadIdx.x; - float elt = static_cast(in_sh[buff][shmem_offset]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - __builtin_assume(amax >= 0); - amax = fmaxf(amax, fabsf(elt)); - out_sh[buff][shmem_offset] = static_cast(elt * scale); - } - - // Wait for shared memory writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - ptx::cp_async_bulk_tensor_1d_shared_to_global( - reinterpret_cast(output + it_offset), - reinterpret_cast(&out_sh[buff]), transaction_size_OUT); - - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - - // Wait for TMA transfer to have finished reading shared memory. 
- ptx::cp_async_bulk_wait_group_read<1>(); - } - } - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - amax = reduce_max(amax, warp_id); - // Update the global amax - if (is_master_thread) { - atomicMaxFloat(amax_ptr, amax); - } - } - - // Update scale-inverse - if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { - reciprocal(scale_inv_ptr, scale); - } - - destroy_barriers(mbar, is_master_thread); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} - -constexpr size_t DBIAS_THREADS_PER_BLOCK = 256; -template -__global__ void __launch_bounds__(DBIAS_THREADS_PER_BLOCK) - reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial, - const size_t rows, const size_t cols) { - using ComputeVec = Vec; - using OutputVec = Vec; - - const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; - - if (thread_id * nvec >= cols) { - return; - } - - const float *const thread_in_base = dbias_partial + thread_id * nvec; - OType *const thread_out_base = dbias_output + thread_id * nvec; - - ComputeVec ldg_vec; - ComputeVec acc_vec; - acc_vec.clear(); - for (int i = 0; i < rows; ++i) { - ldg_vec.load_from(thread_in_base + i * cols); -#pragma unroll - for (int e = 0; e < nvec; ++e) { - acc_vec.data.elt[e] += ldg_vec.data.elt[e]; - } - } - - OutputVec stg_vec; -#pragma unroll - for (int e = 0; e < nvec; ++e) { - stg_vec.data.elt[e] = static_cast(acc_vec.data.elt[e]); - } - stg_vec.store_to(thread_out_base); -} - -template -void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols, - cudaStream_t stream) { - constexpr size_t reduce_dbias_store_bytes = 8; // stg.64 - constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType); - - NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape."); - const size_t reduce_dbias_num_blocks = DIVUP(cols, DBIAS_THREADS_PER_BLOCK * reduce_dbias_nvec); - - reduce_dbias_kernel - <<>>( - reinterpret_cast(dbias->data.dptr), workspace_ptr, rows, cols); - NVTE_CHECK_CUDA(cudaGetLastError()); -} - -template -void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) { - const size_t N = product(input.data.shape); - - const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); - NVTE_CHECK(isFullTile, "Only full tiles are supported."); - NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); - - const size_t chunks = DIVUP(N, CHUNK_SIZE); - const size_t blocks = DIVUP(chunks, CHUNKS_PER_BLOCK); - - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); - const float *const scale_ptr = reinterpret_cast(output->scale.dptr); - - const dim3 block(THREADS_PER_BLOCK); - const dim3 grid(blocks); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, - const IType *input_ptr = reinterpret_cast(input.data.dptr); - OType *output_ptr = reinterpret_cast(output->data.dptr); - - cast_fp8_1D_kernel<<>>( - input_ptr, output_ptr, amax_ptr, scale_inv_ptr, scale_ptr, N);); // NOLINT(*) - ); // NOLINT(*) - NVTE_CHECK_CUDA(cudaGetLastError()); -} - -template -void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, Tensor *dbias, - Tensor *workspace, cudaStream_t 
stream) { - checkCuDriverContext(stream); - - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); - const size_t chunks_Y = DIVUP(rows, FP8_CHUNK_DIM_Y); - const size_t chunks_X = DIVUP(cols, FP8_CHUNK_DIM_X); - const size_t blocks_Y = chunks_Y; - const size_t blocks_X = chunks_X; - - const size_t dbias_rows = blocks_Y; - const size_t dbias_cols = cols; - - NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); - - if constexpr (IS_DBIAS) { - NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input."); - NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); - NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); - - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {dbias_rows, dbias_cols}; - workspace->data.dtype = DType::kFloat32; - return; - } - } - float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); - float *const scale_ptr = reinterpret_cast(output->scale.dptr); - - const dim3 block(FP8_THREADS_PER_CHUNK); - const dim3 grid(blocks_X, blocks_Y); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->data.dtype, OType, - - alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_act_input{}; - alignas(64) CUtensorMap tensor_map_output{}; - - create_2D_tensor_map(tensor_map_input, input.data, rows, cols, FP8_SHMEM_DIM_Y, - FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype)); - - if constexpr (IS_DACT) { - create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, FP8_SHMEM_DIM_Y, - FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype)); - } - - create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y, - FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(output->data.dtype)); - - cast_fp8_2D_kernel - <<>>(tensor_map_input, tensor_map_act_input, tensor_map_output, - workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows, - cols); - NVTE_CHECK_CUDA(cudaGetLastError()); - - if constexpr (IS_DBIAS) { - reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); - }); // NOLINT(*) - ); // NOLINT(*) -} - -template -void mxfp8_quantize(const Tensor &input, const Tensor *act_input, - const Tensor *noop, // TODO (ksivamani) - Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { - using namespace mxfp8_kernel; - checkCuDriverContext(stream); - - bool use_rowwise_scaling = output->has_data(); - bool use_colwise_scaling = output->has_columnwise_data(); - NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); - NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); - - if (use_rowwise_scaling) { - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); - } - if (use_colwise_scaling) { - NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, - "Columnwise scaling tensor must be allocated"); - } - CheckNoopTensor(*noop, "cast_noop"); - - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); - - constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT); - - constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 
128 : 64; - constexpr size_t CHUNK_DIM_X = CAST_DBIAS_ONLY ? 128 : 64; - constexpr size_t THREADS_PER_CHUNK = CAST_DBIAS_ONLY ? 128 : 64; - - constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; - constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; - constexpr size_t BUFF_DIM_Y = THREADS_Y; - constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; - - const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); - const dim3 grid(blocks_X, blocks_Y); - const size_t block_size = THREADS_PER_CHUNK; - - const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; - const size_t scale_stride_colwise = - use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; - - e8m0_t *const scales_rowwise_ptr = - use_rowwise_scaling ? reinterpret_cast(output->scale_inv.dptr) : nullptr; - e8m0_t *const scales_colwise_ptr = - use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; - const size_t dbias_rows = blocks_Y; - const size_t dbias_cols = cols; - - ScalingType scaling_type; - if (use_rowwise_scaling && (!use_colwise_scaling)) { - scaling_type = ScalingType::ROWWISE; - } else if ((!use_rowwise_scaling) && use_colwise_scaling) { - scaling_type = ScalingType::COLWISE; - } else if (use_rowwise_scaling && use_colwise_scaling) { - scaling_type = ScalingType::BIDIMENSIONAL; - } - - if constexpr (IS_DBIAS) { - NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); - NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); - NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); - - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {dbias_rows, dbias_cols}; - workspace->data.dtype = DType::kFloat32; - return; - } - } - - float *const workspace_ptr = IS_DBIAS ? 
reinterpret_cast(workspace->data.dptr) : nullptr; - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - const float *noop_ptr = reinterpret_cast(noop->data.dptr); - - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, - - alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_act_input{}; - alignas(64) CUtensorMap tensor_map_output_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_colwise{}; - - constexpr size_t input_type_bit_size = TypeInfo::size; - constexpr size_t output_type_bit_size = TypeInfo::size; - - create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, - cols, 0, input_type_bit_size); - - if constexpr (IS_DACT) { - create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, BUFF_DIM_Y, - BUFF_DIM_X, cols, 0, input_type_bit_size); - } - - if (use_rowwise_scaling) { - create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, BUFF_DIM_Y, - BUFF_DIM_X, cols, 0, output_type_bit_size); - } - - if (use_colwise_scaling) { - create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, - BUFF_DIM_Y, BUFF_DIM_X, cols, 0, output_type_bit_size); - } - - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; - constexpr size_t buff_elems_total = mxfp8_kernel::BUFFS_NUM * buff_elems; - constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; - constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); - - constexpr size_t elt_input_mem = buff_size_aligned_in; - constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); - constexpr size_t in_mem = elt_input_mem + act_input_mem; - - const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); - const size_t out_colwise_mem = (use_colwise_scaling ? 
buff_size_aligned_out : 0); - const size_t out_mem = out_rowwise_mem + out_colwise_mem; - - const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; - - switch (scaling_type) { - case ScalingType::ROWWISE: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - cast_mxfp8_2D_kernel - <<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - case ScalingType::COLWISE: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - cast_mxfp8_2D_kernel - <<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - case ScalingType::BIDIMENSIONAL: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - cast_mxfp8_2D_kernel - <<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - } - - if constexpr (IS_DBIAS) { - reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); - }); // NOLINT(*) - ); // NOLINT(*) -} - -// This kernel supports only two scaling cases: -// 1. r16c0 - Rowwise NVFP4 -// 2. r16c32 - Rowwise NVFP4 AND Colwise MXFP8 -template -void nvfp4_quantize(const Tensor &input, const Tensor *noop, Tensor *output, cudaStream_t stream) { - using namespace nvfp4_kernel; - using namespace ptx; - checkCuDriverContext(stream); - - NVTE_CHECK(output->has_data(), "NVFP4 Output tensor must be allocated."); - NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); - - NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); - - bool use_colwise_scaling = output->has_columnwise_data(); - if (use_colwise_scaling) { - NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, - "Columnwise scaling tensor must be allocated"); - } - CheckNoopTensor(*noop, "cast_noop"); - - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); - - constexpr size_t CHUNK_DIM_Y = 128; - constexpr size_t CHUNK_DIM_X = 128; - constexpr size_t THREADS_PER_CHUNK = 128; - - constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; - - const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); - const dim3 grid(blocks_X, blocks_Y); - const size_t block_size = THREADS_PER_CHUNK; - - const size_t scale_stride_rowwise = output->scale_inv.shape[1]; - const size_t scale_stride_colwise = - use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; - - fp8e4m3 *const scales_rowwise_e4m3_ptr = reinterpret_cast(output->scale_inv.dptr); - e8m0_t *const scales_colwise_e8m0_ptr = - use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; - - const ScalingType scaling_type = - use_colwise_scaling ? 
ScalingType::BIDIMENSIONAL : ScalingType::ROWWISE; - - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - const float *noop_ptr = reinterpret_cast(noop->data.dptr); - const float *const nvfp4_second_stage_scale_ptr = - reinterpret_cast(output->scale.dptr); - - // Output data type is only required for the column-wise MXFP8 scaling. - // It has no effect for the row-wise NVFP4 scaling, but is set to the default E4M3 for the macros to work - const DType output_data_type = - use_colwise_scaling ? output->columnwise_data.dtype : DType::kFloat8E4M3; - - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output_data_type, OType, alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_output_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_colwise{}; - - create_2D_tensor_map(tensor_map_input, input.data, rows, cols, nvfp4_kernel::BUFF_DIM_Y, - BUFF_DIM_X, cols, 0, sizeof(IType) * 8); - - create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, - nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, 4); - - if (use_colwise_scaling) { - create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, - nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, sizeof(OType) * 8); - } - - constexpr size_t buff_elems = nvfp4_kernel::BUFF_DIM_Y * BUFF_DIM_X; - constexpr size_t buff_elems_total = nvfp4_kernel::BUFFS_NUM * buff_elems; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out_nvfp4 = - DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out_mxfp8 = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_nvfp4_scales = - (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(fp8e4m3); - constexpr size_t buff_size_mxfp8_scales = - (CHUNK_DIM_Y * CHUNK_DIM_X) / 32 * sizeof(e8m0_t); - - constexpr size_t in_mem = buff_size_aligned_in; - - const size_t out_rowwise_data_mem = buff_size_aligned_out_nvfp4; - const size_t out_colwise_data_mem = use_colwise_scaling ? buff_size_aligned_out_mxfp8 : 0; - - const size_t out_rowwise_scales_mem = buff_size_nvfp4_scales; - const size_t out_colwise_scales_mem = use_colwise_scaling ? 
buff_size_mxfp8_scales : 0; - - const size_t out_mem = out_rowwise_data_mem + out_colwise_data_mem + - out_rowwise_scales_mem + out_colwise_scales_mem + - TMA_SHMEM_ALIGNMENT; - - const size_t dshmem_size = in_mem + out_mem; - - switch (scaling_type) { - case ScalingType::ROWWISE: - cudaFuncSetAttribute( - cast_nvfp4_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - - cast_nvfp4_kernel - <<>>( - tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise, - scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr, - nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - break; - case ScalingType::BIDIMENSIONAL: - cudaFuncSetAttribute( - cast_nvfp4_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - - cast_nvfp4_kernel - <<>>( - tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise, - scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr, - nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - break; - }); // NOLINT(*) - ); // NOLINT(*) -} - -namespace detail { - -using Empty = transformer_engine::Empty; - -__device__ inline float identity(float value, const Empty &) { return value; } - -struct DequantizeParam { - const float *scale_inv; -}; - -__device__ inline float dequantize_func(float value, const DequantizeParam ¶m) { - return value * (*(param.scale_inv)); -} - -} // namespace detail - -template -void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output, - cudaStream_t stream) { - constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; - const size_t N = product(input.data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->data.dtype, OType, - if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) { - constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(noop->data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), N, {}, stream); - } else { - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - }); // NOLINT(*) - ); // NOLINT(*) -} - -template -void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output, - cudaStream_t stream) { - constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? 
detail::identity : OP; - const size_t N = product(input->data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input->data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->data.dtype, OType, - if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) { - constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryGradKernelLauncher( - reinterpret_cast(grad.data.dptr), - reinterpret_cast(input->data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), N, {}, stream); - } else { - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - }); // NOLINT(*) - ); // NOLINT(*) -} - -namespace { - -static bool is_full_tile_1D_tensor(const Tensor *const t) { - const size_t N = product(t->data.shape); - const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); - return isFullTile; -} - -bool dimensions_supported_by_TMA(const Tensor *const t) { - const size_t cols = t->flat_last_dim(); - constexpr size_t TMA_bytes = 16; - const size_t alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype()); - return cols % alignment_requirement == 0; -} - -} // namespace - -// Supported by the Arch >= 10.0 -template -void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, const Tensor *noop, - Tensor *output, Tensor *dbias, Tensor *workspace, - cudaStream_t stream) { - switch (output->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - if (!IS_DBIAS && !IS_DACT) { - if (is_full_tile_1D_tensor(output) && is_fp8_dtype(output->dtype()) && - is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && - is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT)) { - // Aligned AND FP8 - cast_fp8_1D(input, output, stream); - } else { - // Unaligned - CastVectorizedUnaryKernelLauncher(input, noop, output, stream); - } - } else if (!IS_DBIAS && IS_DACT) { - if (dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype()) && - is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && - is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT) && - is_aligned_tensor_data(*act_input, TMA_GMEM_ALIGNMENT)) { - // Aligned AND FP8 (+dAct) - cast_fp8_2D(input, act_input, output, dbias, workspace, - stream); - } else { - // Unaligned - CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); - } - } else { - cast_fp8_2D(input, act_input, output, dbias, workspace, - stream); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - mxfp8_quantize(input, act_input, noop, output, dbias, - workspace, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - } -} - -// Supported by the Arch < 10.0 -template -void fp8_quantize_arch_l_100(const Tensor &input, const Tensor *act_input, const Tensor *noop, - Tensor *output, Tensor *dbias, Tensor *workspace, - cudaStream_t stream) { - if (!is_tensor_scaling(output->scaling_mode) || IS_DBIAS) { - // zhongboz: should we just ignore IS_ACT here? 
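// --- Editorial note (not part of the patch): a worked example of the TMA
// eligibility checks in fp8_quantize_arch_ge_100 above; the shapes below are
// hypothetical and only illustrate the arithmetic.
//   * dimensions_supported_by_TMA: alignment_requirement = (16 bytes * 8) /
//     typeToNumBits(dtype), i.e. 8 columns for a 16-bit dtype (bf16/fp16) and
//     16 columns for an 8-bit FP8 output, so a [4096, 3072] tensor qualifies
//     while a [4096, 3070] one falls back to the vectorized launchers.
//   * is_full_tile_1D_tensor: the total element count must be a multiple of
//     ELEMS_PER_BLOCK (= CHUNKS_PER_BLOCK * CHUNK_SIZE = 128 * FP8_THREADS_PER_CHUNK)
//     for the TMA-based cast_fp8_1D path.
//   * In addition, the input/output (and act_input) base pointers must satisfy
//     TMA_GMEM_ALIGNMENT, as checked via is_aligned_tensor_data above.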
- NVTE_ERROR("Not implemented scaling mode or fusion: " + to_string(output->scaling_mode) + - " or IS_DBIAS=true" + " on GPU with compute capability < 10.0."); - } - switch (output->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - if (!IS_DACT) { - CastVectorizedUnaryKernelLauncher(input, noop, output, stream); - } else { - CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); - } - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - } -} - -template -void fp8_quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, Tensor *output, - Tensor *dbias, Tensor *workspace, cudaStream_t stream) { - CheckNoopTensor(*noop, "cast_noop"); - CheckInputTensor(input, "cast_input"); - CheckOutputTensor(*output, "cast_output"); - - if constexpr (IS_DBIAS) { - NVTE_CHECK(dbias != nullptr); - CheckOutputTensor(*dbias, "dbias"); - } - if constexpr (IS_DACT) { - NVTE_CHECK(act_input != nullptr); - CheckInputTensor(*act_input, "activation_input"); - NVTE_CHECK(input.dtype() == act_input->dtype(), "Types of both inputs must match."); - NVTE_CHECK(input.data.shape == act_input->data.shape, "Shapes of both inputs must match."); - } - - NVTE_CHECK(!is_fp8_dtype(input.dtype()), "Input must be in higher precision."); - NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); - - // Supported by the Arch >= 10.0 - if (is_supported_by_CC_100()) { - fp8_quantize_arch_ge_100(input, act_input, noop, output, - dbias, workspace, stream); - } else { - // Supported by the Arch < 10.0 - fp8_quantize_arch_l_100(input, act_input, noop, output, - dbias, workspace, stream); - } -} - -namespace detail { - -template -void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor output, - NVTETensor dbias, NVTETensor workspace, - const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - const Tensor *input_tensor; - const Tensor *activation_input_tensor; - if constexpr (IS_DBIAS || IS_DACT) { - // backward - input is incoming gradient - input_tensor = convertNVTETensorCheck(grad); - activation_input_tensor = convertNVTETensor(input); - } else { - // forward = input is activation input - input_tensor = convertNVTETensorCheck(input); - activation_input_tensor = nullptr; - } - auto output_tensor = convertNVTETensorCheck(output); - auto dbias_tensor = convertNVTETensor(dbias); - auto workspace_tensor = convertNVTETensor(workspace); - - // Quantization config - QuantizationConfig quant_config_cpp; - if (quant_config != nullptr) { - quant_config_cpp = *reinterpret_cast(quant_config); - } - - // Noop flag - Tensor dummy_tensor; - Tensor *noop_tensor = &dummy_tensor; - if (quant_config_cpp.noop_tensor != nullptr) { - noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - } - - // Check for unsupported options - if (quant_config_cpp.stochastic_rounding) { - NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, - "Stochastic rounding is only supported for NVFP4 quantization."); - } - - // Dispatch to quantization kernel depending on data format - switch (output_tensor->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - if (output_tensor->has_columnwise_data()) { - NVTE_CHECK(output_tensor->has_data(), - "Quantizing in only the columnwise direction not supported yet!"); - if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { - cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream); - } else { - cast_transpose_fused( - *input_tensor, 
activation_input_tensor, output_tensor, dbias_tensor, workspace_tensor, - stream); - } - } else if (output_tensor->has_data()) { - fp8_quantize( - *input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor, - workspace_tensor, stream); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - mxfp8_quantize( - *input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor, - workspace_tensor, stream); - break; - } - case NVTE_NVFP4_1D_SCALING: { - // Check tensors - CheckNoopTensor(*noop_tensor, "cast_noop"); - CheckInputTensor(*input_tensor, "input"); - CheckOutputTensor(*output_tensor, "output", false); - - // Choose kernel - int32_t rows = input_tensor->flat_first_dim(); - int32_t cols = input_tensor->flat_last_dim(); - auto dtype = input_tensor->dtype(); - bool use_optimized_kernel = dtype == DType::kBFloat16 && rows % 32 == 0 && cols % 32 == 0 && - output_tensor->has_data(); - - // Launch NVFP4 quantize kernel - if (use_optimized_kernel) { - if (quant_config_cpp.nvfp4_2d_quantization) { - nvfp4_quantize_transpose( - *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - } else { - nvfp4_quantize_transpose( - *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - } - } else { - auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax - : output_tensor->columnwise_amax; - NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), - "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_NVFP4_1D_SCALING for " - "2D quantization"); - quantize_transpose_vector_blockwise_fp4( - /*input=*/input_tensor->data, /*global_amax=*/global_amax, - /*scale_inv=*/output_tensor->scale_inv, - /*scale_inv_t=*/output_tensor->columnwise_scale_inv, - /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, - /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), - /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, - /*swizzled_scale=*/false, - /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, - /*rng_state=*/quant_config_cpp.rng_state, - /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); - } - break; - } - case NVTE_BLOCK_SCALING_2D: { - // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. - NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), - "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D"); - bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - float epsilon = quant_config_cpp.amax_epsilon; - quantize_transpose_square_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, - /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, - /*noop_tensor=*/noop_tensor->data, stream); - break; - } - case NVTE_BLOCK_SCALING_1D: { - // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. 
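// --- Editorial note (not part of the patch): illustration of the NVFP4 kernel
// selection in the NVTE_NVFP4_1D_SCALING case above; the shapes are hypothetical.
//   * A bf16 input of shape [8192, 5120] with rowwise output data satisfies
//     use_optimized_kernel and takes the fused nvfp4_quantize_transpose path
//     (the 2D or 1D variant, depending on quant_config_cpp.nvfp4_2d_quantization).
//   * An fp16 input, or a bf16 input whose rows or columns are not multiples of
//     32, falls back to quantize_transpose_vector_blockwise_fp4, which receives
//     the stochastic-rounding flag and rng_state from the quantization config.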
- NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), - "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"); - bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - float epsilon = quant_config_cpp.amax_epsilon; - FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; - FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; - if (output_tensor->has_data()) { - bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format == - Float8BlockScaleTensorFormat::COMPACT); - rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT - : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; - } - if (output_tensor->has_columnwise_data()) { - bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format == - Float8BlockScaleTensorFormat::COMPACT); - columnwise_option = columnwise_compact - ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT - : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; - } - quantize_transpose_vector_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, - columnwise_option, force_pow_2_scales, noop_tensor->data, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); - } -} - -} // namespace detail -} // namespace transformer_engine - -#endif // TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ diff --git a/transformer_engine/common/util/math.h b/transformer_engine/common/util/math.h index 2f20817fb0..005a600670 100644 --- a/transformer_engine/common/util/math.h +++ b/transformer_engine/common/util/math.h @@ -36,6 +36,8 @@ __device__ inline OType sigmoid(const IType val, const Empty&) { return 1.f / (1.f + expf(-cval)); } +__device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); } + template __device__ inline OType dsigmoid(const IType val, const Empty& e) { const float cval = val; diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index aeac2b4a2c..6605d9cad1 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -449,13 +449,12 @@ static_assert(sizeof(fp16x2) == 4); static_assert(sizeof(fp8e4m3x2) == 2); static_assert(sizeof(fp8e5m2x2) == 2); -#if CUDA_VERSION >= 12080 +#if FP4_TYPE_SUPPORTED using fp4e2m1 = __nv_fp4_e2m1; using fp4e2m1x2 = __nv_fp4x2_e2m1; using fp4e2m1x4 = __nv_fp4x4_e2m1; static_assert(sizeof(fp4e2m1x2) == 1); static_assert(sizeof(fp4e2m1x4) == 2); -#endif // CUDA_VERSION >= 12080 // When converting to .e2m1x2 data formats, the destination operand d has .b8 type. // When converting two .f32 inputs to .e2m1x2, each input is converted to the specified format, @@ -464,7 +463,6 @@ static_assert(sizeof(fp4e2m1x4) == 2); // from input b is stored in the lower 4 bits of d. 
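// --- Editorial sketch (not part of the patch): a minimal illustration of the
// e2m1x2 packing described above. It assumes the CUDA 12.8+ <cuda_fp4.h> type
// aliased as fp4e2m1x2 (__nv_fp4x2_e2m1) provides a float2 constructor analogous
// to the float4 constructor used by mul_cvt_4x below; pack_e2m1x2_example is a
// hypothetical helper, not part of Transformer Engine.
__device__ __forceinline__ uint8_t pack_e2m1x2_example(const float a, const float b) {
  // Both inputs are converted to E2M1 and packed into a single byte
  // (sizeof(fp4e2m1x2) == 1, as asserted above).
  const fp4e2m1x2 packed(make_float2(a, b));
  return *reinterpret_cast<const uint8_t *>(&packed);
}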
// SIMD like "Fused" cast + multiplication (x4) -#if CUDA_VERSION >= 12080 template __device__ __forceinline__ void mul_cvt_4x(fp4e2m1x4 &out, const Tx2 &in01, const Tx2 &in23, const float scale) { @@ -474,7 +472,192 @@ __device__ __forceinline__ void mul_cvt_4x(fp4e2m1x4 &out, const Tx2 &in01, cons const float x3 = static_cast(in23.y) * scale; out = fp4e2m1x4(make_float4(x0, x1, x2, x3)); } -#endif // CUDA_VERSION >= 12080 + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding( + const uint64_t in_4x, const float2 scale, const uint32_t rbits) { + uint16_t out_4x = 0; + constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + if constexpr (has_rs) { + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b16 v0_bf16; \n\t" + ".reg.b16 v1_bf16; \n\t" + ".reg.b16 v2_bf16; \n\t" + ".reg.b16 v3_bf16; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" + "cvt.f32.bf16 v0, v0_bf16; \n\t" + "cvt.f32.bf16 v1, v1_bf16; \n\t" + "cvt.f32.bf16 v2, v2_bf16; \n\t" + "cvt.f32.bf16 v3, v3_bf16; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t" // mind the shuffled elements order + "}" + : "=h"(out_4x) + : "l"(in_4x), "l"(reinterpret_cast(scale)), "r"(rbits)); + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return *reinterpret_cast(&out_4x); +} + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x, + const float2 scale, + const uint32_t rbits) { + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. + if constexpr (is_blackwell) { + // NOTE: rbits unused for rn. + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b16 v0_bf16; \n\t" + ".reg.b16 v1_bf16; \n\t" + ".reg.b16 v2_bf16; \n\t" + ".reg.b16 v3_bf16; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + ".reg.b8 f0; \n\t" + ".reg.b8 f1; \n\t" + "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" + "cvt.f32.bf16 v0, v0_bf16; \n\t" + "cvt.f32.bf16 v1, v1_bf16; \n\t" + "cvt.f32.bf16 v2, v2_bf16; \n\t" + "cvt.f32.bf16 v3, v3_bf16; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + "mov.b32 %0, {f0, f1, f0, f1};\n\t" + "}" + : "=r"(out_4x) + : "l"(in_4x), "l"(reinterpret_cast(scale))); + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. 
" + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return reinterpret_cast(&out_4x)[0]; +} + +template +__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x, + const float2 scale, + const uint32_t rbits) { + if constexpr (USE_STOCHASTIC_ROUNDING) { + return mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(in_4x, scale, rbits); + } else { + return mul_cvt_bf16_to_fp4_4x_with_rn(in_4x, scale, rbits); + } +} + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding( + const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) { + uint16_t out_4x = 0; + constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + if constexpr (has_rs) { + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + "mov.b64 {v0, v1} , %1; \n\t" + "mov.b64 {v2, v3} , %2; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t" // mind the shuffled elements order + "}" + : "=h"(out_4x) + : "l"(reinterpret_cast(in01)), + "l"(reinterpret_cast(in23)), + "l"(reinterpret_cast(scale)), "r"(rbits)); + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return *reinterpret_cast(&out_4x); +} + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2 in01, + const float2 in23, + const float2 scale, + const uint32_t rbits) { + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. + if constexpr (is_blackwell) { + // NOTE: rbits unused for rn. + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + ".reg.b8 f0; \n\t" + ".reg.b8 f1; \n\t" + "mov.b64 {v0, v1} , %1; \n\t" + "mov.b64 {v2, v3} , %2; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + "mov.b32 %0, {f0, f1, f0, f1};\n\t" + "}" + : "=r"(out_4x) + : "l"(reinterpret_cast(in01)), + "l"(reinterpret_cast(in23)), + "l"(reinterpret_cast(scale))); + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return reinterpret_cast(&out_4x)[0]; +} + +template +__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23, + const float2 scale, + const uint32_t rbits) { + if constexpr (USE_STOCHASTIC_ROUNDING) { + return mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, scale, rbits); + } else { + return mul_cvt_fp32_to_fp4_4x_with_rn(in01, in23, scale, rbits); + } +} +#endif // FP4_TYPE_SUPPORTED // SIMD like "Fused" cast + multiplication (x2) __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in,