Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
run: |
apt-get update
apt-get install -y git python3.9 pip cudnn9-cuda-12
pip install cmake==3.21.0 pybind11[global] ninja nvidia-mathdx==25.1.1
pip install cmake==3.21.0 pybind11[global] ninja
- name: 'Checkout'
uses: actions/checkout@v3
with:
Expand All @@ -43,7 +43,7 @@ jobs:
run: |
apt-get update
apt-get install -y git python3.9 pip cudnn9-cuda-12
pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript nvidia-mathdx==25.1.1
pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript
- name: 'Checkout'
uses: actions/checkout@v3
with:
Expand All @@ -63,7 +63,7 @@ jobs:
options: --user root
steps:
- name: 'Dependencies'
run: pip install pybind11[global] nvidia-mathdx==25.1.1
run: pip install pybind11[global]
- name: 'Checkout'
uses: actions/checkout@v3
with:
Expand All @@ -83,7 +83,7 @@ jobs:
options: --user root
steps:
- name: 'Dependencies'
run: pip install torch pybind11[global] einops onnxscript nvidia-mathdx==25.1.1
run: pip install torch pybind11[global] einops onnxscript
- name: 'Checkout'
uses: actions/checkout@v3
with:
Expand Down
2 changes: 1 addition & 1 deletion build_tools/wheel_utils/build_wheels.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ git checkout $TARGET_BRANCH
git submodule update --init --recursive

# Install deps
/opt/python/cp310-cp310/bin/pip install cmake pybind11[global] ninja setuptools wheel nvidia-mathdx==25.1.1
/opt/python/cp310-cp310/bin/pip install cmake pybind11[global] ninja setuptools wheel

if $BUILD_METAPACKAGE ; then
cd /TransformerEngine
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
# See LICENSE for license information.

[build-system]
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "nvidia-mathdx==25.1.1", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"]
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"]

# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"

23 changes: 0 additions & 23 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -98,28 +98,6 @@ set(CUTLASS_TOOLS_INCLUDE_DIR
# Python
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)

# NVIDIA MathDX include directory (from Python package install location)
if(NOT DEFINED MATHDX_INCLUDE_DIR)
execute_process(
COMMAND ${Python_EXECUTABLE} -m pip show nvidia-mathdx
OUTPUT_VARIABLE _PIP_SHOW_MATHDX
ERROR_VARIABLE _PIP_SHOW_MATHDX_ERR
RESULT_VARIABLE _PIP_SHOW_MATHDX_RES
OUTPUT_STRIP_TRAILING_WHITESPACE)
if(NOT _PIP_SHOW_MATHDX_RES EQUAL 0)
message(FATAL_ERROR "Failed to query 'nvidia-mathdx' with pip (using ${Python_EXECUTABLE}): ${_PIP_SHOW_MATHDX_ERR}")
endif()
string(REGEX MATCH "Location: ([^\n\r]+)" _MATHDX_LOC_MATCH "${_PIP_SHOW_MATHDX}")
if(NOT _MATHDX_LOC_MATCH)
message(FATAL_ERROR "Could not parse installation location for 'nvidia-mathdx'. Output was:\n${_PIP_SHOW_MATHDX}")
endif()
set(MATHDX_LOCATION "${CMAKE_MATCH_1}")
set(MATHDX_INCLUDE_DIR "${MATHDX_LOCATION}/nvidia/mathdx/include")
endif()
if(NOT EXISTS "${MATHDX_INCLUDE_DIR}")
message(FATAL_ERROR "MATHDX include directory not found at ${MATHDX_INCLUDE_DIR}. Set MATHDX_INCLUDE_DIR or ensure 'nvidia-mathdx' is installed for ${Python_EXECUTABLE}.")
endif()

# Configure Transformer Engine library
include_directories(${PROJECT_SOURCE_DIR}/..)
set(transformer_engine_SOURCES)
Expand Down Expand Up @@ -263,7 +241,6 @@ target_link_libraries(transformer_engine PUBLIC

target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine PRIVATE ${MATHDX_INCLUDE_DIR})
target_include_directories(transformer_engine SYSTEM PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl)
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@

