Skip to content

Commit eb34783

Browse files
Authored by: ptrendx, pre-commit-ci[bot], Copilot, Oleg-Goncharov
Overhaul the compilation for the arch-specific features (#2279)
* Added sm_120f to the build — Signed-off-by: Przemek Tredak <[email protected]>
* Change the arch specific handling — Signed-off-by: Przemek Tredak <[email protected]>
* Fix — Signed-off-by: Przemek Tredak <[email protected]>
* Support for CUDA<12.9 — Signed-off-by: Przemek Tredak <[email protected]>
* Moved through the rest of the files — Signed-off-by: Przemek Tredak <[email protected]>
* Fix — Signed-off-by: Przemek Tredak <[email protected]>
* Common cases — Signed-off-by: Przemek Tredak <[email protected]>
* Remove pure 100 from the list — Signed-off-by: Przemek Tredak <[email protected]>
* Fix — Signed-off-by: Przemek Tredak <[email protected]>
* CMake changes (not yet working) — Signed-off-by: Przemek Tredak <[email protected]>
* Fix — Signed-off-by: Przemek Tredak <[email protected]>
* Do not pass the arch-specific thing from build_tools — Signed-off-by: Przemek Tredak <[email protected]>
* Fix — Signed-off-by: Przemek Tredak <[email protected]>
* Moved some of the files to arch-specific compilation — Signed-off-by: Przemek Tredak <[email protected]>
* Fix, and also changing the order of compilation to hopefully get the compilation time lower — Signed-off-by: Przemek Tredak <[email protected]>
* Fix for the files overwriting custom compile properties — Signed-off-by: Przemek Tredak <[email protected]>
* Actually make this whole thing work — Signed-off-by: Przemek Tredak <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks — for more information, see https://pre-commit.ci
* Add space to the error message — Co-authored-by: Copilot <[email protected]>; Signed-off-by: Przemyslaw Tredak <[email protected]>
* Apply suggestions from code review — Co-authored-by: Oleg Goncharov <[email protected]>; Signed-off-by: Przemyslaw Tredak <[email protected]>
* Fixes from review — Signed-off-by: Przemek Tredak <[email protected]>
* Changing the naming to be more intuitive — Signed-off-by: Przemek Tredak <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks — for more information, see https://pre-commit.ci
* Add missing cassert include for device-side asserts — Signed-off-by: Przemek Tredak <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks — for more information, see https://pre-commit.ci

---------

Signed-off-by: Przemek Tredak <[email protected]>
Signed-off-by: Przemyslaw Tredak <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <[email protected]>
Co-authored-by: Oleg Goncharov <[email protected]>
1 parent 66acb8e commit eb34783

File tree

7 files changed

+610
-306
lines changed

7 files changed

+610
-306
lines changed

build_tools/utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -257,11 +257,9 @@ def cuda_archs() -> str:
257257
if archs is None:
258258
version = cuda_version()
259259
if version >= (13, 0):
260-
archs = "75;80;89;90;100;100a;103a;120"
261-
elif version >= (12, 9):
262-
archs = "70;80;89;90;100;100a;103a;120"
260+
archs = "75;80;89;90;100;120"
263261
elif version >= (12, 8):
264-
archs = "70;80;89;90;100;100a;120"
262+
archs = "70;80;89;90;100;120"
265263
else:
266264
archs = "70;80;89;90"
267265
return archs

transformer_engine/common/CMakeLists.txt

Lines changed: 153 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,6 @@
55
cmake_minimum_required(VERSION 3.21)
66

77
# Language options
8-
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
9-
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0)
10-
set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
11-
elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
12-
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120)
13-
else ()
14-
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
15-
endif()
16-
endif()
178
set(CMAKE_CXX_STANDARD 17)
189
set(CMAKE_CUDA_STANDARD 17)
1910
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
@@ -30,8 +21,62 @@ project(transformer_engine LANGUAGES CUDA CXX)
3021

