
Commit a823406 (parent 21f865d)

cont : add n_seq_max to batch allocr

ggml-ci

File tree (4 files changed: +24, -15 lines)

  examples/embedding/embedding.cpp
  src/llama-batch.cpp
  src/llama-batch.h
  src/llama-context.cpp


examples/embedding/embedding.cpp

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ int main(int argc, char ** argv) {
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
     const int n_ctx_train = llama_model_n_ctx_train(model);
-    const int n_ctx = llama_n_ctx(ctx);
+    const int n_ctx = llama_n_ctx(ctx);
 
     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);

src/llama-batch.cpp

Lines changed: 18 additions & 11 deletions
@@ -27,6 +27,7 @@ bool llama_batch_allocr::init(
         const llama_vocab & vocab,
         const llama_memory_i * memory,
         uint32_t n_embd,
+        uint32_t n_seq_max,
         bool output_all) {
     clear();
 
@@ -40,6 +41,11 @@ bool llama_batch_allocr::init(
     // validate input batch
     //
 
+    if (n_seq_max > LLAMA_MAX_SEQ) {
+        LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ);
+        return false;
+    }
+
     if (batch.token) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
@@ -52,8 +58,8 @@ bool llama_batch_allocr::init(
     if (batch.seq_id) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
-                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
+                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
                     return false;
                 }
             }
@@ -86,7 +92,7 @@ bool llama_batch_allocr::init(
 
     // initialize the starting position for each sequence based on the positions in the memory
     llama_pos p0[LLAMA_MAX_SEQ];
-    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (!memory) {
             // if no memory -> start from 0
             p0[s] = 0;
@@ -143,7 +149,8 @@ bool llama_batch_allocr::init(
     // compute stats
     //
 
-    this->n_embd = n_embd;
+    this->n_embd    = n_embd;
+    this->n_seq_max = n_seq_max;
 
     // count the outputs in this batch
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -189,7 +196,7 @@ bool llama_batch_allocr::init(
         seq_set_map[cur].push_back(i);
     }
 
-    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_set_unq.test(s)) {
             seq_idx[s] = seq_id_unq.size();
             seq_id_unq.push_back(s);
@@ -241,7 +248,7 @@ bool llama_batch_allocr::init(
     // consistency checks
     //
 
-    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_pos[s].empty()) {
             continue;
         }
@@ -284,8 +291,8 @@ bool llama_batch_allocr::init(
     }
 
     if (memory) {
-        for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
-            for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
+        for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
+            for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
                 if (seq_cpl[s0][s1]) {
                     if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
                         memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
@@ -316,12 +323,12 @@ bool llama_batch_allocr::init(
     //
     {
         seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
            cur_seq_set[s].set();
        }
 
        llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
            cur_seq_pos[s] = -1;
        }
 
@@ -692,7 +699,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
         }
     }
 
-    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_set_unq.test(s)) {
             ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
             ubatch.seq_id_unq.push_back(s);
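
Note: the hunks above replace the compile-time bound LLAMA_MAX_SEQ with a runtime limit n_seq_max that is validated once and then used to bound every per-sequence loop and seq_id check. A minimal standalone sketch of that validation pattern follows; the names (validate_seq_ids, MAX_SEQ_COMPILE_TIME) are hypothetical and not part of the llama.cpp sources.

    // Sketch of the validation pattern: seq ids must be non-negative and
    // strictly below the runtime limit n_seq_max, which itself must not
    // exceed the compile-time maximum (stand-in for LLAMA_MAX_SEQ).
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    constexpr uint32_t MAX_SEQ_COMPILE_TIME = 64; // stand-in for LLAMA_MAX_SEQ

    static bool validate_seq_ids(const std::vector<int32_t> & seq_ids, uint32_t n_seq_max) {
        if (n_seq_max > MAX_SEQ_COMPILE_TIME) {
            std::fprintf(stderr, "n_seq_max = %u > %u\n", n_seq_max, MAX_SEQ_COMPILE_TIME);
            return false;
        }
        for (size_t i = 0; i < seq_ids.size(); ++i) {
            if (seq_ids[i] < 0 || (uint32_t) seq_ids[i] >= n_seq_max) {
                std::fprintf(stderr, "invalid seq_id[%zu] = %d (limit %u)\n", i, seq_ids[i], n_seq_max);
                return false;
            }
        }
        return true;
    }

    int main() {
        // with n_seq_max = 4, seq id 5 is rejected even though it is below the compile-time max
        std::vector<int32_t> ids = { 0, 1, 5 };
        const bool ok = validate_seq_ids(ids, 4);
        std::printf("batch %s\n", ok ? "accepted" : "rejected");
        return 0;
    }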

src/llama-batch.h

Lines changed: 2 additions & 0 deletions
@@ -48,6 +48,7 @@ class llama_batch_allocr {
             const llama_vocab & vocab,
             const llama_memory_i * memory,
             uint32_t n_embd,
+            uint32_t n_seq_max,
             bool output_all);
 
     const llama_batch & get_batch() const;
@@ -100,6 +101,7 @@ class llama_batch_allocr {
     const uint32_t n_pos_per_embd;
 
     uint32_t n_embd;
+    uint32_t n_seq_max;
     uint32_t n_outputs;
 
     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
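
Note: the header change mirrors the .cpp change: init() gains the n_seq_max parameter and the allocator keeps it as a member so later per-sequence loops can stop at the runtime limit rather than at the size of the compile-time arrays. A sketch of that store-then-bound pattern, using a hypothetical class (not the real llama_batch_allocr):

    // Hypothetical sketch: the limit passed to init() is stored and later
    // bounds every per-sequence loop, while storage stays sized by the
    // compile-time maximum (stand-in for LLAMA_MAX_SEQ).
    #include <bitset>
    #include <cstdint>
    #include <vector>

    constexpr uint32_t MAX_SEQ = 64; // stand-in for LLAMA_MAX_SEQ

    class seq_tracker {
    public:
        bool init(uint32_t n_seq_max_) {
            if (n_seq_max_ > MAX_SEQ) {
                return false;
            }
            n_seq_max = n_seq_max_;
            return true;
        }

        void mark(uint32_t s) { seq_set_unq.set(s); }

        // collect the ids of sequences that are actually used
        std::vector<uint32_t> used_ids() const {
            std::vector<uint32_t> res;
            for (uint32_t s = 0; s < n_seq_max; ++s) { // bounded by the runtime limit
                if (seq_set_unq.test(s)) {
                    res.push_back(s);
                }
            }
            return res;
        }

    private:
        uint32_t n_seq_max = MAX_SEQ;      // runtime limit, set by init()
        std::bitset<MAX_SEQ> seq_set_unq;  // storage still sized by the compile-time max
    };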

src/llama-context.cpp

Lines changed: 3 additions & 3 deletions
@@ -740,7 +740,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     const int64_t n_embd = hparams.n_embd;
 
     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
+    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.attn_streams ? cparams.n_seq_max : LLAMA_MAX_SEQ, true)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -907,7 +907,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // when computing embeddings, all tokens are output
     const bool output_all = cparams.embeddings;
 
-    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
+    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.attn_streams ? cparams.n_seq_max : LLAMA_MAX_SEQ, output_all)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -2036,7 +2036,7 @@ void llama_context::opt_epoch_iter(
         batch.logits [pos_batch] = true;
     }
 
-    if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
+    if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.attn_streams ? cparams.n_seq_max : LLAMA_MAX_SEQ, true)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return;
     }
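
Note: all three call sites pass the same expression, cparams.attn_streams ? cparams.n_seq_max : LLAMA_MAX_SEQ, so the stricter per-context limit is enforced only when attention streams are enabled and the previous compile-time bound is kept otherwise. A small sketch of that selection, with a hypothetical params struct (not the real cparams):

    // Hypothetical sketch of the limit selection done at the call sites above.
    #include <cstdint>
    #include <cstdio>

    constexpr uint32_t MAX_SEQ = 64; // stand-in for LLAMA_MAX_SEQ

    struct context_params {
        bool     attn_streams;
        uint32_t n_seq_max;
    };

    static uint32_t effective_seq_limit(const context_params & cparams) {
        // per-context limit only applies when attention streams are enabled
        return cparams.attn_streams ? cparams.n_seq_max : MAX_SEQ;
    }

    int main() {
        const context_params a = { /*attn_streams=*/true,  /*n_seq_max=*/8 };
        const context_params b = { /*attn_streams=*/false, /*n_seq_max=*/8 };

        std::printf("streams on : limit = %u\n", effective_seq_limit(a)); // prints 8
        std::printf("streams off: limit = %u\n", effective_seq_limit(b)); // prints 64
        return 0;
    }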
