@@ -281,8 +281,12 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-    if (self_kv_idxs) {
-        mctx->set_input_kv_idxs(self_kv_idxs, ubatch);
+    if (self_k_idxs) {
+        mctx->set_input_k_idxs(self_k_idxs, ubatch);
+    }
+
+    if (self_v_idxs) {
+        mctx->set_input_v_idxs(self_v_idxs, ubatch);
     }
 
     if (self_kq_mask) {
@@ -291,12 +295,20 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
-    if (self_kv_idxs) {
-        mctx->get_base()->set_input_kv_idxs(self_kv_idxs, ubatch);
+    if (self_k_idxs) {
+        mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
+    }
+
+    if (self_v_idxs) {
+        mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
     }
 
-    if (self_kv_idxs_swa) {
-        mctx->get_swa()->set_input_kv_idxs(self_kv_idxs_swa, ubatch);
+    if (self_k_idxs_swa) {
+        mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
+    }
+
+    if (self_v_idxs_swa) {
+        mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
     }
 
     if (self_kq_mask) {
@@ -345,6 +357,14 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
+    if (self_k_idxs) {
+        mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch);
+    }
+
+    if (self_v_idxs) {
+        mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch);
+    }
+
     if (self_kq_mask) {
         mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
@@ -362,7 +382,8 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_one::set_input(const llama_ubatch *) {
+void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
     GGML_ASSERT(one && ggml_nelements(one) == 1);
     float f_one = 1.0f;
     ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
@@ -1009,6 +1030,9 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
 
     const auto n_kv = inp->mctx->get_attn()->get_n_kv();
 
+    inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
+    inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
+
     inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
@@ -1210,11 +1234,10 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
     const auto n_kv = mctx_cur->get_n_kv();
 
-    inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-    ggml_set_input(inp->self_kv_idxs);
+    inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
+    inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
 
     inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1245,10 +1268,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        const auto & kv_idxs = inp->get_kv_idxs();
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
 
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
@@ -1307,15 +1331,15 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // optionally store to KV cache
     if (k_cur) {
-        const auto & kv_idxs = is_swa ? inp->get_kv_idxs_swa() : inp->get_kv_idxs();
+        const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
 
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
     }
 
     if (v_cur) {
-        const auto & kv_idxs = is_swa ? inp->get_kv_idxs_swa() : inp->get_kv_idxs();
+        const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
 
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1419,8 +1443,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, nullptr, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, nullptr, il));
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
@@ -1455,11 +1482,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
-        inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-        ggml_set_input(inp->self_kv_idxs);
+        inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1470,11 +1496,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
         const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
-        inp->self_kv_idxs_swa = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-        ggml_set_input(inp->self_kv_idxs_swa);
+        inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
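
Note: taken together, the hunks above replace the single self_kv_idxs input with separate self_k_idxs / self_v_idxs tensors that the KV-cache memory context builds (build_input_k_idxs / build_input_v_idxs) and fills (set_input_k_idxs / set_input_v_idxs), and they thread those indices into cpy_k / cpy_v. The standalone C++ sketch below is not part of the diff and uses made-up names and sizes; it only illustrates the underlying idea of scattering per-token K and V rows into cache slots addressed by independent index lists.

// sketch: scatter new K/V rows into preallocated caches using separate
// destination indices for K and for V (they may differ, e.g. when V is
// stored transposed); all buffers and values here are illustrative
#include <cstdint>
#include <cstdio>
#include <vector>

// copy `rows` (idxs.size() rows of `row_size` floats) into `cache`
// at the slots named by `idxs`
static void scatter_rows(std::vector<float> & cache, size_t row_size,
                         const std::vector<float> & rows,
                         const std::vector<int64_t> & idxs) {
    for (size_t r = 0; r < idxs.size(); ++r) {
        for (size_t c = 0; c < row_size; ++c) {
            cache[idxs[r]*row_size + c] = rows[r*row_size + c];
        }
    }
}

int main() {
    const size_t n_kv = 8, row = 4;            // cache slots and elements per row
    std::vector<float> k_cache(n_kv*row, 0.f); // stand-ins for the K and V caches
    std::vector<float> v_cache(n_kv*row, 0.f);

    // two new tokens' worth of K/V rows
    const std::vector<float> k_cur = {1,1,1,1, 2,2,2,2};
    const std::vector<float> v_cur = {9,9,9,9, 8,8,8,8};

    // separate destination indices for K and V, mirroring k_idxs / v_idxs above
    const std::vector<int64_t> k_idxs = {3, 4};
    const std::vector<int64_t> v_idxs = {3, 4};

    scatter_rows(k_cache, row, k_cur, k_idxs);
    scatter_rows(v_cache, row, v_cur, v_idxs);

    printf("k_cache[3][0] = %.0f, v_cache[4][0] = %.0f\n", k_cache[3*row], v_cache[4*row]);
    return 0;
}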