@@ -27,6 +27,7 @@ bool llama_batch_allocr::init(
27
27
const llama_vocab & vocab,
28
28
const llama_memory_i * memory,
29
29
uint32_t n_embd,
30
+ uint32_t n_seq_max,
30
31
bool output_all) {
31
32
clear ();
32
33
@@ -40,6 +41,11 @@ bool llama_batch_allocr::init(
40
41
// validate input batch
41
42
//
42
43
44
+ if (n_seq_max > LLAMA_MAX_SEQ) {
45
+ LLAMA_LOG_ERROR (" %s: n_seq_max = %d > %d\n " , __func__, n_seq_max, LLAMA_MAX_SEQ);
46
+ return false ;
47
+ }
48
+
43
49
if (batch.token ) {
44
50
for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
45
51
if (batch.token [i] < 0 || (uint32_t ) batch.token [i] >= vocab.n_tokens ()) {
@@ -52,8 +58,8 @@ bool llama_batch_allocr::init(
52
58
if (batch.seq_id ) {
53
59
for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
54
60
for (int32_t s = 0 ; s < batch.n_seq_id [i]; ++s) {
55
- if (batch.seq_id && (batch.seq_id [i][s] < 0 || batch.seq_id [i][s] >= LLAMA_MAX_SEQ )) {
56
- LLAMA_LOG_ERROR (" %s: invalid seq_id[%d][%d] = %d > %d\n " , __func__, i, s, batch.seq_id [i][s], LLAMA_MAX_SEQ );
61
+ if (batch.seq_id && (batch.seq_id [i][s] < 0 || batch.seq_id [i][s] >= (llama_seq_id) n_seq_max )) {
62
+ LLAMA_LOG_ERROR (" %s: invalid seq_id[%d][%d] = %d > %d\n " , __func__, i, s, batch.seq_id [i][s], (llama_seq_id) n_seq_max );
57
63
return false ;
58
64
}
59
65
}
@@ -86,7 +92,7 @@ bool llama_batch_allocr::init(
86
92
87
93
// initialize the starting position for each sequence based on the positions in the memory
88
94
llama_pos p0[LLAMA_MAX_SEQ];
89
- for (int32_t s = 0 ; s < LLAMA_MAX_SEQ ; ++s) {
95
+ for (uint32_t s = 0 ; s < n_seq_max ; ++s) {
90
96
if (!memory) {
91
97
// if no memory -> start from 0
92
98
p0[s] = 0 ;
@@ -143,7 +149,8 @@ bool llama_batch_allocr::init(
143
149
// compute stats
144
150
//
145
151
146
- this ->n_embd = n_embd;
152
+ this ->n_embd = n_embd;
153
+ this ->n_seq_max = n_seq_max;
147
154
148
155
// count the outputs in this batch
149
156
for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
@@ -189,7 +196,7 @@ bool llama_batch_allocr::init(
189
196
seq_set_map[cur].push_back (i);
190
197
}
191
198
192
- for (int32_t s = 0 ; s < LLAMA_MAX_SEQ ; ++s) {
199
+ for (uint32_t s = 0 ; s < n_seq_max ; ++s) {
193
200
if (seq_set_unq.test (s)) {
194
201
seq_idx[s] = seq_id_unq.size ();
195
202
seq_id_unq.push_back (s);
@@ -241,7 +248,7 @@ bool llama_batch_allocr::init(
241
248
// consistency checks
242
249
//
243
250
244
- for (int32_t s = 0 ; s < LLAMA_MAX_SEQ ; ++s) {
251
+ for (uint32_t s = 0 ; s < n_seq_max ; ++s) {
245
252
if (seq_pos[s].empty ()) {
246
253
continue ;
247
254
}
@@ -284,8 +291,8 @@ bool llama_batch_allocr::init(
284
291
}
285
292
286
293
if (memory) {
287
- for (int32_t s0 = 0 ; s0 < LLAMA_MAX_SEQ ; ++s0) {
288
- for (int32_t s1 = 0 ; s1 < LLAMA_MAX_SEQ ; ++s1) {
294
+ for (uint32_t s0 = 0 ; s0 < n_seq_max ; ++s0) {
295
+ for (uint32_t s1 = 0 ; s1 < n_seq_max ; ++s1) {
289
296
if (seq_cpl[s0][s1]) {
290
297
if (memory->seq_pos_min (s0) != memory->seq_pos_min (s1) ||
291
298
memory->seq_pos_max (s0) != memory->seq_pos_max (s1)) {
@@ -316,12 +323,12 @@ bool llama_batch_allocr::init(
316
323
//
317
324
{
318
325
seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
319
- for (int32_t s = 0 ; s < LLAMA_MAX_SEQ ; ++s) {
326
+ for (uint32_t s = 0 ; s < n_seq_max ; ++s) {
320
327
cur_seq_set[s].set ();
321
328
}
322
329
323
330
llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
324
- for (int32_t s = 0 ; s < LLAMA_MAX_SEQ ; ++s) {
331
+ for (uint32_t s = 0 ; s < n_seq_max ; ++s) {
325
332
cur_seq_pos[s] = -1 ;
326
333
}
327
334
@@ -692,7 +699,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
692
699
}
693
700
}
694
701
695
- for (int32_t s = 0 ; s < LLAMA_MAX_SEQ ; ++s) {
702
+ for (uint32_t s = 0 ; s < n_seq_max ; ++s) {
696
703
if (seq_set_unq.test (s)) {
697
704
ubatch.seq_idx [s] = ubatch.seq_id_unq .size ();
698
705
ubatch.seq_id_unq .push_back (s);
0 commit comments