@@ -803,6 +803,8 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
         }
     }

+    assert(res.s1 >= res.s0);
+
     return res;
 }

@@ -908,13 +910,8 @@ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint

     auto * k = layers[ikv].k;

-    assert(sinfo.s1 >= sinfo.s0);
-
     const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;

-    assert(ns > 0);
-    assert(ns <= n_seq_virt);
-
     const uint64_t size_virt = ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*get_size());

     return ggml_view_4d(ctx, k,
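
Note: the 4D view returned above relies on the layer's K buffer being laid out as n_seq_virt back-to-back slabs of size_virt bytes, one slab per virtual sequence. The following is a minimal standalone sketch of the implied byte addressing, assuming a plain F16 cache and made-up sizes (an illustration only, not llama.cpp code; ggml_row_size would also handle quantized types):

    // Sketch only: the offsets implied by the strided view over the K buffer.
    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint64_t n_embd_k_gqa = 1024;   // K row width in elements (assumed)
        const uint64_t kv_size      = 4096;   // get_size(): cells per virtual sequence (assumed)
        const uint64_t elt_size     = 2;      // bytes per F16 element (assumed)

        const uint64_t row_size  = n_embd_k_gqa*elt_size; // one KV cell of one layer
        const uint64_t size_virt = row_size*kv_size;      // one virtual-sequence slab

        const uint32_t s0 = 1, s = 2, i = 7;  // view base sequence, target sequence, cell
        const uint64_t offs = size_virt*s0        // view offset: skip sequences below s0
                            + (s - s0)*size_virt  // stride between virtual sequences
                            + i*row_size;         // stride between cells within a sequence

        printf("cell %u of virtual sequence %u starts at byte %llu\n",
               i, s, (unsigned long long) offs);
        return 0;
    }
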
@@ -932,9 +929,6 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint

     const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;

-    assert(ns > 0);
-    assert(ns <= n_seq_virt);
-
     const uint64_t size_virt = ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*get_size());

     if (!v_trans) {
@@ -967,9 +961,20 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);

     if (kv_idxs && supports_set_rows) {
-        k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+        const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
+
+        const uint64_t size_virt = ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*get_size());
+
+        ggml_tensor * k_view = ggml_view_3d(ctx, k, k->ne[0], k->ne[1], ns,
+                ggml_row_size(k->type, k->ne[0]),
+                size_virt,
+                size_virt*sinfo.s0);
+
+        k_cur = ggml_reshape_3d(ctx, k_cur, k_cur->ne[0], k_cur->ne[1]/ns, ns);
+
+        kv_idxs = ggml_reshape_2d(ctx, kv_idxs, n_tokens/ns, ns);

-        return ggml_set_rows(ctx, k, k_cur, kv_idxs);
+        return ggml_set_rows(ctx, k_view, k_cur, kv_idxs);
     }

     // TODO: fallback to old ggml_cpy() method for backwards compatibility
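
The scatter above assumes the ubatch splits its n_tokens evenly across the ns virtual sequences, so each slab receives n_tokens/ns rows addressed by indices local to that slab. A scalar sketch of the same write pattern with made-up sizes (illustration only, not ggml):

    // Sketch only: per-sequence row scatter, mirroring ggml_set_rows on the 3D k_view.
    #include <cstdint>
    #include <vector>

    int main() {
        const int ne0 = 4, kv_size = 8, ns = 2;         // row width, cells per slab, sequences (assumed)
        const int n_tokens = 6, nt = n_tokens/ns;       // tokens per virtual sequence

        std::vector<float>   k(ne0*kv_size*ns, 0.0f);   // the viewed slabs [ne0, kv_size, ns]
        std::vector<float>   k_cur(ne0*n_tokens, 1.0f); // new K rows, reshaped to [ne0, nt, ns]
        std::vector<int64_t> idxs = {0, 1, 2, 5, 6, 7}; // kv_idxs reshaped to [nt, ns], local per slab

        for (int s = 0; s < ns; ++s) {                  // one slab per virtual sequence
            for (int t = 0; t < nt; ++t) {
                const int64_t row = idxs[s*nt + t];     // destination cell, local to slab s
                for (int c = 0; c < ne0; ++c) {
                    k[(s*kv_size + row)*ne0 + c] = k_cur[(s*nt + t)*ne0 + c];
                }
            }
        }
        return 0;
    }
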
@@ -995,27 +1000,46 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);

     if (kv_idxs && supports_set_rows) {
+        const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
+
+        const uint64_t size_virt = ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*get_size());
+
         if (!v_trans) {
-            v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+            ggml_tensor * v_view = ggml_view_3d(ctx, v, v->ne[0], v->ne[1], ns,
+                    ggml_row_size(v->type, v->ne[0]),
+                    size_virt,
+                    size_virt*sinfo.s0);
+
+            v_cur = ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], v_cur->ne[1]/ns, ns);

-            return ggml_set_rows(ctx, v, v_cur, kv_idxs);
+            kv_idxs = ggml_reshape_2d(ctx, kv_idxs, n_tokens/ns, ns);
+
+            return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
         }

         // the row becomes a single element
-        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1]*v->ne[2], v->ne[0]);
+        ggml_tensor * v_view = ggml_view_4d(ctx, v, 1, v->ne[1], v->ne[0], ns,
+                ggml_row_size(v->type, 1),
+                ggml_row_size(v->type, v->ne[1]),
+                size_virt,
+                size_virt*sinfo.s0);

         // note: the V cache is transposed when not using flash attention
-        v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
+        v_cur = ggml_permute(ctx, ggml_reshape_4d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]/ns, ns), 2, 0, 1, 3);

         // note: we can be more explicit here at the cost of extra cont
         // however, above we take advantage that a row of single element is always contiguous regardless of the row stride
+        // v_cur = ggml_reshape_3d(ctx, v_cur, n_embd_v_gqa, v_cur->ne[1]/ns, ns);
         // v_cur = ggml_transpose(ctx, v_cur);
-        // v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+        // v_cur = ggml_cont_4d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1], v_cur->ne[2]);

         // we broadcast the KV indices n_embd_v_gqa times
-        // v       [1, n_kv,     n_embd_v_gqa]
-        // v_cur   [1, n_tokens, n_embd_v_gqa]
-        // kv_idxs [n_tokens, 1, 1]
+        // v       [1, n_kv,        n_embd_v_gqa, ns]
+        // v_cur   [1, n_tokens/ns, n_embd_v_gqa, ns]
+        // kv_idxs [n_tokens/ns, 1, ns]
+
+        kv_idxs = ggml_reshape_3d(ctx, kv_idxs, n_tokens/ns, 1, ns);
+
         return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
     }

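In the transposed (non-flash-attention) path above, every embedding component owns a row of kv_size single-element cells, so the same local cell indices are broadcast across all n_embd_v_gqa components of a token. A scalar sketch of that layout with made-up sizes (illustration only, not ggml):

    // Sketch only: scatter into a transposed per-sequence V slab with broadcast indices.
    #include <cstdint>
    #include <vector>

    int main() {
        const int n_embd_v_gqa = 4, kv_size = 8, ns = 2;         // assumed sizes
        const int n_tokens = 6, nt = n_tokens/ns;

        std::vector<float>   v(n_embd_v_gqa*kv_size*ns, 0.0f);   // transposed slabs [kv_size, n_embd_v_gqa, ns]
        std::vector<float>   v_cur(n_embd_v_gqa*n_tokens, 1.0f); // new V rows [n_embd_v_gqa, nt, ns]
        std::vector<int64_t> idxs = {0, 1, 2, 5, 6, 7};          // kv_idxs [nt, 1, ns], local per slab

        for (int s = 0; s < ns; ++s) {
            for (int e = 0; e < n_embd_v_gqa; ++e) {             // indices broadcast n_embd_v_gqa times
                for (int t = 0; t < nt; ++t) {
                    const int64_t cell = idxs[s*nt + t];         // destination column, local to slab s
                    v[(s*n_embd_v_gqa + e)*kv_size + cell] = v_cur[(s*nt + t)*n_embd_v_gqa + e];
                }
            }
        }
        return 0;
    }
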
@@ -1053,10 +1077,8 @@ void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ub
     int64_t * data = (int64_t *) dst->data;

     for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
-        const int64_t offs = sinfo.seq_id_virt[s]*get_size();
-
         for (uint32_t i = 0; i < sinfo.size(); ++i) {
-            data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
+            data[s*sinfo.size() + i] = sinfo.idxs[s][i];
         }
     }
 }
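
Since cpy_k()/cpy_v() now scatter into views that already start at the slab of sinfo.s0, the host-side indices stay local to each virtual sequence, and the seq_id_virt[s]*get_size() offset that previously globalized them is dropped. A tiny sketch of the two conventions with made-up values (illustration only):

    // Sketch only: old global index vs. new slab-local index for one KV cell.
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t kv_size     = 4096; // get_size(): cells per virtual sequence (assumed)
        const int64_t seq_id_virt = 3;    // virtual sequence chosen for this slot (assumed)
        const int64_t cell        = 17;   // cell picked by find_slot (assumed)

        const int64_t idx_global = seq_id_virt*kv_size + cell; // old: slab offset baked into the index
        const int64_t idx_local  = cell;                       // new: the view supplies the offset

        printf("global = %lld, local = %lld\n", (long long) idx_global, (long long) idx_local);
        return 0;
    }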