Skip to content

Commit

Permalink
Add a macro for namespace (#1419)
Browse files Browse the repository at this point in the history
  • Loading branch information
drisspg authored Jan 15, 2025
1 parent 0fcd405 commit bc482cb
Show file tree
Hide file tree
Showing 102 changed files with 723 additions and 287 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ var/
*.egg-info/
.installed.cfg
*.egg
.eggs/

# IDE-related
.idea/

# Dev
venv
venv
13 changes: 8 additions & 5 deletions csrc/flash_attn/flash_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include <cutlass/numeric_types.h>

#include "namespace_config.h"
#include "hardware_info.h"
#include "flash.h"
#include "static_switch.h"
Expand All @@ -20,6 +21,7 @@
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

namespace FLASH_NAMESPACE {

void set_params_fprop(Flash_fwd_params &params,
// sizes
Expand Down Expand Up @@ -1471,12 +1473,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
}
return {out, softmax_lse};
}
} // namespace FLASH_NAMESPACE

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "FlashAttention";
m.def("fwd", &mha_fwd, "Forward pass");
m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
m.def("bwd", &mha_bwd, "Backward pass");
m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass");
m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass (variable length)");
m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass");
m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass (variable length)");
m.def("fwd_kvcache", &FLASH_NAMESPACE::mha_fwd_kvcache, "Forward pass, with KV-cache");
}
5 changes: 3 additions & 2 deletions csrc/flash_attn/src/alibi.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
#include <cmath>

#include "namespace_config.h"
#include <cute/tensor.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/array.h>

#include "utils.h"

namespace flash {
namespace FLASH_NAMESPACE {

using namespace cute;

Expand Down Expand Up @@ -71,4 +72,4 @@ struct Alibi {

};

} // namespace flash
} // namespace FLASH_NAMESPACE
5 changes: 3 additions & 2 deletions csrc/flash_attn/src/block_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

#pragma once

namespace flash {
#include "namespace_config.h"
namespace FLASH_NAMESPACE {

////////////////////////////////////////////////////////////////////////////////////////////////////

Expand Down Expand Up @@ -45,4 +46,4 @@ struct BlockInfo {

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace flash
} // namespace FLASH_NAMESPACE
9 changes: 5 additions & 4 deletions csrc/flash_attn/src/dropout.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

#pragma once

#include "namespace_config.h"
#include "philox.cuh"
#include "utils.h"

namespace flash {
namespace FLASH_NAMESPACE {

struct Dropout {

Expand All @@ -26,7 +27,7 @@ struct Dropout {
__forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
int block_row_start, int block_col_start, int block_row_stride) {
// convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2)
Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_dropout(tensor_.layout()));
Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_dropout(tensor_.layout()));
using T = typename Engine::value_type;
auto encode_dropout = [](bool keep, T val) {
return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
Expand All @@ -41,7 +42,7 @@ struct Dropout {
#pragma unroll
for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
// if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
uint4 random_uint4 = flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
uint4 random_uint4 = FLASH_NAMESPACE::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
// if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
// Special implementation for 16-bit types: we duplicate the threshold to the
Expand Down Expand Up @@ -91,4 +92,4 @@ struct Dropout {

};

} // namespace flash
} // namespace FLASH_NAMESPACE
5 changes: 5 additions & 0 deletions csrc/flash_attn/src/flash.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@

#pragma once

#include "namespace_config.h"

#include <cuda.h>
#include <vector>

#include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState

namespace FLASH_NAMESPACE {
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;
Expand Down Expand Up @@ -187,3 +190,5 @@ template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_pa
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);

} // namespace FLASH_NAMESPACE
6 changes: 5 additions & 1 deletion csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_bwd_launch_template.h"

namespace FLASH_NAMESPACE {

// Explicit specialization of run_mha_bwd_ for bfloat16, head dim 128,
// causal masking enabled; forwards to the shared hdim128 launch template.
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim128<cutlass::bfloat16_t, true>(params, stream);
}

} // namespace FLASH_NAMESPACE
6 changes: 5 additions & 1 deletion csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_bwd_launch_template.h"

namespace FLASH_NAMESPACE {

// Explicit specialization of run_mha_bwd_ for bfloat16, head dim 128,
// non-causal; forwards to the shared hdim128 launch template.
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim128<cutlass::bfloat16_t, false>(params, stream);
}

} // namespace FLASH_NAMESPACE
6 changes: 5 additions & 1 deletion csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_bwd_launch_template.h"

namespace FLASH_NAMESPACE {

// Explicit specialization of run_mha_bwd_ for fp16 (half_t), head dim 128,
// causal masking enabled; forwards to the shared hdim128 launch template.
template<>
void run_mha_bwd_<cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim128<cutlass::half_t, true>(params, stream);
}

} // namespace FLASH_NAMESPACE
6 changes: 5 additions & 1 deletion csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_bwd_launch_template.h"

