diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 8114cb42046e4..5fcea51edf790 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -634,6 +634,7 @@ struct vk_flash_attn_push_constants { uint32_t nev3; uint32_t nem1; uint32_t nem2; + uint32_t nem3; uint32_t nb01; uint32_t nb02; @@ -649,8 +650,7 @@ struct vk_flash_attn_push_constants { float max_bias; float logit_softcap; - uint32_t mask; - uint32_t n_head_log2; + uint32_t mask_n_head_log2; float m0; float m1; @@ -6050,6 +6050,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx const uint32_t nem1 = mask ? mask->ne[1] : 0; const uint32_t nem2 = mask ? mask->ne[2] : 0; + const uint32_t nem3 = mask ? mask->ne[3] : 0; const uint32_t D = neq0; uint32_t N = neq1; @@ -6119,7 +6120,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa && - qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) { + qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) { // grouped query attention - make the N dimension equal to gqa_ratio, reduce // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1 // and change addressing calculations to index Q's dimension 2. @@ -6311,17 +6312,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } } + uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2; + const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, - nem1, nem2, + nem1, nem2, nem3, q_stride, (uint32_t)nbq2, (uint32_t)nbq3, k_stride, (uint32_t)nbk2, (uint32_t)nbk3, v_stride, (uint32_t)nbv2, (uint32_t)nbv3, scale, max_bias, logit_softcap, - mask != nullptr, n_head_log2, m0, m1, + mask_n_head_log2, m0, m1, gqa_ratio, split_kv, split_k }; ggml_vk_sync_buffers(subctx); @@ -10265,12 +10268,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) { return false; } - // TODO: support broadcast - // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but - // the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505 - if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) { - return false; - } // It's straightforward to support different K/V dequant, but would // significantly increase the number of pipelines if (op->src[1]->type != op->src[2]->type) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 6f80101d1c432..453b08c2ae389 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -100,8 +100,8 @@ void main() { uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; #endif uint32_t m_offset = 0; - if (p.nem2 != 1) { - m_offset = (iq3 % p.nem2) * p.nem1 * KV; + if (p.nem2 != 1 || p.nem3 != 1) { + m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; } [[dont_unroll]] @@ -148,7 +148,7 @@ void main() { } } - if (p.mask != 0) { + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) % Bc; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp index 67b935cf57e50..026f7d9331d27 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp @@ -25,6 +25,7 @@ layout (push_constant) uniform parameter { uint32_t nev3; uint32_t nem1; uint32_t nem2; + uint32_t nem3; uint32_t nb01; uint32_t nb02; @@ -40,8 +41,7 @@ layout (push_constant) uniform parameter { float max_bias; float logit_softcap; - uint32_t mask; - uint32_t n_head_log2; + uint32_t mask_n_head_log2; float m0; float m1; @@ -50,6 +50,9 @@ layout (push_constant) uniform parameter { uint32_t k_num; } p; +#define MASK_ENABLE_BIT (1<<16) +#define N_LOG2_MASK 0xFFFF + layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; #if defined(A_TYPE_PACKED16) @@ -100,8 +103,10 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i { const uint32_t h = iq2 + (r % p.gqa_ratio); - const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); - const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); + uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK; + + const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1); + const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1); return ACC_TYPE(pow(base, ACC_TYPE(exph))); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 26fe50c7a81e5..5722f006c7f19 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -124,8 +124,8 @@ void main() { uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; #endif uint32_t m_offset = 0; - if (p.nem2 != 1) { - m_offset = (iq3 % p.nem2) * p.nem1 * KV; + if (p.nem2 != 1 || p.nem3 != 1) { + m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; } [[dont_unroll]] @@ -180,7 +180,7 @@ void main() { barrier(); } - if (p.mask != 0) { + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index cf47bd53e28bb..677ac405154bb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -131,8 +131,8 @@ void main() { } uint32_t m_offset = 0; - if (p.nem2 != 1) { - m_offset = (iq3 % p.nem2) * p.nem1 * KV * 2 /*sizeof(float16_t)*/; + if (p.nem2 != 1 || p.nem3 != 1) { + m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/; } [[dont_unroll]] @@ -153,7 +153,7 @@ void main() { } } - if (p.mask != 0) { + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);