#include "common/common.h"
#include "common/util/cuda_runtime.h"
#include "common/util/curanddx.hpp"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#include "curanddx.hpp"
#include "cutlass/arch/barrier.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/builders/sm100_common.inl"
Expand All @@ -38,15 +38,6 @@ namespace transformer_engine {
namespace detail {
namespace {

// Define a cuRANDDx descriptor
// Note curanddx::PhiloxRounds<4> means 4 rounds of philox4_32. If the operator is not specified, it will be default to 10.
// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used for do some internal checks, e.g.,
// if shared memory, if needed, is enough for the described problem, usually not applicable.

// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html
using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() + curanddx::SM<800>() + curanddx::Thread());


using namespace cute;
using cute::Tensor; // Ensure unqualified Tensor refers to cute::Tensor, not transformer_engine::Tensor

Expand Down Expand Up @@ -502,16 +493,17 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile,
// Initialize RNG for tile
const size_t rng_sequence
= thread_idx + k_tile * 256 + linear_tile_idx * K_TILE_MAX * 256;
RNG rng(rng_seed, rng_sequence, rng_offset);
curanddx::uniform_bits dist;

transformer_engine::curanddx::detail::philox4x32_native_state<10> rng;
rng.init(rng_seed, rng_sequence, rng_offset);
uint4 random_uint4 = uint4{0, 0, 0, 0};

CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < NumVecs; v++) {
auto acc_scale = cutlass::minimum_with_nan_propagation<ElementAccumulator>{}(acc_scales[v], cutlass::platform::numeric_limits<ElementAccumulator>::max());
// auto acc_scale = acc_scales[v];
if constexpr (kEnableStochasticRounding) {
random_uint4 = dist.generate4(rng);
random_uint4 = rng.generate4();
output_frgs[v] = StochasticNumericConverter(
cutlass::multiplies<cutlass::Array<ElementAccumulator, VectorSize>>{}(
compute_frgs[v],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/transpose/cast_transpose.h"
#include "common/util/curanddx.hpp"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#include "curanddx.hpp"

namespace transformer_engine {

Expand All @@ -33,14 +33,6 @@ using std::uint8_t;

using transformer_engine::detail::TypeExtrema;

// Define a cuRANDDx descriptor
// Note curanddx::PhiloxRounds<4> means 4 rounds of philox4_32. If the operator is not specified, it will be default to 10.
// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used for do some internal checks, e.g.,
// if shared memory, if needed, is enough for the described problem, usually not applicable.
// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html
using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() +
curanddx::SM<800>() + curanddx::Thread());

// clang-format off
/*

Expand Down Expand Up @@ -209,12 +201,15 @@ __device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_
return global_encode_scale;
}

__device__ __forceinline__ uint32_t get_rbits(RNG& rng, uint4& random_uint4, int& rnd_idx) {
__device__ __forceinline__ uint32_t
get_rbits(transformer_engine::curanddx::detail::philox4x32_native_state<10>&
rng, // philox4x32_native_state<10>: 10 rounds of philox4_32
uint4& random_uint4, int& rnd_idx) {
if (rnd_idx == 4) {
rnd_idx = 0;
curanddx::uniform_bits dist;
random_uint4 = dist.generate4(rng);
random_uint4 = rng.generate4();
}

// Treat uint4 as an array of 4x uint32_t elements for indexing
const uint32_t* const rbits_arr = reinterpret_cast<uint32_t*>(&random_uint4);
const uint32_t rbits = rbits_arr[rnd_idx++];
Expand Down Expand Up @@ -348,9 +343,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
threadIdx.x + block_idx_x * kThreadsPerBlock + block_idx_y * gridDim.x * kThreadsPerBlock;
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
RNG rng(rng_seed, rng_sequence, rng_offset);
curanddx::uniform_bits dist;
uint4 random_uint4 = kApplyStochasticRounding ? dist.generate4(rng) : uint4{0, 0, 0, 0};

transformer_engine::curanddx::detail::philox4x32_native_state<10> rng;
rng.init(rng_seed, rng_sequence, rng_offset);
uint4 random_uint4 = kApplyStochasticRounding ? rng.generate4() : uint4{0, 0, 0, 0};

int rnd_idx =
0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x

Expand Down
106 changes: 106 additions & 0 deletions transformer_engine/common/util/curanddx.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CURANDDX_HPP_
#define TRANSFORMER_ENGINE_COMMON_UTIL_CURANDDX_HPP_

namespace transformer_engine {
namespace curanddx {
namespace detail {

inline constexpr unsigned int philox4x32_w32_0 = 0x9E3779B9U;
inline constexpr unsigned int philox4x32_w32_1 = 0xBB67AE85U;
inline constexpr unsigned int philox4x32_m4x32_0 = 0xD2511F53U;
inline constexpr unsigned int philox4x32_m4x32_1 = 0xCD9E8D57U;

__forceinline__ __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
unsigned int* hip) {
*hip = __umulhi(a, b);
return a * b;
}

__forceinline__ __device__ uint4 single_round(uint4 ctr, uint2 key) {
unsigned int hi0;
unsigned int hi1;
unsigned int lo0 = mulhilo32(philox4x32_m4x32_0, ctr.x, &hi0);
unsigned int lo1 = mulhilo32(philox4x32_m4x32_1, ctr.z, &hi1);

uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
return ret;
}

template <unsigned int Rounds>
__forceinline__ __device__ uint4 multiple_rounds(uint4 c, uint2 k) {
for (unsigned int i = 0; i < Rounds - 1; i++) {
c = single_round(c, k); // 1
k.x += philox4x32_w32_0;
k.y += philox4x32_w32_1;
}
return single_round(c, k); // Rounds
}

template <unsigned int Rounds>
struct philox4x32_native_state {
static constexpr unsigned int rounds = Rounds;

uint4 ctr;
uint2 key;

__forceinline__ __device__ void philox_state_incr() {
if (++ctr.x) return;
if (++ctr.y) return;
if (++ctr.z) return;
++ctr.w;
}

__forceinline__ __device__ void philox_state_incr(size_t n) {
unsigned int nlo = (unsigned int)(n);
unsigned int nhi = (unsigned int)(n >> 32);

ctr.x += nlo;
if (ctr.x < nlo) nhi++;

ctr.y += nhi;
if (nhi <= ctr.y) return;
if (++ctr.z) return;
++ctr.w;
}
Comment on lines +59 to +70
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: overflow logic has subtle carry propagation; the increment chain ctr.x += nlo; if (ctr.x < nlo) nhi++; ctr.y += nhi; if (nhi <= ctr.y) return; stops early if nhi <= ctr.y, but this skips carrying into ctr.z/ctr.w when ctr.y wraps. If ctr.y overflows after adding nhi, the carry is lost. should the condition on line 67 be if (nhi > ctr.y) (overflow occurred) instead of if (nhi <= ctr.y) return (no overflow)?


__forceinline__ __device__ void philox_state_incr_hi(size_t n) {
unsigned int nlo = (unsigned int)(n);
unsigned int nhi = (unsigned int)(n >> 32);

ctr.z += nlo;
if (ctr.z < nlo) nhi++;

ctr.w += nhi;
}

// offset is the total # of 128bits generated with a single generate4() call
__forceinline__ __device__ void skip_offset(size_t n) { philox_state_incr(n); }

__forceinline__ __device__ void skip_subsequence(size_t n) { philox_state_incr_hi(n); }

__forceinline__ __device__ void init(size_t seed, size_t subsequence, size_t offset) {
ctr = make_uint4(0, 0, 0, 0);
key.x = (unsigned int)seed;
key.y = (unsigned int)(seed >> 32);

skip_subsequence(subsequence);
skip_offset(offset);
}

__forceinline__ __device__ uint4 generate4() {
auto tmp = multiple_rounds<Rounds>(ctr, key);
philox_state_incr();
return tmp;
}
};
} // namespace detail
} // namespace curanddx
} // namespace transformer_engine

#endif // TRANSFORMER_ENGINE_COMMON_UTIL_CURANDDX_HPP_
28 changes: 16 additions & 12 deletions transformer_engine/common/util/nvfp4_transpose.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ namespace transformer_engine {
#if FP4_TYPE_SUPPORTED
namespace nvfp4_transpose {

using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() +
curanddx::SM<800>() + curanddx::Thread());

using namespace ptx;
using nvfp4_scale_t = fp8e4m3;

Expand Down Expand Up @@ -139,12 +136,15 @@ __device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const
return global_encode_scale;
}

__device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int &rnd_idx) {
__device__ __forceinline__ uint32_t
get_rbits(transformer_engine::curanddx::detail::philox4x32_native_state<10>
&rng, // philox4x32_native_state<10>: 10 rounds of philox4_32
uint4 &random_uint4, int &rnd_idx) {
if (rnd_idx == 4) {
rnd_idx = 0;
curanddx::uniform_bits dist;
random_uint4 = dist.generate4(rng);
random_uint4 = rng.generate4();
}

// Treat uint4 as an array of 4x uint32_t elements for indexing
const uint32_t *const rbits_arr = reinterpret_cast<uint32_t *>(&random_uint4);
const uint32_t rbits = rbits_arr[rnd_idx++];
Expand Down Expand Up @@ -363,9 +363,11 @@ __global__ void __launch_bounds__(THREADS_NUM)
threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM;
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
RNG rng(rng_seed, rng_sequence, rng_offset);
curanddx::uniform_bits dist;
uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? dist.generate4(rng) : uint4{0, 0, 0, 0};

transformer_engine::curanddx::detail::philox4x32_native_state<10> rng;
rng.init(rng_seed, rng_sequence, rng_offset);
uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0};

int rnd_idx =
0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x

Expand Down Expand Up @@ -874,9 +876,11 @@ __global__ void __launch_bounds__(THREADS_NUM)
threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM;
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
RNG rng(rng_seed, rng_sequence, rng_offset);
curanddx::uniform_bits dist;
uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? dist.generate4(rng) : uint4{0, 0, 0, 0};

transformer_engine::curanddx::detail::philox4x32_native_state<10> rng;
rng.init(rng_seed, rng_sequence, rng_offset);
uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0};

int rnd_idx =
0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x

Expand Down
Loading