Commit 3d930a9

graph : separate k and v indices

ggml-ci

1 parent 253304a, commit 3d930a9

4 files changed: +152 -57 lines

src/llama-graph.cpp

Lines changed: 50 additions & 25 deletions
@@ -281,8 +281,12 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-    if (self_kv_idxs) {
-        mctx->set_input_kv_idxs(self_kv_idxs, ubatch);
+    if (self_k_idxs) {
+        mctx->set_input_k_idxs(self_k_idxs, ubatch);
+    }
+
+    if (self_v_idxs) {
+        mctx->set_input_v_idxs(self_v_idxs, ubatch);
     }
 
     if (self_kq_mask) {
@@ -291,12 +295,20 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
-    if (self_kv_idxs) {
-        mctx->get_base()->set_input_kv_idxs(self_kv_idxs, ubatch);
+    if (self_k_idxs) {
+        mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
+    }
+
+    if (self_v_idxs) {
+        mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
     }
 
-    if (self_kv_idxs_swa) {
-        mctx->get_swa()->set_input_kv_idxs(self_kv_idxs_swa, ubatch);
+    if (self_k_idxs_swa) {
+        mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
+    }
+
+    if (self_v_idxs_swa) {
+        mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
     }
 
     if (self_kq_mask) {
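
Previously a single self_kv_idxs tensor was passed to both cpy_k() and cpy_v(). Splitting it into self_k_idxs and self_v_idxs lets each side of the cache build and fill its own index input; presumably this leaves room for the transposed V layout (see cpy_v() in llama-kv-cache-unified.cpp below) to use an index shape different from K's, although in this commit both are still I64 [n_tokens].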
@@ -345,6 +357,14 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
+    if (self_k_idxs) {
+        mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch);
+    }
+
+    if (self_v_idxs) {
+        mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch);
+    }
+
     if (self_kq_mask) {
         mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
@@ -362,7 +382,8 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_one::set_input(const llama_ubatch *) {
+void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
     GGML_ASSERT(one && ggml_nelements(one) == 1);
     float f_one = 1.0f;
     ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
@@ -1009,6 +1030,9 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
 
     const auto n_kv = inp->mctx->get_attn()->get_n_kv();
 
+    inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
+    inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
+
     inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
@@ -1210,11 +1234,10 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
     const auto n_kv = mctx_cur->get_n_kv();
 
-    inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-    ggml_set_input(inp->self_kv_idxs);
+    inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
+    inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
 
     inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1245,10 +1268,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        const auto & kv_idxs = inp->get_kv_idxs();
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
 
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
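
Taken together, the graph-build paths now follow one pattern: the memory context owns the index tensors (build_input_k_idxs()/build_input_v_idxs() at graph-build time, set_input_k_idxs()/set_input_v_idxs() at decode time), and build_attn() hands get_k_idxs() to cpy_k() and get_v_idxs() to cpy_v(). The ISWA variants below follow the same shape, routed through get_base() and get_swa().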
@@ -1307,15 +1331,15 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // optionally store to KV cache
     if (k_cur) {
-        const auto & kv_idxs = is_swa ? inp->get_kv_idxs_swa() : inp->get_kv_idxs();
+        const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
 
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
     }
 
     if (v_cur) {
-        const auto & kv_idxs = is_swa ? inp->get_kv_idxs_swa() : inp->get_kv_idxs();
+        const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
 
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1419,8 +1443,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, nullptr, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, nullptr, il));
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
@@ -1455,11 +1482,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
    {
        const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
-        inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-        ggml_set_input(inp->self_kv_idxs);
+        inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
 
        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
        ggml_set_input(inp->self_kq_mask);
 
        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1470,11 +1496,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
        const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
-        inp->self_kv_idxs_swa = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-        ggml_set_input(inp->self_kv_idxs_swa);
+        inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
        ggml_set_input(inp->self_kq_mask_swa);
 
        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;

src/llama-graph.h

Lines changed: 21 additions & 7 deletions
@@ -248,10 +248,13 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * get_kv_idxs() const { return self_kv_idxs; }
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
-    ggml_tensor * self_kv_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
 
     ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
@@ -276,13 +279,18 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * get_kv_idxs() const { return self_kv_idxs; }
-    ggml_tensor * get_kv_idxs_swa() const { return self_kv_idxs_swa; }
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+    ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
+    ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
+
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
-    ggml_tensor * self_kv_idxs = nullptr; // I64 [n_batch]
-    ggml_tensor * self_kv_idxs_swa = nullptr; // I64 [n_batch]
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
 
     ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
@@ -326,8 +334,14 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+
     ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
 
@@ -343,7 +357,7 @@ class llm_graph_input_one : public llm_graph_input_i {
     llm_graph_input_one() {}
     virtual ~llm_graph_input_one() = default;
 
-    void set_input(const llama_ubatch *) override;
+    void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * one = nullptr; // F32
 };
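
Note that the combined get_kv_idxs()/get_kv_idxs_swa() accessors and the self_kv_idxs/self_kv_idxs_swa members are removed outright rather than kept as aliases, so any out-of-tree graph code that referenced them has to switch to the per-tensor K/V getters and members introduced here.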

src/llama-kv-cache-unified.cpp

Lines changed: 64 additions & 17 deletions
@@ -808,7 +808,7 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
         0);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;
@@ -818,8 +818,8 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
 
     k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
 
-    if (kv_idxs && supports_set_rows) {
-        return ggml_set_rows(ctx, k, k_cur, kv_idxs);
+    if (k_idxs && supports_set_rows) {
+        return ggml_set_rows(ctx, k, k_cur, k_idxs);
    }
 
    // TODO: fallback to old ggml_cpy() method for backwards compatibility
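
For readers unfamiliar with ggml_set_rows(), it scatters whole rows of a source tensor into a destination tensor at positions given by an integer index tensor, which is exactly what the k_idxs/v_idxs inputs drive. A minimal, self-contained sketch of that behavior (not from this patch; it assumes a recent ggml that provides ggml_set_rows() and the CPU helpers in ggml-cpu.h):

```cpp
#include "ggml.h"
#include "ggml-cpu.h"

#include <cstdint>
#include <cstdio>

int main() {
    ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    ggml_context * ctx = ggml_init(params);

    const int64_t n_embd   = 4; // row width (e.g. n_embd_k_gqa)
    const int64_t n_kv     = 8; // destination rows (cache cells)
    const int64_t n_tokens = 2; // source rows (tokens in the ubatch)

    ggml_tensor * cache = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_kv);
    ggml_tensor * cur   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
    ggml_tensor * idxs  = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);

    // cache starts as zeros; token 0 is all 1.0f, token 1 is all 2.0f
    float   * cache_d = (float   *) cache->data;
    float   * cur_d   = (float   *) cur->data;
    int64_t * idxs_d  = (int64_t *) idxs->data;

    for (int64_t i = 0; i < n_embd*n_kv;     ++i) cache_d[i] = 0.0f;
    for (int64_t i = 0; i < n_embd*n_tokens; ++i) cur_d[i]   = 1.0f + i/n_embd;

    // token 0 goes to cache cell 5, token 1 to cache cell 2
    idxs_d[0] = 5;
    idxs_d[1] = 2;

    // row i of `cur` is written into row idxs[i] of `cache`
    ggml_tensor * out = ggml_set_rows(ctx, cache, cur, idxs);

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    printf("cache row 5, col 0: %.1f (expect 1.0)\n", cache_d[5*n_embd + 0]);
    printf("cache row 2, col 0: %.1f (expect 2.0)\n", cache_d[2*n_embd + 0]);
    printf("cache row 0, col 0: %.1f (expect 0.0)\n", cache_d[0*n_embd + 0]);

    ggml_free(ctx);
    return 0;
}
```

Here the two token rows land in cache rows 5 and 2 while the rest stay zero, mirroring how cpy_k() writes each ubatch token into the cache cell chosen by the slot info.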
@@ -832,7 +832,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     return ggml_cpy(ctx, k_cur, k_view);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
@@ -842,9 +842,9 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
 
-    if (kv_idxs && supports_set_rows) {
+    if (v_idxs && supports_set_rows) {
         if (!v_trans) {
-            return ggml_set_rows(ctx, v, v_cur, kv_idxs);
+            return ggml_set_rows(ctx, v, v_cur, v_idxs);
         }
 
         // the row becomes a single element
@@ -859,10 +859,10 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
         //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
 
         // we broadcast the KV indices n_embd_v_gqa times
-        // v       [1, n_kv,     n_embd_v_gqa]
-        // v_cur   [1, n_tokens, n_embd_v_gqa]
-        // kv_idxs [n_tokens, 1, 1]
-        return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
+        // v      [1, n_kv,     n_embd_v_gqa]
+        // v_cur  [1, n_tokens, n_embd_v_gqa]
+        // v_idxs [n_tokens, 1, 1]
+        return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
    }
 
    // TODO: fallback to old ggml_cpy() method for backwards compatibility
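
As the comments above describe, the transposed V path writes through a view whose rows are single elements: with, say, n_tokens = 4 and n_embd_v_gqa = 128, v_cur is treated as [1, 4, 128] and the 4 entries of v_idxs are broadcast across the 128 slices, so each token scatters into 128 one-element rows of the [1, n_kv, 128] view. This is the case where K and V can want differently shaped index inputs, which the separate v_idxs tensor now makes possible.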
@@ -885,7 +885,42 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
-void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+    ggml_set_input(k_idxs);
+
+    return k_idxs;
+}
+
+ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+    ggml_set_input(v_idxs);
+
+    return v_idxs;
+}
+
+void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    for (int64_t i = 0; i < n_tokens; ++i) {
+        data[i] = sinfo.idxs[i];
+    }
+}
+
+void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
     if (!supports_set_rows) {
         return;
     }
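
For reference, the host-side fill performed by set_input_k_idxs()/set_input_v_idxs() is just a per-token copy of the chosen cache-cell indices into the I64 input. A hypothetical standalone helper with the same behavior (fill_idxs and slot_idxs are illustrative names, not part of the patch):

```cpp
#include "ggml.h"

#include <cstdint>
#include <vector>

// hypothetical helper mirroring set_input_k_idxs()/set_input_v_idxs(): write the
// cache-cell index chosen for each ubatch token into the I64 graph input
static void fill_idxs(ggml_tensor * dst, const std::vector<int64_t> & slot_idxs) {
    GGML_ASSERT(dst->type == GGML_TYPE_I64);
    GGML_ASSERT((int64_t) slot_idxs.size() == ggml_nelements(dst));

    // the patch asserts the buffer is host-visible before writing directly
    int64_t * data = (int64_t *) dst->data;
    for (size_t i = 0; i < slot_idxs.size(); ++i) {
        data[i] = slot_idxs[i];
    }
}
```

In the patch itself the indices come from sinfo.idxs, i.e. the cells the unified cache picked for the current ubatch.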
@@ -1906,20 +1941,32 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t
     return kv->get_v(ctx, il, n_kv);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const {
-    return kv->cpy_k(ctx, k_cur, kv_idxs, il, sinfos[i_cur]);
+ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
+    return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
+}
+
+ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
+    return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const {
-    return kv->cpy_v(ctx, v_cur, kv_idxs, il, sinfos[i_cur]);
+ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_k_idxs(ctx, ubatch);
+}
+
+ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_v_idxs(ctx, ubatch);
 }
 
 void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
     kv->set_input_k_shift(dst);
 }
 
-void llama_kv_cache_unified_context::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
-    kv->set_input_kv_idxs(dst, ubatch, sinfos[i_cur]);
+void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
+}
+
+void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
 }
 
 void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
