@@ -808,7 +808,7 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
             0);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;
@@ -818,8 +818,8 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
 
     k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
 
-    if (kv_idxs && supports_set_rows) {
-        return ggml_set_rows(ctx, k, k_cur, kv_idxs);
+    if (k_idxs && supports_set_rows) {
+        return ggml_set_rows(ctx, k, k_cur, k_idxs);
     }
 
     // TODO: fallback to old ggml_cpy() method for backwards compatibility
@@ -832,7 +832,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     return ggml_cpy(ctx, k_cur, k_view);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
@@ -842,9 +842,9 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
 
-    if (kv_idxs && supports_set_rows) {
+    if (v_idxs && supports_set_rows) {
         if (!v_trans) {
-            return ggml_set_rows(ctx, v, v_cur, kv_idxs);
+            return ggml_set_rows(ctx, v, v_cur, v_idxs);
         }
 
         // the row becomes a single element
@@ -859,10 +859,10 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
         // v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
 
         // we broadcast the KV indices n_embd_v_gqa times
-        // v       [1, n_kv,     n_embd_v_gqa]
-        // v_cur   [1, n_tokens, n_embd_v_gqa]
-        // kv_idxs [n_tokens, 1, 1]
-        return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
+        // v      [1, n_kv,     n_embd_v_gqa]
+        // v_cur  [1, n_tokens, n_embd_v_gqa]
+        // v_idxs [n_tokens, 1, 1]
+        return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
     }
 
     // TODO: fallback to old ggml_cpy() method for backwards compatibility
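Side note on the indexed writes above: ggml_set_rows(ctx, dst, src, idxs) scatters the rows of src into dst at the row positions named by the I64 tensor idxs, which is what lets cpy_k/cpy_v write non-contiguous cache cells in one graph op instead of going through the old ggml_cpy-into-view path. A minimal plain-C++ sketch of that scatter semantics (the toy buffers and the helper name scatter_rows are illustrative only, not part of ggml or llama.cpp):

    #include <cstdint>
    #include <vector>

    // Toy model of the row scatter performed by ggml_set_rows() above:
    // dst holds n_kv rows of row_size floats, src holds one row per token,
    // and idxs[i] names the destination row (KV cell) for source row i.
    static void scatter_rows(std::vector<float> & dst,
                             const std::vector<float> & src,
                             const std::vector<int64_t> & idxs,
                             size_t row_size) {
        for (size_t i = 0; i < idxs.size(); ++i) {
            const size_t r = (size_t) idxs[i];
            for (size_t j = 0; j < row_size; ++j) {
                dst[r*row_size + j] = src[i*row_size + j];
            }
        }
    }

For the non-transposed K and V this is one index per token with row_size equal to the per-layer embedding width; in the transposed V case the destination rows degenerate to single elements, which is why the comments above broadcast the indices n_embd_v_gqa times.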
@@ -885,7 +885,49 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
-void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+    ggml_set_input(k_idxs);
+
+    return k_idxs;
+}
+
+ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * v_idxs;
+
+    if (!v_trans) {
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+    } else {
+        // TODO: assert that n_embd_v_gqa is the same for all layers, or take the max
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa());
+    }
+
+    ggml_set_input(v_idxs);
+
+    return v_idxs;
+}
+
+void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    for (int64_t i = 0; i < n_tokens; ++i) {
+        data[i] = sinfo.idxs[i];
+    }
+}
+
+void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
     if (!supports_set_rows) {
         return;
     }
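The hunk above splits the old set_input_kv_idxs into per-tensor helpers and adds build_input_k_idxs/build_input_v_idxs, which allocate the I64 index tensors at graph-build time and mark them as graph inputs; set_input_k_idxs then fills them on the host from sinfo.idxs just before compute. A hedged sketch of how a graph-building caller could wire these together through the context wrappers shown in the next hunk (gf, ctx0, mctx, ubatch, k_cur, v_cur and il are assumed to exist in the caller; this is illustrative, not the actual llama.cpp graph code):

    // 1) graph build time: create the index inputs and record the scatter ops
    ggml_tensor * k_idxs = mctx->build_input_k_idxs(ctx0, ubatch);
    ggml_tensor * v_idxs = mctx->build_input_v_idxs(ctx0, ubatch);

    ggml_build_forward_expand(gf, mctx->cpy_k(ctx0, k_cur, k_idxs, il));
    ggml_build_forward_expand(gf, mctx->cpy_v(ctx0, v_cur, v_idxs, il));

    // 2) just before compute: fill the index tensors with the KV cell indices
    //    chosen for this ubatch (sinfos[i_cur] inside the cache context)
    mctx->set_input_k_idxs(k_idxs, &ubatch);
    mctx->set_input_v_idxs(v_idxs, &ubatch);

Keeping separate K and V index tensors lets the transposed V path size and fill its indices independently of K, as the n_tokens*hparams.n_embd_v_gqa() allocation in build_input_v_idxs shows.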
@@ -1906,20 +1948,32 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t
     return kv->get_v(ctx, il, n_kv);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const {
-    return kv->cpy_k(ctx, k_cur, kv_idxs, il, sinfos[i_cur]);
+ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
+    return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
+}
+
+ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
+    return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const {
-    return kv->cpy_v(ctx, v_cur, kv_idxs, il, sinfos[i_cur]);
+ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_k_idxs(ctx, ubatch);
+}
+
+ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_v_idxs(ctx, ubatch);
 }
 
 void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
     kv->set_input_k_shift(dst);
 }
 
-void llama_kv_cache_unified_context::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
-    kv->set_input_kv_idxs(dst, ubatch, sinfos[i_cur]);
+void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
+}
+
+void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
 }
 
 void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {