diff --git a/common/arg.cpp b/common/arg.cpp index 56827a65908be..4a0e6cbd68cb2 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1464,6 +1464,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.swa_full = true; } ).set_env("LLAMA_ARG_SWA_FULL")); + add_opt(common_arg( + {"--kv-split", "-kvs"}, + string_format("use multiple streams when computing the attention (default: %s)\n" + "[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)", params.kv_split ? "true" : "false"), + [](common_params & params) { + params.kv_split = true; + } + ).set_env("LLAMA_ARG_KV_SPLIT")); add_opt(common_arg( {"--no-context-shift"}, string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"), diff --git a/common/common.cpp b/common/common.cpp index e4e71ad13fb59..51b548cafd6e2 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1157,6 +1157,7 @@ struct llama_context_params common_context_params_to_llama(const common_params & cparams.no_perf = params.no_perf; cparams.op_offload = !params.no_op_offload; cparams.swa_full = params.swa_full; + cparams.kv_unified = !params.kv_split; cparams.type_k = params.cache_type_k; cparams.type_v = params.cache_type_v; diff --git a/common/common.h b/common/common.h index a5abe32859fdd..a65a2fe99c9af 100644 --- a/common/common.h +++ b/common/common.h @@ -330,6 +330,7 @@ struct common_params { bool no_perf = false; // disable performance metrics bool ctx_shift = true; // context shift on inifinite text generation bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) + bool kv_split = false; // disable unified KV cache bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool use_mmap = true; // use mmap for faster loads diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 0ec2999a0c8e9..40ff6483807ee 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -107,7 +107,7 @@ int main(int argc, char ** argv) { const llama_vocab * vocab = llama_model_get_vocab(model); const int n_ctx_train = llama_model_n_ctx_train(model); - const int n_ctx = llama_n_ctx(ctx); + const int n_ctx = llama_n_ctx(ctx); const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index d53e089a4cbc2..46fb451baa712 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -224,6 +224,7 @@ int main(int argc, char ** argv) { auto & client = clients[i]; client.id = i; client.smpl = common_sampler_init(model, params.sampling); + //params.sampling.seed++; } std::vector tokens_system; @@ -345,7 +346,7 @@ int main(int argc, char ** argv) { client.n_decoded = 0; client.i_batch = batch.n_tokens - 1; - LOG_INF("\033[31mClient %3d, seq %4d, junk = %4d, started decoding ...\033[0m\n", client.id, client.seq_id, n_junk_cur); + LOG_INF("\033[31mClient %3d, seq %4d, junk = %4d, prompt = %d, started decoding ...\033[0m\n", client.id, client.seq_id, n_junk_cur, client.n_prompt); g_seq_id += 1; diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 075f14a49e9ac..9122fca6cf99f 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -33,8 +33,10 @@ typedef void (* fattn_kernel_t)( const int ne13, const int ne31, const int ne32, + const int ne33, const int nb31, const int nb32, + const int nb33, const int nb01, const int nb02, const int nb03, @@ -521,7 +523,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) { template // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup( - float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) { + float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) { constexpr int ncols = ncols1*ncols2; const int bidx0 = blockIdx.x; @@ -535,8 +537,8 @@ static __global__ void flash_attn_stream_k_fixup( const int iter_k = ne11 / FATTN_KQ_STRIDE; const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; - const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; - const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; + const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; const bool did_not_have_any_data = kbc0 == kbc0_stop; const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; @@ -545,14 +547,15 @@ static __global__ void flash_attn_stream_k_fixup( return; } - const int channel = kbc0 / (iter_k*iter_j); - const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k; + const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2)); + const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); + const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile. if (jt*ncols1 + j >= ne01) { return; } - dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid; + dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid; // Load the partial result that needs a fixup: float dst_val = 0.0f; @@ -571,7 +574,7 @@ static __global__ void flash_attn_stream_k_fixup( int bidx = bidx0 - 1; int kbc_stop = kbc0; while(true) { - const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; if (kbc == kbc_stop) { // Did not have any data. bidx--; kbc_stop = kbc; @@ -617,16 +620,31 @@ static __global__ void flash_attn_combine_results( const float2 * __restrict__ VKQ_meta, float * __restrict__ dst, const int parallel_blocks) { - VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x; - VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x; - dst += D * gridDim.z*blockIdx.x; + // Dimension 0: threadIdx.x + // Dimension 1: blockIdx.x + // Dimension 2: blockIdx.y + // Dimension 3: blockIdx.z + // Memory layout is permuted with [0, 2, 1, 3] + + const int ne01 = gridDim.x; + const int ne02 = gridDim.y; + + const int col = blockIdx.x; + const int head = blockIdx.y; + const int sequence = blockIdx.z; + + const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head; + + VKQ_parts += j_dst_unrolled * parallel_blocks*D; + VKQ_meta += j_dst_unrolled * parallel_blocks; + dst += j_dst_unrolled * D; const int tid = threadIdx.x; __builtin_assume(tid < D); extern __shared__ float2 meta[]; for (int i = tid; i < 2*parallel_blocks; i += D) { - ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i]; + ((float *) meta)[i] = ((const float *)VKQ_meta) [i]; } __syncthreads(); @@ -644,11 +662,11 @@ static __global__ void flash_attn_combine_results( const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); *((uint32_t *) &KQ_max_scale) &= ftz_mask; - VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid]; + VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid]; VKQ_denominator += KQ_max_scale * meta[l].y; } - dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator; + dst[tid] = VKQ_numerator / VKQ_denominator; } [[noreturn]] @@ -705,8 +723,6 @@ void launch_fattn( GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); - GGML_ASSERT(Q->ne[3] == 1); - ggml_cuda_pool & pool = ctx.pool(); cudaStream_t main_stream = ctx.stream(); const int id = ggml_cuda_get_device(); @@ -853,8 +869,8 @@ void launch_fattn( scale, max_bias, m0, m1, n_head_log2, logit_softcap, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, - mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, + mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0, + mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0, Q->nb[1], Q->nb[2], Q->nb[3], nb11, nb12, nb13, nb21, nb22, nb23, @@ -869,11 +885,11 @@ void launch_fattn( flash_attn_stream_k_fixup <<>> - ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]); + ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]); } } else if (parallel_blocks > 1) { const dim3 block_dim_combine(DV, 1, 1); - const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z); + const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]); const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2); flash_attn_combine_results diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 709589854f0af..6fa2e77299eb0 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -1224,8 +1224,10 @@ static __global__ void flash_attn_ext_f16( const int ne13, const int ne31, const int ne32, + const int ne33, const int nb31, const int nb32, + const int nb33, const int nb01, const int nb02, const int nb03, @@ -1274,8 +1276,8 @@ static __global__ void flash_attn_ext_f16( constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice. // kbc == k block continuous, current index in continuous ijk space. - int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; - const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; + const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined. // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup). @@ -1285,18 +1287,19 @@ static __global__ void flash_attn_ext_f16( int kb0_start = kbc % iter_k; int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc); while (kbc < kbc_stop && kb0_stop == iter_k) { - const int channel = kbc / (iter_k*iter_j); - const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. + const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2)); + const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); + const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile. - const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); - const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2)); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio)); const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : - (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1); - float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); + (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1); + float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2); - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); + const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio)); - const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f; const int kb0_start_kernel = kb0_start * kb_niter; const int kb0_stop_kernel = kb0_stop * kb_niter; @@ -1325,18 +1328,19 @@ static __global__ void flash_attn_ext_f16( return; } - const int channel = kbc / (iter_k*iter_j); - const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. + const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2)); + const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); + const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile. - const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); - const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2)); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio)); const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : - (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1); - float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); + (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1); + float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2); - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); + const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio)); - const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f; const int kb0_start_kernel = kb0_start * kb_niter; const int kb0_stop_kernel = kb0_stop * kb_niter; diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index 0c967f178e7b1..1f141328845a4 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f16( const int ne13, const int ne31, const int ne32, + const int ne33, const int nb31, const int nb32, + const int nb33, const int nb01, const int nb02, const int nb03, @@ -62,15 +64,17 @@ static __global__ void flash_attn_tile_ext_f16( const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. + const int sequence = blockIdx.z / ne02; + const int head = blockIdx.z - sequence*ne02; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); + const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0); + const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); const int stride_KV2 = nb11 / sizeof(half2); - const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); + const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); @@ -255,6 +259,8 @@ static __global__ void flash_attn_tile_ext_f16( __syncthreads(); } + float2 * dst2 = (float2 *) dst; + #pragma unroll for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { const int j_VKQ = j_VKQ_0 + threadIdx.y; @@ -266,21 +272,21 @@ static __global__ void flash_attn_tile_ext_f16( half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]); kqsum_j = warp_reduce_sum((float)kqsum_j); + const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; + #pragma unroll - for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) { - const int i0 = i00 + 2*threadIdx.x; + for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) { + const int i0 = i00 + threadIdx.x; - half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)]; + half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE]; if (gridDim.y == 1) { dst_val /= __half2half2(kqsum_j); } - const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y; - dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val); - dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val); + dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val); } if (gridDim.y != 1 && threadIdx.x == 0) { - dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); + dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); } } #else @@ -290,8 +296,8 @@ static __global__ void flash_attn_tile_ext_f16( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index 908c76dbdd270..a4965583cef1c 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f32( const int ne13, const int ne31, const int ne32, + const int ne33, const int nb31, const int nb32, + const int nb33, const int nb01, const int nb02, const int nb03, @@ -74,15 +76,17 @@ static __global__ void flash_attn_tile_ext_f32( const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. + const int sequence = blockIdx.z / ne02; + const int head = blockIdx.z - sequence*ne02; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); + const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0); + const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); const int stride_KV2 = nb11 / sizeof(half2); - const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); + const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); @@ -265,6 +269,8 @@ static __global__ void flash_attn_tile_ext_f32( __syncthreads(); } + float2 * dst2 = (float2 *) dst; + #pragma unroll for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { const int j_VKQ = j_VKQ_0 + threadIdx.y; @@ -276,22 +282,22 @@ static __global__ void flash_attn_tile_ext_f32( float kqsum_j = kqsum[j_VKQ_0/nwarps]; kqsum_j = warp_reduce_sum(kqsum_j); + const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; + #pragma unroll - for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) { - const int i0 = i00 + 2*threadIdx.x; + for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) { + const int i0 = i00 + threadIdx.x; - float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)]; + float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE]; if (gridDim.y == 1) { dst_val.x /= kqsum_j; dst_val.y /= kqsum_j; } - const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y; - dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x; - dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y; + dst2[j_dst_unrolled*(D/2) + i0] = dst_val; } if (gridDim.y != 1 && threadIdx.x == 0) { - dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); + dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); } } #else diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index e78fb181919fd..b2d469938abf2 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f16( const int ne13, const int ne31, const int ne32, + const int ne33, const int nb31, const int nb32, + const int nb33, const int nb01, const int nb02, const int nb03, @@ -65,14 +67,16 @@ static __global__ void flash_attn_vec_ext_f16( const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. + const int sequence = blockIdx.z / ne02; + const int head = blockIdx.z - sequence*ne02; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - Q += nb02* blockIdx.z + nb01*ic0; - K += nb12*(blockIdx.z / gqa_ratio); - V += nb22*(blockIdx.z / gqa_ratio); + Q += nb03*sequence + nb02* head + nb01*ic0; + K += nb13*sequence + nb12*(head / gqa_ratio); + V += nb23*sequence + nb22*(head / gqa_ratio); - const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); - const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); + const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); @@ -330,12 +334,11 @@ static __global__ void flash_attn_vec_ext_f16( if (gridDim.y == 1) { dst_val /= kqsum[j_VKQ]; } - const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y; - dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val; + dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val; } if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { - dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); + dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); } #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); @@ -344,8 +347,8 @@ static __global__ void flash_attn_vec_ext_f16( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index b2f1724c95588..405b6f5106ea0 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f32( const int ne13, const int ne31, const int ne32, + const int ne33, const int nb31, const int nb32, + const int nb33, const int nb01, const int nb02, const int nb03, @@ -53,8 +55,8 @@ static __global__ void flash_attn_vec_ext_f32( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); @@ -77,14 +79,16 @@ static __global__ void flash_attn_vec_ext_f32( const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. + const int sequence = blockIdx.z / ne02; + const int head = blockIdx.z - sequence*ne02; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - Q += nb02* blockIdx.z + nb01*ic0; - K += nb12*(blockIdx.z / gqa_ratio); - V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape + Q += nb03*sequence + nb02* head + nb01*ic0; + K += nb13*sequence + nb12*(head / gqa_ratio); + V += nb23*sequence + nb22*(head / gqa_ratio); - const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); - const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); + const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); constexpr int nwarps = D / WARP_SIZE; @@ -326,12 +330,11 @@ static __global__ void flash_attn_vec_ext_f32( if (gridDim.y == 1) { dst_val /= kqsum[j_VKQ]; } - const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y; - dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val; + dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val; } if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { - dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); + dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); } #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); @@ -340,8 +343,8 @@ static __global__ void flash_attn_vec_ext_f32( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); - GGML_UNUSED(ne31); GGML_UNUSED(ne32); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); + GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index c95ca7b1f285f..741b8781d29f5 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -47,8 +47,10 @@ static __global__ void flash_attn_ext_f16( const int ne13, const int ne31, const int ne32, + const int ne33, const int nb31, const int nb32, + const int nb33, const int nb01, const int nb02, const int nb03, @@ -95,17 +97,19 @@ static __global__ void flash_attn_ext_f16( constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); + const int sequence = blockIdx.z / ne02; + const int head = blockIdx.z - sequence*ne02; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0); - const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio)); - const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); + const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0); + const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio)); + const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); const half2 * mask2 = (const half2 *) maskh; const int stride_Q = nb01 / sizeof(float); const int stride_KV = nb11 / sizeof(half); - const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); + const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); const half2 slope2 = make_half2(slopef, slopef); @@ -400,7 +404,6 @@ static __global__ void flash_attn_ext_f16( if (ic0 + j_VKQ >= ne01) { return; } - const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y; float KQ_rowsum_j; if (std::is_same::value) { @@ -409,6 +412,8 @@ static __global__ void flash_attn_ext_f16( KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]); } + const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; + #pragma unroll for (int i0 = 0; i0 < D; i0 += warp_size) { const int i = i0 + threadIdx.x; @@ -419,7 +424,7 @@ static __global__ void flash_attn_ext_f16( if (gridDim.y == 1) { dst_val /= KQ_rowsum_j; } - dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val; + dst[j_dst_unrolled*D + i] = dst_val; } if (gridDim.y == 1 || threadIdx.x != 0) { @@ -433,7 +438,7 @@ static __global__ void flash_attn_ext_f16( dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]); } dst_meta_val.y = KQ_rowsum_j; - dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val; + dst_meta[j_dst_unrolled] = dst_meta_val; } #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); @@ -442,7 +447,8 @@ static __global__ void flash_attn_ext_f16( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); - GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); GGML_UNUSED(nb31); + GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 88b17dd682c95..588b575197a4d 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3408,12 +3408,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (op->src[0]->ne[0] == 192) { return false; } - // TODO: support broadcast - // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but - // the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505 - if (op->src[0]->ne[3] != 1) { - return false; - } if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) { return false; } @@ -3426,6 +3420,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) { return true; } + if (op->src[3] && op->src[3]->ne[2] != 1) { + return false; + } return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16; } diff --git a/include/llama.h b/include/llama.h index f73b1ab65fe6f..bc6bdd92e7419 100644 --- a/include/llama.h +++ b/include/llama.h @@ -334,6 +334,9 @@ extern "C" { bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases // ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573 + bool kv_unified; // use a unified buffer across the input sequences when computing the attention + // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix + // ref: https://github.com/ggml-org/llama.cpp/pull/14363 }; // model quantization parameters diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 3bc8554e51ccf..f8227777f19de 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -27,6 +27,7 @@ bool llama_batch_allocr::init( const llama_vocab & vocab, const llama_memory_i * memory, uint32_t n_embd, + uint32_t n_seq_max, bool output_all) { clear(); @@ -40,6 +41,11 @@ bool llama_batch_allocr::init( // validate input batch // + if (n_seq_max > LLAMA_MAX_SEQ) { + LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ); + return false; + } + if (batch.token) { for (int32_t i = 0; i < batch.n_tokens; ++i) { if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) { @@ -52,8 +58,8 @@ bool llama_batch_allocr::init( if (batch.seq_id) { for (int32_t i = 0; i < batch.n_tokens; ++i) { for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) { - if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) { - LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ); + if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) { + LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max); return false; } } @@ -86,7 +92,7 @@ bool llama_batch_allocr::init( // initialize the starting position for each sequence based on the positions in the memory llama_pos p0[LLAMA_MAX_SEQ]; - for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < n_seq_max; ++s) { if (!memory) { // if no memory -> start from 0 p0[s] = 0; @@ -143,7 +149,8 @@ bool llama_batch_allocr::init( // compute stats // - this->n_embd = n_embd; + this->n_embd = n_embd; + this->n_seq_max = n_seq_max; // count the outputs in this batch for (int32_t i = 0; i < batch.n_tokens; ++i) { @@ -189,7 +196,7 @@ bool llama_batch_allocr::init( seq_set_map[cur].push_back(i); } - for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < n_seq_max; ++s) { if (seq_set_unq.test(s)) { seq_idx[s] = seq_id_unq.size(); seq_id_unq.push_back(s); @@ -241,7 +248,7 @@ bool llama_batch_allocr::init( // consistency checks // - for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < n_seq_max; ++s) { if (seq_pos[s].empty()) { continue; } @@ -284,8 +291,8 @@ bool llama_batch_allocr::init( } if (memory) { - for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) { - for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) { + for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) { + for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) { if (seq_cpl[s0][s1]) { if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) || memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) { @@ -316,12 +323,12 @@ bool llama_batch_allocr::init( // { seq_set_t cur_seq_set[LLAMA_MAX_SEQ]; - for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < n_seq_max; ++s) { cur_seq_set[s].set(); } llama_pos cur_seq_pos[LLAMA_MAX_SEQ]; - for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < n_seq_max; ++s) { cur_seq_pos[s] = -1; } @@ -692,7 +699,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u } } - for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < n_seq_max; ++s) { if (seq_set_unq.test(s)) { ubatch.seq_idx[s] = ubatch.seq_id_unq.size(); ubatch.seq_id_unq.push_back(s); diff --git a/src/llama-batch.h b/src/llama-batch.h index 3420803ff9469..1a24440ba7562 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -48,6 +48,7 @@ class llama_batch_allocr { const llama_vocab & vocab, const llama_memory_i * memory, uint32_t n_embd, + uint32_t n_seq_max, bool output_all); const llama_batch & get_batch() const; @@ -100,6 +101,7 @@ class llama_batch_allocr { const uint32_t n_pos_per_embd; uint32_t n_embd; + uint32_t n_seq_max; uint32_t n_outputs; std::array seq_id_0 = { 0 }; // default sequence id diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 06e93b19cbf40..2ad2419fbd50f 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -102,6 +102,7 @@ llama_context::llama_context( cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); cparams.op_offload = params.op_offload; + cparams.kv_unified = params.kv_unified; const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max; @@ -112,6 +113,7 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn); LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); + LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false"); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); @@ -267,7 +269,7 @@ llama_context::llama_context( // reserve worst-case graph if (!hparams.vocab_only && memory) { - const uint32_t n_seqs = cparams.n_seq_max; + const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); @@ -300,7 +302,7 @@ llama_context::llama_context( // reserve with tg graph to get the number of splits and nodes { - auto * gf = graph_reserve(1, 1, 1, mctx.get()); + auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get()); if (!gf) { throw std::runtime_error("failed to allocate compute tg buffers"); } @@ -311,6 +313,10 @@ llama_context::llama_context( // reserve again with pp graph to avoid ggml-alloc reallocations during inference { + // TODO: not sure if the following graph would be worster case for multi-stream KV caches: + // + // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); + // auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); @@ -475,7 +481,7 @@ bool llama_context::kv_self_update(bool optimize) { throw std::runtime_error("failed to initialize memory context"); } - const uint32_t n_seqs = cparams.n_seq_max; + const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); @@ -734,13 +740,15 @@ int llama_context::encode(const llama_batch & batch_inp) { const int64_t n_embd = hparams.n_embd; // note: during encode, we always pass the full sequence starting from pos = 0 - if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) { + if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); return -1; } const uint32_t n_tokens = balloc->get_n_tokens(); + // [TAG_NO_CACHE_PAD] + // TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true const llama_ubatch ubatch = balloc->split_simple(n_tokens); // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot @@ -899,7 +907,7 @@ int llama_context::decode(const llama_batch & batch_inp) { // when computing embeddings, all tokens are output const bool output_all = cparams.embeddings; - if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) { + if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) { LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); return -1; } @@ -2028,7 +2036,7 @@ void llama_context::opt_epoch_iter( batch.logits [pos_batch] = true; } - if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) { + if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); return; } @@ -2187,6 +2195,7 @@ llama_context_params llama_context_default_params() { /*.no_perf =*/ true, /*.op_offload =*/ true, /*.swa_full =*/ true, + /*.kv_unified =*/ true, }; return result; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 118615d5bd2d5..38750affc500b 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -11,8 +11,8 @@ struct llama_cparams { uint32_t n_batch; uint32_t n_ubatch; uint32_t n_seq_max; - int n_threads; // number of threads to use for generation - int n_threads_batch; // number of threads to use for batch processing + int32_t n_threads; // number of threads to use for generation + int32_t n_threads_batch; // number of threads to use for batch processing float rope_freq_base; float rope_freq_scale; @@ -33,6 +33,7 @@ struct llama_cparams { bool no_perf; bool warmup; bool op_offload; + bool kv_unified; enum llama_pooling_type pooling_type; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index a248a7ec22350..1a6355e85d11e 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -982,13 +982,16 @@ ggml_tensor * llm_graph_context::build_attn_mha( float kq_scale) const { const bool v_trans = v->nb[1] > v->nb[2]; + // split the batch into streams if needed + const auto n_stream = k->ne[3]; + + q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream); + q = ggml_permute(ctx0, q, 0, 2, 1, 3); k = ggml_permute(ctx0, k, 0, 2, 1, 3); v = ggml_permute(ctx0, v, 0, 2, 1, 3); - const auto n_tokens = q->ne[1]; - const auto n_head = q->ne[2]; - const auto n_kv = k->ne[1]; + const auto n_kv = k->ne[1]; ggml_tensor * cur; @@ -1030,7 +1033,7 @@ ggml_tensor * llm_graph_context::build_attn_mha( #endif } - cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]); } else { ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); @@ -1075,7 +1078,8 @@ ggml_tensor * llm_graph_context::build_attn_mha( cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3); - cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); + // recombine streams + cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]); if (!cparams.offload_kqv) { // all nodes between the KV store and the attention output are run on the CPU @@ -1122,6 +1126,10 @@ ggml_tensor * llm_graph_context::build_attn( const auto & kq_mask = inp->get_kq_mask(); + // [TAG_NO_CACHE_PAD] + // TODO: if ubatch.equal_seqs == true, we can split the three tensors below into ubatch.n_seqs_unq streams + assert(ubatch.equal_seqs == false); + ggml_tensor * q = q_cur; ggml_tensor * k = k_cur; ggml_tensor * v = v_cur; @@ -1156,13 +1164,14 @@ static std::unique_ptr build_attn_inp_kv_unifie { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA"); - const auto n_kv = mctx_cur->get_n_kv(); + const auto n_kv = mctx_cur->get_n_kv(); const auto n_tokens = ubatch.n_tokens; + const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1362,13 +1371,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif auto inp = std::make_unique(hparams, cparams, mctx_cur); + const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; + { const auto n_kv = mctx_cur->get_base()->get_n_kv(); inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1382,7 +1393,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); + inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream); ggml_set_input(inp->self_kq_mask_swa); inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; diff --git a/src/llama-graph.h b/src/llama-graph.h index fbf8e2889564d..84a5b0b3f9c40 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -255,10 +255,10 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i { ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] - ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] + ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] const llama_hparams & hparams; const llama_cparams & cparams; @@ -289,14 +289,14 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] - ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] + ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch] - ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] + ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1] - ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1] - ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] const llama_hparams & hparams; const llama_cparams & cparams; diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 7aa736e2f39db..c6c67d26f9392 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -65,6 +65,46 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const { return n_embd_head_v * n_head_kv; } +bool llama_hparams::is_n_embd_k_gqa_variable() const { + const uint32_t val = n_embd_k_gqa(); + for (uint32_t il = 0; il < n_layer; ++il) { + if (val != n_embd_k_gqa(il)) { + return true; + } + } + + return false; +} + +bool llama_hparams::is_n_embd_v_gqa_variable() const { + const uint32_t val = n_embd_v_gqa(); + for (uint32_t il = 0; il < n_layer; ++il) { + if (val != n_embd_v_gqa(il)) { + return true; + } + } + + return false; +} + +uint32_t llama_hparams::n_embd_k_gqa_max() const { + uint32_t val = n_embd_k_gqa(); + for (uint32_t il = 0; il < n_layer; ++il) { + val = std::max(val, n_embd_k_gqa(il)); + } + + return val; +} + +uint32_t llama_hparams::n_embd_v_gqa_max() const { + uint32_t val = n_embd_v_gqa(); + for (uint32_t il = 0; il < n_layer; ++il) { + val = std::max(val, n_embd_v_gqa(il)); + } + + return val; +} + uint32_t llama_hparams::n_embd_r() const { if (wkv_head_size != 0) { // for RWKV models diff --git a/src/llama-hparams.h b/src/llama-hparams.h index d0500e4d0fd77..a9d86ca6fd6e5 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -191,6 +191,14 @@ struct llama_hparams { // dimension of value embeddings across all k-v heads uint32_t n_embd_v_gqa(uint32_t il = 0) const; + // true if any layer has a different n_embd_k_gqa/n_embd_v_gqa + bool is_n_embd_k_gqa_variable() const; + bool is_n_embd_v_gqa_variable() const; + + // return the maximum n_embd_k_gqa/n_embd_v_gqa across all layers + uint32_t n_embd_k_gqa_max() const; + uint32_t n_embd_v_gqa_max() const; + // dimension of the rolling state embeddings // corresponds to Mamba's conv_states size or RWKV's token_shift states size uint32_t n_embd_r() const; diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp index fe207ad536032..01d27fb4db9b1 100644 --- a/src/llama-kv-cache-unified-iswa.cpp +++ b/src/llama-kv-cache-unified-iswa.cpp @@ -18,16 +18,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( bool v_trans, bool offload, bool swa_full, + bool unified, uint32_t kv_size, uint32_t n_seq_max, uint32_t n_ubatch, - uint32_t n_pad) : hparams(model.hparams) { + uint32_t n_pad) : hparams(model.hparams), unified(unified) { llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); }; llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); }; const uint32_t size_base = kv_size; - uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad)); + uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad)); // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size if (swa_full) { @@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( kv_base = std::make_unique( model, std::move(filter_base), type_k, type_v, - v_trans, offload, size_base, n_seq_max, n_pad, + v_trans, offload, unified, size_base, n_seq_max, n_pad, 0, LLAMA_SWA_TYPE_NONE); LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); kv_swa = std::make_unique( model, std::move(filter_swa), type_k, type_v, - v_trans, offload, size_swa, n_seq_max, n_pad, + v_trans, offload, unified, size_swa, n_seq_max, n_pad, hparams.n_swa, hparams.swa_type); } @@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all // first try simple split do { + if (!unified) { + // requires equal splits, so we skip the simple split + break; + } + balloc.split_reset(); std::vector ubatches; @@ -140,7 +146,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all std::vector ubatches; while (true) { - auto ubatch = balloc.split_equal(n_ubatch, false); + auto ubatch = balloc.split_equal(n_ubatch, !unified); if (ubatch.n_tokens == 0) { break; diff --git a/src/llama-kv-cache-unified-iswa.h b/src/llama-kv-cache-unified-iswa.h index 23205d826b23b..d2650dadd3595 100644 --- a/src/llama-kv-cache-unified-iswa.h +++ b/src/llama-kv-cache-unified-iswa.h @@ -20,6 +20,7 @@ class llama_kv_cache_unified_iswa : public llama_memory_i { bool v_trans, bool offload, bool swa_full, + bool unified, uint32_t kv_size, uint32_t n_seq_max, uint32_t n_ubatch, @@ -68,6 +69,8 @@ class llama_kv_cache_unified_iswa : public llama_memory_i { private: const llama_hparams & hparams; + const bool unified; + std::unique_ptr kv_base; std::unique_ptr kv_swa; }; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index d3129cc53281e..3b41ddb1f6bb5 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -23,13 +23,14 @@ llama_kv_cache_unified::llama_kv_cache_unified( ggml_type type_v, bool v_trans, bool offload, + bool unified, uint32_t kv_size, uint32_t n_seq_max, uint32_t n_pad, uint32_t n_swa, llama_swa_type swa_type) : model(model), hparams(model.hparams), v_trans(v_trans), - n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { + n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { GGML_ASSERT(kv_size % n_pad == 0); @@ -45,7 +46,7 @@ llama_kv_cache_unified::llama_kv_cache_unified( auto it = ctx_map.find(buft); if (it == ctx_map.end()) { ggml_init_params params = { - /*.mem_size =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()), + /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -64,9 +65,33 @@ llama_kv_cache_unified::llama_kv_cache_unified( return it->second; }; - head = 0; + GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max); - cells.resize(kv_size); + v_heads.resize(n_stream); + for (uint32_t s = 0; s < n_stream; ++s) { + v_heads[s] = 0; + } + + v_cells.resize(n_stream); + for (uint32_t s = 0; s < n_stream; ++s) { + v_cells[s].resize(kv_size); + } + + // by default, all sequence ids are mapped to the 0th stream + seq_to_stream.resize(LLAMA_MAX_SEQ, 0); + + if (n_stream > 1) { + seq_to_stream.resize(n_stream, 0); + for (uint32_t s = 0; s < n_stream; ++s) { + seq_to_stream[s] = s; + } + } + + // [TAG_V_CACHE_VARIABLE] + if (v_trans && hparams.is_n_embd_v_gqa_variable()) { + LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n", + __func__, hparams.n_embd_v_gqa_max()); + } for (uint32_t il = 0; il < n_layer_cache; il++) { if (filter && !filter(il)) { @@ -74,8 +99,9 @@ llama_kv_cache_unified::llama_kv_cache_unified( continue; } - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + // [TAG_V_CACHE_VARIABLE] + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max(); const char * dev_name = "CPU"; @@ -98,14 +124,23 @@ llama_kv_cache_unified::llama_kv_cache_unified( ggml_tensor * k; ggml_tensor * v; - k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size); - v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size); + k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream); + v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream); ggml_format_name(k, "cache_k_l%d", il); ggml_format_name(v, "cache_v_l%d", il); + std::vector k_stream; + std::vector v_stream; + + for (uint32_t s = 0; s < n_stream; ++s) { + k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2])); + v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2])); + } + map_layer_ids[il] = layers.size(); - layers.push_back({ il, k, v }); + + layers.push_back({ il, k, v, k_stream, v_stream, }); } // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE] @@ -148,8 +183,8 @@ llama_kv_cache_unified::llama_kv_cache_unified( const size_t memory_size_k = size_k_bytes(); const size_t memory_size_v = size_v_bytes(); - LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, - (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, + LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream, ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } @@ -160,15 +195,21 @@ llama_kv_cache_unified::llama_kv_cache_unified( const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS"); supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0; + if (!supports_set_rows && !unified) { + LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing LLAMA_SET_ROWS=1\n", __func__); + supports_set_rows = 1; + } + if (!supports_set_rows) { LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__); } } void llama_kv_cache_unified::clear(bool data) { - cells.reset(); - - head = 0; + for (uint32_t s = 0; s < n_stream; ++s) { + v_cells[s].reset(); + v_heads[s] = 0; + } if (data) { for (auto & buf : bufs) { @@ -178,6 +219,11 @@ void llama_kv_cache_unified::clear(bool data) { } bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); + + auto & cells = v_cells[seq_to_stream[seq_id]]; + auto & head = v_heads[seq_to_stream[seq_id]]; + uint32_t new_head = cells.size(); if (p0 < 0) { @@ -224,30 +270,94 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos } void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - if (seq_id_src == seq_id_dst) { + GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size()); + GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size()); + + const auto s0 = seq_to_stream[seq_id_src]; + const auto s1 = seq_to_stream[seq_id_dst]; + + if (s0 == s1) { + // since both sequences are in the same stream, no data copy is necessary + // we just have to update the cells meta data + + auto & cells = v_cells[s0]; + + if (seq_id_src == seq_id_dst) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + if (cells.seq_has(i, seq_id_src)) { + cells.seq_add(i, seq_id_dst); + } + } + return; } - if (p0 < 0) { - p0 = 0; + // cross-stream sequence copies require to copy the actual buffer data + + bool is_full = true; + + if (p0 > 0 && p0 + 1 < (int) get_size()) { + is_full = false; } - if (p1 < 0) { - p1 = std::numeric_limits::max(); + if (p1 > 0 && p1 + 1 < (int) get_size()) { + is_full = false; } - for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.pos_in(i, p0, p1)) { - continue; - } + GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers"); + + // enqueue the copy operation - the buffer copy will be performed during the next update + sc_info.ssrc.push_back(s0); + sc_info.sdst.push_back(s1); + + v_cells[s1].reset(); + for (uint32_t i = 0; i < v_cells[s0].size(); ++i) { + if (v_cells[s0].seq_has(i, seq_id_src)) { + llama_pos pos = v_cells[s0].pos_get(i); + llama_pos shift = v_cells[s0].get_shift(i); + + if (shift != 0) { + pos -= shift; + assert(pos >= 0); + } + + v_cells[s1].pos_set(i, pos); + v_cells[s1].seq_add(i, seq_id_dst); - if (cells.seq_has(i, seq_id_src)) { - cells.seq_add(i, seq_id_dst); + if (shift != 0) { + v_cells[s1].pos_add(i, shift); + } } } + + v_heads[s1] = v_heads[s0]; + + //for (uint32_t s = 0; s < n_stream; ++s) { + // LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s)); + //} } void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); + + auto & cells = v_cells[seq_to_stream[seq_id]]; + auto & head = v_heads[seq_to_stream[seq_id]]; + uint32_t new_head = cells.size(); for (uint32_t i = 0; i < cells.size(); ++i) { @@ -265,6 +375,11 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { } void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); + + auto & cells = v_cells[seq_to_stream[seq_id]]; + auto & head = v_heads[seq_to_stream[seq_id]]; + if (shift == 0) { return; } @@ -304,6 +419,10 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po } void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); + + auto & cells = v_cells[seq_to_stream[seq_id]]; + if (d == 1) { return; } @@ -333,10 +452,18 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po } llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); + + const auto & cells = v_cells[seq_to_stream[seq_id]]; + return cells.seq_pos_min(seq_id); } llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); + + const auto & cells = v_cells[seq_to_stream[seq_id]]; + return cells.seq_pos_max(seq_id); } @@ -351,7 +478,7 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch( std::vector ubatches; while (true) { - auto ubatch = balloc.split_simple(n_ubatch); + auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true); if (ubatch.n_tokens == 0) { break; @@ -387,7 +514,10 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct defrag_info dinfo; // see if we need to defrag - { + if (n_stream == 1) { + // note : for now do not consider defrag for n_stream > 1 + const auto & cells = v_cells[seq_to_stream[0]]; + bool do_defrag = optimize; const auto thold = lctx->get_cparams().defrag_thold; @@ -411,22 +541,22 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct } } - return std::make_unique(this, lctx, do_shift, std::move(dinfo)); + return std::make_unique(this, lctx, do_shift, std::move(dinfo), std::move(sc_info)); } llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector & ubatches) { llama_kv_cache_unified::slot_info_vec_t res; - struct state { - uint32_t head_old; // old position of the head, before placing the ubatch - + struct state_t { slot_info sinfo; // slot info for the ubatch - llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch + std::vector v_heads_old; // old positions of the heads, before placing the ubatch + + std::vector v_cells; // copy of the old cells, before placing the ubatch }; // remember the old state of the cells so we can restore it in the end - std::vector states; + std::vector states; bool success = true; @@ -445,16 +575,35 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st res.push_back(sinfo_new); // store the old state of the cells in the recovery stack - states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)}); + { + state_t state = { sinfo_new, v_heads, {} }; + + for (uint32_t s = 0; s < sinfo_new.n_stream(); ++s) { + auto & cells = v_cells[sinfo_new.strm[s]]; + + state.v_cells.push_back(cells.cp(sinfo_new.idxs[s])); + } + + states.push_back(std::move(state)); + } // now emplace the ubatch apply_ubatch(sinfo_new, ubatch); } + GGML_ASSERT(!states.empty() || !success); + // iterate backwards and restore the cells to their original state for (auto it = states.rbegin(); it != states.rend(); ++it) { - cells.set(it->sinfo.idxs, it->cells); - head = it->head_old; + const auto & sinfo = it->sinfo; + + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + auto & cells = v_cells[sinfo.strm[s]]; + auto & head = v_heads[sinfo.strm[s]]; + + cells.set(sinfo.idxs[s], it->v_cells[s]); + head = it->v_heads_old[s]; + } } if (!success) { @@ -464,11 +613,38 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st return res; } -bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) { +bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info) { bool updated = false; auto * sched = lctx->get_sched(); + if (!sc_info.empty()) { + assert(n_stream > 1 && "stream copy should never happen with a single stream"); + + llama_synchronize(lctx); + + const size_t n_copy = sc_info.ssrc.size(); + + for (size_t i = 0; i < n_copy; ++i) { + const auto ssrc = sc_info.ssrc[i]; + const auto sdst = sc_info.sdst[i]; + + assert(ssrc < n_stream); + assert(sdst < n_stream); + + LLAMA_LOG_DEBUG("%s: copying KV buffer: stream %d to stream %d\n", __func__, ssrc, sdst); + + assert(ssrc != sdst); + + for (uint32_t il = 0; il < layers.size(); ++il) { + const auto & layer = layers[il]; + + ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]); + ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]); + } + } + } + if (do_shift) { if (!get_can_shift()) { GGML_ABORT("The current KV cache / model configuration does not support K-shift"); @@ -503,12 +679,20 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d updated = true; } - cells.reset_shift(); + for (uint32_t s = 0; s < n_stream; ++s) { + auto & cells = v_cells[s]; + + cells.reset_shift(); + } } if (!dinfo.empty()) { LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); + // note: for now do not consider defrag for n_stream > 1 + auto & cells = v_cells[seq_to_stream[0]]; + auto & head = v_heads[seq_to_stream[0]]; + // apply moves: { const auto n_kv = dinfo.ids.size(); @@ -556,23 +740,13 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d } llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const { - const uint32_t n_tokens = ubatch.n_tokens; - - uint32_t head_cur = this->head; - - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (head_cur > cells.get_used() + 2*ubatch.n_tokens) { - head_cur = 0; - } + if (debug > 0) { + const auto & cells = v_cells[seq_to_stream[1]]; - if (n_tokens > cells.size()) { - LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size()); - return { }; - } + const uint32_t head_cur = v_heads[1]; - if (debug > 0) { - LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa); + LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", + __func__, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa); if ((debug == 2 && n_swa > 0) || debug > 2) { std::string ss; @@ -629,86 +803,133 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ } } - uint32_t n_tested = 0; + uint32_t n_tokens = ubatch.n_tokens; + uint32_t n_seqs = 1; - // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head - // for non-continuous slots, we test the tokens one by one - const uint32_t n_test = cont ? n_tokens : 1; + if (n_stream > 1) { + GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0); - slot_info res; + n_seqs = ubatch.n_seqs_unq; + n_tokens = n_tokens / n_seqs; + } + + slot_info res = { + /*.s0 =*/ LLAMA_MAX_SEQ, + /*.s1 =*/ 0, + /*.strm =*/ { }, + /*.idxs =*/ { }, + }; + + res.resize(n_seqs); + + for (uint32_t s = 0; s < n_seqs; ++s) { + const auto seq_id = ubatch.seq_id_unq[s]; + + if (n_stream > 1) { + GGML_ASSERT(ubatch.n_seq_id[s*n_tokens] == 1); + GGML_ASSERT(ubatch.seq_id [s*n_tokens][0] == seq_id); + } + + res.s0 = std::min(res.s0, seq_to_stream[seq_id]); + res.s1 = std::max(res.s1, seq_to_stream[seq_id]); + + res.strm[s] = seq_to_stream[seq_id]; + res.idxs[s].reserve(n_tokens); - auto & idxs = res.idxs; + const auto & cells = v_cells[seq_to_stream[seq_id]]; - idxs.reserve(n_tokens); + uint32_t head_cur = v_heads[seq_to_stream[seq_id]]; - while (true) { - if (head_cur + n_test > cells.size()) { - n_tested += cells.size() - head_cur; + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (head_cur > cells.get_used() + 2*n_tokens) { head_cur = 0; - continue; } - for (uint32_t i = 0; i < n_test; i++) { - const auto idx = head_cur; + if (n_tokens > cells.size()) { + LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size()); + return { }; + } + + uint32_t n_tested = 0; + + // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head + // for non-continuous slots, we test the tokens one by one + const uint32_t n_test = cont ? n_tokens : 1; + + while (true) { + if (head_cur + n_test > cells.size()) { + n_tested += cells.size() - head_cur; + head_cur = 0; + continue; + } + + for (uint32_t i = 0; i < n_test; i++) { + const auto idx = head_cur; + + head_cur++; + n_tested++; - //const llama_pos pos = ubatch.pos[i]; - //const llama_seq_id seq_id = ubatch.seq_id[i][0]; + //const llama_pos pos = ubatch.pos[i]; + //const llama_seq_id seq_id = ubatch.seq_id[i][0]; - // can we use this cell? either: - // - the cell is empty - // - the cell is occupied only by one sequence: - // - (disabled) mask causally, if the sequence is the same as the one we are inserting - // - mask SWA, using current max pos for that sequence in the cache - // always insert in the cell with minimum pos - bool can_use = cells.is_empty(idx); + // can we use this cell? either: + // - the cell is empty + // - the cell is occupied only by one sequence: + // - (disabled) mask causally, if the sequence is the same as the one we are inserting + // - mask SWA, using current max pos for that sequence in the cache + // always insert in the cell with minimum pos + bool can_use = cells.is_empty(idx); - if (!can_use && cells.seq_count(idx) == 1) { - const llama_pos pos_cell = cells.pos_get(idx); + if (!can_use && cells.seq_count(idx) == 1) { + const llama_pos pos_cell = cells.pos_get(idx); - // (disabled) causal mask - // note: it's better to purge any "future" tokens beforehand - //if (cells.seq_has(idx, seq_id)) { - // can_use = pos_cell >= pos; - //} + // (disabled) causal mask + // note: it's better to purge any "future" tokens beforehand + //if (cells.seq_has(idx, seq_id)) { + // can_use = pos_cell >= pos; + //} - if (!can_use) { - const llama_seq_id seq_id_cell = cells.seq_get(idx); + if (!can_use) { + const llama_seq_id seq_id_cell = cells.seq_get(idx); - // SWA mask - if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { - can_use = true; + // SWA mask + if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { + can_use = true; + } } } - } - head_cur++; - n_tested++; + if (can_use) { + res.idxs[s].push_back(idx); + } else { + if (cont) { + break; + } + } + } - if (can_use) { - idxs.push_back(idx); - } else { + if (res.idxs[s].size() == n_tokens) { break; } - } - if (idxs.size() == n_tokens) { - break; - } + if (cont) { + res.idxs[s].clear(); + } - if (cont) { - idxs.clear(); + if (n_tested >= cells.size()) { + //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + return { }; + } } - if (n_tested >= cells.size()) { - //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + // we didn't find a suitable slot - return empty result + if (res.idxs[s].size() < n_tokens) { return { }; } } - // we didn't find a suitable slot - return empty result - if (idxs.size() < n_tokens) { - res.clear(); - } + assert(res.s1 >= res.s0); return res; } @@ -717,41 +938,51 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u // keep track of the max sequence position that we would overwrite with this ubatch // for non-SWA cache, this would be always empty llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; - for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { seq_pos_max_rm[s] = -1; } - assert(ubatch.n_tokens == sinfo.idxs.size()); + assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size()); - for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { - const auto idx = sinfo.idxs.at(i); + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + for (uint32_t ii = 0; ii < sinfo.size(); ++ii) { + const uint32_t i = s*sinfo.size() + ii; - if (!cells.is_empty(idx)) { - assert(cells.seq_count(idx) == 1); + auto & cells = v_cells[sinfo.strm[s]]; - const llama_seq_id seq_id = cells.seq_get(idx); - const llama_pos pos = cells.pos_get(idx); + const auto idx = sinfo.idxs[s][ii]; - seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); + if (!cells.is_empty(idx)) { + assert(cells.seq_count(idx) == 1); - cells.rm(idx); - } + const llama_seq_id seq_id = cells.seq_get(idx); + const llama_pos pos = cells.pos_get(idx); + + seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); - cells.pos_set(idx, ubatch.pos[i]); + cells.rm(idx); + } + + cells.pos_set(idx, ubatch.pos[i]); - for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) { - cells.seq_add(idx, ubatch.seq_id[i][s]); + for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) { + cells.seq_add(idx, ubatch.seq_id[i][s]); + } } } // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence // will be present in the cache. so we have to purge any position which is less than those we would overwrite // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092 - for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { if (seq_pos_max_rm[s] == -1) { continue; } + GGML_ASSERT(s < seq_to_stream.size()); + + auto & cells = v_cells[seq_to_stream[s]]; + if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) { LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n", __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s); @@ -761,7 +992,11 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u } // move the head at the end of the slot - head = sinfo.idxs.back() + 1; + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + auto & head = v_heads[sinfo.strm[s]]; + + head = sinfo.idxs[s].back() + 1; + } } bool llama_kv_cache_unified::get_can_shift() const { @@ -769,49 +1004,87 @@ bool llama_kv_cache_unified::get_can_shift() const { } uint32_t llama_kv_cache_unified::get_size() const { + const auto & cells = v_cells[seq_to_stream[0]]; + return cells.size(); } +uint32_t llama_kv_cache_unified::get_n_stream() const { + return n_stream; +} + bool llama_kv_cache_unified::get_has_shift() const { - return cells.get_has_shift(); + bool result = false; + + for (uint32_t s = 0; s < n_stream; ++s) { + result |= v_cells[s].get_has_shift(); + } + + return result; } uint32_t llama_kv_cache_unified::get_n_kv() const { - return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))); + uint32_t result = 0; + + for (uint32_t s = 0; s < n_stream; ++s) { + const auto & cells = v_cells[s]; + + result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result); + } + + return result; } -ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const { +ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const { const int32_t ikv = map_layer_ids.at(il); auto * k = layers[ikv].k; - return ggml_view_3d(ctx, k, - hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, + const uint64_t kv_size = get_size(); + const uint64_t n_embd_k_gqa = k->ne[0]; + + assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il)); + + const uint32_t ns = sinfo.s1 - sinfo.s0 + 1; + + return ggml_view_4d(ctx, k, + hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns, ggml_row_size(k->type, hparams.n_embd_head_k), - ggml_row_size(k->type, hparams.n_embd_k_gqa(il)), - 0); + ggml_row_size(k->type, n_embd_k_gqa), + ggml_row_size(k->type, n_embd_k_gqa*kv_size), + ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0); } -ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const { +ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const { const int32_t ikv = map_layer_ids.at(il); auto * v = layers[ikv].v; + const uint64_t kv_size = get_size(); + const uint64_t n_embd_v_gqa = v->ne[0]; + + // [TAG_V_CACHE_VARIABLE] + assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il)); + + const uint32_t ns = sinfo.s1 - sinfo.s0 + 1; + if (!v_trans) { // note: v->nb[1] <= v->nb[2] - return ggml_view_3d(ctx, v, - hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, - ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] - ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2] - 0); + return ggml_view_4d(ctx, v, + hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns, + ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2] + ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3] + ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0); } // note: v->nb[1] > v->nb[2] - return ggml_view_3d(ctx, v, - n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, - ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1] - ggml_row_size(v->type, v->ne[1]), // v->nb[2] - 0); + return ggml_view_4d(ctx, v, + n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns, + ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, kv_size), // v->nb[2] + ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3] + ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0); } ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const { @@ -825,12 +1098,18 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_ k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens); if (k_idxs && supports_set_rows) { + if (k->ne[2] > 1) { + k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]); + } + return ggml_set_rows(ctx, k, k_cur, k_idxs); } // TODO: fallback to old ggml_cpy() method for backwards compatibility // will be removed when ggml_set_rows() is adopted by all backends + GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS"); + ggml_tensor * k_view = ggml_view_1d(ctx, k, n_tokens*n_embd_k_gqa, ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head()); @@ -843,37 +1122,38 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_ auto * v = layers[ikv].v; - const int64_t n_embd_v_gqa = v->ne[0]; - const int64_t n_tokens = v_cur->ne[2]; + const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1]; + const int64_t n_tokens = v_cur->ne[2]; v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens); if (v_idxs && supports_set_rows) { if (!v_trans) { + if (v->ne[2] > 1) { + v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]); + } + return ggml_set_rows(ctx, v, v_cur, v_idxs); } - // the row becomes a single element - ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]); + // [TAG_V_CACHE_VARIABLE] + if (n_embd_v_gqa < v->ne[0]) { + v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0); + } - // note: the V cache is transposed when not using flash attention - v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3); + // the row becomes a single element + ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]); - // note: we can be more explicit here at the cost of extra cont - // however, above we take advantage that a row of single element is always continuous regardless of the row stride - //v_cur = ggml_transpose(ctx, v_cur); - //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]); + v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]); - // we broadcast the KV indices n_embd_v_gqa times - // v [1, n_kv, n_embd_v_gqa] - // v_cur [1, n_tokens, n_embd_v_gqa] - // v_idxs [n_tokens, 1, 1] return ggml_set_rows(ctx, v_view, v_cur, v_idxs); } // TODO: fallback to old ggml_cpy() method for backwards compatibility // will be removed when ggml_set_rows() is adopted by all backends + GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS"); + ggml_tensor * v_view = nullptr; if (!v_trans) { @@ -904,7 +1184,13 @@ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, con ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { const uint32_t n_tokens = ubatch.n_tokens; - ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens); + ggml_tensor * v_idxs; + + if (!v_trans) { + v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens); + } else { + v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max()); + } ggml_set_input(v_idxs); @@ -917,12 +1203,17 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_uba } const uint32_t n_tokens = ubatch->n_tokens; + GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream()); GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); int64_t * data = (int64_t *) dst->data; - for (int64_t i = 0; i < n_tokens; ++i) { - data[i] = sinfo.idxs.at(i); + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + const int64_t offs = sinfo.strm[s]*get_size(); + + for (uint32_t i = 0; i < sinfo.size(); ++i) { + data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i]; + } } } @@ -932,12 +1223,48 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba } const uint32_t n_tokens = ubatch->n_tokens; + GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream()); GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); int64_t * data = (int64_t *) dst->data; - for (int64_t i = 0; i < n_tokens; ++i) { - data[i] = sinfo.idxs.at(i); + if (!v_trans) { + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + const int64_t offs = sinfo.strm[s]*get_size(); + + for (uint32_t i = 0; i < sinfo.size(); ++i) { + data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i]; + } + } + } else { + // note: the V cache is transposed when not using flash attention + const int64_t kv_size = get_size(); + + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max(); + + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + const int64_t offs = sinfo.strm[s]*kv_size*n_embd_v_gqa; + + for (uint32_t i = 0; i < sinfo.size(); ++i) { + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + data[s*sinfo.size()*n_embd_v_gqa + i*n_embd_v_gqa + j] = offs + j*kv_size + sinfo.idxs[s][i]; + } + } + } + } +} + +void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const { + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + + int32_t * data = (int32_t *) dst->data; + + for (uint32_t s = 0; s < n_stream; ++s) { + const auto & cells = v_cells[s]; + + for (uint32_t i = 0; i < cells.size(); ++i) { + data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i); + } } } @@ -947,7 +1274,14 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); float * data = (float *) dst->data; - const int64_t n_kv = dst->ne[0]; + const int64_t n_kv = dst->ne[0]; + const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch + + GGML_ASSERT(n_tokens%n_stream == 0); + + // n_tps == n_tokens_per_stream + const int64_t n_tps = n_tokens/n_stream; + const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD); // Use only the previous KV cells of the correct sequence for each token of the ubatch. // It's assumed that if a token in the batch has multiple sequences, they are equivalent. @@ -962,67 +1296,66 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub // xxxxx----- // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615 for (uint32_t h = 0; h < 1; ++h) { - for (uint32_t i = 0; i < n_tokens; ++i) { - const llama_seq_id seq_id = ubatch->seq_id[i][0]; + for (uint32_t s = 0; s < n_stream; ++s) { + for (uint32_t ii = 0; ii < n_tps; ++ii) { + const uint32_t i = s*n_tps + ii; - const llama_pos p1 = ubatch->pos[i]; + const llama_seq_id seq_id = ubatch->seq_id[i][0]; - for (uint32_t j = 0; j < n_kv; ++j) { - float f = 0.0f; + const auto & cells = v_cells[seq_to_stream[seq_id]]; - bool masked = false; + const llama_pos p1 = ubatch->pos[i]; - if (cells.is_empty(j)) { - masked = true; - } else { - const llama_pos p0 = cells.pos_get(j); + for (uint32_t j = 0; j < n_kv; ++j) { + float f = 0.0f; - // mask the token if not the same sequence - masked = masked || (!cells.seq_has(j, seq_id)); + bool masked = false; - // mask future tokens - masked = masked || (causal_attn && p0 > p1); + if (cells.is_empty(j)) { + masked = true; + } else { + const llama_pos p0 = cells.pos_get(j); + + // mask the token if not the same sequence + masked = masked || (!cells.seq_has(j, seq_id)); + + // mask future tokens + masked = masked || (causal_attn && p0 > p1); - // apply SWA if any - masked = masked || (is_masked_swa(p0, p1)); + // apply SWA if any + masked = masked || (is_masked_swa(p0, p1)); - if (!masked && hparams.use_alibi) { - f = -std::abs(p0 - p1); + if (!masked && hparams.use_alibi) { + f = -std::abs(p0 - p1); + } } - } - if (masked) { - f = -INFINITY; - } + if (masked) { + f = -INFINITY; + } - data[h*(n_kv*n_tokens) + i*n_kv + j] = f; - } - } + data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = f; + } - // mask padded tokens - if (data) { - for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (uint32_t j = 0; j < n_kv; ++j) { - data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + // mask padded tokens + if (data) { + for (uint32_t ii = n_tps; ii < n_tps_pad; ++ii) { + for (uint32_t j = 0; j < n_kv; ++j) { + data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = -INFINITY; + } + } } } } } } -void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const { - GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - - int32_t * data = (int32_t *) dst->data; - - for (uint32_t i = 0; i < cells.size(); ++i) { - data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i); - } -} - void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { const int64_t n_tokens = ubatch->n_tokens; + GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams"); + const auto & cells = v_cells[0]; + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing @@ -1129,7 +1462,7 @@ class llm_graph_input_k_shift : public llm_graph_input_i { void set_input(const llama_ubatch * ubatch) override; - ggml_tensor * k_shift; // I32 [kv_size] + ggml_tensor * k_shift; // I32 [kv_size*n_stream] const llama_kv_cache_unified * kv_self; }; @@ -1153,7 +1486,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( auto inp = std::make_unique(this); - inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size()); + inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream); ggml_set_input(inp->k_shift); for (const auto & layer : layers) { @@ -1169,7 +1502,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( ggml_tensor * k = ggml_view_3d(ctx, layer.k, - n_embd_head_k, n_head_kv, cells.size(), + n_embd_head_k, n_head_kv, get_size()*n_stream, ggml_row_size(layer.k->type, n_embd_head_k), ggml_row_size(layer.k->type, n_embd_k_gqa), 0); @@ -1191,6 +1524,10 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( const defrag_info & dinfo) const { auto res = std::make_unique(); + GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag"); + + const auto & cells = v_cells[0]; + const auto & ids = dinfo.ids; #if 0 @@ -1333,6 +1670,10 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( } llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const { + GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag"); + + const auto & cells = v_cells[0]; + const uint32_t n_layer = layers.size(); const uint32_t n_kv = cells.used_max_p1(); @@ -1478,64 +1819,94 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { } void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { - std::vector> cell_ranges; // ranges, from inclusive, to exclusive - uint32_t cell_count = 0; + io.write(&n_stream, sizeof(n_stream)); - // Count the number of cells with the specified seq_id - // Find all the ranges of cells with this seq id (or all, when -1) - uint32_t cell_range_begin = cells.size(); + for (uint32_t s = 0; s < n_stream; ++s) { + cell_ranges_t cr { s, {} }; - for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) { - ++cell_count; - if (cell_range_begin == cells.size()) { - cell_range_begin = i; - } - } else { - if (cell_range_begin != cells.size()) { - cell_ranges.emplace_back(cell_range_begin, i); - cell_range_begin = cells.size(); + uint32_t cell_count = 0; + + const auto & cells = v_cells[s]; + + // Count the number of cells with the specified seq_id + // Find all the ranges of cells with this seq id (or all, when -1) + uint32_t cell_range_begin = cells.size(); + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) { + ++cell_count; + if (cell_range_begin == cells.size()) { + cell_range_begin = i; + } + } else { + if (cell_range_begin != cells.size()) { + cr.data.emplace_back(cell_range_begin, i); + cell_range_begin = cells.size(); + } } } - } - if (cell_range_begin != cells.size()) { - cell_ranges.emplace_back(cell_range_begin, cells.size()); - } + if (cell_range_begin != cells.size()) { + cr.data.emplace_back(cell_range_begin, cells.size()); + } - // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count - uint32_t cell_count_check = 0; - for (const auto & range : cell_ranges) { - cell_count_check += range.second - range.first; - } - GGML_ASSERT(cell_count == cell_count_check); + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : cr.data) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); - io.write(&cell_count, sizeof(cell_count)); + io.write(&cell_count, sizeof(cell_count)); - state_write_meta(io, cell_ranges, seq_id); - state_write_data(io, cell_ranges); + // skip empty streams + if (cell_count == 0) { + continue; + } + + state_write_meta(io, cr, seq_id); + state_write_data(io, cr); + } } void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) { - uint32_t cell_count; - io.read_to(&cell_count, sizeof(cell_count)); + GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size())); - bool res = true; - res = res && state_read_meta(io, cell_count, seq_id); - res = res && state_read_data(io, cell_count); + uint32_t n_stream_cur; + io.read_to(&n_stream_cur, sizeof(n_stream_cur)); + if (n_stream_cur != n_stream) { + throw std::runtime_error("n_stream mismatch"); + } + + for (uint32_t s = 0; s < n_stream; ++s) { + uint32_t cell_count; + io.read_to(&cell_count, sizeof(cell_count)); + + if (cell_count == 0) { + continue; + } + + const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id]; - if (!res) { - if (seq_id == -1) { - clear(true); - } else { - seq_rm(seq_id, -1, -1); + bool res = true; + res = res && state_read_meta(io, strm, cell_count, seq_id); + res = res && state_read_data(io, strm, cell_count); + + if (!res) { + if (seq_id == -1) { + clear(true); + } else { + seq_rm(seq_id, -1, -1); + } + throw std::runtime_error("failed to restore kv cache"); } - throw std::runtime_error("failed to restore kv cache"); } } -void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { - for (const auto & range : cell_ranges) { +void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const { + const auto & cells = v_cells[cr.strm]; + + for (const auto & range : cr.data) { for (uint32_t i = range.first; i < range.second; ++i) { std::vector seq_ids; @@ -1560,7 +1931,9 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std:: } } -void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const { +void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const { + const auto & cells = v_cells[cr.strm]; + const uint32_t v_trans = this->v_trans ? 1 : 0; const uint32_t n_layer = layers.size(); @@ -1576,19 +1949,21 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + auto * k = layer.k_stream[cr.strm]; + // Write key type - const int32_t k_type_i = (int32_t)layer.k->type; + const int32_t k_type_i = (int32_t) k->type; io.write(&k_type_i, sizeof(k_type_i)); // Write row size of key - const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa); + const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa); io.write(&k_size_row, sizeof(k_size_row)); // Read each range of cells of k_size length each into tmp_buf and write out - for (const auto & range : cell_ranges) { + for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * k_size_row; - io.write_tensor(layer.k, range.first * k_size_row, buf_size); + io.write_tensor(k, range.first * k_size_row, buf_size); } } @@ -1598,19 +1973,21 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + auto * v = layer.v_stream[cr.strm]; + // Write value type - const int32_t v_type_i = (int32_t)layer.v->type; + const int32_t v_type_i = (int32_t) v->type; io.write(&v_type_i, sizeof(v_type_i)); // Write row size of value - const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa); + const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa); io.write(&v_size_row, sizeof(v_size_row)); // Read each range of cells of v_size length each into tmp_buf and write out - for (const auto & range : cell_ranges) { + for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * v_size_row; - io.write_tensor(layer.v, range.first * v_size_row, buf_size); + io.write_tensor(v, range.first * v_size_row, buf_size); } } } else { @@ -1622,12 +1999,14 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + auto * v = layer.v_stream[cr.strm]; + // Write value type - const int32_t v_type_i = (int32_t)layer.v->type; + const int32_t v_type_i = (int32_t) v->type; io.write(&v_type_i, sizeof(v_type_i)); // Write element size - const uint32_t v_size_el = ggml_type_size(layer.v->type); + const uint32_t v_size_el = ggml_type_size(v->type); io.write(&v_size_el, sizeof(v_size_el)); // Write GQA embedding size @@ -1636,27 +2015,31 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: // For each row, we get the element values of each cell for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { // Read each range of cells of v_size_el length each into tmp_buf and write out - for (const auto & range : cell_ranges) { + for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t src_offset = (range.first + j * kv_size) * v_size_el; const size_t buf_size = range_size * v_size_el; - io.write_tensor(layer.v, src_offset, buf_size); + io.write_tensor(v, src_offset, buf_size); } } } } } -bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { +bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) { + auto & cells = v_cells[strm]; + auto & head = v_heads[strm]; + if (dest_seq_id != -1) { // single sequence - seq_rm(dest_seq_id, -1, -1); llama_batch_allocr balloc(hparams.n_pos_per_embd()); llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1); + ubatch.seq_id_unq[0] = dest_seq_id; + for (uint32_t i = 0; i < cell_count; ++i) { llama_pos pos; uint32_t n_seq_id; @@ -1693,6 +2076,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell // keep the head at the old position because we will read the KV data into it in state_read_data() head = head_cur; + LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id); + // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values) // Assume that this is one contiguous block of cells GGML_ASSERT(head_cur + cell_count <= cells.size()); @@ -1738,7 +2123,10 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell return true; } -bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) { +bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) { + auto & cells = v_cells[strm]; + auto & head = v_heads[strm]; + uint32_t v_trans; uint32_t n_layer; @@ -1766,10 +2154,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + auto * k = layer.k_stream[strm]; + // Read type of key int32_t k_type_i_ref; io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); - const int32_t k_type_i = (int32_t) layer.k->type; + const int32_t k_type_i = (int32_t) k->type; if (k_type_i != k_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); return false; @@ -1778,7 +2168,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell // Read row size of key uint64_t k_size_row_ref; io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); - const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa); + const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa); if (k_size_row != k_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); return false; @@ -1786,7 +2176,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell if (cell_count) { // Read and set the keys for the whole cell range - ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); + ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); } } @@ -1796,10 +2186,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + auto * v = layer.v_stream[strm]; + // Read type of value int32_t v_type_i_ref; io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)layer.v->type; + const int32_t v_type_i = (int32_t) v->type; if (v_type_i != v_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); return false; @@ -1808,7 +2200,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell // Read row size of value uint64_t v_size_row_ref; io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); - const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa); + const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa); if (v_size_row != v_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); return false; @@ -1816,7 +2208,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell if (cell_count) { // Read and set the values for the whole cell range - ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); + ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); } } } else { @@ -1826,10 +2218,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + auto * v = layer.v_stream[strm]; + // Read type of value int32_t v_type_i_ref; io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)layer.v->type; + const int32_t v_type_i = (int32_t) v->type; if (v_type_i != v_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); return false; @@ -1838,7 +2232,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell // Read element size of value uint32_t v_size_el_ref; io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); - const size_t v_size_el = ggml_type_size(layer.v->type); + const size_t v_size_el = ggml_type_size(v->type); if (v_size_el != v_size_el_ref) { LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); return false; @@ -1856,7 +2250,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell // For each row in the transposed matrix, read the values for the whole cell range for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { const size_t dst_offset = (head + j * cells.size()) * v_size_el; - ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); } } } @@ -1875,18 +2269,26 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) { n_kv = kv->get_size(); + const uint32_t n_stream = kv->get_n_stream(); + // create a dummy slot info - the actual data is irrelevant. we just need to build the graph sinfos.resize(1); - sinfos[0].idxs.resize(1); - sinfos[0].idxs[0] = 0; + sinfos[0].s0 = 0; + sinfos[0].s1 = n_stream - 1; + sinfos[0].idxs.resize(n_stream); + for (uint32_t s = 0; s < n_stream; ++s) { + sinfos[0].strm.push_back(s); + sinfos[0].idxs[s].resize(1, 0); + } } llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified * kv, llama_context * lctx, bool do_shift, - defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) { - if (!do_shift && this->dinfo.empty()) { + defrag_info dinfo, + stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)), sc_info(std::move(sc_info)) { + if (!do_shift && this->dinfo.empty() && this->sc_info.empty()) { status = LLAMA_MEMORY_STATUS_NO_UPDATE; } } @@ -1914,7 +2316,7 @@ bool llama_kv_cache_unified_context::apply() { // no ubatches -> this is a KV cache update if (ubatches.empty()) { - kv->update(lctx, do_shift, dinfo); + kv->update(lctx, do_shift, dinfo, sc_info); return true; } @@ -1941,11 +2343,11 @@ uint32_t llama_kv_cache_unified_context::get_n_kv() const { } ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const { - return kv->get_k(ctx, il, n_kv); + return kv->get_k(ctx, il, n_kv, sinfos[i_cur]); } ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const { - return kv->get_v(ctx, il, n_kv); + return kv->get_v(ctx, il, n_kv, sinfos[i_cur]); } ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const { diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h index b8b0356e830c8..3bfda4600d843 100644 --- a/src/llama-kv-cache-unified.h +++ b/src/llama-kv-cache-unified.h @@ -35,16 +35,50 @@ class llama_kv_cache_unified : public llama_memory_i { std::vector ids; }; + struct stream_copy_info { + bool empty() const { + assert(ssrc.size() == sdst.size()); + return ssrc.empty(); + } + + std::vector ssrc; + std::vector sdst; + }; + // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the // KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]] struct slot_info { // data for ggml_set_rows using idx_vec_t = std::vector; - idx_vec_t idxs; + // number of streams: ns = s1 - s0 + 1 + llama_seq_id s0; + llama_seq_id s1; + + std::vector strm; // [ns] + std::vector idxs; // [ns] uint32_t head() const { - return idxs.at(0); + GGML_ASSERT(idxs.size() == 1); + GGML_ASSERT(!idxs[0].empty()); + + return idxs[0][0]; + } + + void resize(size_t n) { + strm.resize(n); + idxs.resize(n); + } + + size_t size() const { + GGML_ASSERT(idxs.size() == strm.size()); + GGML_ASSERT(!idxs.empty()); + + return idxs[0].size(); + } + + size_t n_stream() const { + return strm.size(); } bool empty() const { @@ -54,9 +88,6 @@ class llama_kv_cache_unified : public llama_memory_i { void clear() { idxs.clear(); } - - // TODO: implement - //std::vector seq_idxs; }; using slot_info_vec_t = std::vector; @@ -68,6 +99,7 @@ class llama_kv_cache_unified : public llama_memory_i { ggml_type type_v, bool v_trans, bool offload, + bool unified, uint32_t kv_size, uint32_t n_seq_max, uint32_t n_pad, @@ -111,7 +143,8 @@ class llama_kv_cache_unified : public llama_memory_i { // llama_kv_cache_unified specific API // - uint32_t get_size() const; + uint32_t get_size() const; + uint32_t get_n_stream() const; bool get_has_shift() const; @@ -122,8 +155,8 @@ class llama_kv_cache_unified : public llama_memory_i { uint32_t get_n_kv() const; // get views of the current state of the cache - ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const; - ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const; + ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const; + ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const; // store k_cur and v_cur in the cache based on the provided head location ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const; @@ -137,7 +170,7 @@ class llama_kv_cache_unified : public llama_memory_i { // return empty vector on failure slot_info_vec_t prepare(const std::vector & ubatches); - bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo); + bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info); // find a slot of kv cells that can hold the ubatch // if cont == true, then the slot must be continuous @@ -157,8 +190,9 @@ class llama_kv_cache_unified : public llama_memory_i { void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; + void set_input_k_shift(ggml_tensor * dst) const; + void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; - void set_input_k_shift (ggml_tensor * dst) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; private: @@ -172,15 +206,15 @@ class llama_kv_cache_unified : public llama_memory_i { ggml_tensor * k; ggml_tensor * v; + + std::vector k_stream; + std::vector v_stream; }; bool v_trans = true; // the value tensor is transposed - // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot()) - // note: this is not part of the KV state and it's only used to speed-up the find_slot() method - uint32_t head = 0; - const uint32_t n_seq_max = 1; + const uint32_t n_stream = 1; // required padding const uint32_t n_pad = 1; @@ -200,7 +234,17 @@ class llama_kv_cache_unified : public llama_memory_i { std::vector ctxs; std::vector bufs; - llama_kv_cells_unified cells; + // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot()) + // note: this is not part of the KV state and it's only used to speed-up the find_slot() method + std::vector v_heads; + + std::vector v_cells; + + // maps from a sequence id to a stream id + std::vector seq_to_stream; + + // pending stream copies that will be applied during the next update + stream_copy_info sc_info; std::vector layers; @@ -237,18 +281,25 @@ class llama_kv_cache_unified : public llama_memory_i { ggml_cgraph * gf, const defrag_info & dinfo) const; - void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; - void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; + struct cell_ranges_t { + uint32_t strm; - bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); - bool state_read_data(llama_io_read_i & io, uint32_t cell_count); + std::vector> data; // ranges, from inclusive, to exclusive + }; + + void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const; + void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const; + + bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1); + bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count); }; class llama_kv_cache_unified_context : public llama_memory_context_i { public: // some shorthands - using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t; - using defrag_info = llama_kv_cache_unified::defrag_info; + using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t; + using defrag_info = llama_kv_cache_unified::defrag_info; + using stream_copy_info = llama_kv_cache_unified::stream_copy_info; // used for errors llama_kv_cache_unified_context(llama_memory_status status); @@ -262,7 +313,8 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { llama_kv_cache_unified * kv, llama_context * lctx, bool do_shift, - defrag_info dinfo); + defrag_info dinfo, + stream_copy_info sc_info); // used to create a batch procesing context from a batch llama_kv_cache_unified_context( @@ -320,6 +372,8 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { defrag_info dinfo; + stream_copy_info sc_info; + // // batch processing context // diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index 6cd10db06b775..eedfaec53e876 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -40,6 +40,7 @@ llama_memory_hybrid::llama_memory_hybrid( offload, kv_size, n_seq_max, + 1, n_pad, n_swa, swa_type diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a322fc39352e7..42abea2d88abb 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -16118,7 +16118,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, } else { const auto padding = llama_kv_cache_unified::get_padding(cparams); - cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); + uint32_t n_ctx_per_stream = cparams.n_ctx; + + if (!cparams.kv_unified) { + n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max; + n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding); + + cparams.n_ctx = n_ctx_per_stream*cparams.n_seq_max; + } else { + n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding); + + cparams.n_ctx = n_ctx_per_stream; + } LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); @@ -16132,7 +16143,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, !cparams.flash_attn, cparams.offload_kqv, params.swa_full, - cparams.n_ctx, + cparams.kv_unified, + n_ctx_per_stream, cparams.n_seq_max, cparams.n_ubatch, padding); @@ -16146,7 +16158,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, params.type_v, !cparams.flash_attn, cparams.offload_kqv, - cparams.n_ctx, + cparams.kv_unified, + n_ctx_per_stream, cparams.n_seq_max, padding, hparams.n_swa, diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 9f6204834d5bf..bd1e3c3091a75 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4282,7 +4282,7 @@ struct test_flash_attn_ext : public test_case { ggml_tensor * m = nullptr; if (mask) { - m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[0], nr23[1]); + m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, nr23[1]); ggml_set_name(m, "m"); } diff --git a/tools/batched-bench/batched-bench.cpp b/tools/batched-bench/batched-bench.cpp index a0a2e5ac56ea9..03628f74b2880 100644 --- a/tools/batched-bench/batched-bench.cpp +++ b/tools/batched-bench/batched-bench.cpp @@ -127,10 +127,9 @@ int main(int argc, char ** argv) { for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) { for (int i = 0; i < pp; ++i) { - common_batch_add(batch, 0, i, { j }, false); + common_batch_add(batch, 0, i, { j }, i == pp - 1); } } - batch.logits[batch.n_tokens - 1] = true; const auto t_pp_start = ggml_time_us();