Skip to content

Commit c43f275

Browse files
CUDA: 4D FlashAttention support (#14628)
* CUDA: 4D FlashAttention support
* CUDA: fix WMMA FA kernel
1 parent ab82dc2 commit c43f275

File tree

9 files changed

+142
-101
lines changed

9 files changed

+142
-101
lines changed

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 35 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -33,8 +33,10 @@ typedef void (* fattn_kernel_t)(
3333
const int ne13,
3434
const int ne31,
3535
const int ne32,
36+
const int ne33,
3637
const int nb31,
3738
const int nb32,
39+
const int nb33,
3840
const int nb01,
3941
const int nb02,
4042
const int nb03,
@@ -521,7 +523,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
521523
template<int D, int ncols1, int ncols2> // D == head size
522524
__launch_bounds__(D, 1)
523525
static __global__ void flash_attn_stream_k_fixup(
524-
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
526+
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
525527
constexpr int ncols = ncols1*ncols2;
526528

527529
const int bidx0 = blockIdx.x;
@@ -535,8 +537,8 @@ static __global__ void flash_attn_stream_k_fixup(
535537
const int iter_k = ne11 / FATTN_KQ_STRIDE;
536538
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
537539

538-
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
539-
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
540+
const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
541+
const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
540542

541543
const bool did_not_have_any_data = kbc0 == kbc0_stop;
542544
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -545,14 +547,15 @@ static __global__ void flash_attn_stream_k_fixup(
545547
return;
546548
}
547549

548-
const int channel = kbc0 / (iter_k*iter_j);
549-
const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
550+
const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
551+
const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
552+
const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
550553

551554
if (jt*ncols1 + j >= ne01) {
552555
return;
553556
}
554557

555-
dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
558+
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
556559

557560
// Load the partial result that needs a fixup:
558561
float dst_val = 0.0f;
@@ -571,7 +574,7 @@ static __global__ void flash_attn_stream_k_fixup(
571574
int bidx = bidx0 - 1;
572575
int kbc_stop = kbc0;
573576
while(true) {
574-
const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
577+
const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
575578
if (kbc == kbc_stop) { // Did not have any data.
576579
bidx--;
577580
kbc_stop = kbc;
@@ -617,16 +620,31 @@ static __global__ void flash_attn_combine_results(
617620
const float2 * __restrict__ VKQ_meta,
618621
float * __restrict__ dst,
619622
const int parallel_blocks) {
620-
VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
621-
VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x;
622-
dst += D * gridDim.z*blockIdx.x;
623+
// Dimension 0: threadIdx.x
624+
// Dimension 1: blockIdx.x
625+
// Dimension 2: blockIdx.y
626+
// Dimension 3: blockIdx.z
627+
// Memory layout is permuted with [0, 2, 1, 3]
628+
629+
const int ne01 = gridDim.x;
630+
const int ne02 = gridDim.y;
631+
632+
const int col = blockIdx.x;
633+
const int head = blockIdx.y;
634+
const int sequence = blockIdx.z;
635+
636+
const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
637+
638+
VKQ_parts += j_dst_unrolled * parallel_blocks*D;
639+
VKQ_meta += j_dst_unrolled * parallel_blocks;
640+
dst += j_dst_unrolled * D;
623641

624642
const int tid = threadIdx.x;
625643
__builtin_assume(tid < D);
626644

627645
extern __shared__ float2 meta[];
628646
for (int i = tid; i < 2*parallel_blocks; i += D) {
629-
((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
647+
((float *) meta)[i] = ((const float *)VKQ_meta) [i];
630648
}
631649

632650
__syncthreads();
@@ -644,11 +662,11 @@ static __global__ void flash_attn_combine_results(
644662
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
645663
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
646664

647-
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
665+
VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
648666
VKQ_denominator += KQ_max_scale * meta[l].y;
649667
}
650668

651-
dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
669+
dst[tid] = VKQ_numerator / VKQ_denominator;
652670
}
653671

654672
[[noreturn]]
@@ -705,8 +723,6 @@ void launch_fattn(
705723

706724
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
707725

708-
GGML_ASSERT(Q->ne[3] == 1);
709-
710726
ggml_cuda_pool & pool = ctx.pool();
711727
cudaStream_t main_stream = ctx.stream();
712728
const int id = ggml_cuda_get_device();
@@ -853,8 +869,8 @@ void launch_fattn(
853869
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
854870
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
855871
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
856-
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
857-
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
872+
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
873+
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
858874
Q->nb[1], Q->nb[2], Q->nb[3],
859875
nb11, nb12, nb13,
860876
nb21, nb22, nb23,
@@ -869,11 +885,11 @@ void launch_fattn(
869885

870886
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
871887
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
872-
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
888+
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
873889
}
874890
} else if (parallel_blocks > 1) {
875891
const dim3 block_dim_combine(DV, 1, 1);
876-
const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
892+
const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
877893
const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
878894

879895
flash_attn_combine_results<DV>

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 22 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -1224,8 +1224,10 @@ static __global__ void flash_attn_ext_f16(
12241224
const int ne13,
12251225
const int ne31,
12261226
const int ne32,
1227+
const int ne33,
12271228
const int nb31,
12281229
const int nb32,
1230+
const int nb33,
12291231
const int nb01,
12301232
const int nb02,
12311233
const int nb03,
@@ -1274,8 +1276,8 @@ static __global__ void flash_attn_ext_f16(
12741276
constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
12751277

12761278
// kbc == k block continuous, current index in continuous ijk space.
1277-
int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
1278-
const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
1279+
int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
1280+
const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
12791281

12801282
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
12811283
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1285,18 +1287,19 @@ static __global__ void flash_attn_ext_f16(
12851287
int kb0_start = kbc % iter_k;
12861288
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
12871289
while (kbc < kbc_stop && kb0_stop == iter_k) {
1288-
const int channel = kbc / (iter_k*iter_j);
1289-
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
1290+
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
1291+
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
1292+
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
12901293

1291-
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
1292-
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
1294+
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
1295+
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
12931296
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
1294-
(const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
1295-
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
1297+
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
1298+
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
12961299

1297-
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
1300+
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
12981301

1299-
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
1302+
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
13001303

13011304
const int kb0_start_kernel = kb0_start * kb_niter;
13021305
const int kb0_stop_kernel = kb0_stop * kb_niter;
@@ -1325,18 +1328,19 @@ static __global__ void flash_attn_ext_f16(
13251328
return;
13261329
}
13271330

1328-
const int channel = kbc / (iter_k*iter_j);
1329-
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
1331+
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
1332+
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
1333+
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
13301334

1331-
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
1332-
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
1335+
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
1336+
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
13331337
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
1334-
(const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
1335-
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
1338+
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
1339+
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
13361340

1337-
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
1341+
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
13381342

1339-
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
1343+
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
13401344

13411345
const int kb0_start_kernel = kb0_start * kb_niter;
13421346
const int kb0_stop_kernel = kb0_stop * kb_niter;

ggml/src/ggml-cuda/fattn-tile-f16.cu

Lines changed: 20 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f16(
3131
const int ne13,
3232
const int ne31,
3333
const int ne32,
34+
const int ne33,
3435
const int nb31,
3536
const int nb32,
37+
const int nb33,
3638
const int nb01,
3739
const int nb02,
3840
const int nb03,
@@ -62,15 +64,17 @@ static __global__ void flash_attn_tile_ext_f16(
6264

6365
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
6466

67+
const int sequence = blockIdx.z / ne02;
68+
const int head = blockIdx.z - sequence*ne02;
6569
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
66-
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
67-
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
68-
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
69-
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
70+
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
71+
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
72+
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
73+
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
7074

7175
const int stride_KV2 = nb11 / sizeof(half2);
7276

73-
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
77+
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
7478
const half slopeh = __float2half(slopef);
7579

7680
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -255,6 +259,8 @@ static __global__ void flash_attn_tile_ext_f16(
255259
__syncthreads();
256260
}
257261

262+
float2 * dst2 = (float2 *) dst;
263+
258264
#pragma unroll
259265
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
260266
const int j_VKQ = j_VKQ_0 + threadIdx.y;
@@ -266,21 +272,21 @@ static __global__ void flash_attn_tile_ext_f16(
266272
half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
267273
kqsum_j = warp_reduce_sum((float)kqsum_j);
268274

275+
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
276+
269277
#pragma unroll
270-
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
271-
const int i0 = i00 + 2*threadIdx.x;
278+
for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
279+
const int i0 = i00 + threadIdx.x;
272280

273-
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
281+
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
274282
if (gridDim.y == 1) {
275283
dst_val /= __half2half2(kqsum_j);
276284
}
277-
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
278-
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
279-
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
285+
dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val);
280286
}
281287

282288
if (gridDim.y != 1 && threadIdx.x == 0) {
283-
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
289+
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
284290
}
285291
}
286292
#else
@@ -290,8 +296,8 @@ static __global__ void flash_attn_tile_ext_f16(
290296
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
291297
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
292298
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
293-
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
294-
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
299+
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
300+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
295301
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
296302
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
297303
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);

ggml/src/ggml-cuda/fattn-tile-f32.cu

Lines changed: 18 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f32(
3131
const int ne13,
3232
const int ne31,
3333
const int ne32,
34+
const int ne33,
3435
const int nb31,
3536
const int nb32,
37+
const int nb33,
3638
const int nb01,
3739
const int nb02,
3840
const int nb03,
@@ -74,15 +76,17 @@ static __global__ void flash_attn_tile_ext_f32(
7476

7577
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
7678

79+
const int sequence = blockIdx.z / ne02;
80+
const int head = blockIdx.z - sequence*ne02;
7781
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
78-
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
79-
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
80-
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
81-
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
82+
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
83+
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
84+
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
85+
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
8286

8387
const int stride_KV2 = nb11 / sizeof(half2);
8488

85-
const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
89+
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
8690

8791
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
8892

@@ -265,6 +269,8 @@ static __global__ void flash_attn_tile_ext_f32(
265269
__syncthreads();
266270
}
267271

272+
float2 * dst2 = (float2 *) dst;
273+
268274
#pragma unroll
269275
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
270276
const int j_VKQ = j_VKQ_0 + threadIdx.y;
@@ -276,22 +282,22 @@ static __global__ void flash_attn_tile_ext_f32(
276282
float kqsum_j = kqsum[j_VKQ_0/nwarps];
277283
kqsum_j = warp_reduce_sum(kqsum_j);
278284

285+
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
286+
279287
#pragma unroll
280-
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
281-
const int i0 = i00 + 2*threadIdx.x;
288+
for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
289+
const int i0 = i00 + threadIdx.x;
282290

283-
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
291+
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
284292
if (gridDim.y == 1) {
285293
dst_val.x /= kqsum_j;
286294
dst_val.y /= kqsum_j;
287295
}
288-
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
289-
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
290-
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
296+
dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
291297
}
292298

293299
if (gridDim.y != 1 && threadIdx.x == 0) {
294-
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
300+
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
295301
}
296302
}
297303
#else

0 commit comments

Comments
 (0)