@@ -68,11 +68,10 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     cells.resize(kv_size);
 
-    if (supports_set_rows) {
-        // TODO: this requirement can be relaxed, but it would be much easier to implement when we have an actual
-        //       model that needs this
-        // ref: https://github.com/ggml-org/llama.cpp/pull/14517
-        GGML_ASSERT(hparams.is_n_embd_v_gqa_homogeneous());
+    // [TAG_V_CACHE_VARIABLE]
+    if (v_trans && hparams.is_n_embd_v_gqa_variable()) {
+        LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n",
+                __func__, hparams.n_embd_v_gqa_max());
     }
 
     for (uint32_t il = 0; il < n_layer_cache; il++) {
@@ -81,8 +80,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
             continue;
         }
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+        // [TAG_V_CACHE_VARIABLE]
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+        const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
 
        const char * dev_name = "CPU";
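Context for the two constructor hunks above: when flash attention is off the V cache is stored transposed, so every layer's V buffer must share one row length; the patch picks the maximum n_embd_v_gqa across layers and pads narrower layers up to it. Below is a minimal standalone C++ sketch of that sizing rule (not part of the patch; the per-layer sizes are made-up values, not from any real model):

// sketch: transposed V caches are all allocated with the max row size
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const std::vector<uint32_t> n_embd_v_gqa_per_layer = {1024, 1024, 512, 256}; // assumed sizes
    const bool v_trans = true; // no flash attention -> V stored transposed

    const uint32_t n_embd_v_gqa_max =
        *std::max_element(n_embd_v_gqa_per_layer.begin(), n_embd_v_gqa_per_layer.end());

    for (size_t il = 0; il < n_embd_v_gqa_per_layer.size(); ++il) {
        // mirrors the constructor change: the transposed cache uses the padded (max) size
        const uint32_t n_embd_v_gqa = !v_trans ? n_embd_v_gqa_per_layer[il] : n_embd_v_gqa_max;
        std::printf("layer %zu: V row size = %u\n", il, n_embd_v_gqa);
    }
    return 0;
}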
@@ -808,19 +808,19 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
         // note: v->nb[1] <= v->nb[2]
         return ggml_view_4d(ctx, v,
                 hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, 1,
-                ggml_row_size(v->type, hparams.n_embd_head_v),             // v->nb[1]
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)),          // v->nb[2]
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*kv_size),  // v->nb[3]
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*kv_size)*0);
+                ggml_row_size(v->type, hparams.n_embd_head_v),  // v->nb[1]
+                ggml_row_size(v->type, v->ne[0]),               // v->nb[2]
+                ggml_row_size(v->type, v->ne[0]*kv_size),       // v->nb[3]
+                ggml_row_size(v->type, v->ne[0]*kv_size)*0);
     }
 
     // note: v->nb[1] > v->nb[2]
     return ggml_view_4d(ctx, v,
             n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, 1,
-            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v),     // v->nb[1]
-            ggml_row_size(v->type, kv_size),                           // v->nb[2]
-            ggml_row_size(v->type, kv_size*hparams.n_embd_v_gqa(il)),  // v->nb[3]
-            ggml_row_size(v->type, kv_size*hparams.n_embd_v_gqa(il))*0);
+            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v),  // v->nb[1]
+            ggml_row_size(v->type, kv_size),                        // v->nb[2]
+            ggml_row_size(v->type, kv_size*v->ne[0]),               // v->nb[3]
+            ggml_row_size(v->type, kv_size*v->ne[0])*0);
 }
 
 ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
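The get_v hunk above switches the view strides from the per-layer hparams.n_embd_v_gqa(il) to the allocated row size v->ne[0], which already includes any padding. A rough standalone sketch of the transposed-view stride arithmetic follows (not part of the patch; it assumes an f32 cache, where ggml_row_size(type, n) reduces to n*sizeof(float), and uses made-up dimensions):

// sketch: byte strides for the transposed V view, derived from the padded row size
#include <cstddef>
#include <cstdio>

int main() {
    const size_t kv_size       = 4096; // cache length (assumed)
    const size_t n_embd_head_v = 128;  // head size (assumed)
    const size_t v_ne0         = 1024; // allocated (possibly padded) row size, i.e. v->ne[0]

    const size_t nb1 = kv_size * n_embd_head_v * sizeof(float); // step between KV heads
    const size_t nb2 = kv_size * sizeof(float);                 // step between embedding components
    const size_t nb3 = kv_size * v_ne0 * sizeof(float);         // step for the (size-1) 4th dimension

    std::printf("nb1=%zu nb2=%zu nb3=%zu\n", nb1, nb2, nb3);
    return 0;
}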
@@ -856,8 +856,8 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     auto * v = layers[ikv].v;
 
-    const int64_t n_embd_v_gqa = v->ne[0];
-    const int64_t n_tokens     = v_cur->ne[2];
+    const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
+    const int64_t n_tokens     = v_cur->ne[2];
 
     v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
@@ -870,6 +870,11 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
         return ggml_set_rows(ctx, v, v_cur, v_idxs);
     }
 
+    // [TAG_V_CACHE_VARIABLE]
+    if (n_embd_v_gqa < v->ne[0]) {
+        v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
+    }
+
     // the row becomes a single element
     ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
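The two cpy_v hunks above take the current layer's true V width from v_cur itself and, when it is narrower than the allocated cache row v->ne[0], zero-pad it (via ggml_pad along dim 0) before scattering into the fixed-width transposed cache. A plain C++ sketch of the idea, with made-up sizes and no ggml dependency:

// sketch: zero-pad a narrow V row up to the allocated cache row width
#include <cstdio>
#include <vector>

int main() {
    const size_t n_embd_v_gqa = 6;  // this layer's actual V row size (assumed)
    const size_t row_padded   = 8;  // allocated cache row size, i.e. v->ne[0] (assumed)

    std::vector<float> v_cur(n_embd_v_gqa, 1.0f);   // one token's V values
    v_cur.resize(row_padded, 0.0f);                 // stands in for ggml_pad along dim 0

    for (float x : v_cur) std::printf("%.0f ", x);
    std::printf("\n"); // prints: 1 1 1 1 1 1 0 0
    return 0;
}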
@@ -916,7 +921,7 @@ ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, con
     if (!v_trans) {
         v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
     } else {
-        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa());
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max());
     }
 
     ggml_set_input(v_idxs);
@@ -957,7 +962,7 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba
     // note: the V cache is transposed when not using flash attention
     const int64_t kv_size = get_size();
 
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max();
 
     for (uint32_t i = 0; i < n_tokens; ++i) {
        for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
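For the transposed cache these last two hunks size and fill one destination index per (token, embedding component) pair, now using the padded n_embd_v_gqa_max. The standalone sketch below is illustrative only, not the actual function body: it shows one plausible transposed addressing where component j of a token stored in cell c lands at row j, column c of the V buffer (cache length, row count, and cell assignments are assumed values):

// sketch: destination indices for scattering into a transposed V cache
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int64_t kv_size      = 16; // cache length (assumed)
    const int64_t n_embd_v_gqa = 4;  // padded row count, i.e. n_embd_v_gqa_max() (assumed)
    const std::vector<int64_t> cells = {3, 4}; // cache cells assigned to two tokens (assumed)

    std::vector<int64_t> v_idxs(cells.size()*n_embd_v_gqa);
    for (size_t i = 0; i < cells.size(); ++i) {
        for (int64_t j = 0; j < n_embd_v_gqa; ++j) {
            v_idxs[i*n_embd_v_gqa + j] = j*kv_size + cells[i]; // row j, column cells[i]
        }
    }

    for (int64_t idx : v_idxs) std::printf("%lld ", (long long) idx);
    std::printf("\n");
    return 0;
}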