Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,11 @@ by h g wells i the time traveller for so it will be convenient to speak of him w
-----------------
```

### pre-trained lm model

This model was trained for 7 epochs using the full text of The Time Machine novel.
[model](https://cpp-transformer-1252366230.cos.ap-beijing.myqcloud.com/lm/checkpoint_20250609_123600_7.bin)

## legacy version

[v1](https://github.com/freelw/cpp-transformer/tree/v1_freeze_20250529)
Expand Down
1 change: 1 addition & 0 deletions dataloaders/language_model/lm_dataloader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ void LMDataLoader::get_token_ids(
}

int token_ids_size = std::min((int)token_ids.size(), max_token_ids_size);
std::cout << "token_ids_size : " << token_ids_size << std::endl;

for (size_t i = 0; i < token_ids_size; ++i) {
std::vector<uint> src_step_tokens;
Expand Down
8 changes: 6 additions & 2 deletions lm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,11 @@ int main(int argc, char* argv[]) {
int gpu = 1;
int max_words_cnt = 256;
float lr = 0.001f;
int lm_predict_cnt = LM_PREDICT_CNT;
std::string checkpoint;
std::string corpus = TIMEMACHINE_RESOURCE_NAME;

while ((opt = getopt(argc, argv, "f:c:e:l:b:g:m:")) != -1) {
while ((opt = getopt(argc, argv, "f:c:e:l:b:g:m:p:")) != -1) {
switch (opt) {
case 'f':
corpus = optarg;
Expand All @@ -128,6 +129,9 @@ int main(int argc, char* argv[]) {
case 'm':
max_words_cnt = atoi(optarg);
break;
case 'p':
lm_predict_cnt = atoi(optarg);
break;
default:
std::cerr << "Usage: " << argv[0]
<< " -f <corpus> -c <checpoint> -e <epochs>" << std::endl;
Expand Down Expand Up @@ -249,7 +253,7 @@ int main(int argc, char* argv[]) {
float* res_buffer = static_cast<float*>(::malloc(
res->get_tensor()->size()
));
for (int i = 0; i < LM_PREDICT_CNT; ++i) {
for (int i = 0; i < lm_predict_cnt; ++i) {
for (int j = 0; j < num_steps; ++j) {
tgt_token_ids_buffer[j] = src_token_ids[j];
}
Expand Down
3 changes: 2 additions & 1 deletion test_lm.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
the time machine
the time machine
the fire burned brightly