@@ -281,12 +281,24 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
+    if (self_kv_idxs) {
+        mctx->set_input_kv_idxs(self_kv_idxs, ubatch);
+    }
+
     if (self_kq_mask) {
         mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
+    if (self_kv_idxs) {
+        mctx->get_base()->set_input_kv_idxs(self_kv_idxs, ubatch);
+    }
+
+    if (self_kv_idxs_swa) {
+        mctx->get_swa()->set_input_kv_idxs(self_kv_idxs_swa, ubatch);
+    }
+
     if (self_kq_mask) {
         mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
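Note: the two set_input() overrides above now upload the KV write indices before each graph execution. As a rough host-side sketch of what set_input_kv_idxs() plausibly does (assuming the unified cache writes the ubatch into contiguous cells starting at its current head offset; the body is illustrative, not copied from this PR), the mctx wrapper would forward to something like:

    void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
        const int64_t n_tokens = ubatch->n_tokens;

        // the indices are written on the host, so the input tensor must live in a host buffer
        GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));

        int64_t * data = (int64_t *) dst->data;

        // token i of the ubatch goes to cache cell head + i
        for (int64_t i = 0; i < n_tokens; ++i) {
            data[i] = head + i;
        }
    }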
@@ -1198,6 +1210,9 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
     const auto n_kv = mctx_cur->get_n_kv();
 
+    inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
+    ggml_set_input(inp->self_kv_idxs);
+
     inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
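The reason self_kv_idxs is registered as a graph input: with an I64 row-index tensor available, the K/V store can become a ggml_set_rows() scatter instead of a ggml_cpy() into a view at a fixed head offset. A hedged sketch of the matching cpy_k() signature change, assuming k_l[il] is the per-layer K buffer with one row per cache cell (illustrative, not this PR's exact body):

    ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const {
        ggml_tensor * k = k_l[il]; // assumed shape: [n_embd_k_gqa, kv_size]

        const int64_t n_embd   = k->ne[0];
        const int64_t n_tokens = k_cur->ne[2];

        k_cur = ggml_reshape_2d(ctx, k_cur, n_embd, n_tokens);

        if (kv_idxs) {
            // scatter: row i of k_cur is stored to cache cell kv_idxs[i]
            return ggml_set_rows(ctx, k, k_cur, kv_idxs);
        }

        // kv_idxs == nullptr: keep the old contiguous copy at the current head
        ggml_tensor * k_view = ggml_view_1d(ctx, k, n_embd*n_tokens, ggml_row_size(k->type, n_embd)*head);

        return ggml_cpy(ctx, k_cur, k_view);
    }

cpy_v() would mirror this, with extra care for the transposed V layout used when flash attention is off.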
@@ -1230,8 +1245,10 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & kv_idxs = inp->get_kv_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
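get_kv_idxs() is presumably a trivial accessor on the input object, along the lines of (inferred, not shown in this diff):

    // inferred accessors on llm_graph_input_attn_kv_unified
    ggml_tensor * get_kv_idxs() const { return self_kv_idxs; }
    ggml_tensor * get_kq_mask() const { return self_kq_mask; }

Since it returns a plain ggml_tensor *, the const auto & binding above is equivalent to taking a copy of the pointer.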
@@ -1290,11 +1307,15 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // optionally store to KV cache
     if (k_cur) {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+        const auto & kv_idxs = is_swa ? inp->get_kv_idxs_swa() : inp->get_kv_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
     }
 
     if (v_cur) {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & kv_idxs = is_swa ? inp->get_kv_idxs_swa() : inp->get_kv_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
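Note: is_swa is the per-layer sliding-window flag, which in llama.cpp comes from the existing hparams API:

    const bool is_swa = hparams.is_swa(il);

For a given layer both writes therefore resolve to the same index tensor; the ternary is repeated only because k_cur and v_cur are checked independently.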
@@ -1398,8 +1419,8 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, nullptr, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, nullptr, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
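In this overload the index argument is deliberately nullptr, which keeps the legacy contiguous-copy path (the fallback branch in the cpy_k() sketch above); only the unified and unified-iSWA attention paths supply real indices.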
@@ -1434,6 +1455,9 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa()
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
+        inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
+        ggml_set_input(inp->self_kv_idxs);
+
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
         ggml_set_input(inp->self_kq_mask);
@@ -1446,6 +1470,9 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa()
 
         const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
+        inp->self_kv_idxs_swa = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
+        ggml_set_input(inp->self_kv_idxs_swa);
+
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
         ggml_set_input(inp->self_kq_mask_swa);
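These two blocks mirror the unified builder at new line 1213: the iSWA input carries one index tensor per cache. The corresponding member declarations in the input class would look roughly like this (inferred from usage, not shown in this diff):

    // inferred members of llm_graph_input_attn_kv_unified_iswa
    ggml_tensor * self_kv_idxs     = nullptr; // I64 [n_tokens], base cache
    ggml_tensor * self_kv_idxs_swa = nullptr; // I64 [n_tokens], SWA cache

Both are sized n_tokens and marked with ggml_set_input() so that set_input() in the first hunk can fill them for every ubatch.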