Commit 1142013

save-load-state : fix example + add ci test (ggml-org#3655)
* save-load-state : fix example (close ggml-org#3606)

* ci : add test for save-load-state example

ggml-ci
1 parent 5fe268a commit 1142013
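
In short, the fixed example now serializes the full context state (rng, logits, embedding and kv_cache) to dump_state.bin, generates a first completion, loads the state into a fresh context, generates a second completion, and exits non-zero unless the two generations match, which is what the new CI step checks. Below is a condensed sketch of that save/restore round-trip, using only the llama.cpp calls that appear in the diff (llama_get_state_size, llama_copy_state_data, llama_set_state_data); the helper name and the minimal error handling are illustrative only, not part of the commit.

    #include <cstdio>
    #include <vector>

    #include "llama.h"

    // Sketch of the state round-trip the example performs (helper name is
    // illustrative). ctx already has the prompt decoded; ctx2 is a fresh
    // context created from the same model.
    static bool roundtrip_state(llama_context * ctx, llama_context * ctx2) {
        // serialize rng, logits, embedding and kv_cache of the first context
        std::vector<uint8_t> state_mem(llama_get_state_size(ctx));
        llama_copy_state_data(ctx, state_mem.data());

        FILE * fp = fopen("dump_state.bin", "wb");
        fwrite(state_mem.data(), 1, state_mem.size(), fp);
        fclose(fp);

        // read the dump back and install it into the second context
        std::vector<uint8_t> state_in(llama_get_state_size(ctx2));
        fp = fopen("dump_state.bin", "rb");
        const size_t n_read = fread(state_in.data(), 1, state_in.size(), fp);
        fclose(fp);

        if (n_read != state_in.size()) {
            return false; // truncated or missing dump
        }

        llama_set_state_data(ctx2, state_in.data());
        return true;
    }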

File tree

2 files changed (+51, -50):

* ci/run.sh
* examples/save-load-state/save-load-state.cpp


ci/run.sh (+6)
@@ -208,6 +208,8 @@ function gg_run_open_llama_3b_v2 {
     (time ./bin/perplexity --model ${model_q5_k} -f ${wiki_test_60} -c 128 -b 128 --chunks 2 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log
     (time ./bin/perplexity --model ${model_q6_k} -f ${wiki_test_60} -c 128 -b 128 --chunks 2 ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log
 
+    (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+
     function check_ppl {
         qnt="$1"
         ppl=$(echo "$2" | grep -oE "[0-9]+\.[0-9]+" | tail -n 1)
@@ -296,6 +298,7 @@ function gg_sum_open_llama_3b_v2 {
     gg_printf '- q4_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q4_k.log)"
     gg_printf '- q5_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q5_k.log)"
     gg_printf '- q6_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q6_k.log)"
+    gg_printf '- save-load-state: \n```\n%s\n```\n' "$(cat $OUT/${ci}-save-load-state.log)"
     gg_printf '- shakespeare (f16):\n```\n%s\n```\n' "$(cat $OUT/${ci}-ppl-shakespeare-f16.log)"
     gg_printf '- shakespeare (f16 lora):\n```\n%s\n```\n' "$(cat $OUT/${ci}-ppl-shakespeare-lora-f16.log)"
     gg_printf '- shakespeare (q8_0):\n```\n%s\n```\n' "$(cat $OUT/${ci}-ppl-shakespeare-q8_0.log)"
@@ -382,6 +385,8 @@ function gg_run_open_llama_7b_v2 {
     (time ./bin/perplexity --model ${model_q5_k} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log
     (time ./bin/perplexity --model ${model_q6_k} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log
 
+    (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+
     function check_ppl {
         qnt="$1"
         ppl=$(echo "$2" | grep -oE "[0-9]+\.[0-9]+" | tail -n 1)
@@ -470,6 +475,7 @@ function gg_sum_open_llama_7b_v2 {
     gg_printf '- q4_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q4_k.log)"
     gg_printf '- q5_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q5_k.log)"
     gg_printf '- q6_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q6_k.log)"
+    gg_printf '- save-load-state: \n```\n%s\n```\n' "$(cat $OUT/${ci}-save-load-state.log)"
     gg_printf '- shakespeare (f16):\n```\n%s\n```\n' "$(cat $OUT/${ci}-ppl-shakespeare-f16.log)"
     gg_printf '- shakespeare (f16 lora):\n```\n%s\n```\n' "$(cat $OUT/${ci}-ppl-shakespeare-lora-f16.log)"
     #gg_printf '- shakespeare (q8_0):\n```\n%s\n```\n' "$(cat $OUT/${ci}-ppl-shakespeare-q8_0.log)"

examples/save-load-state/save-load-state.cpp (+45, -50)
@@ -8,10 +8,7 @@
 
 int main(int argc, char ** argv) {
     gpt_params params;
-    llama_sampling_params & sparams = params.sampling_params;
-    params.seed = 42;
-    params.n_threads = 4;
-    sparams.repeat_last_n = 64;
+
     params.prompt = "The quick brown fox";
 
     if (!gpt_params_parse(argc, argv, params)) {
@@ -25,56 +22,49 @@ int main(int argc, char ** argv) {
     }
 
     auto n_past = 0;
-    auto last_n_tokens_data = std::vector<llama_token>(sparams.repeat_last_n, 0);
+
+    std::string result0;
+    std::string result1;
 
     // init
     llama_model * model;
     llama_context * ctx;
 
-    std::tie(model, ctx) = llama_init_from_gpt_params( params );
-    if (model == nullptr) {
-        return 1;
-    }
-    if (ctx == nullptr) {
-        llama_free_model(model);
+    std::tie(model, ctx) = llama_init_from_gpt_params(params);
+    if (model == nullptr || ctx == nullptr) {
+        fprintf(stderr, "%s : failed to init\n", __func__);
         return 1;
     }
+
+    // tokenize prompt
     auto tokens = llama_tokenize(ctx, params.prompt, true);
-    auto n_prompt_tokens = tokens.size();
-    if (n_prompt_tokens < 1) {
-        fprintf(stderr, "%s : failed to tokenize prompt\n", __func__);
-        llama_free(ctx);
-        llama_free_model(model);
-        return 1;
-    }
 
     // evaluate prompt
-    llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0));
+    llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), n_past, 0));
+    n_past += tokens.size();
 
-    last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
-    n_past += n_prompt_tokens;
-
-    const size_t state_size = llama_get_state_size(ctx);
-    uint8_t * state_mem = new uint8_t[state_size];
-
-    // Save state (rng, logits, embedding and kv_cache) to file
+    // save state (rng, logits, embedding and kv_cache) to file
     {
-        FILE *fp_write = fopen("dump_state.bin", "wb");
-        llama_copy_state_data(ctx, state_mem); // could also copy directly to memory mapped file
-        fwrite(state_mem, 1, state_size, fp_write);
-        fclose(fp_write);
+        std::vector<uint8_t> state_mem(llama_get_state_size(ctx));
+
+        {
+            FILE *fp_write = fopen("dump_state.bin", "wb");
+            llama_copy_state_data(ctx, state_mem.data()); // could also copy directly to memory mapped file
+            fwrite(state_mem.data(), 1, state_mem.size(), fp_write);
+            fclose(fp_write);
+        }
     }
 
     // save state (last tokens)
-    const auto last_n_tokens_data_saved = std::vector<llama_token>(last_n_tokens_data);
     const auto n_past_saved = n_past;
 
     // first run
-    printf("\n%s", params.prompt.c_str());
+    printf("\nfirst run: %s", params.prompt.c_str());
 
     for (auto i = 0; i < params.n_predict; i++) {
         auto * logits = llama_get_logits(ctx);
         auto n_vocab = llama_n_vocab(model);
+
         std::vector<llama_token_data> candidates;
         candidates.reserve(n_vocab);
         for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
@@ -83,9 +73,10 @@ int main(int argc, char ** argv) {
         llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
         auto next_token = llama_sample_token(ctx, &candidates_p);
         auto next_token_str = llama_token_to_piece(ctx, next_token);
-        last_n_tokens_data.push_back(next_token);
 
         printf("%s", next_token_str.c_str());
+        result0 += next_token_str;
+
         if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             llama_free(ctx);
@@ -103,32 +94,28 @@ int main(int argc, char ** argv) {
     // make new context
    auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
 
-    // Load state (rng, logits, embedding and kv_cache) from file
+    printf("\nsecond run: %s", params.prompt.c_str());
+
+    // load state (rng, logits, embedding and kv_cache) from file
     {
-        FILE *fp_read = fopen("dump_state.bin", "rb");
-        if (state_size != llama_get_state_size(ctx2)) {
-            fprintf(stderr, "\n%s : failed to validate state size\n", __func__);
-            llama_free(ctx2);
-            llama_free_model(model);
-            return 1;
-        }
+        std::vector<uint8_t> state_mem(llama_get_state_size(ctx2));
 
-        const size_t ret = fread(state_mem, 1, state_size, fp_read);
-        if (ret != state_size) {
+        FILE * fp_read = fopen("dump_state.bin", "rb");
+
+        const size_t ret = fread(state_mem.data(), 1, state_mem.size(), fp_read);
+        if (ret != state_mem.size()) {
             fprintf(stderr, "\n%s : failed to read state\n", __func__);
             llama_free(ctx2);
            llama_free_model(model);
            return 1;
        }
 
-        llama_set_state_data(ctx2, state_mem); // could also read directly from memory mapped file
+        llama_set_state_data(ctx2, state_mem.data());
+
         fclose(fp_read);
     }
 
-    delete[] state_mem;
-
     // restore state (last tokens)
-    last_n_tokens_data = last_n_tokens_data_saved;
     n_past = n_past_saved;
 
     // second run
@@ -143,10 +130,11 @@ int main(int argc, char ** argv) {
         llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
         auto next_token = llama_sample_token(ctx2, &candidates_p);
         auto next_token_str = llama_token_to_piece(ctx2, next_token);
-        last_n_tokens_data.push_back(next_token);
 
         printf("%s", next_token_str.c_str());
-        if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) {
+        result1 += next_token_str;
+
+        if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1, n_past, 0))) {
            fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
            llama_free(ctx2);
            llama_free_model(model);
@@ -155,10 +143,17 @@ int main(int argc, char ** argv) {
         n_past += 1;
     }
 
-    printf("\n\n");
+    printf("\n");
 
     llama_free(ctx2);
     llama_free_model(model);
 
+    if (result0 != result1) {
+        fprintf(stderr, "\n%s : error : the 2 generations are different\n", __func__);
+        return 1;
+    }
+
+    fprintf(stderr, "\n%s : success\n", __func__);
+
     return 0;
 }
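
For context, both runs drive the same per-token loop shown in the hunks above. The sketch below condenses that loop under the pre-sampling-refactor API used here; the helper name is made up, and the body of the candidate-building inner loop is assumed, since the hunks above elide it.

    #include <string>
    #include <vector>

    #include "common.h" // llama_token_to_piece helper
    #include "llama.h"

    // Sketch of the shared generation loop (illustrative helper, not in the commit).
    // It rebuilds the candidate list from the raw logits on every step and samples
    // the next token with llama_sample_token.
    static std::string generate(llama_context * ctx, llama_model * model, int n_predict, int & n_past) {
        std::string result;

        for (int i = 0; i < n_predict; i++) {
            auto * logits = llama_get_logits(ctx);
            auto n_vocab  = llama_n_vocab(model);

            std::vector<llama_token_data> candidates;
            candidates.reserve(n_vocab);
            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
                // assumed candidate construction (elided by the hunk boundaries above)
                candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
            }
            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

            llama_token next_token = llama_sample_token(ctx, &candidates_p);
            result += llama_token_to_piece(ctx, next_token);

            // feed the sampled token back in for the next step
            if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) {
                break; // evaluation failed
            }
            n_past += 1;
        }

        return result;
    }

Because llama_sample_token draws from the context's rng, restoring the rng as part of the saved state is what allows the second run to reproduce the first one token for token.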