namespace FLASH_NAMESPACE {

// Explicit specialization of run_mha_bwd_ for fp16 (half_t), head dim 128,
// non-causal; forwards to the shared hdim128 launch template.
template<>
void run_mha_bwd_<cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim128<cutlass::half_t, false>(params, stream);
}

} // namespace FLASH_NAMESPACE
6 changes: 5 additions & 1 deletion csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_bwd_launch_template.h"

namespace FLASH_NAMESPACE {

// Explicit specialization of run_mha_bwd_ for bfloat16, head dim 160,
// causal masking enabled; forwards to the shared hdim160 launch template.
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 160, true>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim160<cutlass::bfloat16_t, true>(params, stream);
}

} // namespace FLASH_NAMESPACE
6 changes: 5 additions & 1 deletion csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_bwd_launch_template.h"

namespace FLASH_NAMESPACE {

// Explicit specialization of run_mha_bwd_ for bfloat16, head dim 160,
// non-causal; forwards to the shared hdim160 launch template.
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 160, false>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim160<cutlass::bfloat16_t, false>(params, stream);
}

} // namespace FLASH_NAMESPACE
6 changes: 5 additions & 1 deletion csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_bwd_launch_template.h"

namespace FLASH_NAMESPACE {

// Explicit specialization of run_mha_bwd_ for fp16 (half_t), head dim 160,
// causal masking enabled; forwards to the shared hdim160 launch template.
template<>
void run_mha_bwd_<cutlass::half_t, 160, true>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim160<cutlass::half_t, true>(params, stream);
}

} // namespace FLASH_NAMESPACE
6 changes: 5 additions & 1 deletion csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_bwd_launch_template.h"

namespace FLASH_NAMESPACE {

// Explicit specialization of run_mha_bwd_ for fp16 (half_t), head dim 160,
// non-causal; forwards to the shared hdim160 launch template.
template<>
void run_mha_bwd_<cutlass::half_t, 160, false>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim160<cutlass::half_t, false>(params, stream);
}

} // namespace FLASH_NAMESPACE
6 changes: 5 additions & 1 deletion csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_bwd_launch_template.h"

namespace FLASH_NAMESPACE {

// Explicit specialization of run_mha_bwd_ for bfloat16, head dim 192,
// causal masking enabled; forwards to the shared hdim192 launch template.
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim192<cutlass::bfloat16_t, true>(params, stream);
}

} // namespace FLASH_NAMESPACE
6 changes: 5 additions & 1 deletion csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_bwd_launch_template.h"

namespace FLASH_NAMESPACE {

// Explicit specialization of run_mha_bwd_ for bfloat16, head dim 192,
// non-causal; forwards to the shared hdim192 launch template.
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim192<cutlass::bfloat16_t, false>(params, stream);
}

} // namespace FLASH_NAMESPACE
6 changes: 5 additions & 1 deletion csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_bwd_launch_template.h"

namespace FLASH_NAMESPACE {

// Explicit specialization of run_mha_bwd_ for fp16 (half_t), head dim 192,
// causal masking enabled; forwards to the shared hdim192 launch template.
template<>
void run_mha_bwd_<cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim192<cutlass::half_t, true>(params, stream);
}

} // namespace FLASH_NAMESPACE
6 changes: 5 additions & 1 deletion csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_bwd_launch_template.h"

namespace FLASH_NAMESPACE {

// Explicit specialization of run_mha_bwd_ for fp16 (half_t), head dim 192,
// non-causal; forwards to the shared hdim192 launch template.
template<>
void run_mha_bwd_<cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim192<cutlass::half_t, false>(params, stream);
}

} // namespace FLASH_NAMESPACE
6 changes: 5 additions & 1 deletion csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_bwd_launch_template.h"

namespace FLASH_NAMESPACE {

// Explicit specialization of run_mha_bwd_ for bfloat16, head dim 256,
// causal masking enabled; forwards to the shared hdim256 launch template.
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim256<cutlass::bfloat16_t, true>(params, stream);
}

} // namespace FLASH_NAMESPACE
6 changes: 5 additions & 1 deletion csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_bwd_launch_template.h"

namespace FLASH_NAMESPACE {

// Explicit specialization of run_mha_bwd_ for bfloat16, head dim 256,
// non-causal; forwards to the shared hdim256 launch template.
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim256<cutlass::bfloat16_t, false>(params, stream);
}

} // namespace FLASH_NAMESPACE
6 changes: 5 additions & 1 deletion csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_bwd_launch_template.h"

namespace FLASH_NAMESPACE {

// Explicit specialization of run_mha_bwd_ for fp16 (half_t), head dim 256,
// causal masking enabled; forwards to the shared hdim256 launch template.
template<>
void run_mha_bwd_<cutlass::half_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim256<cutlass::half_t, true>(params, stream);
}

} // namespace FLASH_NAMESPACE
Loading

0 comments on commit bc482cb

Please sign in to comment.