Commit c246784

kv-cache : use ggml_set_rows

ggml-ci

1 parent 8d94219 commit c246784

9 files changed: +335 -123 lines
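
What the commit does: KV-cache writes used to go through ggml_cpy into a ggml_view_1d whose byte offset encodes the destination slot, which bakes the current cache head into the graph. They now go through ggml_set_rows with an I64 tensor of destination row indices that is declared as a graph input and filled per ubatch. A minimal sketch of the two patterns, assuming the upstream ggml_set_rows(ctx, dst, src, idxs) signature; tensor names here are illustrative:

    // before: the destination offset is a graph-time constant derived from `head`
    ggml_tensor * k_view = ggml_view_1d(ctx0, k_cache, n_tokens*n_embd_k,
            ggml_row_size(k_cache->type, n_embd_k)*head);
    ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_view));

    // after: the destination rows are a runtime input, so one graph can serve
    // any slot assignment
    ggml_tensor * kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
    ggml_set_input(kv_idxs); // filled in set_input() for each ubatch
    ggml_build_forward_expand(gf, ggml_set_rows(ctx0, k_cache, k_cur, kv_idxs));
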

src/llama-graph.cpp

Lines changed: 33 additions & 6 deletions
@@ -281,12 +281,24 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
+    if (self_kv_idxs) {
+        mctx->set_input_kv_idxs(self_kv_idxs, ubatch);
+    }
+
     if (self_kq_mask) {
         mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
+    if (self_kv_idxs) {
+        mctx->get_base()->set_input_kv_idxs(self_kv_idxs, ubatch);
+    }
+
+    if (self_kv_idxs_swa) {
+        mctx->get_swa()->set_input_kv_idxs(self_kv_idxs_swa, ubatch);
+    }
+
     if (self_kq_mask) {
         mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
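
set_input_kv_idxs() is implemented in llama-kv-cache-unified.cpp, which is among the nine changed files but not shown in this excerpt. Its job follows from the calls above: copy the per-token destination cell indices chosen during prepare() into the I64 input tensor. A minimal sketch, assuming the context keeps the current slot info in a member named sinfo with an idxs vector:

    void llama_kv_cache_unified_context::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
        const int64_t n_tokens = ubatch->n_tokens;

        GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
        int64_t * data = (int64_t *) dst->data;

        for (int64_t i = 0; i < n_tokens; ++i) {
            data[i] = sinfo.idxs[i]; // cell that will receive token i (assumed field)
        }
    }
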
@@ -1198,6 +1210,9 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
     const auto n_kv = mctx_cur->get_n_kv();
 
+    inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
+    ggml_set_input(inp->self_kv_idxs);
+
     inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
@@ -1230,8 +1245,10 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & kv_idxs = inp->get_kv_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
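
The cpy_k()/cpy_v() signatures gain the extra kv_idxs parameter in llama-kv-cache-unified.{h,cpp} (changed by this commit, not excerpted). The shape of that change can be sketched as follows, with the member names (layers, map_layer_ids, head) assumed: when indices are provided, the write becomes a row scatter via ggml_set_rows; otherwise the old view-plus-copy path is kept:

    ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const {
        ggml_tensor * k = layers[map_layer_ids.at(il)].k; // per-layer K buffer, [n_embd_k_gqa, kv_size]

        const int64_t n_tokens = k_cur->ne[2];

        // one row per token
        k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);

        if (kv_idxs) {
            // scatter each token's row to the cell chosen by prepare()
            return ggml_set_rows(ctx, k, k_cur, kv_idxs);
        }

        // legacy path: contiguous write starting at the current head offset
        return ggml_cpy(ctx, k_cur,
                ggml_view_1d(ctx, k, n_tokens*k->ne[0],
                    ggml_row_size(k->type, k->ne[0])*head));
    }
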
@@ -1290,11 +1307,15 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // optionally store to KV cache
     if (k_cur) {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+        const auto & kv_idxs = is_swa ? inp->get_kv_idxs_swa() : inp->get_kv_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
     }
 
     if (v_cur) {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & kv_idxs = is_swa ? inp->get_kv_idxs_swa() : inp->get_kv_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1398,8 +1419,8 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, nullptr, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, nullptr, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
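
Here the index tensor is deliberately nullptr: this build_attn overload is not migrated to index-based writes, and with no indices cpy_k()/cpy_v() keep the legacy direct-copy behavior (see the sketch above).
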
@@ -1434,6 +1455,9 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa()
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
+        inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
+        ggml_set_input(inp->self_kv_idxs);
+
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
         ggml_set_input(inp->self_kq_mask);
@@ -1446,6 +1470,9 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa()
 
         const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
+        inp->self_kv_idxs_swa = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
+        ggml_set_input(inp->self_kv_idxs_swa);
+
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
         ggml_set_input(inp->self_kq_mask_swa);
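
The iSWA input carries two separate index tensors because the base and SWA caches are independent llama_kv_cache_unified instances: each runs its own prepare() and may place the same ubatch in different cells, which is why set_input() above fills self_kv_idxs and self_kv_idxs_swa from get_base() and get_swa() respectively.
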

src/llama-graph.h

Lines changed: 8 additions & 0 deletions
@@ -248,8 +248,11 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    ggml_tensor * get_kv_idxs() const { return self_kv_idxs; }
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
+    ggml_tensor * self_kv_idxs = nullptr; // I64 [n_batch]
+
     ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
 
@@ -273,9 +276,14 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    ggml_tensor * get_kv_idxs()     const { return self_kv_idxs; }
+    ggml_tensor * get_kv_idxs_swa() const { return self_kv_idxs_swa; }
     ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
+    ggml_tensor * self_kv_idxs     = nullptr; // I64 [n_batch]
+    ggml_tensor * self_kv_idxs_swa = nullptr; // I64 [n_batch]
+
     ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
     ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
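
In the I64 [n_batch] annotation, element i of the index tensor holds the KV cell that will receive the K/V rows of token i of the ubatch; the SWA variant has the same shape but indexes into the SWA cache.
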

src/llama-kv-cache-unified-iswa.cpp

Lines changed: 16 additions & 16 deletions
@@ -113,20 +113,20 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        auto heads_base = kv_base->prepare(ubatches);
-        if (heads_base.empty()) {
+        auto sinfos_base = kv_base->prepare(ubatches);
+        if (sinfos_base.empty()) {
             break;
         }
 
-        auto heads_swa = kv_swa->prepare(ubatches);
-        if (heads_swa.empty()) {
+        auto sinfos_swa = kv_swa->prepare(ubatches);
+        if (sinfos_swa.empty()) {
             break;
         }
 
-        assert(heads_base.size() == heads_swa.size());
+        assert(sinfos_base.size() == sinfos_swa.size());
 
         return std::make_unique<llama_kv_cache_unified_iswa_context>(
-                this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
     } while (false);
 
     // if it fails, try equal split
@@ -144,20 +144,20 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        auto heads_base = kv_base->prepare(ubatches);
-        if (heads_base.empty()) {
+        auto sinfos_base = kv_base->prepare(ubatches);
+        if (sinfos_base.empty()) {
             break;
         }
 
-        auto heads_swa = kv_swa->prepare(ubatches);
-        if (heads_swa.empty()) {
+        auto sinfos_swa = kv_swa->prepare(ubatches);
+        if (sinfos_swa.empty()) {
             break;
         }
 
-        assert(heads_base.size() == heads_swa.size());
+        assert(sinfos_base.size() == sinfos_swa.size());
 
         return std::make_unique<llama_kv_cache_unified_iswa_context>(
-                this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
     } while (false);
 
     // TODO: if we fail again, we should attempt different splitting strategies
@@ -220,13 +220,13 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
 
 llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
         llama_kv_cache_unified_iswa * kv,
-        std::vector<uint32_t> heads_base,
-        std::vector<uint32_t> heads_swa,
+        slot_info_vec_t sinfos_base,
+        slot_info_vec_t sinfos_swa,
         std::vector<llama_ubatch> ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
-    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
+    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
+    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
     status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 
src/llama-kv-cache-unified-iswa.h

Lines changed: 4 additions & 2 deletions
@@ -74,6 +74,8 @@ class llama_kv_cache_unified_iswa : public llama_memory_i {
 
 class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
 public:
+    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+
     // used for errors
     llama_kv_cache_unified_iswa_context(llama_memory_status status);
 
@@ -90,8 +92,8 @@ class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
     // used to create a batch processing context from a batch
     llama_kv_cache_unified_iswa_context(
             llama_kv_cache_unified_iswa * kv,
-            std::vector<uint32_t> heads_base,
-            std::vector<uint32_t> heads_swa,
+            slot_info_vec_t sinfos_base,
+            slot_info_vec_t sinfos_swa,
             std::vector<llama_ubatch> ubatches);
 
     virtual ~llama_kv_cache_unified_iswa_context();
