Commit 6179578

batch : require non-coupled batch with sequential split_equal
ggml-ci
1 parent 5eb1a88 commit 6179578
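
Background: a batch is "coupled" when a single token carries more than one sequence id, which init() records in seq_cpl (and, with this commit, also in the new has_cpl flag). A minimal sketch of such a batch via the public llama.h API — illustrative only, some_token_id is a placeholder:

    // one token assigned to two sequences -> couples sequence 1 to sequence 0
    llama_batch batch = llama_batch_init(/*n_tokens*/ 1, /*embd*/ 0, /*n_seq_max*/ 2);

    batch.n_tokens     = 1;
    batch.token[0]     = some_token_id; // placeholder token id
    batch.pos[0]       = 0;
    batch.n_seq_id[0]  = 2;             // this token belongs to two sequences ...
    batch.seq_id[0][0] = 0;             // ... so init() marks them as coupled
    batch.seq_id[0][1] = 1;
    batch.logits[0]    = true;

    // for such a batch, split_equal(n_ubatch, /*sequential*/ true) now refuses to split

    llama_batch_free(batch);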

6 files changed (+66, -14 lines)

src/llama-batch.cpp

Lines changed: 17 additions & 0 deletions

@@ -166,6 +166,8 @@ bool llama_batch_allocr::init(
 
                 // note: tracking the other way around is not necessary for now
                 //seq_cpl[s0][s1] = true;
+
+                has_cpl = true;
             }
         }
     }
@@ -403,6 +405,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
     return n_outputs;
 }
 
+uint32_t llama_batch_allocr::get_n_used() const {
+    return n_used;
+}
+
 std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
     return out_ids;
 }
@@ -418,6 +424,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
 void llama_batch_allocr::split_reset() {
     out_ids.clear();
 
+    n_used = 0;
+
     used.clear();
     used.resize(get_n_tokens(), false);
 
@@ -442,6 +450,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
         idxs.push_back(cur_idx);
 
         used[cur_idx] = true;
+        ++n_used;
 
         ++cur_idx;
 
@@ -458,6 +467,12 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
 }
 
 llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
+    if (sequential && has_cpl) {
+        LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);
+
+        return {};
+    }
+
     std::vector<seq_set_t> cur_seq_set;
 
     llama_seq_id last_seq_id = -1;
@@ -536,6 +551,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential)
             idxs_per_seq[s].push_back(idx);
 
             used[idx] = true;
+            ++n_used;
 
             ++cur_idx[s];
         }
@@ -577,6 +593,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
        idxs.push_back(cur_idx);
 
        used[cur_idx] = true;
+       ++n_used;
 
        if (idxs.size() >= n_ubatch) {
            break;
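
Taken together, the allocator changes let a caller tell "splitting finished" apart from "splitting gave up": split_equal() now refuses a sequential split of a coupled batch up front, and n_used lets the caller verify that every input token actually landed in some ubatch. A sketch of the resulting caller-side pattern (mirroring the memory modules below):

    balloc.split_reset();

    std::vector<llama_ubatch> ubatches;
    while (true) {
        llama_ubatch ubatch = balloc.split_equal(n_ubatch, /*sequential*/ true);
        if (ubatch.n_tokens == 0) {
            break; // done -- or the coupled-batch error path above returned {}
        }
        ubatches.push_back(std::move(ubatch));
    }

    if (balloc.get_n_used() < balloc.get_n_tokens()) {
        // not every token was consumed: treat the split as failed, do not prepare()
    }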

src/llama-batch.h

Lines changed: 6 additions & 0 deletions

@@ -54,6 +54,7 @@ class llama_batch_allocr {
 
     uint32_t get_n_tokens()  const;
     uint32_t get_n_outputs() const;
+    uint32_t get_n_used()    const;
 
     // the array of output indices in the order they were encountered during the ubatch splitting
     std::vector<int32_t> & get_out_ids();
@@ -113,6 +114,9 @@ class llama_batch_allocr {
     using pos_set_t = std::set<llama_pos>;
     using seq_cpl_t = std::vector<bool>;
 
+    // helper flag to quickly determine if there are any coupled sequences in the batch
+    bool has_cpl;
+
     std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
     std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
 
@@ -126,6 +130,8 @@ class llama_batch_allocr {
     // batch indices of the output
     std::vector<int32_t> out_ids;
 
+    uint32_t n_used;
+
     // used[i] indicates if token i has already been used in a previous ubatch
     std::vector<bool> used;

src/llama-kv-cache-unified-iswa.cpp

Lines changed: 10 additions & 0 deletions

@@ -119,6 +119,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
+            break;
+        }
+
         auto sinfos_base = kv_base->prepare(ubatches);
         if (sinfos_base.empty()) {
             break;
@@ -150,6 +155,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
+            break;
+        }
+
         auto sinfos_base = kv_base->prepare(ubatches);
         if (sinfos_base.empty()) {
             break;
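
The same five-line guard is inserted after each split loop here and in the remaining memory modules. A hypothetical helper that captures the idiom (not part of the commit, which repeats the check inline):

    // hypothetical -- not in the commit; shown only to name the idiom
    static bool balloc_fully_used(const llama_batch_allocr & balloc) {
        // true iff every token of the input batch was placed into some ubatch
        return balloc.get_n_used() == balloc.get_n_tokens();
    }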

src/llama-kv-cache-unified.cpp

Lines changed: 5 additions & 0 deletions

@@ -427,6 +427,11 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
+            break;
+        }
+
         auto sinfos = prepare(ubatches);
         if (sinfos.empty()) {
             break;

src/llama-memory-hybrid.cpp

Lines changed: 5 additions & 0 deletions

@@ -81,6 +81,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
+            break;
+        }
+
         // prepare the recurrent batches first
         if (!mem_recr->prepare(ubatches)) {
             // TODO: will the recurrent cache be in an undefined context at this point?

src/llama-memory-recurrent.cpp

Lines changed: 23 additions & 14 deletions

@@ -365,26 +365,35 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
 llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
     std::vector<llama_ubatch> ubatches;
 
-    while (true) {
-        llama_ubatch ubatch;
+    do {
+        balloc.split_reset();
 
-        if (embd_all) {
-            // if all tokens are output, split by sequence
-            ubatch = balloc.split_seq(n_ubatch);
-        } else {
-            ubatch = balloc.split_equal(n_ubatch, false);
+        while (true) {
+            llama_ubatch ubatch;
+
+            if (embd_all) {
+                // if all tokens are output, split by sequence
+                ubatch = balloc.split_seq(n_ubatch);
+            } else {
+                ubatch = balloc.split_equal(n_ubatch, false);
+            }
+
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
+
+            ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        if (ubatch.n_tokens == 0) {
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
             break;
         }
 
-        ubatches.push_back(std::move(ubatch)); // NOLINT
-    }
-
-    if (!prepare(ubatches)) {
-        return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
+        if (!prepare(ubatches)) {
+            return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+        }
+    } while (false);
 
     return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
 }
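
Unlike the other memory modules, init_batch() here was a plain loop followed by an early return, so the commit wraps the whole split-and-prepare sequence in a do { ... } while (false) block: the new completeness check can then bail out with break and still reach the single context construction at the end. The early-exit idiom in isolation (step_one/step_two are placeholders):

    do {
        if (!step_one()) {
            break; // abandon the remaining steps ...
        }
        if (!step_two()) {
            break;
        }
    } while (false);
    // ... and fall through to the shared tail in either case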
