Skip to content

vulkan: Handle updated FA dim2/3 definition #14518

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 8 additions & 11 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,7 @@ struct vk_flash_attn_push_constants {
uint32_t nev3;
uint32_t nem1;
uint32_t nem2;
uint32_t nem3;

uint32_t nb01;
uint32_t nb02;
Expand All @@ -649,8 +650,7 @@ struct vk_flash_attn_push_constants {
float max_bias;
float logit_softcap;

uint32_t mask;
uint32_t n_head_log2;
uint32_t mask_n_head_log2;
float m0;
float m1;

Expand Down Expand Up @@ -6050,6 +6050,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx

const uint32_t nem1 = mask ? mask->ne[1] : 0;
const uint32_t nem2 = mask ? mask->ne[2] : 0;
const uint32_t nem3 = mask ? mask->ne[3] : 0;

const uint32_t D = neq0;
uint32_t N = neq1;
Expand Down Expand Up @@ -6119,7 +6120,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
}

if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
// and change addressing calculations to index Q's dimension 2.
Expand Down Expand Up @@ -6311,17 +6312,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
}
}

uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;

const vk_flash_attn_push_constants pc = { N, KV,
(uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
(uint32_t)neq2, (uint32_t)neq3,
(uint32_t)nek2, (uint32_t)nek3,
(uint32_t)nev2, (uint32_t)nev3,
nem1, nem2,
nem1, nem2, nem3,
q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
scale, max_bias, logit_softcap,
mask != nullptr, n_head_log2, m0, m1,
mask_n_head_log2, m0, m1,
gqa_ratio, split_kv, split_k };

ggml_vk_sync_buffers(subctx);
Expand Down Expand Up @@ -10265,12 +10268,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
return false;
}
// TODO: support broadcast
// note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but
// the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) {
return false;
}
// It's straightforward to support different K/V dequant, but would
// significantly increase the number of pipelines
if (op->src[1]->type != op->src[2]->type) {
Expand Down
6 changes: 3 additions & 3 deletions ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ void main() {
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
uint32_t m_offset = 0;
if (p.nem2 != 1) {
m_offset = (iq3 % p.nem2) * p.nem1 * KV;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
}

[[dont_unroll]]
Expand Down Expand Up @@ -148,7 +148,7 @@ void main() {
}
}

if (p.mask != 0) {
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {

[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
Expand Down
13 changes: 9 additions & 4 deletions ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ layout (push_constant) uniform parameter {
uint32_t nev3;
uint32_t nem1;
uint32_t nem2;
uint32_t nem3;

uint32_t nb01;
uint32_t nb02;
Expand All @@ -40,8 +41,7 @@ layout (push_constant) uniform parameter {
float max_bias;
float logit_softcap;

uint32_t mask;
uint32_t n_head_log2;
uint32_t mask_n_head_log2;
float m0;
float m1;

Expand All @@ -50,6 +50,9 @@ layout (push_constant) uniform parameter {
uint32_t k_num;
} p;

#define MASK_ENABLE_BIT (1<<16)
#define N_LOG2_MASK 0xFFFF

layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};

#if defined(A_TYPE_PACKED16)
Expand Down Expand Up @@ -100,8 +103,10 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
{
const uint32_t h = iq2 + (r % p.gqa_ratio);

const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;

const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1);
const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1);

return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
Expand Down
6 changes: 3 additions & 3 deletions ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ void main() {
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
uint32_t m_offset = 0;
if (p.nem2 != 1) {
m_offset = (iq3 % p.nem2) * p.nem1 * KV;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
}

[[dont_unroll]]
Expand Down Expand Up @@ -180,7 +180,7 @@ void main() {
barrier();
}

if (p.mask != 0) {
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
Expand Down
6 changes: 3 additions & 3 deletions ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ void main() {
}

uint32_t m_offset = 0;
if (p.nem2 != 1) {
m_offset = (iq3 % p.nem2) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
}

[[dont_unroll]]
Expand All @@ -153,7 +153,7 @@ void main() {
}
}

if (p.mask != 0) {
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
Expand Down
Loading