Skip to content

Commit d7c9777

Browse files
authored
Remove nvidia-mathdx dependency (#2295)
* Remove nvidia-mathdx dep Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Fix SR Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Add comment Signed-off-by: Kirthi Shankar Sivamani <[email protected]> --------- Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
1 parent d2945c6 commit d7c9777

File tree

8 files changed

+145
-70
lines changed

8 files changed

+145
-70
lines changed

.github/workflows/build.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
run: |
2020
apt-get update
2121
apt-get install -y git python3.9 pip cudnn9-cuda-12
22-
pip install cmake==3.21.0 pybind11[global] ninja nvidia-mathdx==25.1.1
22+
pip install cmake==3.21.0 pybind11[global] ninja
2323
- name: 'Checkout'
2424
uses: actions/checkout@v3
2525
with:
@@ -43,7 +43,7 @@ jobs:
4343
run: |
4444
apt-get update
4545
apt-get install -y git python3.9 pip cudnn9-cuda-12
46-
pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript nvidia-mathdx==25.1.1
46+
pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript
4747
- name: 'Checkout'
4848
uses: actions/checkout@v3
4949
with:
@@ -63,7 +63,7 @@ jobs:
6363
options: --user root
6464
steps:
6565
- name: 'Dependencies'
66-
run: pip install pybind11[global] nvidia-mathdx==25.1.1
66+
run: pip install pybind11[global]
6767
- name: 'Checkout'
6868
uses: actions/checkout@v3
6969
with:
@@ -83,7 +83,7 @@ jobs:
8383
options: --user root
8484
steps:
8585
- name: 'Dependencies'
86-
run: pip install torch pybind11[global] einops onnxscript nvidia-mathdx==25.1.1
86+
run: pip install torch pybind11[global] einops onnxscript
8787
- name: 'Checkout'
8888
uses: actions/checkout@v3
8989
with:

build_tools/wheel_utils/build_wheels.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ git checkout $TARGET_BRANCH
2323
git submodule update --init --recursive
2424

2525
# Install deps
26-
/opt/python/cp310-cp310/bin/pip install cmake pybind11[global] ninja setuptools wheel nvidia-mathdx==25.1.1
26+
/opt/python/cp310-cp310/bin/pip install cmake pybind11[global] ninja setuptools wheel
2727

2828
if $BUILD_METAPACKAGE ; then
2929
cd /TransformerEngine

pyproject.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
# See LICENSE for license information.
44

55
[build-system]
6-
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "nvidia-mathdx==25.1.1", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"]
6+
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"]
77

88
# Use legacy backend to import local packages in setup.py
99
build-backend = "setuptools.build_meta:__legacy__"
10-

transformer_engine/common/CMakeLists.txt

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -98,28 +98,6 @@ set(CUTLASS_TOOLS_INCLUDE_DIR
9898
# Python
9999
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
100100

101-
# NVIDIA MathDX include directory (from Python package install location)
102-
if(NOT DEFINED MATHDX_INCLUDE_DIR)
103-
execute_process(
104-
COMMAND ${Python_EXECUTABLE} -m pip show nvidia-mathdx
105-
OUTPUT_VARIABLE _PIP_SHOW_MATHDX
106-
ERROR_VARIABLE _PIP_SHOW_MATHDX_ERR
107-
RESULT_VARIABLE _PIP_SHOW_MATHDX_RES
108-
OUTPUT_STRIP_TRAILING_WHITESPACE)
109-
if(NOT _PIP_SHOW_MATHDX_RES EQUAL 0)
110-
message(FATAL_ERROR "Failed to query 'nvidia-mathdx' with pip (using ${Python_EXECUTABLE}): ${_PIP_SHOW_MATHDX_ERR}")
111-
endif()
112-
string(REGEX MATCH "Location: ([^\n\r]+)" _MATHDX_LOC_MATCH "${_PIP_SHOW_MATHDX}")
113-
if(NOT _MATHDX_LOC_MATCH)
114-
message(FATAL_ERROR "Could not parse installation location for 'nvidia-mathdx'. Output was:\n${_PIP_SHOW_MATHDX}")
115-
endif()
116-
set(MATHDX_LOCATION "${CMAKE_MATCH_1}")
117-
set(MATHDX_INCLUDE_DIR "${MATHDX_LOCATION}/nvidia/mathdx/include")
118-
endif()
119-
if(NOT EXISTS "${MATHDX_INCLUDE_DIR}")
120-
message(FATAL_ERROR "MATHDX include directory not found at ${MATHDX_INCLUDE_DIR}. Set MATHDX_INCLUDE_DIR or ensure 'nvidia-mathdx' is installed for ${Python_EXECUTABLE}.")
121-
endif()
122-
123101
# Configure Transformer Engine library
124102
include_directories(${PROJECT_SOURCE_DIR}/..)
125103
set(transformer_engine_SOURCES)
@@ -263,7 +241,6 @@ target_link_libraries(transformer_engine PUBLIC
263241

264242
target_include_directories(transformer_engine PRIVATE
265243
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
266-
target_include_directories(transformer_engine PRIVATE ${MATHDX_INCLUDE_DIR})
267244
target_include_directories(transformer_engine SYSTEM PRIVATE
268245
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl)
269246
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")

transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919

2020
#include "common/common.h"
2121
#include "common/util/cuda_runtime.h"
22+
#include "common/util/curanddx.hpp"
2223
#include "common/util/ptx.cuh"
2324
#include "common/utils.cuh"
24-
#include "curanddx.hpp"
2525
#include "cutlass/arch/barrier.h"
2626
#include "cutlass/cutlass.h"
2727
#include "cutlass/gemm/collective/builders/sm100_common.inl"
@@ -38,15 +38,6 @@ namespace transformer_engine {
3838
namespace detail {
3939
namespace {
4040

41-
// Define a cuRANDDx descriptor
42-
// Note curanddx::PhiloxRounds<4> means 4 rounds of philox4_32. If the operator is not specified, it defaults to 10.
43-
// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used to do some internal checks, e.g.,
44-
// if shared memory, if needed, is enough for the described problem, usually not applicable.
45-
46-
// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html
47-
using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() + curanddx::SM<800>() + curanddx::Thread());
48-
49-
5041
using namespace cute;
5142
using cute::Tensor; // Ensure unqualified Tensor refers to cute::Tensor, not transformer_engine::Tensor
5243

@@ -502,16 +493,17 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile,
502493
// Initialize RNG for tile
503494
const size_t rng_sequence
504495
= thread_idx + k_tile * 256 + linear_tile_idx * K_TILE_MAX * 256;
505-
RNG rng(rng_seed, rng_sequence, rng_offset);
506-
curanddx::uniform_bits dist;
496+
497+
transformer_engine::curanddx::detail::philox4x32_native_state<10> rng;
498+
rng.init(rng_seed, rng_sequence, rng_offset);
507499
uint4 random_uint4 = uint4{0, 0, 0, 0};
508500

509501
CUTLASS_PRAGMA_UNROLL
510502
for (int v = 0; v < NumVecs; v++) {
511503
auto acc_scale = cutlass::minimum_with_nan_propagation<ElementAccumulator>{}(acc_scales[v], cutlass::platform::numeric_limits<ElementAccumulator>::max());
512504
// auto acc_scale = acc_scales[v];
513505
if constexpr (kEnableStochasticRounding) {
514-
random_uint4 = dist.generate4(rng);
506+
random_uint4 = rng.generate4();
515507
output_frgs[v] = StochasticNumericConverter(
516508
cutlass::multiplies<cutlass::Array<ElementAccumulator, VectorSize>>{}(
517509
compute_frgs[v],

transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
#include "common/common.h"
1818
#include "common/recipe/recipe_common.cuh"
1919
#include "common/transpose/cast_transpose.h"
20+
#include "common/util/curanddx.hpp"
2021
#include "common/util/ptx.cuh"
2122
#include "common/utils.cuh"
22-
#include "curanddx.hpp"
2323

2424
namespace transformer_engine {
2525

@@ -33,14 +33,6 @@ using std::uint8_t;
3333

3434
using transformer_engine::detail::TypeExtrema;
3535

36-
// Define a cuRANDDx descriptor
37-
// Note curanddx::PhiloxRounds<4> means 4 rounds of philox4_32. If the operator is not specified, it defaults to 10.
38-
// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used to do some internal checks, e.g.,
39-
// if shared memory, if needed, is enough for the described problem, usually not applicable.
40-
// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html
41-
using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() +
42-
curanddx::SM<800>() + curanddx::Thread());
43-
4436
// clang-format off
4537
/*
4638
@@ -209,12 +201,15 @@ __device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_
209201
return global_encode_scale;
210202
}
211203

212-
__device__ __forceinline__ uint32_t get_rbits(RNG& rng, uint4& random_uint4, int& rnd_idx) {
204+
__device__ __forceinline__ uint32_t
205+
get_rbits(transformer_engine::curanddx::detail::philox4x32_native_state<10>&
206+
rng, // philox4x32_native_state<10>: 10 rounds of philox4_32
207+
uint4& random_uint4, int& rnd_idx) {
213208
if (rnd_idx == 4) {
214209
rnd_idx = 0;
215-
curanddx::uniform_bits dist;
216-
random_uint4 = dist.generate4(rng);
210+
random_uint4 = rng.generate4();
217211
}
212+
218213
// Treat uint4 as an array of 4x uint32_t elements for indexing
219214
const uint32_t* const rbits_arr = reinterpret_cast<uint32_t*>(&random_uint4);
220215
const uint32_t rbits = rbits_arr[rnd_idx++];
@@ -348,9 +343,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
348343
threadIdx.x + block_idx_x * kThreadsPerBlock + block_idx_y * gridDim.x * kThreadsPerBlock;
349344
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
350345
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
351-
RNG rng(rng_seed, rng_sequence, rng_offset);
352-
curanddx::uniform_bits dist;
353-
uint4 random_uint4 = kApplyStochasticRounding ? dist.generate4(rng) : uint4{0, 0, 0, 0};
346+
347+
transformer_engine::curanddx::detail::philox4x32_native_state<10> rng;
348+
rng.init(rng_seed, rng_sequence, rng_offset);
349+
uint4 random_uint4 = kApplyStochasticRounding ? rng.generate4() : uint4{0, 0, 0, 0};
350+
354351
int rnd_idx =
355352
0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x
356353

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
/*************************************************************************
2+
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
*
4+
* See LICENSE for license information.
5+
************************************************************************/
6+
7+
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CURANDDX_HPP_
8+
#define TRANSFORMER_ENGINE_COMMON_UTIL_CURANDDX_HPP_
9+
10+
namespace transformer_engine {
11+
namespace curanddx {
12+
namespace detail {
13+
14+
// Constants used by the Philox4x32 rounds in this header:
//  - philox4x32_w32_{0,1} are added to key.x / key.y between consecutive rounds
//    (see multiple_rounds).
//  - philox4x32_m4x32_{0,1} are the multipliers applied to ctr.x / ctr.z in each round
//    (see single_round).
inline constexpr unsigned int philox4x32_w32_0 = 0x9E3779B9U;
15+
inline constexpr unsigned int philox4x32_w32_1 = 0xBB67AE85U;
16+
inline constexpr unsigned int philox4x32_m4x32_0 = 0xD2511F53U;
17+
inline constexpr unsigned int philox4x32_m4x32_1 = 0xCD9E8D57U;
18+
19+
// Multiplies two 32-bit unsigned values.
// Returns the low 32 bits of the product; the high 32 bits (obtained with __umulhi)
// are written through `hip`.
__forceinline__ __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
                                                  unsigned int* hip) {
  const unsigned int low_word = a * b;
  *hip = __umulhi(a, b);
  return low_word;
}
24+
25+
// Performs one Philox4x32 round on counter `ctr` with key `key`.
// ctr.x and ctr.z are multiplied by the two round multipliers; the high halves are XORed
// with ctr.w / ctr.y and the key words, and the low halves are passed through, yielding
// the permuted counter for the next round.
__forceinline__ __device__ uint4 single_round(uint4 ctr, uint2 key) {
  unsigned int high_x;
  unsigned int high_z;
  const unsigned int low_x = mulhilo32(philox4x32_m4x32_0, ctr.x, &high_x);
  const unsigned int low_z = mulhilo32(philox4x32_m4x32_1, ctr.z, &high_z);

  uint4 mixed;
  mixed.x = high_z ^ ctr.y ^ key.x;
  mixed.y = low_z;
  mixed.z = high_x ^ ctr.w ^ key.y;
  mixed.w = low_x;
  return mixed;
}
34+
35+
template <unsigned int Rounds>
36+
__forceinline__ __device__ uint4 multiple_rounds(uint4 c, uint2 k) {
37+
for (unsigned int i = 0; i < Rounds - 1; i++) {
38+
c = single_round(c, k); // 1
39+
k.x += philox4x32_w32_0;
40+
k.y += philox4x32_w32_1;
41+
}
42+
return single_round(c, k); // Rounds
43+
}
44+
45+
template <unsigned int Rounds>
46+
struct philox4x32_native_state {
47+
static constexpr unsigned int rounds = Rounds;
48+
49+
uint4 ctr;
50+
uint2 key;
51+
52+
__forceinline__ __device__ void philox_state_incr() {
53+
if (++ctr.x) return;
54+
if (++ctr.y) return;
55+
if (++ctr.z) return;
56+
++ctr.w;
57+
}
58+
59+
__forceinline__ __device__ void philox_state_incr(size_t n) {
60+
unsigned int nlo = (unsigned int)(n);
61+
unsigned int nhi = (unsigned int)(n >> 32);
62+
63+
ctr.x += nlo;
64+
if (ctr.x < nlo) nhi++;
65+
66+
ctr.y += nhi;
67+
if (nhi <= ctr.y) return;
68+
if (++ctr.z) return;
69+
++ctr.w;
70+
}
71+
72+
__forceinline__ __device__ void philox_state_incr_hi(size_t n) {
73+
unsigned int nlo = (unsigned int)(n);
74+
unsigned int nhi = (unsigned int)(n >> 32);
75+
76+
ctr.z += nlo;
77+
if (ctr.z < nlo) nhi++;
78+
79+
ctr.w += nhi;
80+
}
81+
82+
// offset is the total # of 128bits generated with a single generate4() call
83+
__forceinline__ __device__ void skip_offset(size_t n) { philox_state_incr(n); }
84+
85+
__forceinline__ __device__ void skip_subsequence(size_t n) { philox_state_incr_hi(n); }
86+
87+
__forceinline__ __device__ void init(size_t seed, size_t subsequence, size_t offset) {
88+
ctr = make_uint4(0, 0, 0, 0);
89+
key.x = (unsigned int)seed;
90+
key.y = (unsigned int)(seed >> 32);
91+
92+
skip_subsequence(subsequence);
93+
skip_offset(offset);
94+
}
95+
96+
__forceinline__ __device__ uint4 generate4() {
97+
auto tmp = multiple_rounds<Rounds>(ctr, key);
98+
philox_state_incr();
99+
return tmp;
100+
}
101+
};
102+
} // namespace detail
103+
} // namespace curanddx
104+
} // namespace transformer_engine
105+
106+
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_CURANDDX_HPP_

transformer_engine/common/util/nvfp4_transpose.cuh

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,6 @@ namespace transformer_engine {
3232
#if FP4_TYPE_SUPPORTED
3333
namespace nvfp4_transpose {
3434

35-
using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() +
36-
curanddx::SM<800>() + curanddx::Thread());
37-
3835
using namespace ptx;
3936
using nvfp4_scale_t = fp8e4m3;
4037

@@ -139,12 +136,15 @@ __device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const
139136
return global_encode_scale;
140137
}
141138

142-
__device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int &rnd_idx) {
139+
__device__ __forceinline__ uint32_t
140+
get_rbits(transformer_engine::curanddx::detail::philox4x32_native_state<10>
141+
&rng, // philox4x32_native_state<10>: 10 rounds of philox4_32
142+
uint4 &random_uint4, int &rnd_idx) {
143143
if (rnd_idx == 4) {
144144
rnd_idx = 0;
145-
curanddx::uniform_bits dist;
146-
random_uint4 = dist.generate4(rng);
145+
random_uint4 = rng.generate4();
147146
}
147+
148148
// Treat uint4 as an array of 4x uint32_t elements for indexing
149149
const uint32_t *const rbits_arr = reinterpret_cast<uint32_t *>(&random_uint4);
150150
const uint32_t rbits = rbits_arr[rnd_idx++];
@@ -363,9 +363,11 @@ __global__ void __launch_bounds__(THREADS_NUM)
363363
threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM;
364364
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
365365
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
366-
RNG rng(rng_seed, rng_sequence, rng_offset);
367-
curanddx::uniform_bits dist;
368-
uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? dist.generate4(rng) : uint4{0, 0, 0, 0};
366+
367+
transformer_engine::curanddx::detail::philox4x32_native_state<10> rng;
368+
rng.init(rng_seed, rng_sequence, rng_offset);
369+
uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0};
370+
369371
int rnd_idx =
370372
0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x
371373

@@ -874,9 +876,11 @@ __global__ void __launch_bounds__(THREADS_NUM)
874876
threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM;
875877
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
876878
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
877-
RNG rng(rng_seed, rng_sequence, rng_offset);
878-
curanddx::uniform_bits dist;
879-
uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? dist.generate4(rng) : uint4{0, 0, 0, 0};
879+
880+
transformer_engine::curanddx::detail::philox4x32_native_state<10> rng;
881+
rng.init(rng_seed, rng_sequence, rng_offset);
882+
uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0};
883+
880884
int rnd_idx =
881885
0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x
882886

0 commit comments

Comments
 (0)