
Commit 06bb08a

graph : separate k and v indices
ggml-ci
1 parent 253304a commit 06bb08a

File tree

4 files changed: +140 -59 lines changed

src/llama-graph.cpp

Lines changed: 32 additions & 22 deletions
@@ -281,8 +281,12 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-    if (self_kv_idxs) {
-        mctx->set_input_kv_idxs(self_kv_idxs, ubatch);
+    if (self_k_idxs) {
+        mctx->set_input_k_idxs(self_k_idxs, ubatch);
+    }
+
+    if (self_v_idxs) {
+        mctx->set_input_v_idxs(self_v_idxs, ubatch);
     }
 
     if (self_kq_mask) {
@@ -291,12 +295,20 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
-    if (self_kv_idxs) {
-        mctx->get_base()->set_input_kv_idxs(self_kv_idxs, ubatch);
+    if (self_k_idxs) {
+        mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
+    }
+
+    if (self_v_idxs) {
+        mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
     }
 
-    if (self_kv_idxs_swa) {
-        mctx->get_swa()->set_input_kv_idxs(self_kv_idxs_swa, ubatch);
+    if (self_k_idxs_swa) {
+        mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
+    }
+
+    if (self_v_idxs_swa) {
+        mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
     }
 
     if (self_kq_mask) {
@@ -1210,11 +1222,10 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
     const auto n_kv = mctx_cur->get_n_kv();
 
-    inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-    ggml_set_input(inp->self_kv_idxs);
+    inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
+    inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
 
     inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1245,10 +1256,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        const auto & kv_idxs = inp->get_kv_idxs();
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
 
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
@@ -1307,15 +1319,15 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // optionally store to KV cache
     if (k_cur) {
-        const auto & kv_idxs = is_swa ? inp->get_kv_idxs_swa() : inp->get_kv_idxs();
+        const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
 
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
     }
 
     if (v_cur) {
-        const auto & kv_idxs = is_swa ? inp->get_kv_idxs_swa() : inp->get_kv_idxs();
+        const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
 
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1455,11 +1467,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
-        inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-        ggml_set_input(inp->self_kv_idxs);
+        inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1470,11 +1481,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
         const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
-        inp->self_kv_idxs_swa = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-        ggml_set_input(inp->self_kv_idxs_swa);
+        inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
        ggml_set_input(inp->self_kq_mask_swa);
 
        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
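
Note (not part of the commit): the following standalone C++ sketch models what the separated k_idxs/v_idxs inputs drive in the ggml_set_rows-based store path used by cpy_k/cpy_v above: each row of the current batch is scattered to the cache row named by its index. It is a conceptual model only; names, shapes, and values are illustrative, not ggml's API.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Plain-C++ model of the row scatter: batch row i is written to cache row idxs[i].
static void set_rows_model(std::vector<float> & dst, const std::vector<float> & src,
                           const std::vector<int64_t> & idxs, size_t n_embd) {
    for (size_t i = 0; i < idxs.size(); ++i) {
        std::copy(src.begin() + i*n_embd, src.begin() + (i + 1)*n_embd,
                  dst.begin() + idxs[i]*n_embd);
    }
}

int main() {
    const size_t n_embd = 2, n_kv = 6;
    std::vector<float>   cache(n_kv*n_embd, 0.0f);    // one K (or V) cache layer
    std::vector<float>   cur  = {1, 1,  2, 2,  3, 3}; // rows for 3 new tokens
    std::vector<int64_t> idxs = {4, 0, 5};            // non-contiguous cache cells

    set_rows_model(cache, cur, idxs, n_embd);
    std::printf("cache row 4 = {%g, %g}\n", cache[4*n_embd], cache[4*n_embd + 1]);
    return 0;
}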

src/llama-graph.h

Lines changed: 20 additions & 12 deletions
@@ -248,13 +248,16 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * get_kv_idxs() const { return self_kv_idxs; }
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
-    ggml_tensor * self_kv_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
 
-    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch/n_seqs, n_seqs]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -276,18 +279,23 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * get_kv_idxs()     const { return self_kv_idxs; }
-    ggml_tensor * get_kv_idxs_swa() const { return self_kv_idxs_swa; }
+    ggml_tensor * get_k_idxs()     const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs()     const { return self_v_idxs; }
+    ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
+    ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
+
     ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
-    ggml_tensor * self_kv_idxs     = nullptr; // I64 [n_batch]
-    ggml_tensor * self_kv_idxs_swa = nullptr; // I64 [n_batch]
+    ggml_tensor * self_k_idxs     = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs     = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
+    ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
 
-    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch]
+    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
+    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch/n_seqs, n_seqs]
+    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
+    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch/n_seqs, n_seqs]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
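
Note (not part of the commit): the member comments above now describe the KQ mask as [n_kv, n_batch/n_seqs, n_seqs]. A tiny sketch with assumed values, just to make the extents concrete:

#include <cstdio>

int main() {
    // illustrative values only
    const int n_kv    = 32; // cache cells visible to the attention
    const int n_batch = 8;  // tokens in the ubatch
    const int n_seqs  = 2;  // sequences packed into the ubatch

    // extents per the updated comment: [n_kv, n_batch/n_seqs, n_seqs]
    std::printf("self_kq_mask: %d x %d x %d\n", n_kv, n_batch / n_seqs, n_seqs);
    return 0;
}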

src/llama-kv-cache-unified.cpp

Lines changed: 71 additions & 17 deletions
@@ -808,7 +808,7 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
             0);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;
@@ -818,8 +818,8 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
 
     k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
 
-    if (kv_idxs && supports_set_rows) {
-        return ggml_set_rows(ctx, k, k_cur, kv_idxs);
+    if (k_idxs && supports_set_rows) {
+        return ggml_set_rows(ctx, k, k_cur, k_idxs);
     }
 
     // TODO: fallback to old ggml_cpy() method for backwards compatibility
@@ -832,7 +832,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     return ggml_cpy(ctx, k_cur, k_view);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
@@ -842,9 +842,9 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
 
-    if (kv_idxs && supports_set_rows) {
+    if (v_idxs && supports_set_rows) {
         if (!v_trans) {
-            return ggml_set_rows(ctx, v, v_cur, kv_idxs);
+            return ggml_set_rows(ctx, v, v_cur, v_idxs);
         }
 
         // the row becomes a single element
@@ -859,10 +859,10 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
         //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
 
         // we broadcast the KV indices n_embd_v_gqa times
-        // v       [1, n_kv,     n_embd_v_gqa]
-        // v_cur   [1, n_tokens, n_embd_v_gqa]
-        // kv_idxs [n_tokens, 1, 1]
-        return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
+        // v      [1, n_kv,     n_embd_v_gqa]
+        // v_cur  [1, n_tokens, n_embd_v_gqa]
+        // v_idxs [n_tokens, 1, 1]
+        return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
     }
 
     // TODO: fallback to old ggml_cpy() method for backwards compatibility
@@ -885,7 +885,49 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
-void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+    ggml_set_input(k_idxs);
+
+    return k_idxs;
+}
+
+ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * v_idxs;
+
+    if (!v_trans) {
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+    } else {
+        // TODO: assert that n_embd_v_gqa is the same for all layers, or take the max
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa());
+    }
+
+    ggml_set_input(v_idxs);
+
+    return v_idxs;
+}
+
+void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    for (int64_t i = 0; i < n_tokens; ++i) {
+        data[i] = sinfo.idxs[i];
+    }
+}
+
+void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
     if (!supports_set_rows) {
         return;
     }
@@ -1906,20 +1948,32 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t
     return kv->get_v(ctx, il, n_kv);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const {
-    return kv->cpy_k(ctx, k_cur, kv_idxs, il, sinfos[i_cur]);
+ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
+    return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
+}
+
+ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
+    return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const {
-    return kv->cpy_v(ctx, v_cur, kv_idxs, il, sinfos[i_cur]);
+ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_k_idxs(ctx, ubatch);
+}
+
+ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_v_idxs(ctx, ubatch);
 }
 
 void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
     kv->set_input_k_shift(dst);
 }
 
-void llama_kv_cache_unified_context::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
-    kv->set_input_kv_idxs(dst, ubatch, sinfos[i_cur]);
+void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
+}
+
+void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
 }
 
 void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
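
Note (not part of the commit): the new set_input_k_idxs body above copies one cache-cell index per token from the slot info into the I64 input tensor. Below is a self-contained sketch of that host-side fill, with plain vectors standing in for slot_info and the tensor buffer; the transposed-V fill is not shown in the hunk above, so it is omitted here as well.

#include <cstdint>
#include <cstdio>
#include <vector>

// stand-in for slot_info: the cache cell chosen for each token of the ubatch
struct slot_info_sketch {
    std::vector<int64_t> idxs;
};

int main() {
    const uint32_t n_tokens = 3;
    const slot_info_sketch sinfo = { {7, 2, 9} };

    // plays the role of the I64 [n_batch] tensor returned by build_input_k_idxs
    std::vector<int64_t> k_idxs(n_tokens);
    for (uint32_t i = 0; i < n_tokens; ++i) {
        k_idxs[i] = sinfo.idxs[i]; // same per-token copy as set_input_k_idxs above
    }

    // The V side is kept separate because build_input_v_idxs sizes its tensor as
    // n_tokens*n_embd_v_gqa when the V cache is transposed.
    std::printf("k_idxs = {%lld, %lld, %lld}\n",
                (long long) k_idxs[0], (long long) k_idxs[1], (long long) k_idxs[2]);
    return 0;
}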

src/llama-kv-cache-unified.h

Lines changed: 17 additions & 8 deletions
@@ -124,8 +124,8 @@ class llama_kv_cache_unified : public llama_memory_i {
     ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
 
     // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const;
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;
 
     //
     // preparation API
@@ -146,10 +146,15 @@ class llama_kv_cache_unified : public llama_memory_i {
     void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
 
     //
-    // set_input API
+    // input API
     //
 
-    void set_input_kv_idxs   (ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
+    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+
+    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
+    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
+
     void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_k_shift   (ggml_tensor * dst) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
@@ -286,12 +291,16 @@ class llama_kv_cache_unified_context : public llama_memory_context_i {
     ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
 
     // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const;
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
 
-    void set_input_k_shift(ggml_tensor * dst) const;
+    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
 
-    void set_input_kv_idxs   (ggml_tensor * dst, const llama_ubatch * ubatch) const;
+    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
+    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
+
+    void set_input_k_shift   (ggml_tensor * dst) const;
     void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
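
Note (not part of the commit): the reworked "input API" splits index handling into a build step (build_input_*_idxs, during graph construction) and a fill step (set_input_*_idxs, before evaluation). A standalone sketch of that two-phase pattern follows; all types and sizes are hypothetical stand-ins, not llama.cpp code.

#include <cstdint>
#include <cstdio>
#include <vector>

// hypothetical stand-ins for the ggml index tensors
struct idx_inputs {
    std::vector<int64_t> k_idxs;
    std::vector<int64_t> v_idxs;
};

// "build" phase (graph construction): allocate index buffers of the right size,
// mirroring build_input_k_idxs / build_input_v_idxs
static idx_inputs build_inputs(uint32_t n_tokens, uint32_t n_embd_v_gqa, bool v_trans) {
    idx_inputs inp;
    inp.k_idxs.resize(n_tokens);
    inp.v_idxs.resize(v_trans ? size_t(n_tokens)*n_embd_v_gqa : n_tokens);
    return inp;
}

// "fill" phase (before each evaluation): write the chosen cache cells,
// mirroring set_input_k_idxs / set_input_v_idxs
static void fill_k_inputs(idx_inputs & inp, const std::vector<int64_t> & cells) {
    for (size_t i = 0; i < inp.k_idxs.size(); ++i) {
        inp.k_idxs[i] = cells[i];
    }
    // the transposed-V fill is deliberately left out of this sketch
}

int main() {
    idx_inputs inp = build_inputs(/*n_tokens=*/2, /*n_embd_v_gqa=*/4, /*v_trans=*/true);
    fill_k_inputs(inp, {5, 11});
    std::printf("k entries: %zu, v entries: %zu\n", inp.k_idxs.size(), inp.v_idxs.size());
    return 0;
}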

0 commit comments
