From 2cb47e0e1673f05db2fcf53ad56a8dbbb3332139 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 1 Sep 2023 13:21:50 +0300 Subject: [PATCH 1/7] Very minor speedup via simd-group synchronization in f16 x f32 --- ggml-metal.m | 2 +- ggml-metal.metal | 45 ++++++++------------------------------------- 2 files changed, 9 insertions(+), 38 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 4267db9be3e61..f9a2228aa0201 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -971,7 +971,7 @@ void ggml_metal_graph_compute( else if (src0t == GGML_TYPE_Q6_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { - [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; + //[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } } diff --git a/ggml-metal.metal b/ggml-metal.metal index 8cdf0b9d2ba0a..aeb33c581100d 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -515,11 +515,8 @@ kernel void kernel_mul_mat_f16_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - threadgroup float * sum [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpig[[thread_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]]) { const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; @@ -528,42 +525,16 @@ kernel void kernel_mul_mat_f16_f32( device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - uint ith = tpitg.x; - uint nth = tptg.x; - - sum[ith] = 0.0f; - - for (int i = ith; i < ne00; i += nth) { - sum[ith] += (float) x[i] * (float) y[i]; + float sumf = 0; + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; } - // accumulate the sum from all threads in the threadgroup - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%4 == 0) { - for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%16 == 0) { - for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0]; - } - - // Original implementation. 
Left behind commented out for now - //threadgroup_barrier(mem_flags::mem_threadgroup); - //for (uint i = tptg.x/2; i > 0; i /= 2) { - // if (tpitg.x < i) { - // sum[tpitg.x] += sum[tpitg.x + i]; - // } - // threadgroup_barrier(mem_flags::mem_threadgroup); - //} - // - //if (tpitg.x == 0) { - // dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0]; - //} + } kernel void kernel_alibi_f32( From e3ff8c20c89390285db472b78caf00ca8c178872 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 1 Sep 2023 14:54:10 +0300 Subject: [PATCH 2/7] Another very minor speedup on metal --- ggml-metal.metal | 66 +++++++++++++++++++++++++++--------------------- 1 file changed, 37 insertions(+), 29 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index aeb33c581100d..171b0bcf989eb 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -133,19 +133,24 @@ kernel void kernel_soft_max( threadgroup_barrier(mem_flags::mem_threadgroup); } - // broadcast - if (tpitg[0] == 0) { - buf[0] = buf[0]; - } + //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of + // the loop, and when that is done, buf[0] has the correct (synchronized) value + //if (tpitg[0] == 0) { + // buf[0] = buf[0]; + //} - threadgroup_barrier(mem_flags::mem_threadgroup); + //threadgroup_barrier(mem_flags::mem_threadgroup); const float max = buf[0]; // parallel sum buf[tpitg[0]] = 0.0f; for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { - buf[tpitg[0]] += exp(psrc0[i00] - max); + const float exp_psrc0 = exp(psrc0[i00] - max); + buf[tpitg[0]] += exp_psrc0; + // Remember the result of exp here. exp is expensive, so we really do not + // whish to compute it twice. + pdst[i00] = exp_psrc0; } // reduce @@ -157,17 +162,18 @@ kernel void kernel_soft_max( threadgroup_barrier(mem_flags::mem_threadgroup); } - // broadcast - if (tpitg[0] == 0) { - buf[0] = buf[0]; - } + // broadcast - not needed, see above + //// broadcast + //if (tpitg[0] == 0) { + // buf[0] = buf[0]; + //} - threadgroup_barrier(mem_flags::mem_threadgroup); + //threadgroup_barrier(mem_flags::mem_threadgroup); const float sum = buf[0]; for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { - pdst[i00] = exp(psrc0[i00] - max) / sum; + pdst[i00] /= sum; } } @@ -214,25 +220,27 @@ kernel void kernel_norm( } threadgroup_barrier(mem_flags::mem_threadgroup); } - // broadcast - if (tpitg == 0) { - sum[0] /= ne00; - } - threadgroup_barrier(mem_flags::mem_threadgroup); + //// broadcast + //if (tpitg == 0) { + // sum[0] /= ne00; + //} + //threadgroup_barrier(mem_flags::mem_threadgroup); const float mean = sum[0]; - // recenter + // recenter and VARIANCE device float * y = dst + tgpig*ne00; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - y[i00] = x[i00] - mean; - } - - // VARIANCE - // parallel sum sum[tpitg] = 0.0f; for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + y[i00] = x[i00] - mean; sum[tpitg] += y[i00] * y[i00]; } + + //// VARIANCE + //// parallel sum + //sum[tpitg] = 0.0f; + //for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + // sum[tpitg] += y[i00] * y[i00]; + //} // reduce threadgroup_barrier(mem_flags::mem_threadgroup); for (uint i = ntg/2; i > 0; i /= 2) { @@ -241,11 +249,11 @@ kernel void kernel_norm( } threadgroup_barrier(mem_flags::mem_threadgroup); } - // broadcast - if (tpitg == 0) { - sum[0] /= ne00; - } - threadgroup_barrier(mem_flags::mem_threadgroup); + //// broadcast + //if (tpitg == 0) { + // sum[0] /= ne00; + //} + //threadgroup_barrier(mem_flags::mem_threadgroup); const float variance = sum[0]; const float scale = 1.0f/sqrt(variance + eps); 
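The common thread of patches 1 and 2 is dropping threadgroup-memory reductions (and their barriers) in favour of simd-group reductions. Below is a minimal stand-alone Metal sketch of the same pattern used in the reworked kernel_mul_mat_f16_f32: each of the 32 lanes accumulates a strided partial dot product, simd_sum() combines the partials in hardware, and lane 0 writes the result. The kernel name and buffer layout here are illustrative assumptions, not part of the patch series.

#include <metal_stdlib>
using namespace metal;

// Illustrative sketch, not part of the patches: f16 x f32 dot product reduced
// with simd_sum instead of shared memory. Kernel name and buffer indices are
// assumptions made for this example only.
kernel void dot_f16_f32_simd(
        device const half  * x   [[buffer(0)]],
        device const float * y   [[buffer(1)]],
        device       float * out [[buffer(2)]],
        constant     int   & n   [[buffer(3)]],
        uint tiisg [[thread_index_in_simdgroup]]) {

    float sumf = 0.0f;

    // each lane of the 32-wide simd-group handles every 32nd element
    for (int i = tiisg; i < n; i += 32) {
        sumf += (float) x[i] * y[i];
    }

    // hardware reduction across the simd-group: no threadgroup memory,
    // no threadgroup_barrier
    const float all_sum = simd_sum(sumf);

    if (tiisg == 0) {
        out[0] = all_sum;
    }
}

Dispatched with one 32-thread simd-group per threadgroup, this yields the same result as the original shared-memory tree reduction while avoiding all the barriers, which is where the speedup in patch 1 comes from.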
From 2b601702a80e7b30fadc6c886632fb8638111ba9 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 1 Sep 2023 17:06:53 +0300 Subject: [PATCH 3/7] Quite significant PP speedup on metal --- ggml-metal.m | 11 +++++++---- ggml-metal.metal | 33 ++++++++++++++++++++++----------- 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index f9a2228aa0201..d365ddc6734a6 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -906,8 +906,8 @@ void ggml_metal_graph_compute( GGML_ASSERT(ne02 == 1); GGML_ASSERT(ne12 == 1); - nth0 = 2; - nth1 = 32; + nth0 = 4; //1; + nth1 = 8; //32; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32]; } break; case GGML_TYPE_Q5_K: @@ -955,9 +955,12 @@ void ggml_metal_graph_compute( [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 || - src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) { + src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } + else if (src0t == GGML_TYPE_Q4_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else if (src0t == GGML_TYPE_Q3_K) { #ifdef GGML_QKK_64 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; @@ -972,7 +975,7 @@ void ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { //[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne11 + 3)/4, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } } } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index 171b0bcf989eb..caf8d37a435bc 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -505,6 +505,8 @@ kernel void kernel_mul_mat_q8_0_f32( } } +#define N_F16_F32 4 + kernel void kernel_mul_mat_f16_f32( device const char * src0, device const char * src1, @@ -527,20 +529,28 @@ kernel void kernel_mul_mat_f16_f32( uint tiisg[[thread_index_in_simdgroup]]) { const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; + const int64_t rb = N_F16_F32*tgpig.y; const int64_t im = tgpig.z; - device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); - float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; - } + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } } } @@ -1241,7 +1251,8 @@ kernel void kernel_mul_mat_q4_K_f32( const int r0 = tgpig.x; const int r1 = tgpig.y; const int r2 = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + //const int 
first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = r0 * N_DST; const int ib_row = first_row * nb; const uint offset0 = r2/gqa*(nb*ne0); device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; From b557bc326db194931d28cd6ecc804e35bde9d053 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 1 Sep 2023 17:50:21 +0300 Subject: [PATCH 4/7] Another attempt --- ggml-metal.m | 13 ++++++++++--- ggml-metal.metal | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index d365ddc6734a6..c27f73a0203d4 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -76,6 +76,7 @@ GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(norm); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); + GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row); GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32); GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32); @@ -205,6 +206,7 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(norm); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); + GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row); GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32); GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32); @@ -270,6 +272,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(rms_norm); GGML_METAL_DEL_KERNEL(norm); GGML_METAL_DEL_KERNEL(mul_mat_f16_f32); + GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row); GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32); GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32); GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32); @@ -854,7 +857,11 @@ void ggml_metal_graph_compute( { nth0 = 32; nth1 = 1; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; + if (ne11 * ne12 < 2) { + [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row]; + } else { + [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; + } } break; case GGML_TYPE_Q4_0: { @@ -974,8 +981,8 @@ void ggml_metal_graph_compute( else if (src0t == GGML_TYPE_Q6_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { - //[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne11 + 3)/4, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + int64_t ny = (ne11 + 3)/4; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } } } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index caf8d37a435bc..62b222aa405af 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -505,6 +505,45 @@ kernel void kernel_mul_mat_q8_0_f32( } } +kernel void kernel_mul_mat_f16_f32_1row( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int64_t im = tgpig.z; + + device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); + device const float * y 
= (device const float *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } +} + #define N_F16_F32 4 kernel void kernel_mul_mat_f16_f32( From 74df0de9e60c493abda1c5c99624b0723effd01a Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 1 Sep 2023 18:15:45 +0300 Subject: [PATCH 5/7] Minor --- ggml-metal.m | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-metal.m b/ggml-metal.m index c27f73a0203d4..84614537d2dc5 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -857,7 +857,7 @@ void ggml_metal_graph_compute( { nth0 = 32; nth1 = 1; - if (ne11 * ne12 < 2) { + if (ne11 * ne12 < 4) { [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row]; } else { [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; From 363f0bf5580d9d59aaabfb227f6595f4806b32ab Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 2 Sep 2023 18:14:41 +0300 Subject: [PATCH 6/7] Massive improvement for TG for fp16 --- ggml-metal.metal | 73 ++++++++++++++++++++++++++++++++++++------------ 1 file changed, 55 insertions(+), 18 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 62b222aa405af..e2eb5ba358fdb 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -534,14 +534,27 @@ kernel void kernel_mul_mat_f16_f32_1row( device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; + if (ne00 < 128) { + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } else { + device const half4 * x4 = (device const half4 *) x; + device const float4 * y4 = (device const float4 *) y; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k]; + } + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) sumf += (float) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } } - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } } #define N_F16_F32 4 @@ -573,22 +586,46 @@ kernel void kernel_mul_mat_f16_f32( device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); - for (int row = 0; row < N_F16_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } + if (ne00 < 128) { + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; + float sumf = 0; + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } } + } else { + device const half4 * x4 = (device const half4 *)x; + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + device const float4 * 
y4 = (device const float4 *) y; + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) sumf += (float) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } } } From 6af0bab347574fcb18081bda12a5d9de04dfe367 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 3 Sep 2023 09:00:27 +0300 Subject: [PATCH 7/7] ~4-5% improvement for Q8_0 TG on metal --- ggml-metal.metal | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index e2eb5ba358fdb..3fa311b4027f9 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -443,6 +443,8 @@ kernel void kernel_mul_mat_q4_1_f32( mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); } +#define NB_Q8_0 8 + kernel void kernel_mul_mat_q8_0_f32( device const void * src0, device const float * src1, @@ -471,30 +473,30 @@ kernel void kernel_mul_mat_q8_0_f32( device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0; device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - float yl[16]; + float yl[NB_Q8_0]; float sumf[nr]={0.f}; - const int ix = tiisg/2; - const int il = tiisg%2; + const int ix = tiisg/4; + const int il = tiisg%4; - device const float * yb = y + ix * QK8_0 + 16*il; + device const float * yb = y + ix * QK8_0 + NB_Q8_0*il; - // each thread in a SIMD group deals with half a block. - for (int ib = ix; ib < nb; ib += nw/2) { - for (int i = 0; i < 16; ++i) { + // each thread in a SIMD group deals with NB_Q8_0 quants at a time + for (int ib = ix; ib < nb; ib += nw/4) { + for (int i = 0; i < NB_Q8_0; ++i) { yl[i] = yb[i]; } for (int row = 0; row < nr; row++) { - device const int8_t * qs = x[ib+row*nb].qs + 16*il; + device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il; float sumq = 0.f; - for (int iq = 0; iq < 16; ++iq) { + for (int iq = 0; iq < NB_Q8_0; ++iq) { sumq += qs[iq] * yl[iq]; } sumf[row] += sumq*x[ib+row*nb].d; } - yb += QK8_0 * 16; + yb += NB_Q8_0 * nw; } for (int row = 0; row < nr; ++row) {