From bc482cbf918c8cecd368ee54c306684645548e23 Mon Sep 17 00:00:00 2001
From: Driss Guessous <32754868+drisspg@users.noreply.github.com>
Date: Tue, 14 Jan 2025 23:06:48 -0800
Subject: [PATCH] Add a macro for namespace (#1419)

---
 .gitignore | 3 +-
 csrc/flash_attn/flash_api.cpp | 13 +-
 csrc/flash_attn/src/alibi.h | 5 +-
 csrc/flash_attn/src/block_info.h | 5 +-
 csrc/flash_attn/src/dropout.h | 9 +-
 csrc/flash_attn/src/flash.h | 5 +
 .../src/flash_bwd_hdim128_bf16_causal_sm80.cu | 6 +-
 .../src/flash_bwd_hdim128_bf16_sm80.cu | 6 +-
 .../src/flash_bwd_hdim128_fp16_causal_sm80.cu | 6 +-
 .../src/flash_bwd_hdim128_fp16_sm80.cu | 6 +-
 .../src/flash_bwd_hdim160_bf16_causal_sm80.cu | 6 +-
 .../src/flash_bwd_hdim160_bf16_sm80.cu | 6 +-
 .../src/flash_bwd_hdim160_fp16_causal_sm80.cu | 6 +-
 .../src/flash_bwd_hdim160_fp16_sm80.cu | 6 +-
 .../src/flash_bwd_hdim192_bf16_causal_sm80.cu | 6 +-
 .../src/flash_bwd_hdim192_bf16_sm80.cu | 6 +-
 .../src/flash_bwd_hdim192_fp16_causal_sm80.cu | 6 +-
 .../src/flash_bwd_hdim192_fp16_sm80.cu | 6 +-
 .../src/flash_bwd_hdim256_bf16_causal_sm80.cu | 6 +-
 .../src/flash_bwd_hdim256_bf16_sm80.cu | 6 +-
 .../src/flash_bwd_hdim256_fp16_causal_sm80.cu | 6 +-
 .../src/flash_bwd_hdim256_fp16_sm80.cu | 6 +-
 .../src/flash_bwd_hdim32_bf16_causal_sm80.cu | 6 +-
 .../src/flash_bwd_hdim32_bf16_sm80.cu | 6 +-
 .../src/flash_bwd_hdim32_fp16_causal_sm80.cu | 6 +-
 .../src/flash_bwd_hdim32_fp16_sm80.cu | 6 +-
 .../src/flash_bwd_hdim64_bf16_causal_sm80.cu | 6 +-
 .../src/flash_bwd_hdim64_bf16_sm80.cu | 6 +-
 .../src/flash_bwd_hdim64_fp16_causal_sm80.cu | 6 +-
 .../src/flash_bwd_hdim64_fp16_sm80.cu | 6 +-
 .../src/flash_bwd_hdim96_bf16_causal_sm80.cu | 6 +-
 .../src/flash_bwd_hdim96_bf16_sm80.cu | 6 +-
 .../src/flash_bwd_hdim96_fp16_causal_sm80.cu | 6 +-
 .../src/flash_bwd_hdim96_fp16_sm80.cu | 6 +-
 csrc/flash_attn/src/flash_bwd_kernel.h | 93 ++++++------
 .../src/flash_bwd_launch_template.h | 17 ++-
 .../src/flash_bwd_preprocess_kernel.h | 27 ++--
 .../src/flash_fwd_hdim128_bf16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_hdim128_bf16_sm80.cu | 6 +-
 .../src/flash_fwd_hdim128_fp16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_hdim128_fp16_sm80.cu | 6 +-
 .../src/flash_fwd_hdim160_bf16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_hdim160_bf16_sm80.cu | 6 +-
 .../src/flash_fwd_hdim160_fp16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_hdim160_fp16_sm80.cu | 6 +-
 .../src/flash_fwd_hdim192_bf16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_hdim192_bf16_sm80.cu | 6 +-
 .../src/flash_fwd_hdim192_fp16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_hdim192_fp16_sm80.cu | 6 +-
 .../src/flash_fwd_hdim256_bf16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_hdim256_bf16_sm80.cu | 6 +-
 .../src/flash_fwd_hdim256_fp16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_hdim256_fp16_sm80.cu | 6 +-
 .../src/flash_fwd_hdim32_bf16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_hdim32_bf16_sm80.cu | 6 +-
 .../src/flash_fwd_hdim32_fp16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_hdim32_fp16_sm80.cu | 6 +-
 .../src/flash_fwd_hdim64_bf16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_hdim64_bf16_sm80.cu | 6 +-
 .../src/flash_fwd_hdim64_fp16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_hdim64_fp16_sm80.cu | 6 +-
 .../src/flash_fwd_hdim96_bf16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_hdim96_bf16_sm80.cu | 6 +-
 .../src/flash_fwd_hdim96_fp16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_hdim96_fp16_sm80.cu | 6 +-
 csrc/flash_attn/src/flash_fwd_kernel.h | 137 +++++++++--------
 .../src/flash_fwd_launch_template.h | 10 +-
 ...lash_fwd_split_hdim128_bf16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_split_hdim128_bf16_sm80.cu | 6 +-
 ...lash_fwd_split_hdim128_fp16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_split_hdim128_fp16_sm80.cu | 6 +-
 ...lash_fwd_split_hdim160_bf16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_split_hdim160_bf16_sm80.cu | 6 +-
 ...lash_fwd_split_hdim160_fp16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_split_hdim160_fp16_sm80.cu | 6 +-
 ...lash_fwd_split_hdim192_bf16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_split_hdim192_bf16_sm80.cu | 6 +-
 ...lash_fwd_split_hdim192_fp16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_split_hdim192_fp16_sm80.cu | 6 +-
 ...lash_fwd_split_hdim256_bf16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_split_hdim256_bf16_sm80.cu | 6 +-
 ...lash_fwd_split_hdim256_fp16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_split_hdim256_fp16_sm80.cu | 6 +-
 ...flash_fwd_split_hdim32_bf16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_split_hdim32_bf16_sm80.cu | 6 +-
 ...flash_fwd_split_hdim32_fp16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_split_hdim32_fp16_sm80.cu | 6 +-
 ...flash_fwd_split_hdim64_bf16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_split_hdim64_bf16_sm80.cu | 6 +-
 ...flash_fwd_split_hdim64_fp16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_split_hdim64_fp16_sm80.cu | 6 +-
 ...flash_fwd_split_hdim96_bf16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_split_hdim96_bf16_sm80.cu | 6 +-
 ...flash_fwd_split_hdim96_fp16_causal_sm80.cu | 6 +-
 .../src/flash_fwd_split_hdim96_fp16_sm80.cu | 6 +-
 csrc/flash_attn/src/generate_kernels.py | 64 ++++----
 csrc/flash_attn/src/mask.h | 7 +-
 csrc/flash_attn/src/namespace_config.h | 67 +++++++++
 csrc/flash_attn/src/philox.cuh | 6 +-
 csrc/flash_attn/src/rotary.h | 5 +-
 csrc/flash_attn/src/softmax.h | 23 +--
 csrc/flash_attn/src/utils.h | 10 +-
 102 files changed, 723 insertions(+), 287 deletions(-)
 create mode 100644 csrc/flash_attn/src/namespace_config.h

diff --git a/.gitignore b/.gitignore
index 3ad20ae09..1f1f80288 100644
--- a/.gitignore
+++ b/.gitignore
@@ -22,9 +22,10 @@ var/
 *.egg-info/
 .installed.cfg
 *.egg
+.eggs/
 # IDE-related
 .idea/
 # Dev
-venv
\ No newline at end of file
+venv
diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index b15294e6f..b8158fc94 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -12,6 +12,7 @@
 #include
+#include "namespace_config.h"
 #include "hardware_info.h"
 #include "flash.h"
 #include "static_switch.h"
@@ -20,6 +21,7 @@
 #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+namespace FLASH_NAMESPACE {
 void set_params_fprop(Flash_fwd_params &params,
 // sizes
@@ -1471,12 +1473,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
 }
 return {out, softmax_lse};
 }
+} // namespace FLASH_NAMESPACE
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 m.doc() = "FlashAttention";
- m.def("fwd", &mha_fwd, "Forward pass");
- m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
- m.def("bwd", &mha_bwd, "Backward pass");
- m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
- m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
+ m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass");
+ m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass (variable length)");
+ m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass");
+ m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass (variable length)");
+ m.def("fwd_kvcache", &FLASH_NAMESPACE::mha_fwd_kvcache, "Forward pass, with KV-cache");
 }
diff --git a/csrc/flash_attn/src/alibi.h b/csrc/flash_attn/src/alibi.h
index e714233e7..a65a5b379 100644
--- a/csrc/flash_attn/src/alibi.h
+++ b/csrc/flash_attn/src/alibi.h
@@ -1,5 +1,6 @@
 #include
+#include "namespace_config.h"
 #include
 #include
@@ -7,7 +8,7 @@
 #include "utils.h"
-namespace flash {
+namespace FLASH_NAMESPACE {
 using namespace cute;
@@ -71,4 +72,4 @@ struct Alibi {
 };
-} // namespace flash
+} // namespace FLASH_NAMESPACE
diff --git a/csrc/flash_attn/src/block_info.h b/csrc/flash_attn/src/block_info.h
index cf60d653c..9c8baff75 100644
--- a/csrc/flash_attn/src/block_info.h
+++ b/csrc/flash_attn/src/block_info.h
@@ -4,7 +4,8 @@
 #pragma once
-namespace flash {
+#include "namespace_config.h"
+namespace FLASH_NAMESPACE {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -45,4 +46,4 @@ struct BlockInfo {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-} // namespace flash
+} // namespace FLASH_NAMESPACE
diff --git a/csrc/flash_attn/src/dropout.h b/csrc/flash_attn/src/dropout.h
index 4882f97d9..9077b7991 100644
--- a/csrc/flash_attn/src/dropout.h
+++ b/csrc/flash_attn/src/dropout.h
@@ -4,10 +4,11 @@
 #pragma once
+#include "namespace_config.h"
 #include "philox.cuh"
 #include "utils.h"
-namespace flash {
+namespace FLASH_NAMESPACE {
 struct Dropout {
@@ -26,7 +27,7 @@ struct Dropout {
 __forceinline__ __device__ void apply_dropout(Tensor &tensor_, int block_row_start, int block_col_start, int block_row_stride) {
 // convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2)
- Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_dropout(tensor_.layout()));
+ Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_dropout(tensor_.layout()));
 using T = typename Engine::value_type;
 auto encode_dropout = [](bool keep, T val) { return keep ? val : (encode_dropout_in_sign_bit ?
-val : T(0)); @@ -41,7 +42,7 @@ struct Dropout { #pragma unroll for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) { // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));} - uint4 random_uint4 = flash::philox(seed, reinterpret_cast(rowcol), offset); + uint4 random_uint4 = FLASH_NAMESPACE::philox(seed, reinterpret_cast(rowcol), offset); // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);} uint8_t (&rnd_8)[16] = reinterpret_cast(random_uint4); // Special implementation for 16-bit types: we duplicate the threshold to the @@ -91,4 +92,4 @@ struct Dropout { }; -} // namespace flash +} // namespace FLASH_NAMESPACE diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h index 9a503998c..8ffbb62d6 100644 --- a/csrc/flash_attn/src/flash.h +++ b/csrc/flash_attn/src/flash.h @@ -4,11 +4,14 @@ #pragma once +#include "namespace_config.h" + #include #include #include // For at::Generator and at::PhiloxCudaState +namespace FLASH_NAMESPACE { constexpr int TOTAL_DIM = 0; constexpr int H_DIM = 1; constexpr int D_DIM = 2; @@ -187,3 +190,5 @@ template void run_mha_fwd_(Flash_fwd_pa template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE diff --git a/csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu index 13132e86d..b56b168ab 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim128(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu index 85a5dc88e..a82d58ee6 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim128(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu index 5d27cd97b..8e60b0428 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim128(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu index 2d7ddf46b..5567cb263 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim128(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu index c18a78c76..e34dd2454 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim160(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu index 1b6173725..5089d988d 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim160(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu index a511162dc..0272c5797 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim160(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu index c9ce19acb..d3d5d98d1 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim160(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu index f492a7171..23ecfe8d0 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim192(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu index 2df58daa2..823171a13 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim192(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu index 69cad5ae4..3f25fad5f 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim192(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu index 3d4cab58b..7de57f4ed 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim192(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu index 692744597..f5327c989 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim256(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu index d718ec88b..62b05332c 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim256(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu index 551c695e0..b0b837b24 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim256(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu index a58770026..f3f05b1c9 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim256(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu index 1282939a0..ddf5f1a5f 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim32(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu index d6d403638..ecd256443 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim32(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu index 60aa2d60b..684c97378 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim32(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu index b06d50eaa..f6ba14b4a 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim32(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu index 52b93be9d..ba2d81f5f 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim64(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu index 09d9e2b75..a2e778032 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim64(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu index 5a4ea5f46..d65d9659a 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim64(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu index fb115ff76..d3aa81393 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim64(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu index 5f4c26a47..7796ccb74 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim96(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu index 224213d79..2238ef52b 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim96(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu index d0349014f..6761b3ebd 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim96(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu index 663fc8592..bdad1121b 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_bwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim96(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index 6023a6a51..8f42f0ae1 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -4,6 +4,7 @@ #pragma once +#include "namespace_config.h" #include #include @@ -19,7 +20,7 @@ #include "alibi.h" -namespace flash { +namespace FLASH_NAMESPACE { using namespace cute; @@ -352,10 +353,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in #pragma unroll for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; } // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN ); - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN ); return; @@ -371,28 +372,28 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in if (Kernel_traits::Is_V_in_regs) { // Clear the smem tiles to account for predicated off loads - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); - flash::cp_async_fence(); + FLASH_NAMESPACE::cp_async_fence(); } Tensor tdOrdO = make_fragment_like(tdOgdO); Tensor tdOrO = make_fragment_like(tdOgO); if (!Is_first) { // Clear the smem tiles to account for predicated off loads - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM ); } else { - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM ); - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM ); } - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM ); @@ -417,15 +418,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK); // // if (cute::thread(1, 0)) { print(tKrK); } - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); if (!Kernel_traits::Is_V_in_regs) { - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, 
tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); } - flash::cp_async_fence(); + FLASH_NAMESPACE::cp_async_fence(); // if (cute::thread0()) { print(tdOgdO.layout()); printf("\n"); print(tdOrdO); print(tdOrO); } if (Is_first) { @@ -442,14 +443,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view); } - flash::Dropout dropout(params.rng_state[0], params.rng_state[1], params.p_dropout_in_uint8_t, + FLASH_NAMESPACE::Dropout dropout(params.rng_state[0], params.rng_state[1], params.p_dropout_in_uint8_t, bidb, bidh, tidx, params.h); clear(acc_dv); clear(acc_dk); const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; - flash::Alibi alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q); + FLASH_NAMESPACE::Alibi alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q); for (; m_block >= m_block_min; --m_block) { Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_N, MMA_N) @@ -468,21 +469,21 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // cute::copy(smem_tiled_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k)); // } // if (cute::thread0()) { print(tSrK); } - flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp, + FLASH_NAMESPACE::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp, smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV); if constexpr (Is_softcap) { - flash::apply_softcap(acc_s, params.softcap); + FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout())); // if (cute::thread(32, 0)) { print(scores); } // Softcapping - calculating dTanh and scaling dS later with it [[maybe_unused]] Tensor dtanh = make_tensor_like(scores); if constexpr (Is_softcap) { - flash::calculate_dtanh(scores, dtanh, params.softcap); + FLASH_NAMESPACE::calculate_dtanh(scores, dtanh, params.softcap); } // Alibi @@ -500,7 +501,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // So we need to mask out the elements beyond actual_seqlen_k. if (!Is_causal && !Is_local) { if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) { - flash::apply_mask(scores, binfo.actual_seqlen_k, + FLASH_NAMESPACE::apply_mask(scores, binfo.actual_seqlen_k, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16); } } else if (Is_causal) { @@ -510,7 +511,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // But we still want to mask out elements beyond actual_seqlen_k. 
if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) { - flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, + FLASH_NAMESPACE::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)), binfo.actual_seqlen_q, // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4, @@ -520,7 +521,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right || (m_block + 1) * kBlockM >= n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) { - flash::apply_mask_local(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, + FLASH_NAMESPACE::apply_mask_local(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)), binfo.actual_seqlen_q, AtomLayoutMS * 16, params.window_size_left, params.window_size_right); @@ -530,7 +531,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // if (cute::thread(32, 0)) { print(scores); } // Compute the exponential value. - flash::scale_apply_exp2(scores, lse, params.scale_softmax_log2); + FLASH_NAMESPACE::scale_apply_exp2(scores, lse, params.scale_softmax_log2); if constexpr (Is_dropout) { int warp_id = tidx / 32; int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS; @@ -543,11 +544,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in } // Convert scores from fp32 to fp16/bf16 Tensor rP = !Is_dropout - ? flash::convert_type(acc_s) - : flash::convert_type_relu(acc_s); + ? FLASH_NAMESPACE::convert_type(acc_s) + : FLASH_NAMESPACE::convert_type_relu(acc_s); // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_N, MMA_N / 2) // if using m16n8k16 or (4, MMA_N, MMA_N) if using m16n8k8. 
- Tensor tPrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + Tensor tPrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP); // ((Atom,AtomNum), MMA_N, MMA_N) cute::copy(smem_tiled_copy_PdS, tPaP, tPsP); // if (cute::thread0()) { print(tPaP); } @@ -560,7 +561,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in CUTE_STATIC_ASSERT_V(size<2>(acc_dp) == size<2>(acc_s)); // MMA clear(acc_dp); - // Tensor acc_dp_reshaped = make_tensor(acc_dp.data(), flash::convert_layout_acc_rowcol(acc_dp.layout())); + // Tensor acc_dp_reshaped = make_tensor(acc_dp.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_dp.layout())); // #pragma unroll // for (int mi = 0; mi < size<0>(acc_dp_reshaped); ++mi) { // #pragma unroll @@ -571,7 +572,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // if (cute::thread0()) { print(dP_sum); } - flash::gemm( + FLASH_NAMESPACE::gemm( acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp, smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV ); @@ -612,13 +613,13 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in tSsQ.data() = tSsQ.data() + sQ_offset; // Advance gQ tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride)); - flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ); - flash::cp_async_fence(); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ); + FLASH_NAMESPACE::cp_async_fence(); } Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout()); // Convert dS from fp32 to fp16 - Tensor tdSrdS = flash::convert_type(dS_reshaped); + Tensor tdSrdS = FLASH_NAMESPACE::convert_type(dS_reshaped); // if (cute::thread0()) { print(tPrP); } Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS); // ((Atom,AtomNum), MMA_N, MMA_N) cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS); @@ -626,10 +627,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // Layout p_l = tPrP.layout(); // Tensor tdVrPt = make_tensor(tPrP.data(), make_layout(get<0>(p_l), get<2>(p_l), get<1>(p_l))); - // flash::gemm_rs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt); + // FLASH_NAMESPACE::gemm_rs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt); // Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout()); - // flash::gemm_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt); - flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv, + // FLASH_NAMESPACE::gemm_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt); + FLASH_NAMESPACE::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt); // if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); } // if (cute::thread0()) { print(acc_dv); } @@ -641,15 +642,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride)); if (Is_first) { tdOgO.data() = tdOgO.data() + (-int(kBlockM * params.o_row_stride)); - flash::copy(gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ); - flash::copy(gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ); + FLASH_NAMESPACE::copy(gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ); + FLASH_NAMESPACE::copy(gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ); } else { - flash::copy(gmem_tiled_copy_dO, 
tdOgdO, tdOsdO, tQcQ, tQpQ); - flash::cp_async_fence(); + FLASH_NAMESPACE::copy(gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ); + FLASH_NAMESPACE::cp_async_fence(); } } - flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq, + FLASH_NAMESPACE::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq, smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt); // if (cute::thread0()) { print(acc_dq); } @@ -678,12 +679,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in #pragma unroll for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; } // Convert acc_dq from fp32 to fp16 - Tensor rdQ = flash::convert_type(acc_dq); + Tensor rdQ = FLASH_NAMESPACE::convert_type(acc_dq); Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N) cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); } - flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv, + FLASH_NAMESPACE::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv, smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt); // if (cute::thread0()) { print(acc_dk); } if (Double_buffer) { // Double buffer for sQ @@ -693,8 +694,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in __syncthreads(); // Advance gQ tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride)); - flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ); - flash::cp_async_fence(); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ); + FLASH_NAMESPACE::cp_async_fence(); } if (Is_first && m_block > m_block_min) { @@ -730,8 +731,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in for (int i = 0; i < size(acc_dk); ++i) { acc_dk(i) *= params.scale_softmax_rp_dropout; } // Convert acc_dv from fp32 to fp16 - Tensor rdK = flash::convert_type(acc_dk); - Tensor rdV = flash::convert_type(acc_dv); + Tensor rdK = FLASH_NAMESPACE::convert_type(acc_dk); + Tensor rdV = FLASH_NAMESPACE::convert_type(acc_dv); Tensor sdK = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K) Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K) @@ -782,10 +783,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in #pragma unroll for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; } // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN ); - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN ); diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h index 3b79a01c5..b719cf988 100644 --- a/csrc/flash_attn/src/flash_bwd_launch_template.h +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -4,6 +4,7 @@ #pragma once +#include "namespace_config.h" #include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK #include "static_switch.h" @@ -12,6 +13,8 @@ #include "flash_bwd_preprocess_kernel.h" #include "flash_bwd_kernel.h" +namespace FLASH_NAMESPACE { + // Determine if the architecture supports FLASH and define a macro to handle parameter modifiers #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #define 
ARCH_SUPPORTS_FLASH @@ -30,7 +33,7 @@ __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params) DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) { #if defined(ARCH_SUPPORTS_FLASH) - flash::compute_dq_dk_dv(params); + FLASH_NAMESPACE::compute_dq_dk_dv(params); #else FLASH_UNSUPPORTED_ARCH #endif @@ -39,7 +42,7 @@ DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bo DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) { #if defined(ARCH_SUPPORTS_FLASH) static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false - flash::compute_dq_dk_dv_seqk_parallel(params); + FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel(params); #else FLASH_UNSUPPORTED_ARCH #endif @@ -48,22 +51,22 @@ DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool template __global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) { - flash::compute_dot_do_o(params); + FLASH_NAMESPACE::compute_dot_do_o(params); } template __global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) { - flash::clear_dKVaccum(params); + FLASH_NAMESPACE::clear_dKVaccum(params); } template __global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) { - flash::convert_dQ(params, nsplits); + FLASH_NAMESPACE::convert_dQ(params, nsplits); } template __global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) { - flash::convert_dKV(params); + FLASH_NAMESPACE::convert_dKV(params); } template @@ -321,3 +324,5 @@ void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream) { } }); } + +} // namespace FLASH_NAMESPACE { diff --git a/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h b/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h index c8e307417..016a01070 100644 --- a/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h @@ -4,6 +4,7 @@ #pragma once +#include "namespace_config.h" #include #include @@ -14,7 +15,7 @@ #include "kernel_traits.h" #include "utils.h" -namespace flash { +namespace FLASH_NAMESPACE { using namespace cute; @@ -32,8 +33,8 @@ inline __device__ void dot_do_o(Tensor const &do_, Tensor(do_.layout()), get<2>(do_.layout())))); Tensor o_reshaped = make_tensor(o.data(), do_reshaped.layout()); - Tensor do_fp32 = flash::convert_type(do_reshaped); - Tensor o_fp32 = flash::convert_type(o_reshaped); + Tensor do_fp32 = FLASH_NAMESPACE::convert_type(do_reshaped); + Tensor o_fp32 = FLASH_NAMESPACE::convert_type(o_reshaped); #pragma unroll for (int mi = 0; mi < size<0>(do_reshaped); ++mi) { float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0); @@ -41,8 +42,8 @@ inline __device__ void dot_do_o(Tensor const &do_, Tensor(do_reshaped); ni++) { dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni); } - flash::SumOp sum_op; - dP_sum_cur = flash::Allreduce::run(dP_sum_cur, sum_op) * scale; + FLASH_NAMESPACE::SumOp sum_op; + dP_sum_cur = FLASH_NAMESPACE::Allreduce::run(dP_sum_cur, sum_op) * scale; if (threadIdx.x % THREADS_PER_ROW == 0) { dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur; } @@ -116,10 +117,10 @@ inline __device__ void compute_dot_do_o(const Params ¶ms) { Tensor tdOrdO = make_fragment_like(tdOgdO); Tensor tdOrO = make_fragment_like(tdOgO); - flash::copy( + 
FLASH_NAMESPACE::copy( gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM ); - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM ); // By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final @@ -244,7 +245,7 @@ inline __device__ void convert_dQ(const Params ¶ms, const int nsplits) { #pragma unroll for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; } // Convert acc_dq from fp32 to fp16 - Tensor rdQ = flash::convert_type(acc_dq); + Tensor rdQ = FLASH_NAMESPACE::convert_type(acc_dq); Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N) cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); __syncthreads(); @@ -257,7 +258,7 @@ inline __device__ void convert_dQ(const Params ¶ms, const int nsplits) { #pragma unroll for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; } // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM ); } @@ -349,8 +350,8 @@ inline __device__ void convert_dKV(const Params ¶ms) { acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout; } // Convert acc_dk from fp32 to fp16 - Tensor rdK = flash::convert_type(acc_dk); - Tensor rdV = flash::convert_type(acc_dv); + Tensor rdK = FLASH_NAMESPACE::convert_type(acc_dk); + Tensor rdV = FLASH_NAMESPACE::convert_type(acc_dv); Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N) Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N) cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK); @@ -367,10 +368,10 @@ inline __device__ void convert_dKV(const Params ¶ms) { #pragma unroll for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; } // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN ); - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN ); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu index 9383c1024..baca4777b 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim128(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu index f03abda48..230059eab 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. 
// This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim128(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu index c616628c8..ab3ce5500 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim128(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu index 4ff6b9fbf..030328544 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim128(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu index d6d4371bf..27d9e9d8a 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim160(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu index 5af68ac38..943e508eb 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim160(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu index 1ef511a6b..92904627b 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim160(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu index 96abfbd8a..7b3749e25 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim160(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu index 077d25d09..c04b7b9e0 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim192(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu index ea5f265fe..72468c382 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim192(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu index a4a7bc242..89ffeb8e1 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim192(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu index c30c4a14f..729660e8e 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim192(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu index f84e978c9..c20084704 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim256(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu index c52f0417b..2f8ee2496 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim256(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu index f96f7edc6..fe8e55083 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim256(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu index 9c7c6b93d..6a99c93a4 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim256(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu index e21d0408c..8cb6578f5 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim32(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu index f377a5b8f..ba7adb8f4 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim32(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu index 74e4d66ae..b0cb6844d 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim32(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu index e85db18e3..cfbcf6c52 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim32(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu index 9297e8bb6..7d703bb45 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim64(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu index 8364b1e7e..e0ab7cb59 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim64(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu index 1c6ed7ef0..d4c869f7e 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim64(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu index 3c87573ba..fa2863579 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim64(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu index 49fae856a..d7c626cd9 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim96(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu index c5af1cf63..9b2d5f7a0 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim96(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu index b0d6c9928..9eb4fdd6c 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim96(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu index c97aa33f8..0a1ea1a93 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim96(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 94352e5ac..1ba07da15 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -4,6 +4,7 @@ #pragma once +#include "namespace_config.h" #include "philox_unpack.cuh" // For at::cuda::philox::unpack #include @@ -20,7 +21,7 @@ #include "dropout.h" #include "rotary.h" -namespace flash { +namespace FLASH_NAMESPACE { using namespace cute; @@ -66,7 +67,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi constexpr int kNWarps = Kernel_traits::kNWarps; auto seed_offset = at::cuda::philox::unpack(params.philox_args); - flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t, + FLASH_NAMESPACE::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t, bidb, bidh, tidx, params.h); // Save seed and offset for backward, before any early exiting. 
Otherwise the 0-th thread block might @@ -115,7 +116,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } } // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM ); #pragma unroll @@ -246,7 +247,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Prologue // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM); if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } @@ -255,7 +256,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // // if (cute::thread0()) { print(sQNoSwizzle); } if (Kernel_traits::Share_Q_K_smem) { - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M @@ -265,14 +266,14 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi int n_block = n_block_max - 1; // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. - flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV, + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); cute::cp_async_fence(); // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } // __syncthreads(); if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { - flash::cp_async_wait<1>(); + FLASH_NAMESPACE::cp_async_wait<1>(); __syncthreads(); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M @@ -281,10 +282,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi clear(acc_o); - flash::Softmax<2 * size<1>(acc_o)> softmax; + FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax; const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; - flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); + FLASH_NAMESPACE::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); // For performance reason, we separate out two kinds of iterations: // those that need masking on S, and those that don't. 
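Every hunk in this kernel header follows the same mechanical pattern as the generated files above: pull in "namespace_config.h", open namespace FLASH_NAMESPACE { where namespace flash { used to be, and qualify cross-header calls with FLASH_NAMESPACE:: instead of flash::. The configuration header itself is not reproduced in this excerpt; the sketch below only illustrates the minimal contract the rest of the patch relies on, namely that FLASH_NAMESPACE expands to a valid namespace name and defaults to the original flash. The default value and the BEGIN/END convenience macros are assumptions, not something the diff shows.

    // namespace_config.h -- illustrative sketch, not the actual file from this patch.
    #pragma once

    #ifndef FLASH_NAMESPACE
    #define FLASH_NAMESPACE flash   // assumed default: preserves the original `namespace flash`
    #endif

    // Assumed convenience macros so headers could open/close the namespace uniformly;
    // the hunks in this patch open the namespace directly instead.
    #define FLASH_NAMESPACE_BEGIN namespace FLASH_NAMESPACE {
    #define FLASH_NAMESPACE_END   } // namespace FLASH_NAMESPACE
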
@@ -301,37 +302,37 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) clear(acc_s); - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); // Advance gV if (masking_step > 0) { - flash::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); } else { // Clear the smem tiles to account for predicated off loads - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); } cute::cp_async_fence(); - flash::gemm( + FLASH_NAMESPACE::gemm( acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K ); // if (cute::thread0()) { print(acc_s); } if constexpr (Is_softcap){ - flash::apply_softcap(acc_s, params.softcap); + FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } mask.template apply_mask( acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 ); - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); if (n_block > n_block_min) { - flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -343,7 +344,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); // Convert acc_s from fp32 to fp16/bf16 - Tensor rP = flash::convert_type(acc_s); + Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; int block_col_idx = n_block * (kBlockN / 32); if (Return_softmax) { @@ -361,9 +362,9 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
- Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); // if (cute::thread0()) { print(tOrP); } - flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // if (cute::thread0()) { print(scores); } // This check is at the end of the loop since we always have at least 1 iteration @@ -377,23 +378,23 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi for (; n_block >= n_block_min; --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) clear(acc_s); - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); - flash::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); cute::cp_async_fence(); - flash::gemm( + FLASH_NAMESPACE::gemm( acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K ); if constexpr (Is_softcap){ - flash::apply_softcap(acc_s, params.softcap); + FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); if (n_block > n_block_min) { - flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -405,7 +406,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); - Tensor rP = flash::convert_type(acc_s); + Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; int block_col_idx = n_block * (kBlockN / 32); if (Return_softmax) { @@ -423,8 +424,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
- Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); - flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); + FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } // Epilogue @@ -432,7 +433,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, params.rp_dropout); // Convert acc_o from fp32 to fp16/bf16 - Tensor rO = flash::convert_type(acc_o); + Tensor rO = FLASH_NAMESPACE::convert_type(acc_o); Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) // Partition sO to match the accumulator partitioning auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); @@ -487,7 +488,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } } // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM ); } @@ -563,7 +564,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } } // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM ); #pragma unroll @@ -730,18 +731,18 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons auto tKgK_data = tKgK.data(); auto tVgV_data = tVgV.data(); for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { - flash::copy_w_min_idx( + FLASH_NAMESPACE::copy_w_min_idx( tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN ); tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); if (params.rotary_dim == 0) { - flash::copy_w_min_idx( + FLASH_NAMESPACE::copy_w_min_idx( tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN ); } else { if (params.is_rotary_interleaved) { // Don't clear OOB_K because we're writing to global memory - flash::copy_rotary_interleaved( + FLASH_NAMESPACE::copy_rotary_interleaved( tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim ); @@ -749,7 +750,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2)); } else { // Don't clear OOB_K because we're writing to global memory - flash::copy_rotary_contiguous( + FLASH_NAMESPACE::copy_rotary_contiguous( tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim ); @@ -784,7 +785,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Read Q from gmem to smem, optionally apply rotary embedding. 
if (!Append_KV || params.rotary_dim == 0) { // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM); } else { const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); @@ -807,12 +808,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); if (params.is_rotary_interleaved) { - flash::copy_rotary_interleaved( + FLASH_NAMESPACE::copy_rotary_interleaved( tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, 0, params.d, params.rotary_dim ); } else { - flash::copy_rotary_contiguous( + FLASH_NAMESPACE::copy_rotary_contiguous( tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, 0, params.d, params.rotary_dim ); @@ -821,21 +822,21 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons int n_block = n_block_max - 1; // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); cute::cp_async_fence(); - // flash::cp_async_wait<0>(); + // FLASH_NAMESPACE::cp_async_wait<0>(); // __syncthreads(); // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); } // __syncthreads(); clear(acc_o); - flash::Softmax<2 * size<1>(acc_o)> softmax; + FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax; const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; - flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); + FLASH_NAMESPACE::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); // For performance reason, we separate out two kinds of iterations: // those that need masking on S, and those that don't. 
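The flash:: to FLASH_NAMESPACE:: rewrites in these hunks are behavior-preserving as long as the macro still expands to flash; the point of the indirection is that a build which needs a second, independently configured copy of these kernels can pick a different namespace at compile time. A hypothetical usage sketch follows (the namespace value, flags, and include are illustrative assumptions, not taken from the patch):

    // Hypothetical: compile one copy of the extension with, e.g.
    //   nvcc ... -DFLASH_NAMESPACE=flash_vendored ...
    // so its symbols cannot collide with another copy built with the default name.
    #include "namespace_config.h"

    // Call sites can go through an alias instead of hard-coding either spelling.
    namespace fa = FLASH_NAMESPACE;  // expands to flash_vendored in that build
    // fa::run_mha_fwd_<cutlass::half_t, 128, false>(params, stream);
    // (assuming the entry points live in this namespace, as the explicit
    //  specializations in the generated files suggest)
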
@@ -852,7 +853,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) clear(acc_s); - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); // Advance gV @@ -866,22 +867,22 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; } - flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); } else { // Clear the smem tiles to account for predicated off loads - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); } cute::cp_async_fence(); - flash::gemm( + FLASH_NAMESPACE::gemm( acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K ); // if (cute::thread0()) { print(acc_s); } if constexpr (Is_softcap){ - flash::apply_softcap(acc_s, params.softcap); + FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } @@ -889,7 +890,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 ); - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); } // __syncthreads(); @@ -905,7 +906,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons const int block_table_offset_next =(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; } - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -918,12 +919,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } // Convert acc_s from fp32 to fp16/bf16 - Tensor rP = flash::convert_type(acc_s); + Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
- Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); - flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // This check is at the end of the loop since we always have at least 1 iteration if (n_masking_steps > 1 && n_block <= n_block_min) { @@ -936,7 +937,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons for (; n_block >= n_block_min; --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) clear(acc_s); - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); // Advance gV if (block_table == nullptr) { @@ -948,18 +949,18 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; } - flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); cute::cp_async_fence(); - flash::gemm( + FLASH_NAMESPACE::gemm( acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K ); if constexpr (Is_softcap){ - flash::apply_softcap(acc_s, params.softcap); + FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); if (n_block > n_block_min) { // Advance gK @@ -972,7 +973,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; } - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -983,12 +984,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons ); softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); - Tensor rP = flash::convert_type(acc_s); + Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
- Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); - flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } // Epilogue @@ -1005,7 +1006,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons >; auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma); auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor rO = flash::convert_type(acc_o); + Tensor rO = FLASH_NAMESPACE::convert_type(acc_o); Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) @@ -1064,7 +1065,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } } // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM ); } @@ -1087,7 +1088,7 @@ inline __device__ void compute_attn(const Params ¶ms) { // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. - flash::compute_attn_1rowblock(params, bidb, bidh, m_block); + FLASH_NAMESPACE::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1101,7 +1102,7 @@ inline __device__ void compute_attn_splitkv(const Params ¶ms) { const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; const int n_split_idx = Split ? blockIdx.y : 0; const int num_n_splits = Split ? 
gridDim.y : 1; - flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); + FLASH_NAMESPACE::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1242,7 +1243,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { } // Load Oaccum in then scale and accumulate to O for (int split = 0; split < params.num_splits; ++split) { - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM ); #pragma unroll @@ -1262,7 +1263,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { } // if (cute::thread0()) { print_tensor(tOrO); } - Tensor rO = flash::convert_type(tOrO); + Tensor rO = FLASH_NAMESPACE::convert_type(tOrO); // Write to gO #pragma unroll for (int m = 0; m < size<1>(rO); ++m) { @@ -1290,4 +1291,4 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { } } -} // namespace flash +} // namespace FLASH_NAMESPACE diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index b04667c55..227f3c257 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -3,6 +3,7 @@ ******************************************************************************/ #pragma once +#include "namespace_config.h" #include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK #include "static_switch.h" @@ -10,6 +11,8 @@ #include "flash.h" #include "flash_fwd_kernel.h" +namespace FLASH_NAMESPACE { + // Determine if the architecture supports FLASH and define a macro to handle parameter modifiers #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #define ARCH_SUPPORTS_FLASH @@ -29,7 +32,7 @@ __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { #if defined(ARCH_SUPPORTS_FLASH) static_assert(!(Is_causal && Is_local)); // Enforce constraints - flash::compute_attn(params); + FLASH_NAMESPACE::compute_attn(params); #else FLASH_UNSUPPORTED_ARCH #endif @@ -37,7 +40,7 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, b DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) { #if defined(ARCH_SUPPORTS_FLASH) - flash::compute_attn_splitkv(params); + FLASH_NAMESPACE::compute_attn_splitkv(params); #else FLASH_UNSUPPORTED_ARCH #endif @@ -45,7 +48,7 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_lo DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) { static_assert(Log_max_splits >= 1); - flash::combine_attn_seqk_parallel(params); + FLASH_NAMESPACE::combine_attn_seqk_parallel(params); } template @@ -327,3 +330,4 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { // run_flash_fwd, Is_dropout, Is_causal>(params, stream); }); } +} // namespace FLASH_NAMESPACE diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu index a959c9ceb..40559c640 
100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu index e608e308e..48500b8f1 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu index 3dd74e273..355902924 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu index addacedf4..6aa638de8 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu index 8ace7bda9..f5167b333 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu index 1e133ec1a..ee02db1a3 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu index 1723c69e0..2b0472038 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu index 892d2352a..2b833bd53 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu index d07ee0af2..979deee41 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu index 23cfa59d5..236365e4f 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu index 273a28442..9c4420fa8 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu index 0f588d1f4..872f5ced8 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu index 370fe9ca3..8fee9f57b 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu index 508f07f7d..6adcb1bf2 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu index 019ded67f..df05869f7 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu index 708f5542a..51bd8e4d7 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu index 5a205b7e7..fa340d6f0 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu index 2c576f118..0f2adec7a 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu index 484a15e93..345551033 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu index 5474ae89d..ec9523de0 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu index 8c7da41dd..750c69fcc 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu index 93f29dea8..a1b26d84f 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu index 1e2e12b8c..306116710 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu index 16c34ed3f..aeda6bfdd 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu index 50080c47e..d55eb4039 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu index ae56ddd4c..a139c0743 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu index ed305767e..8e6634323 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu index 022064656..2a874bf60 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu @@ -1,7 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/generate_kernels.py b/csrc/flash_attn/src/generate_kernels.py index 119e34956..7b2130bab 100644 --- a/csrc/flash_attn/src/generate_kernels.py +++ b/csrc/flash_attn/src/generate_kernels.py @@ -1,8 +1,3 @@ -# Copied from Driss Guessous's PR in PyTorch: https://github.com/pytorch/pytorch/pull/105602 - -# This file is run to generate the kernel instantiations for the flash_attn kernels -# They are written to several files in order to speed up compilation - import argparse import itertools from dataclasses import dataclass @@ -17,27 +12,40 @@ SM = [80] # Sm80 kernels support up to HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 256] IS_CAUSAL = ["false", "true"] -KERNEL_IMPL_TEMPLATE_FWD = """#include "flash_fwd_launch_template.h" +NAMESPACE_INCLUDE = '#include "namespace_config.h"\n' + +def get_fwd_template() -> str: + return NAMESPACE_INCLUDE + """#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE {{ template<> void run_mha_fwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params ¶ms, cudaStream_t stream) {{ run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream); }} -""" -KERNEL_IMPL_TEMPLATE_FWD_SPLIT = """#include "flash_fwd_launch_template.h" +}} // namespace FLASH_NAMESPACE""" + +def get_fwd_split_template() -> str: + return NAMESPACE_INCLUDE + """#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE {{ template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params ¶ms, cudaStream_t stream); -""" -KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h" +}} // namespace FLASH_NAMESPACE""" + +def get_bwd_template() -> str: + return NAMESPACE_INCLUDE + """#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE {{ template<> void run_mha_bwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_bwd_params ¶ms, cudaStream_t stream) {{ run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream); }} -""" +}} // namespace FLASH_NAMESPACE""" @dataclass class Kernel: @@ -49,37 +57,33 @@ class Kernel: @property def template(self) -> str: - if self.direction == "fwd": - return KERNEL_IMPL_TEMPLATE_FWD.format( - DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal - ) - elif self.direction == "bwd": - return KERNEL_IMPL_TEMPLATE_BWD.format( - DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal - ) - else: - return KERNEL_IMPL_TEMPLATE_FWD_SPLIT.format( - DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal - ) + template_funcs = { + "fwd": get_fwd_template, + "bwd": get_bwd_template, + "fwd_split": get_fwd_split_template + } + template_func = template_funcs[self.direction] + return template_func().format( + DTYPE=DTYPE_MAP[self.dtype], + HEAD_DIM=self.head_dim, + IS_CAUSAL=self.is_causal + ) @property def filename(self) -> str: return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_{'causal_' if self.is_causal == 'true' else ''}sm{self.sm}.cu" - def get_all_kernels() -> List[Kernel]: for direction in ["fwd", "fwd_split", "bwd"]: for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM): yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction) - 
 def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
     prelude = """// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"\n
-"""
-    (autogen_dir / kernel.filename).write_text(prelude + kernel.template)
-
+// This file is auto-generated. See "generate_kernels.py"\n"""
+    content = prelude + kernel.template
+    (autogen_dir / kernel.filename).write_text(content)
 
 def main(output_dir: Optional[str]) -> None:
     if output_dir is None:
@@ -90,13 +94,11 @@ def main(output_dir: Optional[str]) -> None:
     for kernel in get_all_kernels():
         write_kernel(kernel, output_dir)
 
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         prog="generate_kernels",
         description="Generate the flash_attention kernels template instantiations",
     )
-    # Set an optional output directory
     parser.add_argument(
         "-o",
         "--output_dir",
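For reference, the refactored fwd template renders one translation unit per (dtype, head_dim, is_causal) combination, of roughly the following shape. This sketch assumes the fp16 / head_dim 128 / non-causal case, and assumes that DTYPE_MAP, which is defined earlier in generate_kernels.py and is outside these hunks, maps "fp16" to cutlass::half_t:

    // Illustrative output of get_fwd_template() for ("fp16", 128, "false");
    // the fwd_split and bwd templates render analogously.
    #include "namespace_config.h"
    #include "flash_fwd_launch_template.h"

    namespace FLASH_NAMESPACE {

    template<> void run_mha_fwd_<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {
        run_mha_fwd_hdim128<cutlass::half_t, false>(params, stream);
    }

    } // namespace FLASH_NAMESPACE

The generator is invoked the same way as before, for example python generate_kernels.py -o csrc/flash_attn/src, which rewrites the auto-generated .cu files shown above.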
diff --git a/csrc/flash_attn/src/mask.h b/csrc/flash_attn/src/mask.h
index 7ba435a37..544065d1d 100644
--- a/csrc/flash_attn/src/mask.h
+++ b/csrc/flash_attn/src/mask.h
@@ -3,10 +3,11 @@
  ******************************************************************************/
 
 #pragma once
 
+#include "namespace_config.h"
 #include
 
-namespace flash {
+namespace FLASH_NAMESPACE {
 
 using namespace cute;
 
@@ -137,7 +138,7 @@ struct Mask {
         // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); }
         if constexpr (Need_masking) {
             // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
-            Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout()));
+            Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tensor_.layout()));
             // Do we need both row and column indices, or just column incides?
             static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;
             const int lane_id = threadIdx.x % 32;
@@ -210,4 +211,4 @@ struct Mask {
 
 };
 
-} // namespace flash
+} // namespace FLASH_NAMESPACE
diff --git a/csrc/flash_attn/src/namespace_config.h b/csrc/flash_attn/src/namespace_config.h
new file mode 100644
index 000000000..a6fad57b1
--- /dev/null
+++ b/csrc/flash_attn/src/namespace_config.h
@@ -0,0 +1,67 @@
+/**
+ * @file namespace_config.h
+ * @brief Configuration file for Flash namespace management and isolation
+ *
+ * This header provides configuration macros for managing the Flash namespace
+ * across a codebase. It allows for flexible namespace naming and provides
+ * utilities for namespace declaration and scoping.
+ *
+ * Usage Examples:
+ *
+ * 1. Basic namespace wrapping:
+ * @code
+ * namespace FLASH_NAMESPACE {
+ * class FlashDevice {
+ *     // Implementation
+ * };
+ * } // namespace FLASH_NAMESPACE
+ * @endcode
+ *
+ * 2. Accessing types within the namespace:
+ * @code
+ * FLASH_NAMESPACE_ALIAS(FlashDevice) device;
+ * @endcode
+ *
+ * 3. Defining content within namespace scope:
+ * @code
+ * FLASH_NAMESPACE_SCOPE(
+ *   struct Configuration {
+ *     uint32_t size;
+ *     bool enabled;
+ *   };
+ * )
+ * @endcode
+ *
+ * 4. Custom namespace name:
+ * @code
+ * #define FLASH_NAMESPACE custom_flash
+ * #include "namespace_config.h"
+ * @endcode
+ *
+ * Configuration:
+ * - The default namespace is 'flash' if FLASH_NAMESPACE is not defined
+ * - Define FLASH_NAMESPACE before including this header to customize the
+ *   namespace name
+ *
+ * Best Practices:
+ * - Include this header in all files that need access to the Flash namespace
+ *
+ */
+#pragma once
+
+#ifndef FLASH_NAMESPACE_CONFIG_H
+#define FLASH_NAMESPACE_CONFIG_H
+
+// Set default namespace to flash
+#ifndef FLASH_NAMESPACE
+#define FLASH_NAMESPACE flash
+#endif
+
+#define FLASH_NAMESPACE_ALIAS(name) FLASH_NAMESPACE::name
+
+#define FLASH_NAMESPACE_SCOPE(content) \
+    namespace FLASH_NAMESPACE {        \
+        content                        \
+    }
+
+#endif // FLASH_NAMESPACE_CONFIG_H
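Downstream code consumes this header in two ways: library sources include it and wrap their definitions in namespace FLASH_NAMESPACE { ... } (as the surrounding hunks do), while an embedding project can rename the namespace by defining FLASH_NAMESPACE before any flash-attention header is included. A minimal sketch, assuming a hypothetical consumer translation unit and the illustrative names my_flash and TuningKnobs:

    // Hypothetical consumer file; "my_flash" and "TuningKnobs" are example names.
    #define FLASH_NAMESPACE my_flash     // must precede any flash-attention include
    #include "namespace_config.h"

    // Wrap definitions with the convenience macro (expands to namespace my_flash { ... }):
    FLASH_NAMESPACE_SCOPE(
        struct TuningKnobs { int num_splits; };
    )

    // Refer to a single name through the alias macro (expands to my_flash::TuningKnobs):
    FLASH_NAMESPACE_ALIAS(TuningKnobs) knobs{4};

Because the header falls back to #define FLASH_NAMESPACE flash, builds that never define the macro keep the original flash:: symbol names.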
diff --git a/csrc/flash_attn/src/philox.cuh b/csrc/flash_attn/src/philox.cuh
index cd7e4d2fa..5205f4542 100644
--- a/csrc/flash_attn/src/philox.cuh
+++ b/csrc/flash_attn/src/philox.cuh
@@ -2,7 +2,9 @@
 #pragma once
 // Philox CUDA.
 
-namespace flash {
+#include "namespace_config.h"
+
+namespace FLASH_NAMESPACE {
 
 struct ull2 {
     unsigned long long x;
@@ -48,4 +50,4 @@ __forceinline__ __device__ uint4 philox(unsigned long long seed,
     return output;
 }
 
-} // namespace flash
+} // namespace FLASH_NAMESPACE
diff --git a/csrc/flash_attn/src/rotary.h b/csrc/flash_attn/src/rotary.h
index 7f1614ad2..dbae24c62 100644
--- a/csrc/flash_attn/src/rotary.h
+++ b/csrc/flash_attn/src/rotary.h
@@ -6,11 +6,12 @@
 
 #include
 
+#include "namespace_config.h"
 #include "utils.h"
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-namespace flash {
+namespace FLASH_NAMESPACE {
 
 using namespace cute;
 
@@ -149,4 +150,4 @@ __forceinline__ __device__ void copy_rotary_contiguous(Tensor
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-} // namespace flash
+} // namespace FLASH_NAMESPACE
diff --git a/csrc/flash_attn/src/softmax.h b/csrc/flash_attn/src/softmax.h
index ebf1b0979..01589aded 100644
--- a/csrc/flash_attn/src/softmax.h
+++ b/csrc/flash_attn/src/softmax.h
@@ -10,10 +10,11 @@
 
 #include
 
+#include "namespace_config.h"
 #include "philox.cuh"
 #include "utils.h"
 
-namespace flash {
+namespace FLASH_NAMESPACE {
 
 using namespace cute;
 
@@ -135,18 +136,18 @@ struct Softmax {
 
     template
     __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {
         // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
-        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
+        Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout()));
         static_assert(decltype(size<0>(scores))::value == kNRows);
         if (Is_first) {
-            flash::template reduce_max(scores, row_max);
-            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
-            flash::reduce_sum(scores, row_sum);
+            FLASH_NAMESPACE::template reduce_max(scores, row_max);
+            FLASH_NAMESPACE::scale_apply_exp2(scores, row_max, softmax_scale_log2);
+            FLASH_NAMESPACE::reduce_sum(scores, row_sum);
         } else {
             Tensor scores_max_prev = make_fragment_like(row_max);
             cute::copy(row_max, scores_max_prev);
-            flash::template reduce_max(scores, row_max);
+            FLASH_NAMESPACE::template reduce_max(scores, row_max);
             // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
-            Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
+            Tensor acc_o_rowcol = make_tensor(acc_o.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_o.layout()));
             static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
             #pragma unroll
             for (int mi = 0; mi < size(row_max); ++mi) {
@@ -158,10 +159,10 @@ struct Softmax {
                 #pragma unroll
                 for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
             }
-            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
+            FLASH_NAMESPACE::scale_apply_exp2(scores, row_max, softmax_scale_log2);
             // We don't do the reduce across threads here since we don't need to use the row_sum.
             // We do that reduce at the end when we need to normalize the softmax.
-            flash::reduce_sum(scores, row_sum);
+            FLASH_NAMESPACE::reduce_sum(scores, row_sum);
         }
     };
 
@@ -170,7 +171,7 @@ struct Softmax {
         SumOp sum_op;
         quad_allreduce_(row_sum, row_sum, sum_op);
         TensorT lse = make_fragment_like(row_sum);
-        Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
+        Tensor acc_o_rowcol = make_tensor(acc_o.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_o.layout()));
         static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
         #pragma unroll
         for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
@@ -185,4 +186,4 @@ struct Softmax {
     };
 };
 
-} // namespace flash
+} // namespace FLASH_NAMESPACE
diff --git a/csrc/flash_attn/src/utils.h b/csrc/flash_attn/src/utils.h
index b7408ec44..a7729aede 100644
--- a/csrc/flash_attn/src/utils.h
+++ b/csrc/flash_attn/src/utils.h
@@ -21,9 +21,11 @@
 #include
 #include
 
+#include "namespace_config.h"
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-namespace flash {
+namespace FLASH_NAMESPACE {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
@@ -268,8 +270,8 @@ __forceinline__ __device__ auto convert_type_relu(Tensor const &
     }
     Tensor out = make_tensor(make_rmem_ptr(out_uint32.data()), tensor.layout());
 #else
-    Tensor out = flash::convert_type(tensor);
-    flash::relu_(out);
+    Tensor out = FLASH_NAMESPACE::convert_type(tensor);
+    FLASH_NAMESPACE::relu_(out);
 #endif
     return out;
 }
@@ -408,4 +410,4 @@ __forceinline__ __device__ void calculate_dtanh(Tensor &src_te
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-} // namespace flash
+} // namespace FLASH_NAMESPACE
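The net effect of the patch is that every translation unit under csrc/flash_attn now includes namespace_config.h and places its symbols in FLASH_NAMESPACE (default flash), so two differently configured builds can coexist in one process provided each build defines the macro consistently across all of its sources, for example via a -DFLASH_NAMESPACE=my_flash compile definition. The sketch below shows a hypothetical host-side caller outside the namespace; it assumes that flash.h, after this patch, declares Flash_fwd_params and the run_mha_fwd_ template inside FLASH_NAMESPACE, matching the instantiations generated above:

    // Hypothetical wrapper translation unit, compiled with the same FLASH_NAMESPACE
    // value as the flash-attention sources it links against.
    #include <cuda_runtime.h>
    #include <cutlass/numeric_types.h>

    #include "namespace_config.h"
    #include "flash.h"

    void forward_hdim128_fp16(FLASH_NAMESPACE::Flash_fwd_params &params, cudaStream_t stream) {
        // Explicit qualification keeps the call unambiguous even when the
        // namespace has been renamed at build time.
        FLASH_NAMESPACE::run_mha_fwd_<cutlass::half_t, 128, /*Is_causal=*/false>(params, stream);
    }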