Skip to content

Commit 8c68219

Browse files
committed
kv-cache : fix non-FA path with virtual sequences
ggml-ci
1 parent 7c6487b commit 8c68219

File tree

1 file changed

+43
-21
lines changed

1 file changed

+43
-21
lines changed

src/llama-kv-cache-unified.cpp

Lines changed: 43 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,8 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
803803
}
804804
}
805805

806+
assert(res.s1 >= res.s0);
807+
806808
return res;
807809
}
808810

@@ -908,13 +910,8 @@ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint
908910

909911
auto * k = layers[ikv].k;
910912

911-
assert(sinfo.s1 >= sinfo.s0);
912-
913913
const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
914914

915-
assert(ns > 0);
916-
assert(ns <= n_seq_virt);
917-
918915
const uint64_t size_virt = ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*get_size());
919916

920917
return ggml_view_4d(ctx, k,
@@ -932,9 +929,6 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
932929

933930
const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
934931

935-
assert(ns > 0);
936-
assert(ns <= n_seq_virt);
937-
938932
const uint64_t size_virt = ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*get_size());
939933

940934
if (!v_trans) {
@@ -967,9 +961,20 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
967961
k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
968962

969963
if (kv_idxs && supports_set_rows) {
970-
k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
964+
const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
965+
966+
const uint64_t size_virt = ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*get_size());
967+
968+
ggml_tensor * k_view = ggml_view_3d(ctx, k, k->ne[0], k->ne[1], ns,
969+
ggml_row_size(k->type, k->ne[0]),
970+
size_virt,
971+
size_virt*sinfo.s0);
972+
973+
k_cur = ggml_reshape_3d(ctx, k_cur, k_cur->ne[0], k_cur->ne[1]/ns, ns);
974+
975+
kv_idxs = ggml_reshape_2d(ctx, kv_idxs, n_tokens/ns, ns);
971976

972-
return ggml_set_rows(ctx, k, k_cur, kv_idxs);
977+
return ggml_set_rows(ctx, k_view, k_cur, kv_idxs);
973978
}
974979

975980
// TODO: fallback to old ggml_cpy() method for backwards compatibility
@@ -995,27 +1000,46 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
9951000
v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
9961001

9971002
if (kv_idxs && supports_set_rows) {
1003+
const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
1004+
1005+
const uint64_t size_virt = ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*get_size());
1006+
9981007
if (!v_trans) {
999-
v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
1008+
ggml_tensor * v_view = ggml_view_3d(ctx, v, v->ne[0], v->ne[1], ns,
1009+
ggml_row_size(v->type, v->ne[0]),
1010+
size_virt,
1011+
size_virt*sinfo.s0);
1012+
1013+
v_cur = ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], v_cur->ne[1]/ns, ns);
10001014

1001-
return ggml_set_rows(ctx, v, v_cur, kv_idxs);
1015+
kv_idxs = ggml_reshape_2d(ctx, kv_idxs, n_tokens/ns, ns);
1016+
1017+
return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
10021018
}
10031019

10041020
// the row becomes a single element
1005-
ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1]*v->ne[2], v->ne[0]);
1021+
ggml_tensor * v_view = ggml_view_4d(ctx, v, 1, v->ne[1], v->ne[0], ns,
1022+
ggml_row_size(v->type, 1),
1023+
ggml_row_size(v->type, v->ne[1]),
1024+
size_virt,
1025+
size_virt*sinfo.s0);
10061026

10071027
// note: the V cache is transposed when not using flash attention
1008-
v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
1028+
v_cur = ggml_permute(ctx, ggml_reshape_4d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]/ns, ns), 2, 0, 1, 3);
10091029

10101030
// note: we can be more explicit here at the cost of extra cont
10111031
// however, above we take advantage that a row of single element is always contiguous regardless of the row stride
1032+
//v_cur = ggml_reshape_3d(ctx, v_cur, n_embd_v_gqa, v_cur->ne[1]/ns, ns);
10121033
//v_cur = ggml_transpose(ctx, v_cur);
1013-
//v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
1034+
//v_cur = ggml_cont_4d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1], v_cur->ne[2]);
10141035

10151036
// we broadcast the KV indices n_embd_v_gqa times
1016-
// v [1, n_kv, n_embd_v_gqa]
1017-
// v_cur [1, n_tokens, n_embd_v_gqa]
1018-
// kv_idxs [n_tokens, 1, 1]
1037+
// v [1, n_kv, n_embd_v_gqa, ns]
1038+
// v_cur [1, n_tokens/ns, n_embd_v_gqa, ns]
1039+
// kv_idxs [n_tokens/ns, 1, ns]
1040+
1041+
kv_idxs = ggml_reshape_3d(ctx, kv_idxs, n_tokens/ns, 1, ns);
1042+
10191043
return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
10201044
}
10211045

@@ -1053,10 +1077,8 @@ void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ub
10531077
int64_t * data = (int64_t *) dst->data;
10541078

10551079
for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
1056-
const int64_t offs = sinfo.seq_id_virt[s]*get_size();
1057-
10581080
for (uint32_t i = 0; i < sinfo.size(); ++i) {
1059-
data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
1081+
data[s*sinfo.size() + i] = sinfo.idxs[s][i];
10601082
}
10611083
}
10621084
}

0 commit comments

Comments
 (0)