Commit eb5856c

llama : add "virtual sequences"

ggml-ci

1 parent a70c8a0 commit eb5856c

15 files changed: +641 -248 lines

examples/parallel/parallel.cpp (2 additions, 1 deletion)

@@ -235,7 +235,7 @@ int main(int argc, char ** argv) {
 
     // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
     // users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
-    llama_batch batch = llama_batch_init(n_ctx, 0, 1);
+    llama_batch batch = llama_batch_init(n_ctx*n_clients, 0, 1);
 
     int32_t n_total_prompt = 0;
     int32_t n_total_gen    = 0;
@@ -289,6 +289,7 @@ int main(int argc, char ** argv) {
            // all sequences have ended - clear the entire KV cache
            for (int i = 1; i <= n_clients; ++i) {
                llama_memory_seq_rm(mem, i, -1, -1);
+
                // but keep the system prompt
                llama_memory_seq_cp(mem, 0, i, -1, -1);
            }
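For context, a minimal sketch (not part of the commit) of why the batch buffer must now hold up to n_ctx*n_clients tokens: in the worst case every client submits a prompt as long as the context before the main loop chunks it into params.n_batch pieces. The helper, token values and sizes below are hypothetical; only llama_batch_init/llama_batch_free and the llama_batch fields come from llama.h.

// minimal sketch, assuming hypothetical sizes; llama_batch fields are from llama.h
#include "llama.h"

static void fill_clients(llama_batch & batch, int n_clients, int n_tokens_per_client) {
    batch.n_tokens = 0;
    for (int c = 0; c < n_clients; ++c) {
        for (int t = 0; t < n_tokens_per_client; ++t) {
            const int i = batch.n_tokens++;
            batch.token   [i]    = 0;                              // placeholder token id
            batch.pos     [i]    = t;
            batch.n_seq_id[i]    = 1;
            batch.seq_id  [i][0] = 1 + c;                          // seq 0 is reserved for the system prompt
            batch.logits  [i]    = (t == n_tokens_per_client - 1); // only the last token per client needs logits
        }
    }
}

int main() {
    const int n_ctx = 2048, n_clients = 4;
    // worst case: every client submits a prompt as long as the context
    llama_batch batch = llama_batch_init(n_ctx*n_clients, 0, 1);
    fill_clients(batch, n_clients, 8);
    llama_batch_free(batch);
    return 0;
}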

src/llama-batch.cpp (27 additions, 1 deletion)

@@ -166,6 +166,8 @@ bool llama_batch_allocr::init(
 
                 // note: tracking the other way around is not necessary for now
                 //seq_cpl[s0][s1] = true;
+
+                has_cpl = true;
             }
         }
     }
@@ -405,6 +407,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
     return n_outputs;
 }
 
+uint32_t llama_batch_allocr::get_n_used() const {
+    return n_used;
+}
+
 std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
     return out_ids;
 }
@@ -420,6 +426,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
 void llama_batch_allocr::split_reset() {
     out_ids.clear();
 
+    n_used = 0;
+
     used.clear();
     used.resize(get_n_tokens(), false);
 
@@ -444,6 +452,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
         idxs.push_back(cur_idx);
 
         used[cur_idx] = true;
+        ++n_used;
 
         ++cur_idx;
 
@@ -459,9 +468,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
     return ubatch_add(idxs, idxs.size(), false);
 }
 
-llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
+llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
+    if (sequential && has_cpl) {
+        LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);
+
+        return {};
+    }
+
     std::vector<seq_set_t> cur_seq_set;
 
+    llama_seq_id last_seq_id = -1;
+
     // determine the non-overlapping sequence sets participating in this ubatch
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         if (used[i]) {
@@ -478,9 +495,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
             }
         }
 
+        // accept only increasing sequence ids
+        if (sequential) {
+            add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
+        }
+
         if (add) {
             cur_seq_set.push_back(seq_set[i]);
 
+            last_seq_id = batch.seq_id[i][0];
+
             if (cur_seq_set.size() > n_ubatch) {
                 break;
             }
@@ -529,6 +553,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
             idxs_per_seq[s].push_back(idx);
 
             used[idx] = true;
+            ++n_used;
 
             ++cur_idx[s];
         }
@@ -570,6 +595,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
         idxs.push_back(cur_idx);
 
         used[cur_idx] = true;
+        ++n_used;
 
         if (idxs.size() >= n_ubatch) {
             break;
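A compact sketch of how the new split_equal()/get_n_used() pair is meant to be driven; it mirrors the init_batch() changes in llama-kv-cache-unified-iswa.cpp further down. The helper name and the assumption that the internal header src/llama-batch.h is available are mine.

// sketch of the driver pattern, assuming access to the internal header src/llama-batch.h
#include "llama-batch.h"

#include <vector>

static bool collect_equal_splits(llama_batch_allocr & balloc, uint32_t n_ubatch,
                                 std::vector<llama_ubatch> & ubatches) {
    balloc.split_reset();

    while (true) {
        // sequential == true: only accept strictly increasing sequence ids per ubatch
        llama_ubatch ubatch = balloc.split_equal(n_ubatch, /*sequential =*/ true);
        if (ubatch.n_tokens == 0) {
            break;
        }

        ubatches.push_back(std::move(ubatch));
    }

    // the new n_used counter tells the caller whether every token found a place in some ubatch
    return balloc.get_n_used() == balloc.get_n_tokens();
}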

src/llama-batch.h (8 additions, 1 deletion)

@@ -54,6 +54,7 @@ class llama_batch_allocr {
 
     uint32_t get_n_tokens()  const;
     uint32_t get_n_outputs() const;
+    uint32_t get_n_used()    const;
 
     // the array of output indices in the order they were encountered during the ubatch splitting
     std::vector<int32_t> & get_out_ids();
@@ -69,7 +70,8 @@ class llama_batch_allocr {
     llama_ubatch split_simple(uint32_t n_ubatch);
 
     // make ubatches of equal-length sequences sets
-    llama_ubatch split_equal(uint32_t n_ubatch);
+    // if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
+    llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
 
     // sequence-set-wise split - each ubatch contains a single sequence-set
     llama_ubatch split_seq(uint32_t n_ubatch);
@@ -112,6 +114,9 @@ class llama_batch_allocr {
     using pos_set_t = std::set<llama_pos>;
     using seq_cpl_t = std::vector<bool>;
 
+    // helper flag to quickly determine if there are any coupled sequences in the batch
+    bool has_cpl;
+
     std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
     std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
 
@@ -125,6 +130,8 @@ class llama_batch_allocr {
     // batch indices of the output
     std::vector<int32_t> out_ids;
 
+    uint32_t n_used;
+
     // used[i] indicates if token i has already been used in a previous ubatch
     std::vector<bool> used;
 
src/llama-context.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ llama_context::llama_context(
3333
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
3434
}
3535

36+
const char * LLAMA_HT = getenv("LLAMA_HT");
37+
cparams.n_seq_virt = LLAMA_HT ? cparams.n_seq_max : 1;
38+
3639
cparams.n_threads = params.n_threads;
3740
cparams.n_threads_batch = params.n_threads_batch;
3841
cparams.yarn_ext_factor = params.yarn_ext_factor;
@@ -1308,7 +1311,8 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
13081311
this->n_outputs = n_outputs;
13091312

13101313
llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
1311-
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
1314+
//llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
1315+
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens, 1);
13121316

13131317
auto * gf = graph_init();
13141318
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
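The feature is gated at context creation: if the LLAMA_HT environment variable is set, every sequence gets its own virtual sequence (n_seq_virt == n_seq_max); otherwise n_seq_virt stays 1 and behavior is unchanged. A standalone sketch of that gate, with a hypothetical n_seq_max:

// standalone sketch of the LLAMA_HT gate added above; n_seq_max is a hypothetical value
#include <cstdint>
#include <cstdio>
#include <cstdlib>

int main() {
    const uint32_t n_seq_max = 8;

    // setting LLAMA_HT (to any value) gives every sequence its own virtual KV-cache sequence
    const char * LLAMA_HT = std::getenv("LLAMA_HT");
    const uint32_t n_seq_virt = LLAMA_HT ? n_seq_max : 1;

    std::printf("n_seq_virt = %u\n", n_seq_virt);
    return 0;
}

Presumably, running any of the examples with LLAMA_HT set (e.g. LLAMA_HT=1) exercises the new code path.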

src/llama-cparams.h (3 additions, 2 deletions)

@@ -11,8 +11,9 @@ struct llama_cparams {
     uint32_t n_batch;
     uint32_t n_ubatch;
     uint32_t n_seq_max;
-    int      n_threads;       // number of threads to use for generation
-    int      n_threads_batch; // number of threads to use for batch processing
+    uint32_t n_seq_virt;
+    int32_t  n_threads;       // number of threads to use for generation
+    int32_t  n_threads_batch; // number of threads to use for batch processing
 
     float rope_freq_base;
     float rope_freq_scale;

src/llama-graph.cpp (16 additions, 9 deletions)

@@ -1000,13 +1000,13 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
 
-        const auto n_kv = inp->mctx->get_attn()->get_n_kv();
+        const auto n_kv   = inp->mctx->get_attn()->get_n_kv();
+        const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
 
         inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1033,6 +1033,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
                   float   kq_scale) const {
     const bool v_trans = v->nb[1] > v->nb[2];
 
+    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+
+    q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_seqs, n_seqs);
+
     q = ggml_permute(ctx0, q, 0, 2, 1, 3);
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);
@@ -1081,7 +1085,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 #endif
         }
 
-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
 
@@ -1126,7 +1130,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
         cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
 
-        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
 
         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
@@ -1204,12 +1208,13 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-        const auto n_kv = mctx_cur->get_n_kv();
+        const auto n_kv   = mctx_cur->get_n_kv();
+        const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
 
         inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1451,13 +1456,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
+    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
         inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1471,7 +1478,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
         inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
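To make the shape change concrete: with virtual sequences the Q tensor and the KQ mask gain a fourth dimension, so each virtual sequence attends only to its own slice of the batch. A standalone ggml shape sketch with hypothetical sizes (the padding via GGML_KQ_MASK_PAD done by the real code is omitted here):

// shape sketch of the change above, with hypothetical sizes (n_embd_head = 128,
// n_head = 32, n_kv = 256, n_tokens = 64 split across n_seqs = 4 virtual sequences)
#include "ggml.h"

int main() {
    ggml_init_params params = { /*.mem_size   =*/ 16*1024*1024,
                                /*.mem_buffer =*/ nullptr,
                                /*.no_alloc   =*/ true };
    ggml_context * ctx = ggml_init(params);

    const int64_t n_embd_head = 128, n_head = 32, n_kv = 256, n_tokens = 64, n_seqs = 4;

    // Q used to stay 3D; with virtual sequences the token dimension is split per sequence
    ggml_tensor * q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, n_tokens);
    q = ggml_reshape_4d(ctx, q, q->ne[0], q->ne[1], q->ne[2]/n_seqs, n_seqs);

    // the KQ mask becomes [n_kv, n_tokens/n_seqs, 1, n_seqs] instead of [n_kv, n_tokens]
    ggml_tensor * kq_mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_seqs, 1, n_seqs);

    GGML_ASSERT(q->ne[3] == kq_mask->ne[3]); // one mask slice per virtual sequence

    ggml_free(ctx);
    return 0;
}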

src/llama-graph.h (9 additions, 9 deletions)

@@ -255,10 +255,10 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
     ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
-    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
 
-    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch/n_seqs, n_seqs]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -289,14 +289,14 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
     ggml_tensor * self_k_idxs     = nullptr; // I64 [n_batch]
-    ggml_tensor * self_v_idxs     = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs     = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
     ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
-    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
 
-    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch]
+    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
+    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch/n_seqs, n_seqs]
+    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
+    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch/n_seqs, n_seqs]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;

src/llama-kv-cache-unified-iswa.cpp (21 additions, 5 deletions)

@@ -20,14 +20,15 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
         bool     swa_full,
         uint32_t kv_size,
         uint32_t n_seq_max,
+        uint32_t n_seq_virt,
         uint32_t n_ubatch,
-        uint32_t n_pad) : hparams(model.hparams) {
+        uint32_t n_pad) : hparams(model.hparams), n_seq_virt(n_seq_virt) {
     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
     llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
 
     const uint32_t size_base = kv_size;
 
-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
+    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(n_seq_max/n_seq_virt) + n_ubatch, n_pad));
 
     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
     if (swa_full) {
@@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
 
     kv_base = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_base), type_k, type_v,
-            v_trans, offload, size_base, n_seq_max, n_pad,
+            v_trans, offload, size_base, n_seq_max, n_seq_virt, n_pad,
             0, LLAMA_SWA_TYPE_NONE);
 
     LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
 
     kv_swa = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_swa), type_k, type_v,
-            v_trans, offload, size_swa, n_seq_max, n_pad,
+            v_trans, offload, size_swa, n_seq_max, n_seq_virt, n_pad,
             hparams.n_swa, hparams.swa_type);
 }
 
@@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
     // first try simple split
    do {
+        if (n_seq_virt > 1) {
+            // requires equal splits, so we skip the simple split
+            break;
+        }
+
        balloc.split_reset();
 
        std::vector<llama_ubatch> ubatches;
@@ -113,6 +119,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
            ubatches.push_back(std::move(ubatch)); // NOLINT
        }
 
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
+            break;
+        }
+
        auto sinfos_base = kv_base->prepare(ubatches);
        if (sinfos_base.empty()) {
            break;
@@ -135,7 +146,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
        std::vector<llama_ubatch> ubatches;
        while (true) {
-            auto ubatch = balloc.split_equal(n_ubatch);
+            auto ubatch = balloc.split_equal(n_ubatch, n_seq_virt > 1);
 
            if (ubatch.n_tokens == 0) {
                break;
@@ -144,6 +155,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
            ubatches.push_back(std::move(ubatch)); // NOLINT
        }
 
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
+            break;
+        }
+
        auto sinfos_base = kv_base->prepare(ubatches);
        if (sinfos_base.empty()) {
            break;
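The SWA cache sizing now divides by n_seq_virt: when LLAMA_HT is enabled (n_seq_virt == n_seq_max), the n_swa*n_seq_max term collapses to a single sequence's worth of SWA cells. A worked example with hypothetical parameters (the pad lambda mirrors GGML_PAD):

// worked example of the new SWA cache sizing, with hypothetical parameters
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_swa = 4096, n_seq_max = 8, n_ubatch = 512, n_pad = 256;
    const uint32_t size_base = 65536;

    auto pad = [](uint32_t x, uint32_t n) { return (x + n - 1) / n * n; };

    // before this commit: the SWA cache holds n_swa cells for each of the n_seq_max sequences
    const uint32_t size_swa_old = std::min(size_base, pad(n_swa*n_seq_max + n_ubatch, n_pad));              // 33280

    // with LLAMA_HT set, n_seq_virt == n_seq_max, so the term reduces to one sequence's worth of cells
    const uint32_t n_seq_virt = n_seq_max;
    const uint32_t size_swa_new = std::min(size_base, pad(n_swa*(n_seq_max/n_seq_virt) + n_ubatch, n_pad)); // 4608

    std::printf("size_swa: %u -> %u\n", size_swa_old, size_swa_new);
    return 0;
}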

src/llama-kv-cache-unified-iswa.h (3 additions, 0 deletions)

@@ -22,6 +22,7 @@ class llama_kv_cache_unified_iswa : public llama_memory_i {
             bool     swa_full,
             uint32_t kv_size,
             uint32_t n_seq_max,
+            uint32_t n_seq_virt,
             uint32_t n_ubatch,
             uint32_t n_pad);
 
@@ -68,6 +69,8 @@ class llama_kv_cache_unified_iswa : public llama_memory_i {
 private:
     const llama_hparams & hparams;
 
+    const uint32_t n_seq_virt = 1;
+
     std::unique_ptr<llama_kv_cache_unified> kv_base;
     std::unique_ptr<llama_kv_cache_unified> kv_swa;
 };
