Skip to content

Commit a86b929

Browse files
committed
kv-cache : perform stream copies lazily after llama_synchronize
ggml-ci
1 parent a823406 commit a86b929

File tree

2 files changed

+54
-17
lines changed

2 files changed

+54
-17
lines changed

src/llama-kv-cache-unified.cpp

Lines changed: 33 additions & 13 deletions
Original file line number · Diff line number · Diff line change
@@ -311,14 +311,9 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
311311

312312
GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers");
313313

314-
//LLAMA_LOG_WARN("%s: copying KV buffer from %d (stream = %d) to %d (stream = %d)\n", __func__, seq_id_src, s0, seq_id_dst, s1);
315-
316-
for (uint32_t il = 0; il < layers.size(); ++il) {
317-
const auto & layer = layers[il];
318-
319-
ggml_backend_tensor_copy(layer.k_stream[s0], layer.k_stream[s1]);
320-
ggml_backend_tensor_copy(layer.v_stream[s0], layer.v_stream[s1]);
321-
}
314+
// enqueue the copy operation - the buffer copy will be performed during the next update
315+
sc_info.ssrc.push_back(s0);
316+
sc_info.sdst.push_back(s1);
322317

323318
v_cells[s1].reset();
324319
for (uint32_t i = 0; i < v_cells[s0].size(); ++i) {
@@ -526,7 +521,7 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
526521
}
527522
}
528523

529-
return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
524+
return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo), std::move(sc_info));
530525
}
531526

532527
llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
@@ -598,11 +593,35 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
598593
return res;
599594
}
600595

601-
bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) {
596+
bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info) {
602597
bool updated = false;
603598

604599
auto * sched = lctx->get_sched();
605600

601+
if (!sc_info.empty()) {
602+
assert(n_stream > 1 && "stream copy should never happen with a single stream");
603+
604+
llama_synchronize(lctx);
605+
606+
const size_t n_copy = sc_info.ssrc.size();
607+
608+
for (size_t i = 0; i < n_copy; ++i) {
609+
const auto ssrc = sc_info.ssrc.at(i);
610+
const auto sdst = sc_info.sdst.at(i);
611+
612+
LLAMA_LOG_DEBUG("%s: copying KV buffer: stream %d to stream %d\n", __func__, ssrc, sdst);
613+
614+
assert(ssrc != sdst);
615+
616+
for (uint32_t il = 0; il < layers.size(); ++il) {
617+
const auto & layer = layers[il];
618+
619+
ggml_backend_tensor_copy(layer.k_stream.at(ssrc), layer.k_stream.at(sdst));
620+
ggml_backend_tensor_copy(layer.v_stream.at(ssrc), layer.v_stream.at(sdst));
621+
}
622+
}
623+
}
624+
606625
if (do_shift) {
607626
if (!get_can_shift()) {
608627
GGML_ABORT("The current KV cache / model configuration does not support K-shift");
@@ -2242,8 +2261,9 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
22422261
llama_kv_cache_unified * kv,
22432262
llama_context * lctx,
22442263
bool do_shift,
2245-
defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
2246-
if (!do_shift && this->dinfo.empty()) {
2264+
defrag_info dinfo,
2265+
stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)), sc_info(std::move(sc_info)) {
2266+
if (!do_shift && this->dinfo.empty() && this->sc_info.empty()) {
22472267
status = LLAMA_MEMORY_STATUS_NO_UPDATE;
22482268
}
22492269
}
@@ -2271,7 +2291,7 @@ bool llama_kv_cache_unified_context::apply() {
22712291

22722292
// no ubatches -> this is a KV cache update
22732293
if (ubatches.empty()) {
2274-
kv->update(lctx, do_shift, dinfo);
2294+
kv->update(lctx, do_shift, dinfo, sc_info);
22752295

22762296
return true;
22772297
}

src/llama-kv-cache-unified.h

Lines changed: 21 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -35,6 +35,16 @@ class llama_kv_cache_unified : public llama_memory_i {
3535
std::vector<uint32_t> ids;
3636
};
3737

38+
// Records KV-cache cross-stream copies enqueued by seq_cp(); the actual
// buffer copies are deferred and performed lazily during the next update()
// (after llama_synchronize), per this commit's change.
struct stream_copy_info {
39+
// True when no copies are pending. The two vectors are parallel arrays,
// so they must always have the same length (one entry per enqueued copy).
bool empty() const {
40+
assert(ssrc.size() == sdst.size());
41+
return ssrc.empty();
42+
}
43+
44+
// Parallel arrays of stream ids: copy stream ssrc[i] -> stream sdst[i].
std::vector<uint32_t> ssrc;
45+
std::vector<uint32_t> sdst;
46+
};
47+
3848
// for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
3949
// KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
4050
struct slot_info {
@@ -158,7 +168,7 @@ class llama_kv_cache_unified : public llama_memory_i {
158168
// return empty vector on failure
159169
slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
160170

161-
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
171+
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info);
162172

163173
// find a slot of kv cells that can hold the ubatch
164174
// if cont == true, then the slot must be continuous
@@ -231,6 +241,9 @@ class llama_kv_cache_unified : public llama_memory_i {
231241
// maps from a sequence id to a stream id
232242
std::vector<uint32_t> seq_to_stream;
233243

244+
// pending stream copies that will be applied during the next update
245+
stream_copy_info sc_info;
246+
234247
std::vector<kv_layer> layers;
235248

236249
// model layer id -> KV cache layer id
@@ -282,8 +295,9 @@ class llama_kv_cache_unified : public llama_memory_i {
282295
class llama_kv_cache_unified_context : public llama_memory_context_i {
283296
public:
284297
// some shorthands
285-
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
286-
using defrag_info = llama_kv_cache_unified::defrag_info;
298+
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
299+
using defrag_info = llama_kv_cache_unified::defrag_info;
300+
using stream_copy_info = llama_kv_cache_unified::stream_copy_info;
287301

288302
// used for errors
289303
llama_kv_cache_unified_context(llama_memory_status status);
@@ -297,7 +311,8 @@ class llama_kv_cache_unified_context : public llama_memory_context_i {
297311
llama_kv_cache_unified * kv,
298312
llama_context * lctx,
299313
bool do_shift,
300-
defrag_info dinfo);
314+
defrag_info dinfo,
315+
stream_copy_info sc_info);
301316

302317
// used to create a batch processing context from a batch
303318
llama_kv_cache_unified_context(
@@ -355,6 +370,8 @@ class llama_kv_cache_unified_context : public llama_memory_context_i {
355370

356371
defrag_info dinfo;
357372

373+
stream_copy_info sc_info;
374+
358375
//
359376
// batch processing context
360377
//

0 commit comments

Comments (0)