6
6
7
7
#include " common.h"
8
8
#include " llama.h"
9
+ #include " grammar-parser.h"
9
10
10
11
#include < cmath>
11
12
#include < cstdio>
@@ -109,16 +110,35 @@ int main(int argc, char ** argv) {
109
110
// used to determine end of generation
110
111
bool has_eos = false ;
111
112
113
+ // grammar stuff
114
+ struct llama_grammar * grammar_dft = NULL ;
115
+ struct llama_grammar * grammar_tgt = NULL ;
116
+
117
+ grammar_parser::parse_state parsed_grammar;
118
+
119
+ // if requested - load the grammar, error checking is omitted for brevity
120
+ if (!params.grammar .empty ()) {
121
+ parsed_grammar = grammar_parser::parse (params.grammar .c_str ());
122
+ // will be empty (default) if there are parse errors
123
+ if (parsed_grammar.rules .empty ()) {
124
+ return 1 ;
125
+ }
126
+
127
+ std::vector<const llama_grammar_element *> grammar_rules (parsed_grammar.c_rules ());
128
+ grammar_tgt = llama_grammar_init (grammar_rules.data (), grammar_rules.size (), parsed_grammar.symbol_ids .at (" root" ));
129
+ }
130
+
112
131
const auto t_dec_start = ggml_time_us ();
113
132
114
133
while (true ) {
115
134
LOG (" drafted: %s\n " , LOG_TOKENS_TOSTR_PRETTY (ctx_dft, drafted));
116
135
117
- // sample from the drafted tokens if any
118
136
int i_dft = 0 ;
119
137
while (true ) {
120
- const llama_token id = llama_sample_token (ctx_tgt, NULL , NULL , params, last_tokens, candidates, i_dft);
138
+ // sample from the target model
139
+ const llama_token id = llama_sample_token (ctx_tgt, NULL , grammar_tgt, params, last_tokens, candidates, i_dft);
121
140
141
+ // remember which tokens were sampled - used for repetition penalties during sampling
122
142
last_tokens.erase (last_tokens.begin ());
123
143
last_tokens.push_back (id);
124
144
@@ -134,8 +154,9 @@ int main(int argc, char ** argv) {
134
154
135
155
++n_predict;
136
156
157
+ // check if the draft matches the target
137
158
if (i_dft < (int ) drafted.size () && id == drafted[i_dft]) {
138
- LOG (" drafted token %d accepted\n " , id );
159
+ LOG (" the sampled target token matches the %dth drafted token (%d, '%s') - accepted\n " , i_dft, id, token_str. c_str () );
139
160
++n_accept;
140
161
++n_past_tgt;
141
162
++n_past_dft;
@@ -145,6 +166,14 @@ int main(int argc, char ** argv) {
145
166
}
146
167
147
168
// the drafted token was rejected or we are out of drafted tokens
169
+
170
+ if (i_dft < (int ) drafted.size ()) {
171
+ LOG (" the %dth drafted token (%d, '%s') does not match the sampled target token (%d, '%s') - rejected\n " ,
172
+ i_dft, drafted[i_dft], llama_token_to_piece (ctx_dft, drafted[i_dft]).c_str (), id, token_str.c_str ());
173
+ } else {
174
+ LOG (" out of drafted tokens\n " );
175
+ }
176
+
148
177
llama_eval (ctx_dft, &id, 1 , n_past_dft, params.n_threads );
149
178
++n_past_dft;
150
179
@@ -158,7 +187,16 @@ int main(int argc, char ** argv) {
158
187
break ;
159
188
}
160
189
161
- // sample n_draft tokens from the draft model picking the best token
190
+ if (grammar_tgt) {
191
+ if (grammar_dft) {
192
+ llama_grammar_free (grammar_dft);
193
+ }
194
+ grammar_dft = llama_grammar_copy (grammar_tgt);
195
+
196
+ LOG (" copied target grammar to draft grammar\n " );
197
+ }
198
+
199
+ // sample n_draft tokens from the draft model using greedy decoding
162
200
int n_past_cur = n_past_dft;
163
201
for (int i = 0 ; i < n_draft; ++i) {
164
202
float * logits = llama_get_logits (ctx_dft);
@@ -170,32 +208,48 @@ int main(int argc, char ** argv) {
170
208
171
209
llama_token_data_array cur_p = { candidates.data (), candidates.size (), false };
172
210
211
+ if (grammar_dft != NULL ) {
212
+ llama_sample_grammar (ctx_dft, &cur_p, grammar_dft);
213
+ }
214
+
173
215
// computes softmax and sorts the candidates
174
216
llama_sample_softmax (ctx_dft, &cur_p);
175
217
176
218
for (int i = 0 ; i < 3 ; ++i) {
177
- LOG (" - draft candidate %d : %d (%.3f)\n " , i, cur_p.data [i].id , cur_p.data [i].p );
219
+ LOG (" - draft candidate %3d : %6d (%8 .3f) '%s' \n " , i, cur_p.data [i].id , cur_p.data [i].p , llama_token_to_piece (ctx_dft, cur_p. data [i]. id ). c_str () );
178
220
}
179
221
180
- // too low probability, stop drafting
222
+ // TODO: better logic?
181
223
if (cur_p.data [0 ].p < 2 *cur_p.data [1 ].p ) {
224
+ LOG (" stopping drafting, probability too low: %.3f < 2*%.3f\n " , cur_p.data [0 ].p , cur_p.data [1 ].p );
182
225
break ;
183
226
}
184
227
185
- drafted.push_back (cur_p.data [0 ].id );
228
+ // drafted token
229
+ const llama_token id = cur_p.data [0 ].id ;
230
+
231
+ drafted.push_back (id);
186
232
++n_drafted;
187
233
188
- if (i < n_draft - 1 ) {
189
- // evaluate the drafted token on the draft model
190
- llama_eval (ctx_dft, &drafted.back (), 1 , n_past_cur, params.n_threads );
191
- ++n_past_cur;
234
+ // no need to evaluate the last drafted token, since we won't use the result
235
+ if (i == n_draft - 1 ) {
236
+ break ;
237
+ }
238
+
239
+ // evaluate the drafted token on the draft model
240
+ llama_eval (ctx_dft, &drafted.back (), 1 , n_past_cur, params.n_threads );
241
+ ++n_past_cur;
242
+
243
+ if (grammar_dft != NULL ) {
244
+ llama_grammar_accept_token (ctx_dft, grammar_dft, id);
192
245
}
193
246
}
194
247
195
248
// evaluate the target model on the drafted tokens
196
249
llama_eval (ctx_tgt, drafted.data (), drafted.size (), n_past_tgt, params.n_threads );
197
250
++n_past_tgt;
198
251
252
+ // the first token is always proposed by the target model before the speculation loop
199
253
drafted.erase (drafted.begin ());
200
254
}
201
255
@@ -226,6 +280,10 @@ int main(int argc, char ** argv) {
226
280
llama_free (ctx_dft);
227
281
llama_free_model (model_dft);
228
282
283
+ if (grammar_dft != NULL ) {
284
+ llama_grammar_free (grammar_dft);
285
+ llama_grammar_free (grammar_tgt);
286
+ }
229
287
llama_backend_free ();
230
288
231
289
fprintf (stderr, " \n\n " );
0 commit comments