Commit 386425f

cont : handle variable V heads
ggml-ci
1 parent 40f8c48 commit 386425f

3 files changed (+57, -30 lines)

src/llama-hparams.cpp

Lines changed: 26 additions & 8 deletions
@@ -65,26 +65,44 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
     return n_embd_head_v * n_head_kv;
 }
 
-bool llama_hparams::is_n_embd_k_gqa_homogeneous() const {
-    uint32_t val = n_embd_k_gqa();
+bool llama_hparams::is_n_embd_k_gqa_variable() const {
+    const uint32_t val = n_embd_k_gqa();
     for (uint32_t il = 0; il < n_layer; ++il) {
         if (val != n_embd_k_gqa(il)) {
-            return false;
+            return true;
         }
     }
 
-    return true;
+    return false;
 }
 
-bool llama_hparams::is_n_embd_v_gqa_homogeneous() const {
-    uint32_t val = n_embd_v_gqa();
+bool llama_hparams::is_n_embd_v_gqa_variable() const {
+    const uint32_t val = n_embd_v_gqa();
     for (uint32_t il = 0; il < n_layer; ++il) {
         if (val != n_embd_v_gqa(il)) {
-            return false;
+            return true;
         }
     }
 
-    return true;
+    return false;
+}
+
+uint32_t llama_hparams::n_embd_k_gqa_max() const {
+    uint32_t val = n_embd_k_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        val = std::max(val, n_embd_k_gqa(il));
+    }
+
+    return val;
+}
+
+uint32_t llama_hparams::n_embd_v_gqa_max() const {
+    uint32_t val = n_embd_v_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        val = std::max(val, n_embd_v_gqa(il));
+    }
+
+    return val;
 }
 
 uint32_t llama_hparams::n_embd_r() const {
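
The new helpers reduce to a scan over the per-layer values: is_n_embd_*_gqa_variable() returns true as soon as any layer differs from layer 0, and n_embd_*_gqa_max() keeps a running maximum. A minimal standalone sketch of the same idea, assuming a hypothetical non-empty vector of per-layer sizes rather than the real llama_hparams accessors:

// standalone sketch, not part of the commit: the per-layer scan behind the new
// helpers, applied to a hypothetical non-empty vector of per-layer sizes
#include <algorithm>
#include <cstdint>
#include <vector>

bool is_size_variable(const std::vector<uint32_t> & sizes) {
    // true if any layer differs from layer 0 (mirrors is_n_embd_v_gqa_variable)
    return std::any_of(sizes.begin(), sizes.end(),
            [&](uint32_t s) { return s != sizes.front(); });
}

uint32_t size_max(const std::vector<uint32_t> & sizes) {
    // largest per-layer value (mirrors n_embd_v_gqa_max)
    return *std::max_element(sizes.begin(), sizes.end());
}

int main() {
    const std::vector<uint32_t> sizes = {1024, 1024, 512}; // hypothetical per-layer values
    return is_size_variable(sizes) && size_max(sizes) == 1024 ? 0 : 1;
}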

src/llama-hparams.h

Lines changed: 7 additions & 3 deletions
@@ -189,9 +189,13 @@ struct llama_hparams {
     // dimension of value embeddings across all k-v heads
     uint32_t n_embd_v_gqa(uint32_t il = 0) const;
 
-    // true if all layers have the same n_embd_k_gqa/n_embd_v_gqa
-    bool is_n_embd_k_gqa_homogeneous() const;
-    bool is_n_embd_v_gqa_homogeneous() const;
+    // true if any layer has a different n_embd_k_gqa/n_embd_v_gqa
+    bool is_n_embd_k_gqa_variable() const;
+    bool is_n_embd_v_gqa_variable() const;
+
+    // return the maximum n_embd_k_gqa/n_embd_v_gqa across all layers
+    uint32_t n_embd_k_gqa_max() const;
+    uint32_t n_embd_v_gqa_max() const;
 
     // dimension of the rolling state embeddings
     // corresponds to Mamba's conv_states size or RWKV's token_shift states size
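
A short usage sketch of the declarations above (the helper below is illustrative, not from the commit): a buffer that must be able to hold one V row for any layer is sized by the maximum across layers; when the sizes are homogeneous, n_embd_v_gqa_max() simply equals n_embd_v_gqa().

// illustrative only (not from the commit): sizing a per-token V scratch buffer
// with the new accessors; llama_hparams comes from the header shown above
#include <cstdint>
#include <vector>

#include "llama-hparams.h"

std::vector<float> make_v_scratch(const llama_hparams & hp, uint32_t n_tokens) {
    // the widest layer decides the row width; for homogeneous models this
    // is the same value that n_embd_v_gqa() (layer 0) would return
    const uint32_t n_embd_v = hp.n_embd_v_gqa_max();
    return std::vector<float>(size_t(n_embd_v) * n_tokens);
}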

src/llama-kv-cache-unified.cpp

Lines changed: 24 additions & 19 deletions
@@ -68,11 +68,10 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 
     cells.resize(kv_size);
 
-    if (supports_set_rows) {
-        // TODO: this requirement can be relaxed, but it would be much easier to implement when we have an actual
-        //       model that needs this
-        // ref: https://github.com/ggml-org/llama.cpp/pull/14517
-        GGML_ASSERT(hparams.is_n_embd_v_gqa_homogeneous());
+    // [TAG_V_CACHE_VARIABLE]
+    if (v_trans && hparams.is_n_embd_v_gqa_variable()) {
+        LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n",
+                __func__, hparams.n_embd_v_gqa_max());
     }
 
     for (uint32_t il = 0; il < n_layer_cache; il++) {
@@ -81,8 +80,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
             continue;
         }
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+        // [TAG_V_CACHE_VARIABLE]
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+        const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
 
         const char * dev_name = "CPU";
 
@@ -808,19 +808,19 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
         // note: v->nb[1] <= v->nb[2]
         return ggml_view_4d(ctx, v,
                 hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, 1,
-                ggml_row_size(v->type, hparams.n_embd_head_v),            // v->nb[1]
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)),         // v->nb[2]
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*kv_size), // v->nb[3]
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*kv_size)*0);
+                ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
+                ggml_row_size(v->type, v->ne[0]),              // v->nb[2]
+                ggml_row_size(v->type, v->ne[0]*kv_size),      // v->nb[3]
+                ggml_row_size(v->type, v->ne[0]*kv_size)*0);
     }
 
     // note: v->nb[1] > v->nb[2]
     return ggml_view_4d(ctx, v,
             n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, 1,
-            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v),    // v->nb[1]
-            ggml_row_size(v->type, kv_size),                          // v->nb[2]
-            ggml_row_size(v->type, kv_size*hparams.n_embd_v_gqa(il)), // v->nb[3]
-            ggml_row_size(v->type, kv_size*hparams.n_embd_v_gqa(il))*0);
+            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
+            ggml_row_size(v->type, kv_size),                       // v->nb[2]
+            ggml_row_size(v->type, kv_size*v->ne[0]),              // v->nb[3]
+            ggml_row_size(v->type, kv_size*v->ne[0])*0);
 }
 
 ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
@@ -856,8 +856,8 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     auto * v = layers[ikv].v;
 
-    const int64_t n_embd_v_gqa = v->ne[0];
-    const int64_t n_tokens     = v_cur->ne[2];
+    const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
+    const int64_t n_tokens     = v_cur->ne[2];
 
     v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
 
@@ -870,6 +870,11 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
         return ggml_set_rows(ctx, v, v_cur, v_idxs);
     }
 
+    // [TAG_V_CACHE_VARIABLE]
+    if (n_embd_v_gqa < v->ne[0]) {
+        v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
+    }
+
     // the row becomes a single element
     ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
 
@@ -916,7 +921,7 @@ ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, con
     if (!v_trans) {
         v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
     } else {
-        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa());
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max());
     }
 
     ggml_set_input(v_idxs);
@@ -957,7 +962,7 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba
         // note: the V cache is transposed when not using flash attention
         const int64_t kv_size = get_size();
 
-        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max();
 
         for (uint32_t i = 0; i < n_tokens; ++i) {
             for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
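
Taken together, the cache changes follow one rule: when the V cache is transposed (flash attention disabled) and the V sizes vary across layers, every layer's V buffer is allocated at n_embd_v_gqa_max(), and cpy_v zero-pads narrower layers with ggml_pad before ggml_set_rows; with flash attention the per-layer size is kept. A rough standalone sketch of that sizing arithmetic, assuming a hypothetical list of per-layer widths:

// rough sketch, not part of the commit: how the allocation width and the amount
// of zero-padding fall out of the per-layer V widths (hypothetical values below)
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const std::vector<uint32_t> n_embd_v_gqa_il = {1024, 1024, 512, 1024}; // hypothetical per-layer widths
    const bool v_trans = true; // flash attention disabled -> the V cache is transposed

    const uint32_t n_max = *std::max_element(n_embd_v_gqa_il.begin(), n_embd_v_gqa_il.end());

    for (size_t il = 0; il < n_embd_v_gqa_il.size(); ++il) {
        // allocation width chosen in the constructor: per-layer when !v_trans, the max otherwise
        const uint32_t alloc = !v_trans ? n_embd_v_gqa_il[il] : n_max;
        // zero columns cpy_v would append via ggml_pad before ggml_set_rows
        const uint32_t pad   = alloc - n_embd_v_gqa_il[il];
        printf("layer %zu: width %u, allocated %u, padded by %u\n", il, n_embd_v_gqa_il[il], alloc, pad);
    }

    return 0;
}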
