Commit de6e65f

ggerganov authored and arthw committed
llama : refactor llama_context, llama_kv_cache, llm_build_context (ggml-org#12181)
* llama : refactor llama_context, llama_kv_cache, llm_build_context
* graph : don't mutate the KV cache during defrag
* context : reduce virtuals + remove test function
* context : move interface implementation to source file + factory
* graph : move KV cache build functions to llama_context impl
* graph : remove model reference from build_pooling
* graph : remove llama_model reference
* kv_cache : provide rope factors
* graph : rework inputs to use only unique_ptr, remove attn input abstraction
* context : remove llama_context_i abstraction
* context : clean-up
* graph : clean-up
* llama : remove redundant keywords (struct, enum)
* model : adapt gemma3
* graph : restore same attention ops as on master
* llama : remove TODO + fix indent
1 parent a573718 · commit de6e65f

46 files changed: +13785 −12072 lines

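The diffs below are mostly mechanical: each public llama_kv_cache_* call site becomes the corresponding llama_kv_self_* call, which operates on the context's own KV cache. A minimal sketch of the mapping, assuming ctx is a valid llama_context (model/context setup and error handling omitted):

#include "llama.h"

// Sketch only: the renamed KV-cache entry points as used after this commit.
// Assumes `ctx` was created via the usual llama.h model/context setup.
static void kv_self_rename_examples(llama_context * ctx) {
    llama_kv_self_clear(ctx);                    // was llama_kv_cache_clear
    llama_kv_self_seq_rm (ctx, 0,  0, -1);       // was llama_kv_cache_seq_rm
    llama_kv_self_seq_add(ctx, 0, 32, -1, -32);  // was llama_kv_cache_seq_add
    llama_kv_self_seq_cp (ctx, 0,  1, -1, -1);   // was llama_kv_cache_seq_cp
}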

common/common.cpp

Lines changed: 3 additions & 3 deletions
@@ -955,8 +955,8 @@ struct common_init_result common_init_from_params(common_params & params) {
         return iparams;
     }

-    if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) {
-        LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__);
+    if (params.ctx_shift && !llama_kv_self_can_shift(lctx)) {
+        LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
         params.ctx_shift = false;
     }

@@ -1060,7 +1060,7 @@ struct common_init_result common_init_from_params(common_params & params) {
     if (llama_model_has_decoder(model)) {
         llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
     }
-    llama_kv_cache_clear(lctx);
+    llama_kv_self_clear(lctx);
     llama_synchronize(lctx);
     llama_perf_context_reset(lctx);
 }
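The renamed llama_kv_self_can_shift probe above matters because not every context supports shifting cached positions, so callers check before enabling context shifting. A minimal sketch of the same guard in isolation, assuming ctx_shift is the caller's feature flag (the helper name is illustrative, not part of the codebase):

#include <cstdio>
#include "llama.h"

// Illustrative helper: probe the context and fall back gracefully
// instead of failing later at decode time.
static bool maybe_enable_ctx_shift(llama_context * lctx, bool ctx_shift) {
    if (ctx_shift && !llama_kv_self_can_shift(lctx)) {
        fprintf(stderr, "KV cache shifting is not supported for this context, disabling\n");
        return false;
    }
    return ctx_shift;
}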

common/speculative.cpp

Lines changed: 4 additions & 4 deletions
@@ -173,7 +173,7 @@ llama_tokens common_speculative_gen_draft(
     result.reserve(params.n_draft);

     if (reuse_n == 0) {
-        llama_kv_cache_clear(ctx);
+        llama_kv_self_clear(ctx);

         prompt.clear();
     } else {
@@ -192,14 +192,14 @@ llama_tokens common_speculative_gen_draft(
     }

     if (reuse_i > 0) {
-        llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
-        llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
+        llama_kv_self_seq_rm (ctx, 0, 0, reuse_i);
+        llama_kv_self_seq_add(ctx, 0, reuse_i, -1, -reuse_i);

         prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
     }

     if (reuse_n < (int) prompt.size()) {
-        llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1);
+        llama_kv_self_seq_rm (ctx, 0, reuse_n, -1);

         prompt.erase(prompt.begin() + reuse_n, prompt.end());
     }
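The seq_rm/seq_add pair above is the cache-reuse trick in the draft context: drop the first reuse_i cached tokens of sequence 0, shift the surviving positions left by reuse_i so the cache lines up with the trimmed prompt, then discard anything past the reusable prefix. A minimal sketch of the same pattern with the state factored out (function and parameter names are illustrative):

#include "llama.h"

// Illustrative sketch of the trimming above. Assumes `ctx` is a valid
// llama_context whose sequence 0 currently holds n_cached tokens.
static void trim_cached_prefix(llama_context * ctx, int reuse_i, int reuse_n, int n_cached) {
    // drop cached positions [0, reuse_i) ...
    llama_kv_self_seq_rm (ctx, 0, 0, reuse_i);
    // ... then shift positions [reuse_i, end) left by reuse_i (p1 = -1 means "to the end")
    llama_kv_self_seq_add(ctx, 0, reuse_i, -1, -reuse_i);

    // everything past the reusable prefix is stale: remove [reuse_n, end)
    if (reuse_n < n_cached) {
        llama_kv_self_seq_rm(ctx, 0, reuse_n, -1);
    }
}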

examples/batched-bench/batched-bench.cpp

Lines changed: 2 additions & 2 deletions
@@ -132,7 +132,7 @@ int main(int argc, char ** argv) {

     const auto t_pp_start = ggml_time_us();

-    llama_kv_cache_clear(ctx);
+    llama_kv_self_clear(ctx);

     if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
         LOG_ERR("%s: llama_decode() failed\n", __func__);
@@ -141,7 +141,7 @@ int main(int argc, char ** argv) {

     if (is_pp_shared) {
         for (int32_t i = 1; i < pl; ++i) {
-            llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
+            llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
         }
     }
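The seq_cp loop above is what makes the shared-prompt benchmark cheap: the prompt is decoded once into sequence 0 and its KV cells are then copied to every other parallel sequence instead of being re-decoded. A minimal sketch, assuming ctx is a valid llama_context (the helper name is illustrative):

#include "llama.h"

// Illustrative sketch of the shared-prompt copy above. The (-1, -1) position
// range means "copy all cached positions" from sequence 0 to sequence i.
static void share_prompt_cache(llama_context * ctx, int32_t n_parallel) {
    for (int32_t i = 1; i < n_parallel; ++i) {
        llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
    }
}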

examples/batched.swift/Sources/main.swift

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ if llama_decode(context, batch) != 0 {
 }

 for i in 1 ..< n_parallel {
-    llama_kv_cache_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
+    llama_kv_self_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
 }

 if n_parallel > 1 {

examples/cvector-generator/cvector-generator.cpp

Lines changed: 1 addition & 1 deletion
@@ -342,7 +342,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
 }

 static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
-    llama_kv_cache_clear(ctx);
+    llama_kv_self_clear(ctx);
     if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
         fprintf(stderr, "%s : failed to eval\n", __func__);
         return false;

examples/embedding/embedding.cpp

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
     const struct llama_model * model = llama_get_model(ctx);

     // clear previous kv_cache values (irrelevant for embeddings)
-    llama_kv_cache_clear(ctx);
+    llama_kv_self_clear(ctx);

     // run model
     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);

examples/gritlm/gritlm.cpp

Lines changed: 2 additions & 2 deletions
@@ -45,7 +45,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
     }

     // clear previous kv_cache values (irrelevant for embeddings)
-    llama_kv_cache_clear(ctx);
+    llama_kv_self_clear(ctx);
     llama_set_embeddings(ctx, true);
     llama_set_causal_attn(ctx, false);

@@ -102,7 +102,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std

     llama_token eos_token = llama_vocab_eos(vocab);

-    llama_kv_cache_clear(ctx);
+    llama_kv_self_clear(ctx);
     llama_set_embeddings(ctx, false);
     llama_set_causal_attn(ctx, true);
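GritLM drives both embedding and generation through a single context, so each phase first clears the cache and flips the context into the right mode. A minimal sketch of the two switches, assuming ctx is a valid llama_context (the helper names are illustrative):

#include "llama.h"

// Illustrative sketch of the mode toggles above.
static void enter_embedding_mode(llama_context * ctx) {
    llama_kv_self_clear(ctx);           // previous generation state is irrelevant here
    llama_set_embeddings(ctx, true);    // produce embeddings instead of logits
    llama_set_causal_attn(ctx, false);  // bidirectional attention for encoding
}

static void enter_generation_mode(llama_context * ctx) {
    llama_kv_self_clear(ctx);           // drop the encoding pass's cache
    llama_set_embeddings(ctx, false);
    llama_set_causal_attn(ctx, true);   // causal attention for autoregressive decoding
}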

examples/imatrix/imatrix.cpp

Lines changed: 1 addition & 1 deletion
@@ -495,7 +495,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
     const auto t_start = std::chrono::high_resolution_clock::now();

     // clear the KV cache
-    llama_kv_cache_clear(ctx);
+    llama_kv_self_clear(ctx);

     llama_batch batch = llama_batch_init(n_batch, 0, 1);

examples/infill/infill.cpp

Lines changed: 2 additions & 2 deletions
@@ -332,8 +332,8 @@ int main(int argc, char ** argv) {
             LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
                 n_past, n_left, n_ctx, params.n_keep, n_discard);

-            llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
-            llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+            llama_kv_self_seq_rm (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
+            llama_kv_self_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);

             n_past -= n_discard;
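This hunk is the classic context shift: when the context fills up, keep the first n_keep tokens (the extra +1 preserves the token at position 0, typically BOS), delete the next n_discard, and slide the tail left so decoding can continue. A small sketch of the arithmetic with concrete, illustrative numbers (the helper name is illustrative):

#include "llama.h"

// Illustrative sketch of the shift above. With n_keep = 4, n_discard = 8,
// n_past = 32: remove cached positions [5, 13), then shift [13, 32) left by 8
// so the cache is contiguous again and decoding resumes at position 24.
static void context_shift(llama_context * ctx, int n_keep, int n_discard, int & n_past) {
    llama_kv_self_seq_rm (ctx, 0, n_keep + 1, n_keep + n_discard + 1);
    llama_kv_self_seq_add(ctx, 0, n_keep + 1 + n_discard, n_past, -n_discard);
    n_past -= n_discard;
}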

examples/llama-bench/llama-bench.cpp

Lines changed: 2 additions & 2 deletions
@@ -1578,7 +1578,7 @@ int main(int argc, char ** argv) {

         test t(inst, lmodel, ctx);

-        llama_kv_cache_clear(ctx);
+        llama_kv_self_clear(ctx);

         // cool off before the test
         if (params.delay) {
@@ -1618,7 +1618,7 @@ int main(int argc, char ** argv) {
         }

         for (int i = 0; i < params.reps; i++) {
-            llama_kv_cache_clear(ctx);
+            llama_kv_self_clear(ctx);

             uint64_t t_start = get_time_ns();