3122
# CUDA Toolkit
3223
find_package(CUDAToolkit REQUIRED)
33-
if (CUDAToolkit_VERSION VERSION_LESS 12.0)
34-
message(FATAL_ERROR "CUDA 12.0+ is required, but found CUDA ${CUDAToolkit_VERSION}")
24+
if (CUDAToolkit_VERSION VERSION_LESS 12.1)
25+
message(FATAL_ERROR "CUDA 12.1+ is required, but found CUDA ${CUDAToolkit_VERSION}")
26+
endif()
27+
28+
# Process GPU architectures
29+
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
30+
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0)
31+
set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
32+
elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
33+
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120)
34+
else ()
35+
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
36+
endif()
37+
endif()
38+
39+
# Process CMAKE_CUDA_ARCHITECTURES to separate generic and specific architectures
40+
set(NVTE_GENERIC_ARCHS)
41+
set(NVTE_SPECIFIC_ARCHS)
42+
43+
# Check for architecture 100
44+
list(FIND CMAKE_CUDA_ARCHITECTURES "100" arch_100_index)
45+
if(NOT arch_100_index EQUAL -1)
46+
list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "100")
47+
list(APPEND NVTE_GENERIC_ARCHS "100")
48+
list(APPEND NVTE_SPECIFIC_ARCHS "100a")
49+
if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9)
50+
list(APPEND NVTE_SPECIFIC_ARCHS "103a")
51+
endif()
52+
endif()
53+
54+
# Check for architecture 101 (if we see this we are in toolkit <= 12.9)
55+
list(FIND CMAKE_CUDA_ARCHITECTURES "101" arch_101_index)
56+
if(NOT arch_101_index EQUAL -1)
57+
list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "101")
58+
list(APPEND NVTE_GENERIC_ARCHS "101")
59+
list(APPEND NVTE_SPECIFIC_ARCHS "101a")
60+
endif()
61+
62+
# Check for architecture 110 (if we see this we are in toolkit >= 13.0)
63+
list(FIND CMAKE_CUDA_ARCHITECTURES "110" arch_110_index)
64+
if(NOT arch_110_index EQUAL -1)
65+
list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "110")
66+
list(APPEND NVTE_GENERIC_ARCHS "110")
67+
list(APPEND NVTE_SPECIFIC_ARCHS "110f")
68+
endif()
69+
70+
# Check for architecture 120
71+
list(FIND CMAKE_CUDA_ARCHITECTURES "120" arch_120_index)
72+
if(NOT arch_120_index EQUAL -1)
73+
list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "120")
74+
list(APPEND NVTE_GENERIC_ARCHS "120")
75+
if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9)
76+
list(APPEND NVTE_SPECIFIC_ARCHS "120f")
77+
else()
78+
list(APPEND NVTE_SPECIFIC_ARCHS "120a")
79+
endif()
3580
endif()
3681

