
Commit faa6071

ggerganov authored and am17an committed
llama : reuse compute graphs
ggml-ci
1 parent 99af79e commit faa6071

17 files changed, +456 -185 lines

common/arg.cpp

Lines changed: 8 additions & 0 deletions
@@ -1464,6 +1464,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.swa_full = true;
         }
     ).set_env("LLAMA_ARG_SWA_FULL"));
+    add_opt(common_arg(
+        {"--graph-reuse", "-gr"},
+        string_format("reuse previous compute graphs when possible (default: %s)"
+            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/14482)", params.graph_reuse ? "true" : "false"),
+        [](common_params & params) {
+            params.graph_reuse = true;
+        }
+    ).set_env("LLAMA_ARG_GRAPH_REUSE"));
     add_opt(common_arg(
         {"--no-context-shift"},
         string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),

common/common.cpp

Lines changed: 1 addition & 0 deletions
@@ -1157,6 +1157,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.no_perf     = params.no_perf;
     cparams.op_offload  = !params.no_op_offload;
     cparams.swa_full    = params.swa_full;
+    cparams.graph_reuse = params.graph_reuse;

     cparams.type_k = params.cache_type_k;
     cparams.type_v = params.cache_type_v;
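As a quick orientation (a sketch, not code from the commit): the flag set on common_params flows into the core llama_context_params through the helper shown above, so front ends that already call common_context_params_to_llama pick it up automatically.

    #include "common.h"

    static llama_context_params ctx_params_with_graph_reuse() {
        common_params params;
        params.graph_reuse = true; // same effect as --graph-reuse / LLAMA_ARG_GRAPH_REUSE=1

        // copies graph_reuse over to llama_context_params (see the hunk above)
        return common_context_params_to_llama(params);
    }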

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -330,6 +330,7 @@ struct common_params {
     bool no_perf     = false; // disable performance metrics
     bool ctx_shift   = true;  // context shift on infinite text generation
     bool swa_full    = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
+    bool graph_reuse = false; // reuse previous compute graphs when possible

     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
     bool use_mmap         = true;  // use mmap for faster loads

include/llama.h

Lines changed: 3 additions & 0 deletions
@@ -374,6 +374,8 @@ extern "C" {
         bool swa_full;    // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
                           // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
                           // ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
+
+        bool graph_reuse; // reuse previous compute graphs when possible
     };

     // model quantization parameters
@@ -1429,6 +1431,7 @@ extern "C" {

         int32_t n_p_eval;
         int32_t n_eval;
+        int32_t n_reused;
     };

     struct llama_perf_sampler_data {
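For context, a minimal sketch (not part of the commit) of how the new field and counter are used from the public API; it assumes the current llama.h entry points (llama_model_load_from_file, llama_init_from_model, llama_perf_context) and omits the actual decode loop.

    #include "llama.h"

    void run_with_graph_reuse(const char * model_path) {
        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_model_load_from_file(model_path, mparams);

        llama_context_params cparams = llama_context_default_params();
        cparams.graph_reuse = true; // new field introduced by this commit

        llama_context * ctx = llama_init_from_model(model, cparams);

        // ... run llama_decode() as usual ...

        // n_reused reports how many times a previously built graph was reused
        llama_perf_context_data perf = llama_perf_context(ctx);
        (void) perf.n_reused;

        llama_free(ctx);
        llama_model_free(model);
    }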

src/llama-batch.h

Lines changed: 25 additions & 0 deletions
@@ -34,6 +34,31 @@ struct llama_ubatch {
     llama_seq_id *  seq_id_unq; // [n_seqs_unq]    | s | seq_id
     int32_t      *  seq_idx;    // [LLAMA_MAX_SEQ] | - | seq_idx
     int8_t       *  output;     // [n_tokens]      | i | -
+
+    bool is_same(const llama_ubatch & other) const {
+        bool res =
+            equal_seqs   == other.equal_seqs &&
+            n_tokens     == other.n_tokens &&
+            n_seq_tokens == other.n_seq_tokens &&
+            n_seqs       == other.n_seqs &&
+            n_seqs_unq   == other.n_seqs_unq &&
+            (
+                (!token && !other.token) ||
+                (!embd  && !other.embd)
+            );
+
+        if (!res) {
+            return false;
+        }
+
+        // TODO: this won't work because seq_id_unq ptr can point to an old balloc that has
+        //       been freed by this point. find a way to fix this
+        //for (uint32_t s = 0; s < n_seqs_unq; ++s) {
+        //    res &= seq_id_unq[s] == other.seq_id_unq[s];
+        //}
+
+        return res;
+    }
 };

 // a helper for sanitizing, fulfilling and splitting a batch
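To illustrate the intent of is_same(), a simplified sketch only (not the llm_graph_result logic from llama-context.cpp; the graph_cache name and the internal llama-batch.h include are assumptions):

    #include "llama-batch.h" // internal header that defines llama_ubatch

    // Caches the descriptor of the last ubatch and reports whether the graph
    // built for it could be kept for the current one, in the spirit of the
    // reuse check added by this commit.
    struct graph_cache {
        llama_ubatch prev {};
        bool         have_prev = false;

        bool can_reuse(const llama_ubatch & cur) {
            const bool same = have_prev && prev.is_same(cur);
            prev      = cur; // remember the descriptor for the next step
            have_prev = true;
            return same;
        }
    };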

src/llama-context.cpp

Lines changed: 88 additions & 82 deletions
@@ -101,7 +101,8 @@ llama_context::llama_context(

     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);

-    cparams.op_offload = params.op_offload;
+    cparams.op_offload  = params.op_offload;
+    cparams.graph_reuse = params.graph_reuse;

     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;

@@ -227,8 +228,8 @@ llama_context::llama_context(

     LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);

-    // buffer used to store the computation graph and the tensor meta data
-    buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+    gf_res_prev.reset(new llm_graph_result(max_nodes));
+    gf_res_reserve.reset(new llm_graph_result(max_nodes));

     // TODO: move these checks to ggml_backend_sched
     // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
@@ -388,10 +389,6 @@ ggml_backend_sched_t llama_context::get_sched() const {
     return sched.get();
 }

-ggml_context * llama_context::get_ctx_compute() const {
-    return ctx_compute.get();
-}
-
 uint32_t llama_context::n_ctx() const {
     return cparams.n_ctx;
 }
@@ -678,38 +675,52 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }

-llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+llm_graph_result_i * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
     if (mctx && !mctx->apply()) {
         LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
         ret = GGML_STATUS_FAILED;
         return nullptr;
     }

-    auto * gf = graph_init();
-    if (!gf) {
-        LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
-        ret = GGML_STATUS_FAILED;
-        return nullptr;
-    }
+    auto * res = gf_res_prev.get();
+    auto * gf  = res->get_gf();

-    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
-    if (!res) {
-        LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
-        ret = GGML_STATUS_FAILED;
-        return nullptr;
-    }
+    // the new graph parameters
+    // in order to correctly reuse a graph, its full topology has to be uniquely determined by these parameters
+    const auto gparams = graph_params(res, ubatch, mctx, gtype);

-    // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+    const bool can_reuse = cparams.graph_reuse && res->update(gparams);
+    if (can_reuse) {
+        LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
+        n_reused++;
+    } else {
+        res->reset();

-    if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
-        LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
-        ret = GGML_STATUS_ALLOC_FAILED;
-        return nullptr;
+        ggml_backend_sched_reset(sched.get());
+        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
+
+        //const auto t_start_us = ggml_time_us();
+
+        gf = model.build_graph(gparams);
+
+        //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
+
+        if (!gf) {
+            LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
+            ret = GGML_STATUS_FAILED;
+            return nullptr;
+        }
+
+        if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+            ret = GGML_STATUS_ALLOC_FAILED;
+            return nullptr;
+        }
     }

     res->set_inputs(&ubatch);

-    const auto status = graph_compute(gf, ubatch.n_tokens > 1);
+    const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
     if (status != GGML_STATUS_SUCCESS) {
         LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
         ret = status;
@@ -767,9 +778,6 @@ int llama_context::encode(const llama_batch & batch_inp) {

     n_outputs = n_tokens;

-    ggml_backend_sched_reset(sched.get());
-    ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
-
     const auto causal_attn_org = cparams.causal_attn;

     // always use non-causal attention for encoder graphs
@@ -778,7 +786,7 @@
     cparams.causal_attn = false;

     ggml_status status;
-    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
+    const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);

     cparams.causal_attn = causal_attn_org;

@@ -846,7 +854,9 @@

     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
-    ggml_backend_sched_reset(sched.get());
+    if (!cparams.graph_reuse) {
+        ggml_backend_sched_reset(sched.get());
+    }

     // TODO: hacky solution
     if (model.arch == LLM_ARCH_T5 && t_embd) {
@@ -1005,11 +1015,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
             n_outputs = n_outputs_new;
         }

-        ggml_backend_sched_reset(sched.get());
-        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
-
         ggml_status status;
-        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
+        const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);

         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1192,7 +1199,9 @@ int llama_context::decode(const llama_batch & batch_inp) {

     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
-    ggml_backend_sched_reset(sched.get());
+    if (!cparams.graph_reuse) {
+        ggml_backend_sched_reset(sched.get());
+    }

     return 0;
 }
@@ -1275,20 +1284,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
 // graph
 //

-int32_t llama_context::graph_max_nodes() const {
-    return std::max<int32_t>(65536, 5*model.n_tensors());
-}
-
-ggml_cgraph * llama_context::graph_init() {
-    ggml_init_params params = {
-        /*.mem_size   =*/ buf_compute_meta.size(),
-        /*.mem_buffer =*/ buf_compute_meta.data(),
-        /*.no_alloc   =*/ true,
-    };
-
-    ctx_compute.reset(ggml_init(params));
-
-    return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
+uint32_t llama_context::graph_max_nodes() const {
+    return std::max<uint32_t>(65536u, 5u*model.n_tensors());
 }

 ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
@@ -1301,6 +1298,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
         LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
     }

+    gf_res_prev->reset();
+    ggml_backend_sched_reset(sched.get());
+
     // store the n_outputs as it is, and restore it afterwards
     // TODO: not sure if needed, might simplify in the future by removing this
     const auto save_n_outputs = this->n_outputs;
@@ -1310,17 +1310,15 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
     llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);

-    auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
+    auto * res = gf_res_reserve.get();

-    this->n_outputs = save_n_outputs;
+    const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);

-    if (!res) {
-        LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
-        return nullptr;
-    }
+    res->reset();

-    ggml_backend_sched_reset(sched.get());
+    auto * gf = model.build_graph(gparams);
+
+    this->n_outputs = save_n_outputs;

     // initialize scheduler with the specified graph
     if (!ggml_backend_sched_reserve(sched.get(), gf)) {
@@ -1331,28 +1329,27 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     return gf;
 }

-llm_graph_result_ptr llama_context::graph_build(
-        ggml_context * ctx,
-        ggml_cgraph * gf,
-        const llama_ubatch & ubatch,
-        llm_graph_type gtype,
-        const llama_memory_context_i * mctx) {
-    return model.build_graph(
-        {
-            /*.ctx         =*/ ctx,
-            /*.arch        =*/ model.arch,
-            /*.hparams     =*/ model.hparams,
-            /*.cparams     =*/ cparams,
-            /*.ubatch      =*/ ubatch,
-            /*.sched       =*/ sched.get(),
-            /*.backend_cpu =*/ backend_cpu,
-            /*.cvec        =*/ &cvec,
-            /*.loras       =*/ &loras,
-            /*.mctx        =*/ mctx,
-            /*.cross       =*/ &cross,
-            /*.n_outputs   =*/ n_outputs,
-            /*.cb          =*/ graph_get_cb(),
-        }, gf, gtype);
+llm_graph_params llama_context::graph_params(
+        llm_graph_result_i * res,
+        const llama_ubatch & ubatch,
+        const llama_memory_context_i * mctx,
+        llm_graph_type gtype) const {
+    return {
+        /*.arch        =*/ model.arch,
+        /*.hparams     =*/ model.hparams,
+        /*.cparams     =*/ cparams,
+        /*.ubatch      =*/ ubatch,
+        /*.gtype       =*/ gtype,
+        /*.sched       =*/ sched.get(),
+        /*.backend_cpu =*/ backend_cpu,
+        /*.cvec        =*/ &cvec,
+        /*.loras       =*/ &loras,
+        /*.mctx        =*/ mctx,
+        /*.cross       =*/ &cross,
+        /*.n_outputs   =*/ n_outputs,
+        /*.cb          =*/ graph_get_cb(),
+        /*.res         =*/ res,
+    };
 }

 ggml_status llama_context::graph_compute(
@@ -1930,6 +1927,7 @@ llama_perf_context_data llama_context::perf_get_data() const {
     data.t_eval_ms = 1e-3 * t_eval_us;
     data.n_p_eval  = std::max(1, n_p_eval);
     data.n_eval    = std::max(1, n_eval);
+    data.n_reused  = std::max(0, n_reused);

     return data;
 }
@@ -1938,6 +1936,7 @@ void llama_context::perf_reset() {
     t_start_us = ggml_time_us();
     t_eval_us = n_eval = 0;
     t_p_eval_us = n_p_eval = 0;
+    n_reused = 0;
 }

 //
@@ -2064,8 +2063,13 @@ void llama_context::opt_epoch_iter(
             break;
         }

-        auto * gf = graph_init();
-        auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
+        auto * res = gf_res_prev.get();
+
+        const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
+
+        res->reset();
+
+        auto * gf = model.build_graph(gparams);

         struct ggml_context * ctx_compute_opt;
         {
@@ -2187,6 +2191,7 @@ llama_context_params llama_context_default_params() {
         /*.no_perf     =*/ true,
         /*.op_offload  =*/ true,
         /*.swa_full    =*/ true,
+        /*.graph_reuse =*/ false,
     };

     return result;
@@ -2807,6 +2812,7 @@ void llama_perf_context_print(const llama_context * ctx) {
     LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
             __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
     LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
+    LLAMA_LOG_INFO("%s: graphs reused = %10d\n", __func__, data.n_reused);
 }

 void llama_perf_context_reset(llama_context * ctx) {
