update asm pa for blockscale (256/128,128) #231

Status: Open. Wants to merge 2 commits into base: main.
3 changes: 2 additions & 1 deletion aiter/ops/attention.py
@@ -44,7 +44,8 @@ def pa_fwd_asm(
K_QScale: Optional[torch.Tensor],
V_QScale: Optional[torch.Tensor],
out_: Optional[torch.Tensor] = None,
high_precision: Optional[int] = 1 # [0, 1, 2] 2 is the highest precision, this is only for fp8 kvcache
high_precision: Optional[int] = 1, # [0, 1, 2] 2 is the highest precision, this is only for fp8 kvcache
block_shape: Optional[tuple[int,int]] = None,
) -> torch.Tensor: ...


3 changes: 2 additions & 1 deletion csrc/include/attention_asm.h
@@ -12,4 +12,5 @@ torch::Tensor pa_fwd(torch::Tensor &Q, // [num_seqs, num_heads, hea
std::optional<torch::Tensor> &K_QScale,
std::optional<torch::Tensor> &V_QScale,
std::optional<torch::Tensor> &out_,
std::optional<int> high_precision = 1);
std::optional<int> high_precision = 1,
std::optional<std::tuple<int, int>> block_shape = std::nullopt);
3 changes: 2 additions & 1 deletion csrc/include/rocm_ops.hpp
100644 → 100755
@@ -40,7 +40,8 @@
py::arg("K_QScale") = std::nullopt, \
py::arg("V_QScale") = std::nullopt, \
py::arg("out_") = std::nullopt, \
py::arg("high_precision") = 1);
py::arg("high_precision") = 1, \
py::arg("block_shape") = std::nullopt);

#define ATTENTION_CK_PYBIND \
m.def("pa_fwd_naive", &pa_fwd_naive, "pa_fwd_naive", \
23 changes: 21 additions & 2 deletions csrc/py_itfs_cu/asm_pa.cpp
@@ -52,7 +52,8 @@ torch::Tensor pa_fwd(torch::Tensor &Q, // [num_seqs, num_heads, hea
std::optional<torch::Tensor> &K_QScale,
std::optional<torch::Tensor> &V_QScale,
std::optional<torch::Tensor> &out_,
std::optional<int> high_precision = 1)
std::optional<int> high_precision = 1,
std::optional<std::tuple<int, int>> block_shape = std::nullopt)
{
torch::Tensor output = out_.value_or(torch::empty_like(Q));
int batch = context_lens.size(0);
@@ -106,7 +107,25 @@ torch::Tensor pa_fwd(torch::Tensor &Q, // [num_seqs, num_heads, hea
AiterAsmKernel *impl_ptr = nullptr;
if (K_QScale)
{
if (Q.dtype() == at::ScalarType::Half)
if (block_shape.has_value())
{
if (block_shape.value() == std::make_tuple(128, 128) && Q.dtype() == at::ScalarType::BFloat16 && K.dtype() == at::ScalarType::Float8_e4m3fnuz)
{
static AiterAsmKernel impl_a16w16_b16_f8_blockscale128("pa_a16w8_2tg_g8_f8_kv128_bf16", "pa_a16w8_2tg_g8_f8_kv128_bf16.co");
impl_ptr = &impl_a16w16_b16_f8_blockscale128;
}
else if (block_shape.value() == std::make_tuple(256, 128) && Q.dtype() == at::ScalarType::BFloat16 && K.dtype() == at::ScalarType::Float8_e4m3fnuz)
{
static AiterAsmKernel impl_a16w16_b16_f8_blockscale256("pa_a16w8_2tg_g8_f8_kv256_bf16", "pa_a16w8_2tg_g8_f8_kv256_bf16.co");
impl_ptr = &impl_a16w16_b16_f8_blockscale256;
}
else
{
TORCH_CHECK(false,
__func__, ": only support block_shape == (128, 128) | (256, 128), Q dtype == BFloat16 and quantType == fp8 for now");
}
}
else if (Q.dtype() == at::ScalarType::Half)
{
if (K.dtype() == at::ScalarType::Byte || K.dtype() == at::ScalarType::Char)
{
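The new branch only fires when a K/V scale is supplied; it then picks one of the two kernel images added under hsa/ based on block_shape, and requires bfloat16 queries with an fp8 (e4m3fnuz) kv-cache. A hypothetical Python-side summary of that dispatch (the helper is illustrative, not part of the PR; kernel names, shapes, and dtype constraints come from the code above):

import torch

# Kernel images added by this PR, keyed by (tokens per quantization block, head size).
_BLOCKSCALE_KERNELS = {
    (128, 128): ("pa_a16w8_2tg_g8_f8_kv128_bf16", "pa_a16w8_2tg_g8_f8_kv128_bf16.co"),
    (256, 128): ("pa_a16w8_2tg_g8_f8_kv256_bf16", "pa_a16w8_2tg_g8_f8_kv256_bf16.co"),
}

def select_blockscale_kernel(block_shape, q_dtype, k_dtype):
    """Illustrative mirror of the new dispatch in pa_fwd: returns (kernel name, .co file)."""
    if q_dtype != torch.bfloat16 or k_dtype != torch.float8_e4m3fnuz:
        raise ValueError("block-scale path expects bf16 Q and an fp8 (e4m3fnuz) kv-cache")
    if tuple(block_shape) not in _BLOCKSCALE_KERNELS:
        raise ValueError("only block_shape (128, 128) or (256, 128) is supported for now")
    return _BLOCKSCALE_KERNELS[tuple(block_shape)]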
Binary file added hsa/pa_a16w8_2tg_g8_f8_kv128_bf16.co
Binary file added hsa/pa_a16w8_2tg_g8_f8_kv256_bf16.co
129 changes: 98 additions & 31 deletions op_tests/test_pa.py
100644 → 100755
@@ -7,7 +7,7 @@
import torch
import aiter
from aiter import paged_attn as ops
from aiter.test_common import checkAllclose, perftest, tensor_dump, tensor_load
from aiter.test_common import checkAllclose, perftest, tensor_dump, tensor_load, benchmark
from aiter import pertoken_quant

uniform_range = (-1, 1)
@@ -28,6 +28,7 @@
# // same as 8bit per token quant but 4 bit
'KV_4BIT_PER_TOKEN',
'KV_8BIT_PER_TENSOR',
'KV_8BIT_PER_BLOCK',
]


@@ -376,7 +377,8 @@ def run_aiter_asm(query,
max_num_blocks,
k_scale=None,
v_scale=None,
high_precision=0):
high_precision=0,
block_shape=None):
return aiter.pa_fwd_asm(
query,
k_cache,
@@ -387,7 +389,8 @@ def run_aiter_asm(query,
k_scale,
v_scale,
None,
high_precision
high_precision,
block_shape
)


@@ -461,7 +464,7 @@ def asm_V_shuffle(VC):
VC = VC.permute(0, 1, 3, 2, 4).contiguous()
return VC


@benchmark()
def test_paged_attention(
ctx_lens: int,
num_seqs: int,
@@ -472,7 +475,8 @@ def test_paged_attention(
dtype: torch.dtype,
kv_cache_dtype: str,
seed: int,
device: str
device: str,
block_shape: Optional[Tuple[int, int]] = None,
) -> None:
torch.set_default_device(device)
# Using default kv_scale
@@ -486,6 +490,8 @@
num_queries_per_kv = num_query_heads // num_kv_heads
max_seq_len = ctx_lens
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
if block_shape is not None:
max_num_blocks_per_seq = ((max_seq_len + block_shape[0] - 1) // block_shape[0]) * block_shape[0] // block_size
num_blocks = max_num_blocks_per_seq*num_seqs
print(f'{debug_mode=}')

@@ -510,10 +516,17 @@
# Create the block tables.
block_tables_lst: List[List[int]] = []
for _ in range(num_seqs):
block_table = [
random.randint(0, num_blocks - 1)
for _ in range(max_num_blocks_per_seq)
]
if block_shape is None:
block_table = [
random.randint(0, num_blocks - 1)
for _ in range(max_num_blocks_per_seq)
]
else:
serial_block_num = block_shape[0] // block_size
random_block = (random.randint(0, num_blocks - max_num_blocks_per_seq) //serial_block_num) * serial_block_num
block_table = list(
range(random_block, random_block + max_num_blocks_per_seq)
)
block_tables_lst.append(block_table)

block_tables = torch.tensor(block_tables_lst, dtype=torch.int)
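
With block_shape set, each quantization block spans block_shape[0] // block_size consecutive kv-cache pages, which is presumably why the test hands every sequence a contiguous, aligned run of page indices instead of random ones. A small worked example of the bookkeeping above, using the largest test configuration (values are illustrative only):

import random

block_size  = 16
block_shape = (256, 128)   # 256 tokens share one scale, head size 128
max_seq_len = 4097
num_seqs    = 128

# Round the context up to whole quantization blocks, then count pages:
# ceil(4097 / 256) = 17 blocks -> 17 * 256 / 16 = 272 pages per sequence.
max_num_blocks_per_seq = ((max_seq_len + block_shape[0] - 1) // block_shape[0]) \
    * block_shape[0] // block_size
num_blocks = max_num_blocks_per_seq * num_seqs

# The starting page is aligned to a quantization-block boundary and the
# table is a plain contiguous range, mirroring the loop above.
serial_block_num = block_shape[0] // block_size          # 16 pages per quant block
random_block = (random.randint(0, num_blocks - max_num_blocks_per_seq)
                // serial_block_num) * serial_block_num
block_table = list(range(random_block, random_block + max_num_blocks_per_seq))
assert max_num_blocks_per_seq == 272 and len(block_table) == 272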
@@ -569,7 +582,8 @@ def test_paged_attention(
(2, torch.float8_e4m3fnuz),
(2, torch.int8),
(4, torch.float8_e4m3fnuz),
]:
] if block_shape is None else \
[(5, torch.float8_e4m3fnuz)]:
quant_algo = ck_naive_quant_algo[quant_algo_]
if quant_algo == "NO":
k_quant_, k_scale_, v_quant_, v_scale_ = k_cache, torch.empty(
@@ -634,6 +648,47 @@
# )
# checkAllclose(out_aiter_asm, out_aiter_naive,
# msg=f'golden vs ck_naive(quant:{ck_naive_quant_algo[quant_algo_]}, kvcache:{cache_type_}):{time_aiter_naive:.2f} us......')
elif quant_algo == "KV_8BIT_PER_BLOCK":
assert block_shape in [(128, 128), (256, 128)], "KV_8BIT_PER_BLOCK only supports block_shape (128, 128) or (256, 128)"
assert head_size == block_shape[1], "KV_8BIT_PER_BLOCK only supports head_size == block_shape[1]"
assert num_blocks % (block_shape[0] // block_size) == 0, "KV_8BIT_PER_BLOCK only supports num_blocks multiple of (block_shape[0] // block_size)"

x = k_cache.shape[-1]
k_cache_permute = k_cache.view(num_blocks // (block_shape[0] // block_size), (block_shape[0] // block_size), num_kv_heads, head_size//x, block_size, x).permute(0, 2, 1, 3, 4, 5).contiguous()
k_quant_, k_scale_asm = pertoken_quant(k_cache_permute.view(num_blocks // (block_shape[0] // block_size), num_kv_heads, -1), quant_dtype=torch.float8_e4m3fnuz)
k_cache_permute = (k_quant_.view(num_blocks // (block_shape[0] // block_size), num_kv_heads, -1).to(torch.float) * k_scale_asm.to(torch.float)).to(k_cache.dtype)
k_cache = k_cache_permute.view(num_blocks // (block_shape[0] // block_size), num_kv_heads, (block_shape[0] // block_size), head_size//x, block_size, x).permute(0, 2, 1, 3, 4, 5).contiguous()
del k_cache_permute
k_cache = k_cache.view(num_blocks, num_kv_heads, head_size//x, block_size, x)
k_quant_ = k_quant_.view(num_blocks // (block_shape[0] // block_size), num_kv_heads, (block_shape[0] // block_size), head_size//x, block_size, x).permute(0, 2, 1, 3, 5, 4).contiguous()
x = 16 // torch.float8_e4m3fnuz.itemsize
k_quant_ = k_quant_.view(num_blocks, num_kv_heads, head_size//x, x, block_size).permute(0, 1, 2, 4, 3).contiguous()
k_scale_asm = k_scale_asm.view(num_blocks // (block_shape[0] // block_size), num_kv_heads).permute(1, 0).contiguous()

v_cache_permute = v_cache.view(num_blocks // (block_shape[0] // block_size), (block_shape[0] // block_size), num_kv_heads, head_size, block_size).permute(0, 2, 1, 3, 4).contiguous()
v_quant_, v_scale_asm = pertoken_quant(v_cache_permute.view(num_blocks // (block_shape[0] // block_size), num_kv_heads, -1), quant_dtype=torch.float8_e4m3fnuz)
v_cache_permute = (v_quant_.view(num_blocks // (block_shape[0] // block_size), num_kv_heads, -1).to(torch.float) * v_scale_asm.to(torch.float)).to(v_cache.dtype)
v_cache = v_cache_permute.view(num_blocks // (block_shape[0] // block_size), num_kv_heads, (block_shape[0] // block_size), head_size, block_size).permute(0, 2, 1, 3, 4).contiguous()
del v_cache_permute
v_cache = v_cache.view(num_blocks, num_kv_heads, head_size, block_size)
v_quant_ = v_quant_.view(num_blocks // (block_shape[0] // block_size), num_kv_heads, (block_shape[0] // block_size), head_size, block_size).permute(0, 2, 1, 3, 4).contiguous()
v_quant_ = v_quant_.view(num_blocks, num_kv_heads, head_size, block_size)
v_scale_asm = v_scale_asm.view(num_blocks // (block_shape[0] // block_size), num_kv_heads).permute(1, 0).contiguous()

out_golden, time_aiter = run_aiter(
query,
k_cache,
v_cache,
block_tables,
seq_lens,
max_seq_len,
kv_cache_dtype,
num_kv_heads,
scale,
alibi_slopes,
k_scale,
v_scale,
)

if quant_algo_ != 0:
out_aiter_asm, time_aiter_asm = run_aiter_asm(
@@ -650,9 +705,10 @@
max_num_blocks_per_seq,
k_scale_asm,
v_scale_asm,
block_shape=block_shape
)
checkAllclose(out_golden, out_aiter_asm,
msg=f'golden vs aiter_asm:{time_aiter_asm:.2f} us......(quant:{ck_naive_quant_algo[quant_algo_]}, kvcache:{cache_type_})')
msg=f'golden vs aiter_asm:{time_aiter_asm:.2f} us......(quant:{ck_naive_quant_algo[quant_algo_]}, kvcache:{cache_type_}, {block_shape=})')

if dtype in [torch.bfloat16, torch.float16] and quant_algo_ == 2 and cache_type_ == torch.float8_e4m3fnuz:
if dtype == torch.bfloat16:
@@ -682,26 +738,27 @@
# if quant_algo == "KV_8BIT_PER_TENSOR":
# q_quant_, q_scale_ = aiter.per_tensor_quant(
# query, quant_dtype=cache_type_)
out_native, time_native = run_native(
query,
# q_quant_,
k_quant_,
v_quant_,
block_tables,
seq_lens,
max_seq_len,
kv_cache_dtype,
num_kv_heads,
scale,
# scale*q_scale_.item(),
alibi_slopes,
k_scale_,
v_scale_,
num_queries_per_kv,
dtype
)
checkAllclose(
out_golden, out_native, msg=f'golden vs torch_native: {time_native:.2f} us...... (quant:{ck_naive_quant_algo[quant_algo_]}, kvcache:{cache_type_})')
if quant_algo != "KV_8BIT_PER_BLOCK":
out_native, time_native = run_native(
query,
# q_quant_,
k_quant_,
v_quant_,
block_tables,
seq_lens,
max_seq_len,
kv_cache_dtype,
num_kv_heads,
scale,
# scale*q_scale_.item(),
alibi_slopes,
k_scale_,
v_scale_,
num_queries_per_kv,
dtype
)
checkAllclose(
out_golden, out_native, msg=f'golden vs torch_native: {time_native:.2f} us...... (quant:{ck_naive_quant_algo[quant_algo_]}, kvcache:{cache_type_})')

if debug_mode == DUMP:
dump_input(query,
@@ -751,3 +808,13 @@ def test_paged_attention(
for dtype in [torch.float16, torch.bfloat16]:
test_paged_attention(ctx_len, 128, num_heads, 128, False, 16,
dtype, "auto", 0, "cuda:0")

for num_heads in [(4, 1), (8, 1), (32, 8)]:
for ctx_len in [7, 26, 57, 66, 109, 128, 257, 282, 4097]:
for dtype in [torch.bfloat16]:
test_paged_attention(ctx_len, 128, num_heads, 128, False, 16, torch.bfloat16, "auto", 0, "cuda:0", block_shape=(128,128))

for num_heads in [(4, 1), (8, 1), (32, 8)]:
for ctx_len in [7, 26, 57, 66, 109, 128, 257, 282, 4097]:
for dtype in [torch.bfloat16]:
test_paged_attention(ctx_len, 128, num_heads, 128, False, 16, torch.bfloat16, "auto", 0, "cuda:0", block_shape=(256,128))
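
For reference, the KV_8BIT_PER_BLOCK branch above boils down to one fp8 scale per (quantization block, kv head): whole blocks of block_shape[0] tokens are flattened per head and quantized with a single scale. A stripped-down sketch of that idea on the V cache, reusing aiter.pertoken_quant the same way the test does (shapes are illustrative, and the layout shuffles required by the asm kernel are omitted):

import torch
from aiter import pertoken_quant

block_size   = 16                       # tokens per kv-cache page
block_tokens = 256                      # block_shape[0]: tokens sharing one scale
num_kv_heads, head_size = 8, 128
pages_per_quant_block = block_tokens // block_size        # 16 pages per block
num_pages = 2 * pages_per_quant_block                     # two quantization blocks

# v_cache laid out as [num_pages, num_kv_heads, head_size, block_size]
v_cache = torch.randn(num_pages, num_kv_heads, head_size, block_size,
                      dtype=torch.bfloat16, device="cuda")

# Gather the pages of each quantization block, one row per (block, head), and
# quantize each row with a single scale; per-token quant over a flattened block
# is exactly per-block quant.
grouped = (v_cache
           .view(-1, pages_per_quant_block, num_kv_heads, head_size, block_size)
           .permute(0, 2, 1, 3, 4)
           .reshape(-1, num_kv_heads, pages_per_quant_block * head_size * block_size))
v_quant, v_scale = pertoken_quant(grouped, quant_dtype=torch.float8_e4m3fnuz)
# v_quant holds the fp8 values; v_scale holds one scale per (quant block, kv head).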