@@ -101,7 +101,8 @@ llama_context::llama_context(
 
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
 
-    cparams.op_offload = params.op_offload;
+    cparams.op_offload  = params.op_offload;
+    cparams.graph_reuse = params.graph_reuse;
 
     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
 
@@ -227,8 +228,8 @@ llama_context::llama_context(
 
         LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
 
-        // buffer used to store the computation graph and the tensor meta data
-        buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+        gf_res_prev.reset(new llm_graph_result(max_nodes));
+        gf_res_reserve.reset(new llm_graph_result(max_nodes));
 
         // TODO: move these checks to ggml_backend_sched
         // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
@@ -388,10 +389,6 @@ ggml_backend_sched_t llama_context::get_sched() const {
     return sched.get();
 }
 
-ggml_context * llama_context::get_ctx_compute() const {
-    return ctx_compute.get();
-}
-
 uint32_t llama_context::n_ctx() const {
     return cparams.n_ctx;
 }
@@ -678,38 +675,52 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
-llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+llm_graph_result_i * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
     if (mctx && !mctx->apply()) {
         LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
         ret = GGML_STATUS_FAILED;
         return nullptr;
     }
 
-    auto * gf = graph_init();
-    if (!gf) {
-        LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
-        ret = GGML_STATUS_FAILED;
-        return nullptr;
-    }
+    auto * res = gf_res_prev.get();
+    auto * gf  = res->get_gf();
 
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
-    if (!res) {
-        LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
-        ret = GGML_STATUS_FAILED;
-        return nullptr;
-    }
+    // the new graph parameters
+    // in order to correctly reuse a graph, its full topology has to be uniquely determined by these parameters
+    const auto gparams = graph_params(res, ubatch, mctx, gtype);
 
-    // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+    const bool can_reuse = cparams.graph_reuse && res->update(gparams);
+    if (can_reuse) {
+        LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
+        n_reused++;
+    } else {
+        res->reset();
 
-    if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
-        LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
-        ret = GGML_STATUS_ALLOC_FAILED;
-        return nullptr;
+        ggml_backend_sched_reset(sched.get());
+        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
+
+        // const auto t_start_us = ggml_time_us();
+
+        gf = model.build_graph(gparams);
+
+        // LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
+
+        if (!gf) {
+            LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
+            ret = GGML_STATUS_FAILED;
+            return nullptr;
+        }
+
+        if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+            ret = GGML_STATUS_ALLOC_FAILED;
+            return nullptr;
+        }
     }
 
     res->set_inputs(&ubatch);
 
-    const auto status = graph_compute(gf, ubatch.n_tokens > 1);
+    const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
     if (status != GGML_STATUS_SUCCESS) {
         LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
         ret = status;
@@ -767,9 +778,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
     n_outputs = n_tokens;
 
-    ggml_backend_sched_reset(sched.get());
-    ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
-
     const auto causal_attn_org = cparams.causal_attn;
 
     // always use non-causal attention for encoder graphs
@@ -778,7 +786,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     cparams.causal_attn = false;
 
     ggml_status status;
-    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
+    const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
 
     cparams.causal_attn = causal_attn_org;
 
@@ -846,7 +854,9 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
-    ggml_backend_sched_reset(sched.get());
+    if (!cparams.graph_reuse) {
+        ggml_backend_sched_reset(sched.get());
+    }
 
     // TODO: hacky solution
     if (model.arch == LLM_ARCH_T5 && t_embd) {
@@ -1005,11 +1015,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
             n_outputs = n_outputs_new;
         }
 
-        ggml_backend_sched_reset(sched.get());
-        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
-
         ggml_status status;
-        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
+        const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
 
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1192,7 +1199,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
-    ggml_backend_sched_reset(sched.get());
+    if (!cparams.graph_reuse) {
+        ggml_backend_sched_reset(sched.get());
+    }
 
     return 0;
 }
@@ -1275,20 +1284,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
 // graph
 //
 
-int32_t llama_context::graph_max_nodes() const {
-    return std::max<int32_t>(65536, 5*model.n_tensors());
-}
-
-ggml_cgraph * llama_context::graph_init() {
-    ggml_init_params params = {
-        /*.mem_size   =*/ buf_compute_meta.size(),
-        /*.mem_buffer =*/ buf_compute_meta.data(),
-        /*.no_alloc   =*/ true,
-    };
-
-    ctx_compute.reset(ggml_init(params));
-
-    return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
+uint32_t llama_context::graph_max_nodes() const {
+    return std::max<uint32_t>(65536u, 5u*model.n_tensors());
 }
 
 ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
@@ -1301,6 +1298,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
         LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
     }
 
+    gf_res_prev->reset();
+    ggml_backend_sched_reset(sched.get());
+
     // store the n_outputs as it is, and restore it afterwards
     // TODO: not sure if needed, might simplify in the future by removing this
     const auto save_n_outputs = this->n_outputs;
@@ -1310,17 +1310,15 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
     llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
 
-    auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
+    auto * res = gf_res_reserve.get();
 
-    this->n_outputs = save_n_outputs;
+    const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
 
-    if (!res) {
-        LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
-        return nullptr;
-    }
+    res->reset();
 
-    ggml_backend_sched_reset(sched.get());
+    auto * gf = model.build_graph(gparams);
+
+    this->n_outputs = save_n_outputs;
 
     // initialize scheduler with the specified graph
     if (!ggml_backend_sched_reserve(sched.get(), gf)) {
@@ -1331,28 +1329,27 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     return gf;
 }
 
-llm_graph_result_ptr llama_context::graph_build(
-        ggml_context * ctx,
-        ggml_cgraph * gf,
-        const llama_ubatch & ubatch,
-        llm_graph_type gtype,
-        const llama_memory_context_i * mctx) {
-    return model.build_graph(
-        {
-            /*.ctx         =*/ ctx,
-            /*.arch        =*/ model.arch,
-            /*.hparams     =*/ model.hparams,
-            /*.cparams     =*/ cparams,
-            /*.ubatch      =*/ ubatch,
-            /*.sched       =*/ sched.get(),
-            /*.backend_cpu =*/ backend_cpu,
-            /*.cvec        =*/ &cvec,
-            /*.loras       =*/ &loras,
-            /*.mctx        =*/ mctx,
-            /*.cross       =*/ &cross,
-            /*.n_outputs   =*/ n_outputs,
-            /*.cb          =*/ graph_get_cb(),
-        }, gf, gtype);
+llm_graph_params llama_context::graph_params(
+        llm_graph_result_i * res,
+        const llama_ubatch & ubatch,
+        const llama_memory_context_i * mctx,
+        llm_graph_type gtype) const {
+    return {
+        /*.arch        =*/ model.arch,
+        /*.hparams     =*/ model.hparams,
+        /*.cparams     =*/ cparams,
+        /*.ubatch      =*/ ubatch,
+        /*.gtype       =*/ gtype,
+        /*.sched       =*/ sched.get(),
+        /*.backend_cpu =*/ backend_cpu,
+        /*.cvec        =*/ &cvec,
+        /*.loras       =*/ &loras,
+        /*.mctx        =*/ mctx,
+        /*.cross       =*/ &cross,
+        /*.n_outputs   =*/ n_outputs,
+        /*.cb          =*/ graph_get_cb(),
+        /*.res         =*/ res,
+    };
 }
 
 ggml_status llama_context::graph_compute(
@@ -1930,6 +1927,7 @@ llama_perf_context_data llama_context::perf_get_data() const {
     data.t_eval_ms = 1e-3 * t_eval_us;
     data.n_p_eval  = std::max(1, n_p_eval);
     data.n_eval    = std::max(1, n_eval);
+    data.n_reused  = std::max(0, n_reused);
 
     return data;
 }
@@ -1938,6 +1936,7 @@ void llama_context::perf_reset() {
     t_start_us  = ggml_time_us();
     t_eval_us   = n_eval = 0;
     t_p_eval_us = n_p_eval = 0;
+    n_reused    = 0;
 }
 
 //
@@ -2064,8 +2063,13 @@ void llama_context::opt_epoch_iter(
             break;
         }
 
-        auto * gf = graph_init();
-        auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
+        auto * res = gf_res_prev.get();
+
+        const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
+
+        res->reset();
+
+        auto * gf = model.build_graph(gparams);
 
         struct ggml_context * ctx_compute_opt;
         {
@@ -2187,6 +2191,7 @@ llama_context_params llama_context_default_params() {
         /*.no_perf     =*/ true,
         /*.op_offload  =*/ true,
         /*.swa_full    =*/ true,
+        /*.graph_reuse =*/ false,
     };
 
     return result;
@@ -2807,6 +2812,7 @@ void llama_perf_context_print(const llama_context * ctx) {
     LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
             __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
     LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
+    LLAMA_LOG_INFO("%s: graphs reused = %10d\n", __func__, data.n_reused);
 }
 
 void llama_perf_context_reset(llama_context * ctx) {
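Editor's note: the sketch below is not part of the diff. It shows how a caller might opt into the new graph_reuse flag via llama_context_params (the field added above, defaulting to false) and observe the new "graphs reused" counter through llama_perf_context_print. The other calls are the existing public llama.h entry points; the model path is a placeholder.

// usage_sketch.cpp - illustrative only, assumes the graph_reuse field from this change
#include "llama.h"

int main() {
    llama_backend_init();

    // load a model (placeholder path)
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams);
    if (!model) {
        return 1;
    }

    // create a context with graph reuse enabled
    llama_context_params cparams = llama_context_default_params();
    cparams.graph_reuse = true; // field added by this change; defaults to false

    llama_context * ctx = llama_init_from_model(model, cparams);
    if (!ctx) {
        llama_model_free(model);
        return 1;
    }

    // ... run llama_decode() as usual; when the graph parameters of consecutive
    // ubatches match, the previously built graph is reused instead of rebuilt ...

    llama_perf_context_print(ctx); // now also reports "graphs reused"

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}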