Add support for qk hidden dim different from v hidden dim #1166

Status: Open. Wants to merge 54 commits into base: main.

Changes are shown from 1 commit. Commits (54):
5eeef1e  intermediate save (xiayuqing0622, Aug 12, 2024)
331a601  support var dim (xiayuqing0622, Aug 13, 2024)
02da101  modify readme (xiayuqing0622, Aug 13, 2024)
2bce87c  compatible (xiayuqing0622, Aug 15, 2024)
51a8bcb  Merge branch 'main' into dim (xiayuqing0622, Aug 19, 2024)
e8b4082  test_head_dim (smallscientist1, Aug 19, 2024)
ebf0b16  add test headdim (smallscientist1, Aug 20, 2024)
ab35fc2  fix some config bug (smallscientist1, Aug 20, 2024)
4e94c20  update test headdim (smallscientist1, Aug 20, 2024)
e31b6a4  Merge branch 'Dao-AILab:main' into dim (xiayuqing0622, Aug 20, 2024)
89dbe52  update test headdim splitkv (smallscientist1, Aug 20, 2024)
fc094a4  Merge commit '89dbe521b48000ee4f3d942d7c3498c698817159' into dim (smallscientist1, Aug 20, 2024)
d11b7ae  update ReadMe.md (smallscientist1, Aug 20, 2024)
21ca4bc  remove unused file (smallscientist1, Aug 20, 2024)
4c3462a  revert Readme (smallscientist1, Aug 20, 2024)
f63411d  create bench headdim (smallscientist1, Aug 21, 2024)
3e0c7c4  update bench result (smallscientist1, Aug 22, 2024)
3caa059  update Readme (smallscientist1, Aug 22, 2024)
493a430  reorg code to reduce compile time (smallscientist1, Aug 22, 2024)
0607e6c  update (128,256) config (smallscientist1, Aug 22, 2024)
fd6fc29  add (192,128) (smallscientist1, Aug 26, 2024)
b6d7493  add config (192,128) (smallscientist1, Aug 26, 2024)
85fb8d2  fix bug (smallscientist1, Aug 26, 2024)
f0644c2  fix bug backward (smallscientist1, Aug 27, 2024)
0092285  fix bug (smallscientist1, Aug 27, 2024)
6e88a4d  Add support for dim(192,128) (#1) (smallscientist1, Aug 27, 2024)
255cd5a  add optional dim compile (smallscientist1, Aug 28, 2024)
e666f96  Merge branch 'Dao-AILab:main' into dim (xiayuqing0622, Sep 3, 2024)
00979f5  support different head kv (smallscientist1, Sep 4, 2024)
feeab17  add test_head (smallscientist1, Sep 4, 2024)
18b309d  update flash api head (smallscientist1, Sep 4, 2024)
6909ab4  fix interface bug (smallscientist1, Sep 4, 2024)
3c8bb2b  Merge pull request #2 from xiayuqing0622/head (smallscientist1, Sep 4, 2024)
5f26eb0  update README (smallscientist1, Sep 4, 2024)
536a8cc  benchmark head_headdim (smallscientist1, Sep 4, 2024)
ca6335d  fix bench bug (smallscientist1, Sep 4, 2024)
def41c0  fix bug for numhead (smallscientist1, Sep 5, 2024)
6e8d537  add autotuner (smallscientist1, Sep 6, 2024)
83fd7a5  basetuner fwd (smallscientist1, Sep 6, 2024)
7cf4858  update autotuner FLashFwd (smallscientist1, Sep 10, 2024)
1ca8397  autotuner fwd (smallscientist1, Sep 10, 2024)
1e5c49d  update code (smallscientist1, Sep 10, 2024)
409bdde  update autotuner log (smallscientist1, Sep 12, 2024)
d4b620a  update tunner (smallscientist1, Sep 12, 2024)
be21a0a  fix bug kernel launch (smallscientist1, Sep 12, 2024)
90fa651  update autotuner tile space (smallscientist1, Sep 18, 2024)
1ba39eb  update cutlass bugfix (smallscientist1, Sep 18, 2024)
c5fa3c9  add autotuner doc (smallscientist1, Sep 18, 2024)
31ea0bb  Merge pull request #3 from xiayuqing0622/dim_autotuner (smallscientist1, Sep 18, 2024)
b09eaee  update readme (smallscientist1, Sep 18, 2024)
cd9fee4  update autotuner (smallscientist1, Sep 19, 2024)
014c349  update readme (smallscientist1, Sep 19, 2024)
cd91625  Merge branch 'dim_pr' into dim_pr1 (smallscientist1, Sep 19, 2024)
d578cff  Merge pull request #4 from xiayuqing0622/dim_pr1 (smallscientist1, Sep 19, 2024)
Changes from commit 1ca8397f63b8c5ac72156491ac7242b4670813ad ("autotuner fwd", smallscientist1, Sep 10, 2024):
31 changes: 23 additions & 8 deletions autotuner/base_tunner.py
@@ -1,13 +1,15 @@
import ctypes
import os
from concurrent.futures import ThreadPoolExecutor
# import multiprocessing
# from functools import partial
import tempfile
import subprocess
import importlib.util

import ctypes
import torch
from configs.base_config import BaseConfig
from configs import BaseConfig, supported_configs

import pprint
import json
@@ -70,8 +72,8 @@ class BaseTunner:
def __init__(self, arch, torch_array: list, op_name, tempdir):
self.arch = arch
self.torch_array = torch_array
self.Br_list = [32, 64, 128, 256]
self.Bc_list = [32, 64, 128, 256]
self.Br_list = [32, 64, 128] # [32, 64, 128, 256]
self.Bc_list = [32, 64, 128] # [32, 64, 128, 256]

self.template_dir = "autotuner/template"
self.op_name = op_name
@@ -80,6 +82,7 @@ def __init__(self, arch, torch_array: list, op_name, tempdir):
"dim_qk": torch_array[0].shape[-1],
"dim_v": torch_array[2].shape[-1]
}
# TODO: causal, dropout
self.shape_config = ShapeConfig(torch_array[0].shape[-1],torch_array[2].shape[-1])
self.tempdir = tempdir

@@ -89,11 +92,12 @@ def compile(self, configs:list, timeout: float = None):
code_emitter.generate_code(self.shape_config, configs)


def profile(self, config:BaseConfig, device="cuda:0") -> float:
def profile(self, config:BaseConfig, device="cuda:0", repeat=30) -> float:
spec = importlib.util.spec_from_file_location("flash_attn_func", self.tempdir+"/"+config.output_dir+"/flash_attn_profile_interface.py")
flash_attn_func = importlib.util.module_from_spec(spec)
spec.loader.exec_module(flash_attn_func)
latency = profile_fwd(flash_attn_func)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
flash_attn_func = mod.flash_attn_func
latency = profile_fwd(flash_attn_func, self.shape_config.Kd, self.shape_config.D, is_bf16=self.shape_config.is_bf16, causal=self.shape_config.is_causal, device=device, repeats=repeat)
if latency < 0:
latency = 1e8
# remove lib
@@ -108,7 +112,7 @@ def get_tuned_configs(self):
for Bc in self.Bc_list:
cur_configs = self.generate_configs(Br,Bc,dim_qk,dim_v)
for cur_config in cur_configs:
if self.operation == "flash_fwd" and self.validate_register_fuse(cur_config):
if self.op_name == "flash_fwd" and self.validate_register_fuse(cur_config):
configs.append(cur_config)
else: # BWD
if self.validate_kernel(cur_config):
@@ -139,6 +143,17 @@ def tune(self, log_path="./logs/"):
# cresults = self.compile(configs,src_dir.name,timeout=1200)
# cresults = self.compile_parallel(configs,src_dir.name,timeout=120)
self.compile(configs,timeout=120)

# warm up (parallel compile module)
# module name must be different in api.py
with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
latencys = executor.map(self.profile, configs, ["cuda:0" for _ in range(len(configs))], [1 for _ in range(len(configs))])
# with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
# latencys = executor.map(_profile,[self.tempdir for _ in range(len(configs))],[self.shape_config for _ in range(len(configs))], configs, ["cuda:0" for _ in range(len(configs))], [1 for _ in range(len(configs))])
# multiprocessing.set_start_method('spawn', force=True)
# pool = multiprocessing.Pool(os.cpu_count())
# outs = pool.map(partial(self.profile, repeat=1), configs)

profile_dict = {}
latency = 1e8
best_config = None
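
To make the new warm-up step concrete, here is a minimal sketch (not the PR's exact code) of the pattern this hunk introduces: every candidate config is profiled once with repeat=1 purely to trigger compilation and module loading in parallel, and a serial pass afterwards (assumed here, since the diff cuts off) keeps the lowest-latency config.

from concurrent.futures import ThreadPoolExecutor
import os

def select_best(tunner, configs):
    # Parallel warm-up: each profile(cfg, device, repeat=1) call builds and loads
    # that config's extension module; the measured time is discarded.
    with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
        list(executor.map(tunner.profile,
                          configs,
                          ["cuda:0"] * len(configs),   # device argument per config
                          [1] * len(configs)))         # repeat=1: compile only
    # Assumed serial measurement pass: keep the fastest config.
    best_config, best_latency = None, 1e8
    for cfg in configs:
        latency = tunner.profile(cfg, device="cuda:0", repeat=30)
        if latency < best_latency:
            best_config, best_latency = cfg, latency
    return best_config, best_latency
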
4 changes: 3 additions & 1 deletion autotuner/code_emitter.py
@@ -22,6 +22,8 @@ def __init__(self, template_dir, output_dir) -> None:
]
self.kernel_file_list = [
"flash_fwd.h",
"flash_profile.h",
"flash_fwd_launch_template_profile.h"
]

def generate_code(self, shape_config:ShapeConfig, configs:list[BaseConfig]):
@@ -41,7 +43,6 @@ def generate_code(self, shape_config:ShapeConfig, configs:list[BaseConfig]):
f.write(code_template)

# generate kernel code
# TODO: parallelize
for config in configs:
kernel_code_dir = Path(output_dir) / Path(config.output_dir)
if not kernel_code_dir.exists():
@@ -59,6 +60,7 @@ def generate_code(self, shape_config:ShapeConfig, configs:list[BaseConfig]):
code_template = f.read()
code_template = code_template.replace("OUTPUT_DIR", f"\"{str(output_dir)}\"")
code_template = code_template.replace("OUTPUT_KERNEL_DIR", f"\"{str(kernel_code_dir)}\"")
code_template = code_template.replace("CONFIG_NAME", f"\"{str(config)}\"")
with open(Path(kernel_code_dir) / Path("flash_attn_profile_interface.py"), "w") as f:
f.write(code_template)
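
For context, the CONFIG_NAME substitution added here is what makes the parallel warm-up above workable: torch.utils.cpp_extension.load keys its build and module caching on the extension name, so every generated interface file needs a distinct name. A minimal sketch of the substitution, assuming a template that contains the literal OUTPUT_DIR, OUTPUT_KERNEL_DIR and CONFIG_NAME tokens:

from pathlib import Path

def emit_interface(template_path: Path, output_dir: Path, kernel_code_dir: Path, config) -> None:
    code = template_path.read_text()
    code = code.replace("OUTPUT_DIR", f'"{output_dir}"')
    code = code.replace("OUTPUT_KERNEL_DIR", f'"{kernel_code_dir}"')
    # str(config) is e.g. "192_128_64_64_8", so the loaded extension becomes
    # "flash_attn_cuda192_128_64_64_8" and does not collide across configs.
    code = code.replace("CONFIG_NAME", f'"{config}"')
    (kernel_code_dir / "flash_attn_profile_interface.py").write_text(code)
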

6 changes: 6 additions & 0 deletions autotuner/configs/__init__.py
@@ -0,0 +1,6 @@
from .base_config import BaseConfig
from .fwd_config import FlashFwdConfig

supported_configs = {
"flash_fwd": FlashFwdConfig,
}
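
A hypothetical lookup showing how the new registry is meant to be used; the concrete values are illustrative only:

from configs import supported_configs

ConfigCls = supported_configs["flash_fwd"]      # -> FlashFwdConfig
cfg = ConfigCls(Kd=192, D=128, Br=64, Bc=64)    # qk head dim 192, v head dim 128
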
2 changes: 1 addition & 1 deletion autotuner/configs/base_config.py
@@ -10,7 +10,7 @@ def __init__(self, Kd, D, Br, Bc, Nwarps=8) -> None:
self.template_dir = None

def __repr__(self) -> str:
return "Config(Kd={}, D={}, Br={}, Bc={}, Nwarps={}".format(self.Kd, self.D, self.Br, self.Bc, self.Nwarps)
return "Config(Kd={}, D={}, Br={}, Bc={}, Nwarps={})".format(self.Kd, self.D, self.Br, self.Bc, self.Nwarps)

def __str__(self) -> str:
return f"{self.Kd}_{self.D}_{self.Br}_{self.Bc}_{self.Nwarps}"
2 changes: 1 addition & 1 deletion autotuner/configs/fwd_config.py
@@ -12,7 +12,7 @@ def __init__(self, Kd, D, Br, Bc, Nwarps=8, isQinRegs:bool = False, SharedQKSmem
self.template_dir = os.path.join(os.path.dirname(__file__), "../../../csrc/kernels/attention")

def __repr__(self) -> str:
return "Config(Kd={}, D={}, Br={}, Bc={}, Nwarps={}, isQinRegs={}, SharedQKSmem={}".format(self.Kd, self.D, self.Br, self.Bc, self.Nwarps, self.isQinRegs, self.SharedQKSmem)
return "Config(Kd={}, D={}, Br={}, Bc={}, Nwarps={}, isQinRegs={}, SharedQKSmem={})".format(self.Kd, self.D, self.Br, self.Bc, self.Nwarps, self.isQinRegs, self.SharedQKSmem)

def __str__(self) -> str:
return f"{self.Kd}_{self.D}_{self.Br}_{self.Bc}_{self.Nwarps}_{self.isQinRegs}_{self.SharedQKSmem}"
12 changes: 7 additions & 5 deletions autotuner/template/flash_attn_profile_interface.py
@@ -37,7 +37,7 @@


flash_attn_cuda = torch.utils.cpp_extension.load(
name="flash_attn_cuda",
name="flash_attn_cuda"+CONFIG_NAME,
sources=[
OUTPUT_DIR + "/flash_profile_api.cpp", # "csrc/flash_attn/flash_api.cpp",
OUTPUT_DIR + "/flash_fwd.cu",
@@ -612,6 +612,7 @@ def forward(
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
ctx.headdim_qk = q.shape[-1] # before padding
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
q,
k,
@@ -657,8 +658,8 @@ def backward(ctx, dout, *args):
ctx.deterministic,
rng_state=rng_state,
)
dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
dk = dk[..., : k.shape[-1]]
dq = dq[..., : ctx.headdim_qk] # We could have padded the head dimension
dk = dk[..., : ctx.headdim_qk]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None

@@ -686,6 +687,7 @@ def forward(
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
ctx.headdim_qk = q.shape[-1] # before padding
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
q,
k,
@@ -744,8 +746,8 @@ def backward(ctx, dout, *args):
ctx.deterministic,
rng_state=rng_state,
)
dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
dk = dk[..., : k.shape[-1]]
dq = dq[..., : ctx.headdim_qk] # We could have padded the head dimension
dk = dk[..., : ctx.headdim_qk]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None
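
The point of stashing ctx.headdim_qk before padding is that once dim_qk and dim_v differ, the backward pass has to slice dq and dk back to the original qk head dim while dv follows the v head dim carried by dout. A small illustrative sketch of that slicing (names are not the PR's exact code):

def backward_slices(dq_padded, dk_padded, dv_padded, dout, headdim_qk):
    # headdim_qk was recorded in forward (ctx.headdim_qk) before any padding.
    dq = dq_padded[..., :headdim_qk]
    dk = dk_padded[..., :headdim_qk]
    dv = dv_padded[..., :dout.shape[-1]]   # v head dim follows the output gradient
    return dq, dk, dv
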

2 changes: 1 addition & 1 deletion autotuner/template/flash_fwd.h
@@ -1,4 +1,4 @@
#include "flash_fwd_launch_template.h"
#include "flash_fwd_launch_template_profile.h"

#define False false
#define True true
168 changes: 168 additions & 0 deletions autotuner/template/flash_fwd_launch_template_profile.h
@@ -0,0 +1,168 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/

#pragma once

#include <ATen/cuda/CUDAContext.h>

#include "static_switch.h"
#include "flash_profile.h"
#include "flash_fwd_kernel.h"

// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define ARCH_SUPPORTS_FLASH
#define KERNEL_PARAM_MODIFIER __grid_constant__
#else
#define KERNEL_PARAM_MODIFIER
#endif

// Define a macro for unsupported architecture handling to centralize the error message
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");

// Use a macro to clean up kernel definitions
#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \
template<typename Kernel_traits, __VA_ARGS__> \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)

DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {
#if defined(ARCH_SUPPORTS_FLASH)
static_assert(!(Is_causal && Is_local)); // Enforce constraints
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}

DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) {
#if defined(ARCH_SUPPORTS_FLASH)
flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}

DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) {
static_assert(Log_max_splits >= 1);
flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
}

template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
constexpr size_t smem_size = Kernel_traits::kSmemSize;
// printf("smem_size = %d\n", smem_size);

// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
// https://github.com/kokkos/kokkos-kernels/issues/349
// https://github.com/HazyResearch/flash-attention/issues/21

const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(num_m_block, params.b, params.h);
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
const bool is_even_K = params.d == Kernel_traits::kQKHeadDim; //TODO: Check if this is correct
const bool return_softmax = params.p_ptr != nullptr;
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
// Will only return softmax if dropout, to reduce compilation time.
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If return_softmax, set IsEvenMNConst to false to reduce number of templates
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kVHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;// TODO: Check if this is correct
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
// int ctas_per_sm;
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
});
});
}

template<typename Kernel_traits, bool Is_causal>
void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs");
static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem");
constexpr size_t smem_size = Kernel_traits::kSmemSize;
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
const bool is_even_K = params.d == Kernel_traits::kQKHeadDim; //TODO: Check if this is correct
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If Is_local, set Is_causal to false
auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kVHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV>; // TODO: Check if this is correct
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
});
});
});
if (params.num_splits > 1) {
// We want kBlockM to be as small as possible for more parallelism.
// With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
// If headdim is divisible by 64, then we set kBlockM = 8, etc.
constexpr static int kBlockM = Kernel_traits::kQKHeadDim % 128 == 0 ? 4 : (Kernel_traits::kQKHeadDim % 64 == 0 ? 8 : 16); // TODO: Check if this is correct
dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
if (params.num_splits <= 2) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 4) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 8) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 16) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 32) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 64) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 128) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
}

template<typename T, int QKHeaddim, int VHeaddim, bool Is_causal>
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int kBlockM = 64; // Fixed for all head dimensions
// TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
// and for headdim 192 with block size 64 x 128.
// Also for headdim 160 with block size 64 x 128 after the rotary addition.
constexpr static int kBlockN = QKHeaddim <= 64 ? 256 : (QKHeaddim <= 128 ? 128 : 64);
run_flash_splitkv_fwd<Flash_fwd_kernel_traits<QKHeaddim, VHeaddim, kBlockM, kBlockN, 4, false, false, T>, Is_causal>(params, stream);
}
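
As a readability aid, the split-KV combine dispatch above reduces to the following arithmetic (a Python rendering, not part of the PR): kBlockM shrinks as the qk head dim becomes more divisible, and the Log_max_splits template argument is the smallest k with num_splits <= 2**k, capped at 7 (128 splits).

import math

def combine_launch_params(qk_headdim: int, num_splits: int):
    if qk_headdim % 128 == 0:
        kBlockM = 4
    elif qk_headdim % 64 == 0:
        kBlockM = 8
    else:
        kBlockM = 16
    log_max_splits = max(1, math.ceil(math.log2(num_splits)))
    assert log_max_splits <= 7, "the combine kernel handles at most 128 splits"
    return kBlockM, log_max_splits

# e.g. combine_launch_params(192, 10) -> (8, 4)
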
196 changes: 196 additions & 0 deletions autotuner/template/flash_profile.h
@@ -0,0 +1,196 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/

#pragma once

#include <cuda.h>
#include <vector>

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif

#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack

constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Qkv_params {
using index_t = int64_t;
// The QKV matrices.
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
void *__restrict__ v_ptr;

// The stride between rows of the Q, K and V matrices.
index_t q_batch_stride;
index_t k_batch_stride;
index_t v_batch_stride;
index_t q_row_stride;
index_t k_row_stride;
index_t v_row_stride;
index_t q_head_stride;
index_t k_head_stride;
index_t v_head_stride;

// The number of heads.
int h, h_k, h_v;
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
// different from nheads (query).
int h_h_k_ratio; // precompute h / h_k,
int h_h_v_ratio;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Flash_fwd_params : public Qkv_params {

// The O matrix (output).
void * __restrict__ o_ptr;
void * __restrict__ oaccum_ptr;

// The stride between rows of O.
index_t o_batch_stride;
index_t o_row_stride;
index_t o_head_stride;

// The pointer to the P matrix.
void * __restrict__ p_ptr;

// The pointer to the softmax sum.
void * __restrict__ softmax_lse_ptr;
void * __restrict__ softmax_lseaccum_ptr;

// The dimensions.
int b, seqlen_q, seqlen_k, seqlen_knew, d, vd, seqlen_q_rounded, seqlen_k_rounded, d_rounded, vd_rounded, rotary_dim, total_q;

// The scaling factors for the kernel.
float scale_softmax;
float scale_softmax_log2;

// array of length b+1 holding starting offset of each sequence.
int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k;
int * __restrict__ leftpad_k;

// If provided, the actual length of each k sequence.
int * __restrict__ seqused_k;

int *__restrict__ blockmask;

// The K_new and V_new matrices.
void * __restrict__ knew_ptr;
void * __restrict__ vnew_ptr;

// The stride between rows of the Q, K and V matrices.
index_t knew_batch_stride;
index_t vnew_batch_stride;
index_t knew_row_stride;
index_t vnew_row_stride;
index_t knew_head_stride;
index_t vnew_head_stride;

// The cos and sin matrices for rotary embedding.
void * __restrict__ rotary_cos_ptr;
void * __restrict__ rotary_sin_ptr;

// The indices to index into the KV cache.
int * __restrict__ cache_batch_idx;

// Paged KV cache
int * __restrict__ block_table;
index_t block_table_batch_stride;
int page_block_size;

// The dropout probability (probability of keeping an activation).
float p_dropout;
// uint32_t p_dropout_in_uint;
// uint16_t p_dropout_in_uint16_t;
uint8_t p_dropout_in_uint8_t;

// Scale factor of 1 / (1 - p_dropout).
float rp_dropout;
float scale_softmax_rp_dropout;

// Local window size
int window_size_left, window_size_right;
float softcap;

// Random state.
at::PhiloxCudaState philox_args;

// Pointer to the RNG seed (idx 0) and offset (idx 1).
uint64_t * rng_state;

bool is_bf16;
bool is_causal;

// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
bool is_seqlens_k_cumulative;

bool is_rotary_interleaved;

int num_splits; // For split-KV version

void * __restrict__ alibi_slopes_ptr;
index_t alibi_slopes_batch_stride;

bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d).
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Flash_bwd_params : public Flash_fwd_params {

// The dO and dQKV matrices.
void *__restrict__ do_ptr;
void *__restrict__ dq_ptr;
void *__restrict__ dk_ptr;
void *__restrict__ dv_ptr;

// To accumulate dQ
void *__restrict__ dq_accum_ptr;
void *__restrict__ dk_accum_ptr;
void *__restrict__ dv_accum_ptr;

// // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
// dimension void *__restrict__ dk_accum_ptr; void *__restrict__
// dv_accum_ptr;

// The stride between rows of the dO, dQ, dK and dV matrices.
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
// The code probably won't work for arrays larger than 2GB.
index_t do_batch_stride;
index_t do_row_stride;
index_t do_head_stride;
index_t dq_batch_stride;
index_t dk_batch_stride;
index_t dv_batch_stride;
index_t dq_row_stride;
index_t dk_row_stride;
index_t dv_row_stride;
index_t dq_head_stride;
index_t dk_head_stride;
index_t dv_head_stride;

// The pointer to the softmax d sum.
void *__restrict__ dsoftmax_sum;

bool deterministic;
index_t dq_accum_split_stride;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int QKHeaddim, int VHeaddim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
// template<typename T, int QKHeaddim, int VHeaddim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

// template<typename T, int QKHeaddim, int VHeaddim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
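
The new h_v and h_h_v_ratio fields extend the existing MQA/GQA head mapping so that K and V can carry different head counts. A hypothetical helper (not from the PR) showing the intended indexing:

def kv_heads_for_query_head(q_head: int, h: int, h_k: int, h_v: int):
    h_h_k_ratio = h // h_k   # precomputed as in Qkv_params
    h_h_v_ratio = h // h_v
    return q_head // h_h_k_ratio, q_head // h_h_v_ratio

# e.g. with 32 query heads, 8 K heads and 4 V heads,
# query head 13 reads K head 13 // 4 = 3 and V head 13 // 8 = 1.
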
139 changes: 85 additions & 54 deletions autotuner/template/flash_profile_api.cpp

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions autotuner/tunner.py
@@ -3,8 +3,8 @@
import os
import torch

from .base_tunner import BaseTunner
from .configs.fwd_config import FlashFwdConfig
from base_tunner import BaseTunner
from configs.fwd_config import FlashFwdConfig

class FlashFwdTunner(BaseTunner):
def __init__(self, arch, torch_array: list, tempdir: str):
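
Finally, a hypothetical end-to-end driver for the tunner after this import change (run from inside autotuner/ so the absolute imports resolve). The arch string and the tensor layout are assumptions for illustration; the only requirement visible in the diff is that torch_array[0] supplies dim_qk and torch_array[2] supplies dim_v.

import tempfile
import torch
from tunner import FlashFwdTunner

q = torch.randn(2, 2048, 32, 192, dtype=torch.bfloat16, device="cuda")   # dim_qk = 192
k = torch.randn(2, 2048, 32, 192, dtype=torch.bfloat16, device="cuda")
v = torch.randn(2, 2048, 32, 128, dtype=torch.bfloat16, device="cuda")   # dim_v = 128

with tempfile.TemporaryDirectory() as tmpdir:
    tunner = FlashFwdTunner(arch="80", torch_array=[q, k, v], tempdir=tmpdir)
    tunner.tune(log_path="./logs/")
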