 int main(int argc, char ** argv) {
     gpt_params params;
-    llama_sampling_params & sparams = params.sampling_params;
-    params.seed           = 42;
-    params.n_threads      = 4;
-    sparams.repeat_last_n = 64;
+
     params.prompt = "The quick brown fox";
 
     if (!gpt_params_parse(argc, argv, params)) {
@@ -25,56 +22,49 @@ int main(int argc, char ** argv) {
     }
 
     auto n_past = 0;
-    auto last_n_tokens_data = std::vector<llama_token>(sparams.repeat_last_n, 0);
+
+    std::string result0;
+    std::string result1;
 
     // init
     llama_model * model;
     llama_context * ctx;
 
-    std::tie(model, ctx) = llama_init_from_gpt_params( params );
-    if (model == nullptr) {
-        return 1;
-    }
-    if (ctx == nullptr) {
-        llama_free_model(model);
+    std::tie(model, ctx) = llama_init_from_gpt_params(params);
+    if (model == nullptr || ctx == nullptr) {
+        fprintf(stderr, "%s : failed to init\n", __func__);
         return 1;
     }
+
+    // tokenize prompt
     auto tokens = llama_tokenize(ctx, params.prompt, true);
-    auto n_prompt_tokens = tokens.size();
-    if (n_prompt_tokens < 1) {
-        fprintf(stderr, "%s : failed to tokenize prompt\n", __func__);
-        llama_free(ctx);
-        llama_free_model(model);
-        return 1;
-    }
 
     // evaluate prompt
-    llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0));
+    llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), n_past, 0));
+    n_past += tokens.size();
 
-    last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
-    n_past += n_prompt_tokens;
-
-    const size_t state_size = llama_get_state_size(ctx);
-    uint8_t * state_mem = new uint8_t[state_size];
-
-    // Save state (rng, logits, embedding and kv_cache) to file
+    // save state (rng, logits, embedding and kv_cache) to file
     {
-        FILE *fp_write = fopen("dump_state.bin", "wb");
-        llama_copy_state_data(ctx, state_mem); // could also copy directly to memory mapped file
-        fwrite(state_mem, 1, state_size, fp_write);
-        fclose(fp_write);
+        std::vector<uint8_t> state_mem(llama_get_state_size(ctx));
+
+        {
+            FILE *fp_write = fopen("dump_state.bin", "wb");
+            llama_copy_state_data(ctx, state_mem.data()); // could also copy directly to memory mapped file
+            fwrite(state_mem.data(), 1, state_mem.size(), fp_write);
+            fclose(fp_write);
+        }
     }
 
     // save state (last tokens)
-    const auto last_n_tokens_data_saved = std::vector<llama_token>(last_n_tokens_data);
     const auto n_past_saved = n_past;
 
     // first run
-    printf("\n%s", params.prompt.c_str());
+    printf("\nfirst run: %s", params.prompt.c_str());
 
     for (auto i = 0; i < params.n_predict; i++) {
         auto * logits = llama_get_logits(ctx);
         auto n_vocab = llama_n_vocab(model);
+
         std::vector<llama_token_data> candidates;
         candidates.reserve(n_vocab);
         for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
@@ -83,9 +73,10 @@ int main(int argc, char ** argv) {
         llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
         auto next_token = llama_sample_token(ctx, &candidates_p);
         auto next_token_str = llama_token_to_piece(ctx, next_token);
-        last_n_tokens_data.push_back(next_token);
 
         printf("%s", next_token_str.c_str());
+        result0 += next_token_str;
+
         if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             llama_free(ctx);
@@ -103,32 +94,28 @@ int main(int argc, char ** argv) {
     // make new context
    auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
 
-    // Load state (rng, logits, embedding and kv_cache) from file
+    printf("\nsecond run: %s", params.prompt.c_str());
+
+    // load state (rng, logits, embedding and kv_cache) from file
     {
-        FILE *fp_read = fopen("dump_state.bin", "rb");
-        if (state_size != llama_get_state_size(ctx2)) {
-            fprintf(stderr, "\n%s : failed to validate state size\n", __func__);
-            llama_free(ctx2);
-            llama_free_model(model);
-            return 1;
-        }
+        std::vector<uint8_t> state_mem(llama_get_state_size(ctx2));
 
-        const size_t ret = fread(state_mem, 1, state_size, fp_read);
-        if (ret != state_size) {
+        FILE * fp_read = fopen("dump_state.bin", "rb");
+
+        const size_t ret = fread(state_mem.data(), 1, state_mem.size(), fp_read);
+        if (ret != state_mem.size()) {
             fprintf(stderr, "\n%s : failed to read state\n", __func__);
             llama_free(ctx2);
             llama_free_model(model);
             return 1;
         }
 
-        llama_set_state_data(ctx2, state_mem); // could also read directly from memory mapped file
+        llama_set_state_data(ctx2, state_mem.data());
+
         fclose(fp_read);
     }
 
-    delete[] state_mem;
-
     // restore state (last tokens)
-    last_n_tokens_data = last_n_tokens_data_saved;
     n_past = n_past_saved;
 
     // second run
@@ -143,10 +130,11 @@ int main(int argc, char ** argv) {
         llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
         auto next_token = llama_sample_token(ctx2, &candidates_p);
         auto next_token_str = llama_token_to_piece(ctx2, next_token);
-        last_n_tokens_data.push_back(next_token);
 
         printf("%s", next_token_str.c_str());
-        if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) {
+        result1 += next_token_str;
+
+        if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1, n_past, 0))) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             llama_free(ctx2);
             llama_free_model(model);
@@ -155,10 +143,17 @@ int main(int argc, char ** argv) {
         n_past += 1;
     }
 
-    printf("\n\n");
+    printf("\n");
 
     llama_free(ctx2);
     llama_free_model(model);
 
+    if (result0 != result1) {
+        fprintf(stderr, "\n%s : error : the 2 generations are different\n", __func__);
+        return 1;
+    }
+
+    fprintf(stderr, "\n%s : success\n", __func__);
+
     return 0;
 }
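
For context, the core of what this change tests can be condensed into a small helper: serialize one context's state and replay it into a fresh context built from the same model. The following is a minimal sketch against the API the diff itself uses (llama_get_state_size, llama_copy_state_data, llama_set_state_data, llama_context_params_from_gpt_params); the helper name clone_context is hypothetical and not part of the commit.

// Sketch: copy the full state (rng, logits, embedding, kv_cache) of src
// into a brand-new context over the same model weights.
#include "common.h"
#include "llama.h"

#include <cstdint>
#include <vector>

static llama_context * clone_context(llama_model * model, llama_context * src, const gpt_params & params) {
    // size the buffer from the source context
    std::vector<uint8_t> state_mem(llama_get_state_size(src));

    // serialize the source state into the buffer ...
    llama_copy_state_data(src, state_mem.data());

    // ... and replay it into a fresh context made from the same model
    llama_context * dst = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
    llama_set_state_data(dst, state_mem.data());

    return dst;
}

Note that per-run bookkeeping such as n_past is not part of the serialized blob, which is why the example keeps n_past_saved alongside dump_state.bin and restores it separately before the second run.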