diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 506bc83f08..f40b281895 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -19,7 +19,7 @@ jobs: run: | apt-get update apt-get install -y git python3.9 pip cudnn9-cuda-12 - pip install cmake==3.21.0 pybind11[global] ninja nvidia-mathdx==25.1.1 + pip install cmake==3.21.0 pybind11[global] ninja - name: 'Checkout' uses: actions/checkout@v3 with: @@ -43,7 +43,7 @@ jobs: run: | apt-get update apt-get install -y git python3.9 pip cudnn9-cuda-12 - pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript nvidia-mathdx==25.1.1 + pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript - name: 'Checkout' uses: actions/checkout@v3 with: @@ -63,7 +63,7 @@ jobs: options: --user root steps: - name: 'Dependencies' - run: pip install pybind11[global] nvidia-mathdx==25.1.1 + run: pip install pybind11[global] - name: 'Checkout' uses: actions/checkout@v3 with: @@ -83,7 +83,7 @@ jobs: options: --user root steps: - name: 'Dependencies' - run: pip install torch pybind11[global] einops onnxscript nvidia-mathdx==25.1.1 + run: pip install torch pybind11[global] einops onnxscript - name: 'Checkout' uses: actions/checkout@v3 with: diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh index 954a8f1c67..d0055b791d 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -23,7 +23,7 @@ git checkout $TARGET_BRANCH git submodule update --init --recursive # Install deps -/opt/python/cp310-cp310/bin/pip install cmake pybind11[global] ninja setuptools wheel nvidia-mathdx==25.1.1 +/opt/python/cp310-cp310/bin/pip install cmake pybind11[global] ninja setuptools wheel if $BUILD_METAPACKAGE ; then cd /TransformerEngine diff --git a/pyproject.toml b/pyproject.toml index 8692ad9610..35a7c20727 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,8 +3,7 @@ # See LICENSE for license information. [build-system] -requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "nvidia-mathdx==25.1.1", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"] +requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"] # Use legacy backend to import local packages in setup.py build-backend = "setuptools.build_meta:__legacy__" - diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 175abd3530..e388dd794b 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -98,28 +98,6 @@ set(CUTLASS_TOOLS_INCLUDE_DIR # Python find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) -# NVIDIA MathDX include directory (from Python package install location) -if(NOT DEFINED MATHDX_INCLUDE_DIR) - execute_process( - COMMAND ${Python_EXECUTABLE} -m pip show nvidia-mathdx - OUTPUT_VARIABLE _PIP_SHOW_MATHDX - ERROR_VARIABLE _PIP_SHOW_MATHDX_ERR - RESULT_VARIABLE _PIP_SHOW_MATHDX_RES - OUTPUT_STRIP_TRAILING_WHITESPACE) - if(NOT _PIP_SHOW_MATHDX_RES EQUAL 0) - message(FATAL_ERROR "Failed to query 'nvidia-mathdx' with pip (using ${Python_EXECUTABLE}): ${_PIP_SHOW_MATHDX_ERR}") - endif() - string(REGEX MATCH "Location: ([^\n\r]+)" _MATHDX_LOC_MATCH "${_PIP_SHOW_MATHDX}") - if(NOT _MATHDX_LOC_MATCH) - message(FATAL_ERROR "Could not parse installation location for 'nvidia-mathdx'. Output was:\n${_PIP_SHOW_MATHDX}") - endif() - set(MATHDX_LOCATION "${CMAKE_MATCH_1}") - set(MATHDX_INCLUDE_DIR "${MATHDX_LOCATION}/nvidia/mathdx/include") -endif() -if(NOT EXISTS "${MATHDX_INCLUDE_DIR}") - message(FATAL_ERROR "MATHDX include directory not found at ${MATHDX_INCLUDE_DIR}. Set MATHDX_INCLUDE_DIR or ensure 'nvidia-mathdx' is installed for ${Python_EXECUTABLE}.") -endif() - # Configure Transformer Engine library include_directories(${PROJECT_SOURCE_DIR}/..) set(transformer_engine_SOURCES) @@ -263,7 +241,6 @@ target_link_libraries(transformer_engine PUBLIC target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) -target_include_directories(transformer_engine PRIVATE ${MATHDX_INCLUDE_DIR}) target_include_directories(transformer_engine SYSTEM PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl) target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu index 263a32623e..12f02dba6b 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -19,9 +19,9 @@ #include "common/common.h" #include "common/util/cuda_runtime.h" +#include "common/util/curanddx.hpp" #include "common/util/ptx.cuh" #include "common/utils.cuh" -#include "curanddx.hpp" #include "cutlass/arch/barrier.h" #include "cutlass/cutlass.h" #include "cutlass/gemm/collective/builders/sm100_common.inl" @@ -38,15 +38,6 @@ namespace transformer_engine { namespace detail { namespace { -// Define a cuRANDDx descriptor -// Note curanddx::PhiloxRounds<4> means 4 rounds of philox4_32. If the operator is not specified, it will be default to 10. -// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used for do some internal checks, e.g., -// if shared memory, if needed, is enough for the described problem, usually not applicable. - -// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html -using RNG = decltype(curanddx::Generator() + curanddx::PhiloxRounds<10>() + curanddx::SM<800>() + curanddx::Thread()); - - using namespace cute; using cute::Tensor; // Ensure unqualified Tensor refers to cute::Tensor, not transformer_engine::Tensor @@ -502,8 +493,9 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, // Initialize RNG for tile const size_t rng_sequence = thread_idx + k_tile * 256 + linear_tile_idx * K_TILE_MAX * 256; - RNG rng(rng_seed, rng_sequence, rng_offset); - curanddx::uniform_bits dist; + + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + rng.init(rng_seed, rng_sequence, rng_offset); uint4 random_uint4 = uint4{0, 0, 0, 0}; CUTLASS_PRAGMA_UNROLL @@ -511,7 +503,7 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, auto acc_scale = cutlass::minimum_with_nan_propagation{}(acc_scales[v], cutlass::platform::numeric_limits::max()); // auto acc_scale = acc_scales[v]; if constexpr (kEnableStochasticRounding) { - random_uint4 = dist.generate4(rng); + random_uint4 = rng.generate4(); output_frgs[v] = StochasticNumericConverter( cutlass::multiplies>{}( compute_frgs[v], diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index 4735fdcbe0..b49a54fbdb 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -17,9 +17,9 @@ #include "common/common.h" #include "common/recipe/recipe_common.cuh" #include "common/transpose/cast_transpose.h" +#include "common/util/curanddx.hpp" #include "common/util/ptx.cuh" #include "common/utils.cuh" -#include "curanddx.hpp" namespace transformer_engine { @@ -33,14 +33,6 @@ using std::uint8_t; using transformer_engine::detail::TypeExtrema; -// Define a cuRANDDx descriptor -// Note curanddx::PhiloxRounds<4> means 4 rounds of philox4_32. If the operator is not specified, it will be default to 10. -// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used for do some internal checks, e.g., -// if shared memory, if needed, is enough for the described problem, usually not applicable. -// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html -using RNG = decltype(curanddx::Generator() + curanddx::PhiloxRounds<10>() + - curanddx::SM<800>() + curanddx::Thread()); - // clang-format off /* @@ -209,12 +201,15 @@ __device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_ return global_encode_scale; } -__device__ __forceinline__ uint32_t get_rbits(RNG& rng, uint4& random_uint4, int& rnd_idx) { +__device__ __forceinline__ uint32_t +get_rbits(transformer_engine::curanddx::detail::philox4x32_native_state<10>& + rng, // philox4x32_native_state<10>: 10 rounds of philox4_32 + uint4& random_uint4, int& rnd_idx) { if (rnd_idx == 4) { rnd_idx = 0; - curanddx::uniform_bits dist; - random_uint4 = dist.generate4(rng); + random_uint4 = rng.generate4(); } + // Treat uint4 as an array of 4x uint32_t elements for indexing const uint32_t* const rbits_arr = reinterpret_cast(&random_uint4); const uint32_t rbits = rbits_arr[rnd_idx++]; @@ -348,9 +343,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo threadIdx.x + block_idx_x * kThreadsPerBlock + block_idx_y * gridDim.x * kThreadsPerBlock; const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; - RNG rng(rng_seed, rng_sequence, rng_offset); - curanddx::uniform_bits dist; - uint4 random_uint4 = kApplyStochasticRounding ? dist.generate4(rng) : uint4{0, 0, 0, 0}; + + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + rng.init(rng_seed, rng_sequence, rng_offset); + uint4 random_uint4 = kApplyStochasticRounding ? rng.generate4() : uint4{0, 0, 0, 0}; + int rnd_idx = 0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x diff --git a/transformer_engine/common/util/curanddx.hpp b/transformer_engine/common/util/curanddx.hpp new file mode 100644 index 0000000000..4d7c90a019 --- /dev/null +++ b/transformer_engine/common/util/curanddx.hpp @@ -0,0 +1,106 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CURANDDX_HPP_ +#define TRANSFORMER_ENGINE_COMMON_UTIL_CURANDDX_HPP_ + +namespace transformer_engine { +namespace curanddx { +namespace detail { + +inline constexpr unsigned int philox4x32_w32_0 = 0x9E3779B9U; +inline constexpr unsigned int philox4x32_w32_1 = 0xBB67AE85U; +inline constexpr unsigned int philox4x32_m4x32_0 = 0xD2511F53U; +inline constexpr unsigned int philox4x32_m4x32_1 = 0xCD9E8D57U; + +__forceinline__ __device__ unsigned int mulhilo32(unsigned int a, unsigned int b, + unsigned int* hip) { + *hip = __umulhi(a, b); + return a * b; +} + +__forceinline__ __device__ uint4 single_round(uint4 ctr, uint2 key) { + unsigned int hi0; + unsigned int hi1; + unsigned int lo0 = mulhilo32(philox4x32_m4x32_0, ctr.x, &hi0); + unsigned int lo1 = mulhilo32(philox4x32_m4x32_1, ctr.z, &hi1); + + uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0}; + return ret; +} + +template +__forceinline__ __device__ uint4 multiple_rounds(uint4 c, uint2 k) { + for (unsigned int i = 0; i < Rounds - 1; i++) { + c = single_round(c, k); // 1 + k.x += philox4x32_w32_0; + k.y += philox4x32_w32_1; + } + return single_round(c, k); // Rounds +} + +template +struct philox4x32_native_state { + static constexpr unsigned int rounds = Rounds; + + uint4 ctr; + uint2 key; + + __forceinline__ __device__ void philox_state_incr() { + if (++ctr.x) return; + if (++ctr.y) return; + if (++ctr.z) return; + ++ctr.w; + } + + __forceinline__ __device__ void philox_state_incr(size_t n) { + unsigned int nlo = (unsigned int)(n); + unsigned int nhi = (unsigned int)(n >> 32); + + ctr.x += nlo; + if (ctr.x < nlo) nhi++; + + ctr.y += nhi; + if (nhi <= ctr.y) return; + if (++ctr.z) return; + ++ctr.w; + } + + __forceinline__ __device__ void philox_state_incr_hi(size_t n) { + unsigned int nlo = (unsigned int)(n); + unsigned int nhi = (unsigned int)(n >> 32); + + ctr.z += nlo; + if (ctr.z < nlo) nhi++; + + ctr.w += nhi; + } + + // offset is the total # of 128bits generated with a single generate4() call + __forceinline__ __device__ void skip_offset(size_t n) { philox_state_incr(n); } + + __forceinline__ __device__ void skip_subsequence(size_t n) { philox_state_incr_hi(n); } + + __forceinline__ __device__ void init(size_t seed, size_t subsequence, size_t offset) { + ctr = make_uint4(0, 0, 0, 0); + key.x = (unsigned int)seed; + key.y = (unsigned int)(seed >> 32); + + skip_subsequence(subsequence); + skip_offset(offset); + } + + __forceinline__ __device__ uint4 generate4() { + auto tmp = multiple_rounds(ctr, key); + philox_state_incr(); + return tmp; + } +}; +} // namespace detail +} // namespace curanddx +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMMON_UTIL_CURANDDX_HPP_ diff --git a/transformer_engine/common/util/nvfp4_transpose.cuh b/transformer_engine/common/util/nvfp4_transpose.cuh index 45fa29f0e9..629520aeb7 100644 --- a/transformer_engine/common/util/nvfp4_transpose.cuh +++ b/transformer_engine/common/util/nvfp4_transpose.cuh @@ -32,9 +32,6 @@ namespace transformer_engine { #if FP4_TYPE_SUPPORTED namespace nvfp4_transpose { -using RNG = decltype(curanddx::Generator() + curanddx::PhiloxRounds<10>() + - curanddx::SM<800>() + curanddx::Thread()); - using namespace ptx; using nvfp4_scale_t = fp8e4m3; @@ -139,12 +136,15 @@ __device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const return global_encode_scale; } -__device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int &rnd_idx) { +__device__ __forceinline__ uint32_t +get_rbits(transformer_engine::curanddx::detail::philox4x32_native_state<10> + &rng, // philox4x32_native_state<10>: 10 rounds of philox4_32 + uint4 &random_uint4, int &rnd_idx) { if (rnd_idx == 4) { rnd_idx = 0; - curanddx::uniform_bits dist; - random_uint4 = dist.generate4(rng); + random_uint4 = rng.generate4(); } + // Treat uint4 as an array of 4x uint32_t elements for indexing const uint32_t *const rbits_arr = reinterpret_cast(&random_uint4); const uint32_t rbits = rbits_arr[rnd_idx++]; @@ -363,9 +363,11 @@ __global__ void __launch_bounds__(THREADS_NUM) threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM; const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; - RNG rng(rng_seed, rng_sequence, rng_offset); - curanddx::uniform_bits dist; - uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? dist.generate4(rng) : uint4{0, 0, 0, 0}; + + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + rng.init(rng_seed, rng_sequence, rng_offset); + uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0}; + int rnd_idx = 0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x @@ -874,9 +876,11 @@ __global__ void __launch_bounds__(THREADS_NUM) threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM; const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; - RNG rng(rng_seed, rng_sequence, rng_offset); - curanddx::uniform_bits dist; - uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? dist.generate4(rng) : uint4{0, 0, 0, 0}; + + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + rng.init(rng_seed, rng_sequence, rng_offset); + uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0}; + int rnd_idx = 0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x