Commit 291a785

llama : rename batch.logits to batch.output
This commit renames the `logits` field of the `llama_batch` struct to `output`. The motivation for this change (apart from the existing TODO comment) is that the `logits` field is really used to specify whether output should be generated for a token at all. When generating embeddings, for example, setting `logits` to true is confusing, since no logits are involved.
1 parent: 9f4cc8f
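To make the renamed flag concrete, here is a minimal usage sketch (not part of this commit; `prompt_tokens` and `ctx` stand in for a tokenized prompt and an initialized context). It requests output only for the last prompt token, the same pattern the examples below use:

// Hypothetical usage sketch of the renamed field:
// decode a prompt and request output only for its last token.
llama_batch batch = llama_batch_init(512, 0, 1);

for (size_t i = 0; i < prompt_tokens.size(); i++) {
    // the last argument is the renamed output flag (formerly `logits`)
    common_batch_add(batch, prompt_tokens[i], (llama_pos) i, { 0 }, false);
}

// llama_decode will generate output (logits) only for the last token
batch.output[batch.n_tokens - 1] = true;

if (llama_decode(ctx, batch) != 0) {
    fprintf(stderr, "llama_decode() failed\n");
}

llama_batch_free(batch);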

File tree

19 files changed: +52 −53 lines

common/common.cpp

Lines changed: 3 additions & 3 deletions
@@ -607,7 +607,7 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
             << ", pos " << std::to_string(batch.pos[i])
             << ", n_seq_id " << std::to_string(batch.n_seq_id[i])
             << ", seq_id " << std::to_string(batch.seq_id[i][0])
-            << ", logits " << std::to_string(batch.logits[i]);
+            << ", output " << std::to_string(batch.output[i]);
     }

     buf << " ]";
@@ -1617,7 +1617,7 @@ void common_batch_add(
                  llama_token id,
                  llama_pos pos,
                  const std::vector<llama_seq_id> & seq_ids,
-                 bool logits) {
+                 bool output) {
     GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");

     batch.token   [batch.n_tokens] = id;
@@ -1626,7 +1626,7 @@ void common_batch_add(
     for (size_t i = 0; i < seq_ids.size(); ++i) {
         batch.seq_id[batch.n_tokens][i] = seq_ids[i];
     }
-    batch.logits  [batch.n_tokens] = logits;
+    batch.output  [batch.n_tokens] = output;

     batch.n_tokens++;
 }

examples/batched-bench/batched-bench.cpp

Lines changed: 2 additions & 2 deletions
@@ -73,7 +73,7 @@ int main(int argc, char ** argv) {
             batch.pos      + i,
             batch.n_seq_id + i,
             batch.seq_id   + i,
-            batch.logits   + i,
+            batch.output   + i,
         };

         const int ret = llama_decode(ctx, batch_view);
@@ -128,7 +128,7 @@ int main(int argc, char ** argv) {
             common_batch_add(batch, 0, i, { j }, false);
         }
     }
-    batch.logits[batch.n_tokens - 1] = true;
+    batch.output[batch.n_tokens - 1] = true;

     const auto t_pp_start = ggml_time_us();

examples/batched.swift/Sources/main.swift

Lines changed: 3 additions & 3 deletions
@@ -104,11 +104,11 @@ for (i, token) in tokens.enumerated() {
     if let seq_id = batch.seq_id[i] {
         seq_id[0] = 0
     }
-    batch.logits[i] = 0
+    batch.output[i] = 0
 }

 // llama_decode will output logits only for the last token of the prompt
-batch.logits[Int(batch.n_tokens) - 1] = 1
+batch.output[Int(batch.n_tokens) - 1] = 1

 if llama_decode(context, batch) != 0 {
     print("llama_decode() failed")
@@ -171,7 +171,7 @@ while n_cur <= n_len {
         if let seq_id = batch.seq_id[Int(batch.n_tokens)] {
             seq_id[0] = Int32(i)
         }
-        batch.logits[Int(batch.n_tokens)] = 1
+        batch.output[Int(batch.n_tokens)] = 1

         i_batch[i] = batch.n_tokens

examples/batched/batched.cpp

Lines changed: 1 addition & 1 deletion
@@ -131,7 +131,7 @@ int main(int argc, char ** argv) {
     }

     // llama_decode will output logits only for the last token of the prompt
-    batch.logits[batch.n_tokens - 1] = true;
+    batch.output[batch.n_tokens - 1] = true;

     if (llama_decode(ctx, batch) != 0) {
         LOG_ERR("%s: llama_decode() failed\n", __func__);

examples/embedding/embedding.cpp

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
     }

     for (int i = 0; i < batch.n_tokens; i++) {
-        if (!batch.logits[i]) {
+        if (!batch.output[i]) {
             continue;
         }

examples/llama.android/llama/src/main/cpp/llama-android.cpp

Lines changed: 3 additions & 3 deletions
@@ -193,7 +193,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
         common_batch_add(*batch, 0, i, { 0 }, false);
     }

-    batch->logits[batch->n_tokens - 1] = true;
+    batch->output[batch->n_tokens - 1] = true;
     llama_kv_cache_clear(context);

     const auto t_pp_start = ggml_time_us();
@@ -297,7 +297,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
     for (int i = 0; i < n_tokens; ++i) {
         batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
     }
-    batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
+    batch->output = (int8_t *) malloc(sizeof(int8_t) * n_tokens);

     return reinterpret_cast<jlong>(batch);
 }
@@ -381,7 +381,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
     }

     // llama_decode will output logits only for the last token of the prompt
-    batch->logits[batch->n_tokens - 1] = true;
+    batch->output[batch->n_tokens - 1] = true;

     if (llama_decode(context, *batch) != 0) {
         LOGe("llama_decode() failed");

examples/llama.swiftui/llama.cpp.swift/LibLlama.swift

Lines changed: 4 additions & 4 deletions
@@ -9,14 +9,14 @@ func llama_batch_clear(_ batch: inout llama_batch) {
     batch.n_tokens = 0
 }

-func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ logits: Bool) {
+func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ outputs: Bool) {
     batch.token   [Int(batch.n_tokens)] = id
     batch.pos     [Int(batch.n_tokens)] = pos
     batch.n_seq_id[Int(batch.n_tokens)] = Int32(seq_ids.count)
     for i in 0..<seq_ids.count {
         batch.seq_id[Int(batch.n_tokens)]![Int(i)] = seq_ids[i]
     }
-    batch.logits  [Int(batch.n_tokens)] = logits ? 1 : 0
+    batch.output  [Int(batch.n_tokens)] = outputs ? 1 : 0

     batch.n_tokens += 1
 }
@@ -139,7 +139,7 @@ actor LlamaContext {
             let i = Int(i1)
             llama_batch_add(&batch, tokens_list[i], Int32(i), [0], false)
         }
-        batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+        batch.output[Int(batch.n_tokens) - 1] = 1 // true

         if llama_decode(context, batch) != 0 {
             print("llama_decode() failed")
@@ -208,7 +208,7 @@ actor LlamaContext {
         for i in 0..<n_tokens {
             llama_batch_add(&batch, 0, Int32(i), [0], false)
         }
-        batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+        batch.output[Int(batch.n_tokens) - 1] = 1 // true

         llama_kv_cache_clear(context)

examples/llava/llava.cpp

Lines changed: 4 additions & 4 deletions
@@ -441,13 +441,13 @@ struct llava_embd_batch {
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id>   seq_id_0;
     std::vector<llama_seq_id *> seq_ids;
-    std::vector<int8_t>         logits;
+    std::vector<int8_t>         outputs;
     llama_batch batch;
     llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
         pos     .resize(n_tokens);
         n_seq_id.resize(n_tokens);
         seq_ids .resize(n_tokens + 1);
-        logits  .resize(n_tokens);
+        outputs .resize(n_tokens);
         seq_id_0.resize(1);
         seq_id_0[0] = seq_id;
         seq_ids [n_tokens] = nullptr;
@@ -458,13 +458,13 @@ struct llava_embd_batch {
             /*pos      =*/ pos.data(),
             /*n_seq_id =*/ n_seq_id.data(),
             /*seq_id   =*/ seq_ids.data(),
-            /*logits   =*/ logits.data(),
+            /*output   =*/ outputs.data(),
         };
         for (int i = 0; i < n_tokens; i++) {
             batch.pos     [i] = pos_0 + i;
             batch.n_seq_id[i] = 1;
             batch.seq_id  [i] = seq_id_0.data();
-            batch.logits  [i] = false;
+            batch.output  [i] = false;
         }
     }
 };

examples/parallel/parallel.cpp

Lines changed: 2 additions & 2 deletions
@@ -266,7 +266,7 @@ int main(int argc, char ** argv) {

             // extract the logits only for the last token
             if (batch.n_tokens > 0) {
-                batch.logits[batch.n_tokens - 1] = true;
+                batch.output[batch.n_tokens - 1] = true;
             }

             client.n_prompt = tokens_prompt.size();
@@ -309,7 +309,7 @@ int main(int argc, char ** argv) {
                 batch.pos      + i,
                 batch.n_seq_id + i,
                 batch.seq_id   + i,
-                batch.logits   + i,
+                batch.output   + i,
             };

             const int ret = llama_decode(ctx, batch_view);

examples/passkey/passkey.cpp

Lines changed: 2 additions & 2 deletions
@@ -146,7 +146,7 @@ int main(int argc, char ** argv) {
         }

         if (i + n_batch >= n_tokens_all) {
-            batch.logits[batch.n_tokens - 1] = true;
+            batch.output[batch.n_tokens - 1] = true;
         }

         if (llama_decode(ctx, batch) != 0) {
@@ -180,7 +180,7 @@ int main(int argc, char ** argv) {
         }

         if (i + n_batch >= n_tokens_all) {
-            batch.logits[batch.n_tokens - 1] = true;
+            batch.output[batch.n_tokens - 1] = true;
         }

         if (llama_decode(ctx, batch) != 0) {
