Skip to content

Commit 2f01628

Browse files
jeffbolznv authored and Minh141120 committed
vulkan: Handle updated FA dim2/3 definition (ggml-org#14518)
* vulkan: Handle updated FA dim2/3 definition

  Pack mask boolean and n_head_log2 into a single dword to keep the push constant block under the 128B limit.

* handle null mask for gqa

* allow gqa with dim3>1
1 parent 344c1ce commit 2f01628

File tree

5 files changed

+26
-24
lines changed

5 files changed

+26
-24
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,7 @@ struct vk_flash_attn_push_constants {
636636
uint32_t nev3;
637637
uint32_t nem1;
638638
uint32_t nem2;
639+
uint32_t nem3;
639640

640641
uint32_t nb01;
641642
uint32_t nb02;
@@ -651,8 +652,7 @@ struct vk_flash_attn_push_constants {
651652
float max_bias;
652653
float logit_softcap;
653654

654-
uint32_t mask;
655-
uint32_t n_head_log2;
655+
uint32_t mask_n_head_log2;
656656
float m0;
657657
float m1;
658658

@@ -6114,6 +6114,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
61146114

61156115
const uint32_t nem1 = mask ? mask->ne[1] : 0;
61166116
const uint32_t nem2 = mask ? mask->ne[2] : 0;
6117+
const uint32_t nem3 = mask ? mask->ne[3] : 0;
61176118

61186119
const uint32_t HSK = nek0;
61196120
const uint32_t HSV = nev0;
@@ -6181,7 +6182,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
61816182
}
61826183

61836184
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
6184-
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
6185+
qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
61856186
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
61866187
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
61876188
// and change addressing calculations to index Q's dimension 2.
@@ -6351,17 +6352,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
63516352
}
63526353
}
63536354

6355+
uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
6356+
63546357
const vk_flash_attn_push_constants pc = { N, KV,
63556358
(uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
63566359
(uint32_t)neq2, (uint32_t)neq3,
63576360
(uint32_t)nek2, (uint32_t)nek3,
63586361
(uint32_t)nev2, (uint32_t)nev3,
6359-
nem1, nem2,
6362+
nem1, nem2, nem3,
63606363
q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
63616364
k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
63626365
v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
63636366
scale, max_bias, logit_softcap,
6364-
mask != nullptr, n_head_log2, m0, m1,
6367+
mask_n_head_log2, m0, m1,
63656368
gqa_ratio, split_kv, split_k };
63666369

63676370
ggml_vk_sync_buffers(subctx);
@@ -10306,12 +10309,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1030610309
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
1030710310
return false;
1030810311
}
10309-
// TODO: support broadcast
10310-
// note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but
10311-
// the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
10312-
if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) {
10313-
return false;
10314-
}
1031510312
// It's straightforward to support different K/V dequant, but would
1031610313
// significantly increase the number of pipelines
1031710314
if (op->src[1]->type != op->src[2]->type) {

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ void main() {
101101
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
102102
#endif
103103
uint32_t m_offset = 0;
104-
if (p.nem2 != 1) {
105-
m_offset = (iq3 % p.nem2) * p.nem1 * KV;
104+
if (p.nem2 != 1 || p.nem3 != 1) {
105+
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
106106
}
107107

108108
[[dont_unroll]]
@@ -149,7 +149,7 @@ void main() {
149149
}
150150
}
151151

152-
if (p.mask != 0) {
152+
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
153153

154154
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
155155
uint32_t c = (idx + tid) % Bc;

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ layout (push_constant) uniform parameter {
2525
uint32_t nev3;
2626
uint32_t nem1;
2727
uint32_t nem2;
28+
uint32_t nem3;
2829

2930
uint32_t nb01;
3031
uint32_t nb02;
@@ -40,8 +41,7 @@ layout (push_constant) uniform parameter {
4041
float max_bias;
4142
float logit_softcap;
4243

43-
uint32_t mask;
44-
uint32_t n_head_log2;
44+
uint32_t mask_n_head_log2;
4545
float m0;
4646
float m1;
4747

@@ -50,6 +50,9 @@ layout (push_constant) uniform parameter {
5050
uint32_t k_num;
5151
} p;
5252

53+
#define MASK_ENABLE_BIT (1<<16)
54+
#define N_LOG2_MASK 0xFFFF
55+
5356
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
5457

5558
#if defined(A_TYPE_PACKED16)
@@ -100,8 +103,10 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
100103
{
101104
const uint32_t h = iq2 + (r % p.gqa_ratio);
102105

103-
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
104-
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
106+
uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;
107+
108+
const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1);
109+
const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1);
105110

106111
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
107112
}

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,8 @@ void main() {
126126
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
127127
#endif
128128
uint32_t m_offset = 0;
129-
if (p.nem2 != 1) {
130-
m_offset = (iq3 % p.nem2) * p.nem1 * KV;
129+
if (p.nem2 != 1 || p.nem3 != 1) {
130+
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
131131
}
132132

133133
[[dont_unroll]]
@@ -182,7 +182,7 @@ void main() {
182182
barrier();
183183
}
184184

185-
if (p.mask != 0) {
185+
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
186186
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
187187
uint32_t c = (idx + tid) % Bc;
188188
uint32_t r = (idx + tid) / Bc;

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ void main() {
131131
}
132132

133133
uint32_t m_offset = 0;
134-
if (p.nem2 != 1) {
135-
m_offset = (iq3 % p.nem2) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
134+
if (p.nem2 != 1 || p.nem3 != 1) {
135+
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
136136
}
137137

138138
[[dont_unroll]]
@@ -153,7 +153,7 @@ void main() {
153153
}
154154
}
155155

156-
if (p.mask != 0) {
156+
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
157157
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
158158
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
159159
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);

0 commit comments

Comments
 (0)