3782
# cuDNN frontend API
@@ -78,9 +123,28 @@ endif()
78123
# Configure Transformer Engine library
79124
include_directories(${PROJECT_SOURCE_DIR}/..)
80125
set(transformer_engine_SOURCES)
81-
list(APPEND transformer_engine_SOURCES
126+
set(transformer_engine_cpp_sources)
127+
set(transformer_engine_cuda_sources)
128+
set(transformer_engine_cuda_arch_specific_sources)
129+
130+
list(APPEND transformer_engine_cpp_sources
82131
cudnn_utils.cpp
83132
transformer_engine.cpp
133+
fused_attn/fused_attn.cpp
134+
gemm/config.cpp
135+
normalization/common.cpp
136+
normalization/layernorm/ln_api.cpp
137+
normalization/rmsnorm/rmsnorm_api.cpp
138+
util/cuda_driver.cpp
139+
util/cuda_nvml.cpp
140+
util/cuda_runtime.cpp
141+
util/multi_stream.cpp
142+
util/rtc.cpp
143+
comm_gemm_overlap/userbuffers/ipcsocket.cc
144+
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
145+
comm_gemm_overlap/comm_gemm_overlap.cpp)
146+
147+
list(APPEND transformer_engine_cuda_sources
84148
common.cu
85149
multi_tensor/adam.cu
86150
multi_tensor/compute_scale.cu
@@ -92,40 +156,23 @@ list(APPEND transformer_engine_SOURCES
92156
transpose/cast_transpose_fusion.cu
93157
transpose/transpose_fusion.cu
94158
transpose/multi_cast_transpose.cu
95-
transpose/quantize_transpose_square_blockwise.cu
96159
transpose/quantize_transpose_vector_blockwise.cu
97160
transpose/swap_first_dims.cu
98-
transpose/quantize_transpose_vector_blockwise_fp4.cu
99-
activation/gelu.cu
100161
dropout/dropout.cu
101162
fused_attn/flash_attn.cu
102163
fused_attn/context_parallel.cu
103164
fused_attn/kv_cache.cu
104165
fused_attn/fused_attn_f16_max512_seqlen.cu
105166
fused_attn/fused_attn_f16_arbitrary_seqlen.cu
106-
activation/relu.cu
107-
activation/swiglu.cu
108167
fused_attn/fused_attn_fp8.cu
109-
fused_attn/fused_attn.cpp
110168
fused_attn/utils.cu
111-
gemm/config.cpp
112169
gemm/cublaslt_gemm.cu
113-
gemm/cutlass_grouped_gemm.cu
114-
normalization/common.cpp
115-
normalization/layernorm/ln_api.cpp
116170
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
117171
normalization/layernorm/ln_fwd_cuda_kernel.cu
118-
normalization/rmsnorm/rmsnorm_api.cpp
119172
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
120173
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
121174
permutation/permutation.cu
122-
util/cast.cu
123175
util/padding.cu
124-
util/cuda_driver.cpp
125-
util/cuda_nvml.cpp
126-
util/cuda_runtime.cpp
127-
util/multi_stream.cpp
128-
util/rtc.cpp
129176
swizzle/swizzle.cu
130177
swizzle/swizzle_block_scaling.cu
131178
fused_softmax/scaled_masked_softmax.cu
@@ -139,12 +186,58 @@ list(APPEND transformer_engine_SOURCES
139186
recipe/delayed_scaling.cu
140187
recipe/fp8_block_scaling.cu
141188
recipe/nvfp4.cu
189+
comm_gemm_overlap/userbuffers/userbuffers.cu)
190+
191+
list(APPEND transformer_engine_cuda_arch_specific_sources
192+
gemm/cutlass_grouped_gemm.cu
193+
util/cast.cu
194+
activation/gelu.cu
195+
activation/relu.cu
196+
activation/swiglu.cu
197+
transpose/quantize_transpose_square_blockwise.cu
198+
transpose/quantize_transpose_vector_blockwise_fp4.cu
142199
hadamard_transform/hadamard_transform.cu
143-
hadamard_transform/hadamard_transform_cast_fusion.cu
144-
comm_gemm_overlap/userbuffers/ipcsocket.cc
145-
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
146-
comm_gemm_overlap/userbuffers/userbuffers.cu
147-
comm_gemm_overlap/comm_gemm_overlap.cpp)
200+
hadamard_transform/hadamard_transform_cast_fusion.cu)
201+
202+
# Compiling the files with the worst compilation time first to hopefully overlap
203+
# better with the faster-compiling cpp files
204+
list(APPEND transformer_engine_SOURCES ${transformer_engine_cuda_arch_specific_sources}
205+
${transformer_engine_cuda_sources}
206+
${transformer_engine_cpp_sources})
207+
208+
# Set compile options for CUDA sources with generic architectures
209+
foreach(cuda_source IN LISTS transformer_engine_cuda_sources)
210+
set(arch_compile_options)
211+
foreach(arch IN LISTS NVTE_GENERIC_ARCHS)
212+
list(APPEND arch_compile_options "--generate-code=arch=compute_${arch},code=sm_${arch}")
213+
endforeach()
214+
215+
if(arch_compile_options)
216+
set_property(
217+
SOURCE ${cuda_source}
218+
APPEND
219+
PROPERTY
220+
COMPILE_OPTIONS ${arch_compile_options}
221+
)
222+
endif()
223+
endforeach()
224+
225+
# Set compile options for CUDA sources with specific architectures
226+
foreach(cuda_source IN LISTS transformer_engine_cuda_arch_specific_sources)
227+
set(arch_compile_options)
228+
foreach(arch IN LISTS NVTE_SPECIFIC_ARCHS)
229+
list(APPEND arch_compile_options "--generate-code=arch=compute_${arch},code=sm_${arch}")
230+
endforeach()
231+
232+
if(arch_compile_options)
233+
set_property(
234+
SOURCE ${cuda_source}
235+
APPEND
236+
PROPERTY
237+
COMPILE_OPTIONS ${arch_compile_options}
238+
)
239+
endif()
240+
endforeach()
148241

149242
if (NVTE_WITH_CUBLASMP)
150243
list(APPEND transformer_engine_SOURCES
@@ -249,28 +342,35 @@ target_include_directories(transformer_engine PRIVATE
249342
"${CMAKE_CURRENT_BINARY_DIR}/string_headers")
250343

251344
# Compiler options
252-
set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
253-
fused_softmax/scaled_upper_triang_masked_softmax.cu
254-
fused_softmax/scaled_aligned_causal_masked_softmax.cu
255-
multi_tensor/adam.cu
256-
multi_tensor/compute_scale.cu
257-
multi_tensor/l2norm.cu
258-
multi_tensor/scale.cu
259-
multi_tensor/sgd.cu
260-
fused_attn/flash_attn.cu
261-
fused_attn/context_parallel.cu
262-
fused_attn/kv_cache.cu
263-
PROPERTIES
264-
COMPILE_OPTIONS "--use_fast_math")
345+
set(nvte_sources_with_fast_math)
346+
list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu
347+
fused_softmax/scaled_upper_triang_masked_softmax.cu
348+
fused_softmax/scaled_aligned_causal_masked_softmax.cu
349+
multi_tensor/adam.cu
350+
multi_tensor/compute_scale.cu
351+
multi_tensor/l2norm.cu
352+
multi_tensor/scale.cu
353+
multi_tensor/sgd.cu
354+
fused_attn/flash_attn.cu
355+
fused_attn/context_parallel.cu
356+
fused_attn/kv_cache.cu)
357+
265358
option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
266359
if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
267-
set_source_files_properties(activation/gelu.cu
268-
activation/relu.cu
269-
activation/swiglu.cu
270-
util/cast.cu
271-
PROPERTIES
272-
COMPILE_OPTIONS "--use_fast_math")
360+
list(APPEND nvte_sources_with_fast_math activation/gelu.cu
361+
activation/relu.cu
362+
activation/swiglu.cu
363+
util/cast.cu)
273364
endif()
365+
366+
foreach(cuda_source IN LISTS nvte_sources_with_fast_math)
367+
set_property(
368+
SOURCE ${cuda_source}
369+
APPEND
370+
PROPERTY
371+
COMPILE_OPTIONS "--use_fast_math")
372+
endforeach()
373+
274374
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
275375
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
276376

transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -97,22 +97,23 @@ cutlass::Array<cutlass::float_e2m1_t, 8>
9797
StochasticNumericConverterBase(cutlass::Array<float, 8> const &input, cutlass::Array<uint32_t, 2> const &rbits) {
9898
using result_type = cutlass::Array<cutlass::float_e2m1_t, 8>;
9999
result_type output;
100-
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
101-
auto output_ptr = reinterpret_cast<uint16_t *>(&output);
102-
asm volatile( \
103-
"{\n" \
104-
"cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" \
105-
"cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" \
106-
"}" \
107-
: "=h"(output_ptr[0]),
100+
constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
101+
if constexpr (has_rs) {
102+
auto output_ptr = reinterpret_cast<uint16_t *>(&output);
103+
asm volatile( \
104+
"{\n" \
105+
"cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" \
106+
"cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" \
107+
"}" \
108+
: "=h"(output_ptr[0]),
108109
"=h"(output_ptr[1])
109-
: "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]),
110+
: "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]),
110111
"f"(input[4]), "f"(input[5]), "f"(input[6]), "f"(input[7]),
111112
"r"(rbits[0]), "r"(rbits[1]));
112-
#else
113-
NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific. "
114-
"Try recompiling with sm_XXXa instead of sm_XXX.");
115-
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
113+
} else {
114+
NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific. "
115+
"Try recompiling with sm_XXXa instead of sm_XXX.");
116+
}
116117
return output;
117118
}
118119

0 commit comments

Comments (0)