diff --git a/README.md b/README.md index 0133fa5..8ba93f6 100644 --- a/README.md +++ b/README.md @@ -192,6 +192,11 @@ by h g wells i the time traveller for so it will be convenient to speak of him w ----------------- ``` +### pre-trained language model + +This model was trained for 7 epochs on the full text of the novel The Time Machine. +[model](https://cpp-transformer-1252366230.cos.ap-beijing.myqcloud.com/lm/checkpoint_20250609_123600_7.bin) + ## legacy version [v1](https://github.com/freelw/cpp-transformer/tree/v1_freeze_20250529) diff --git a/dataloaders/language_model/lm_dataloader.cpp b/dataloaders/language_model/lm_dataloader.cpp index 5143748..44cd6ef 100644 --- a/dataloaders/language_model/lm_dataloader.cpp +++ b/dataloaders/language_model/lm_dataloader.cpp @@ -32,6 +32,7 @@ void LMDataLoader::get_token_ids( } int token_ids_size = std::min((int)token_ids.size(), max_token_ids_size); + std::cout << "token_ids_size : " << token_ids_size << std::endl; for (size_t i = 0; i < token_ids_size; ++i) { std::vector src_step_tokens; diff --git a/lm.cpp b/lm.cpp index 4603179..e057163 100644 --- a/lm.cpp +++ b/lm.cpp @@ -102,10 +102,11 @@ int main(int argc, char* argv[]) { int gpu = 1; int max_words_cnt = 256; float lr = 0.001f; + int lm_predict_cnt = LM_PREDICT_CNT; std::string checkpoint; std::string corpus = TIMEMACHINE_RESOURCE_NAME; - while ((opt = getopt(argc, argv, "f:c:e:l:b:g:m:")) != -1) { + while ((opt = getopt(argc, argv, "f:c:e:l:b:g:m:p:")) != -1) { switch (opt) { case 'f': corpus = optarg; @@ -128,6 +129,9 @@ int main(int argc, char* argv[]) { case 'm': max_words_cnt = atoi(optarg); break; + case 'p': + lm_predict_cnt = atoi(optarg); + break; default: std::cerr << "Usage: " << argv[0] << " -f -c -e " << std::endl; @@ -249,7 +253,7 @@ int main(int argc, char* argv[]) { float* res_buffer = static_cast(::malloc( res->get_tensor()->size() )); - for (int i = 0; i < LM_PREDICT_CNT; ++i) { + for (int i = 0; i < lm_predict_cnt; ++i) { for (int j = 0; j < num_steps; 
++j) { tgt_token_ids_buffer[j] = src_token_ids[j]; } diff --git a/test_lm.txt b/test_lm.txt index 3592ed8..a1e6558 100644 --- a/test_lm.txt +++ b/test_lm.txt @@ -1 +1,2 @@ -the time machine \ No newline at end of file +the time machine +the fire burned brightly