|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include "config/config.h" |
| 4 | +#include <string> |
| 5 | +#include <vector> |
| 6 | +#include <map> |
| 7 | +#include <unordered_map> |
| 8 | +#include <set> |
| 9 | +#include <fstream> |
| 10 | +#include <sstream> |
| 11 | +#include <algorithm> |
| 12 | +#include <iostream> |
| 13 | +#include <random> |
| 14 | +#include <stdexcept> |
| 15 | + |
| 16 | +// DataLoader handles BPE tokenization, corpus loading, and batch generation. |
| 17 | +struct DataLoader |
| 18 | +{ |
| 19 | + std::vector<std::string> vocab; // id to token string |
| 20 | + std::map<std::string, int> token_to_id; // token string to id |
| 21 | + std::vector<std::pair<int, int>> merges; // merge rules in training order |
| 22 | + std::map<std::pair<int,int>, int> merge_rank; // pair to training step index |
| 23 | + int base_vocab_size{0}; // number of single-char tokens |
| 24 | + int vocab_size{0}; |
| 25 | + |
| 26 | + std::vector<int> train_data; |
| 27 | + std::vector<int> val_data; |
| 28 | + |
| 29 | + // Load text file and train BPE tokenizer on it. |
| 30 | + // Splits into train and validation sets. |
| 31 | + void load(const std::string &path, |
| 32 | + int target_vocab = BPE_VOCAB_SIZE, |
| 33 | + double train_split = TRAIN_SPLIT) |
| 34 | + { |
| 35 | + std::ifstream f(path); |
| 36 | + if (!f.is_open()) |
| 37 | + throw std::runtime_error("[DataLoader] Cannot open file: " + path); |
| 38 | + |
| 39 | + std::ostringstream ss; |
| 40 | + ss << f.rdbuf(); |
| 41 | + std::string text = ss.str(); |
| 42 | + if (text.empty()) |
| 43 | + throw std::runtime_error("[DataLoader] File is empty: " + path); |
| 44 | + |
| 45 | + std::cout << "[BPE] Text length: " << text.size() << " characters\n"; |
| 46 | + std::cout << "[BPE] Target vocab size: " << target_vocab << "\n"; |
| 47 | + std::cout.flush(); |
| 48 | + |
| 49 | + std::vector<int> data = train_bpe(text, target_vocab); |
| 50 | + |
| 51 | + int n = (int)(train_split * (double)data.size()); |
| 52 | + train_data = std::vector<int>(data.begin(), data.begin() + n); |
| 53 | + val_data = std::vector<int>(data.begin() + n, data.end()); |
| 54 | + |
| 55 | + if ((int)train_data.size() <= BLOCK_SIZE || |
| 56 | + (int)val_data.size() <= BLOCK_SIZE) |
| 57 | + throw std::runtime_error( |
| 58 | + "[DataLoader] Dataset too small for BLOCK_SIZE=" + |
| 59 | + std::to_string(BLOCK_SIZE)); |
| 60 | + |
| 61 | + std::cout << "[DATA] Total tokens : " << data.size() << "\n"; |
| 62 | + std::cout << "[DATA] Train tokens : " << train_data.size() << "\n"; |
| 63 | + std::cout << "[DATA] Val tokens : " << val_data.size() << "\n"; |
| 64 | + } |
| 65 | + |
| 66 | + // Encode a string to BPE token ids. |
| 67 | + // Characters not found in base vocab are skipped silently. |
| 68 | + std::vector<int> encode(const std::string &text) const |
| 69 | + { |
| 70 | + return apply_merges(base_encode(text)); |
| 71 | + } |
| 72 | + |
| 73 | + // Decode BPE token ids back to the original string. |
| 74 | + std::string decode(const std::vector<int> &ids) const |
| 75 | + { |
| 76 | + std::string out; |
| 77 | + for (int id : ids) |
| 78 | + if (id >= 0 && id < (int)vocab.size()) |
| 79 | + out += vocab[id]; |
| 80 | + return out; |
| 81 | + } |
| 82 | + |
| 83 | + // Get a random batch of (input, target) token pairs. |
| 84 | + // Input is [batch_size * block_size] tokens. |
| 85 | + // Target is the next token for each input position. |
| 86 | + std::pair<std::vector<int>, std::vector<int>> |
| 87 | + get_batch(const std::string &split, int batch_size, int block_size, |
| 88 | + std::mt19937 &rng) const |
| 89 | + { |
| 90 | + const std::vector<int> &d = (split == "train") ? train_data : val_data; |
| 91 | + std::uniform_int_distribution<int> dist(0, (int)d.size() - block_size - 1); |
| 92 | + |
| 93 | + std::vector<int> x(batch_size * block_size); |
| 94 | + std::vector<int> y(batch_size * block_size); |
| 95 | + |
| 96 | + for (int b = 0; b < batch_size; ++b) { |
| 97 | + int start = dist(rng); |
| 98 | + for (int t = 0; t < block_size; ++t) { |
| 99 | + x[b * block_size + t] = d[start + t]; |
| 100 | + y[b * block_size + t] = d[start + t + 1]; |
| 101 | + } |
| 102 | + } |
| 103 | + return {x, y}; |
| 104 | + } |
| 105 | + |
| 106 | +private: |
| 107 | + |
| 108 | + // Convert a string to char-level token ids. |
| 109 | + std::vector<int> base_encode(const std::string &text) const |
| 110 | + { |
| 111 | + std::vector<int> ids; |
| 112 | + ids.reserve(text.size()); |
| 113 | + for (char c : text) { |
| 114 | + auto it = token_to_id.find(std::string(1, c)); |
| 115 | + if (it != token_to_id.end()) |
| 116 | + ids.push_back(it->second); |
| 117 | + } |
| 118 | + return ids; |
| 119 | + } |
| 120 | + |
| 121 | + // Apply all BPE merge rules in training order to a token sequence. |
| 122 | + std::vector<int> apply_merges(std::vector<int> ids) const |
| 123 | + { |
| 124 | + for (int rank = 0; rank < (int)merges.size(); ++rank) { |
| 125 | + int left = merges[rank].first; |
| 126 | + int right = merges[rank].second; |
| 127 | + int new_id = base_vocab_size + rank; |
| 128 | + |
| 129 | + std::vector<int> out; |
| 130 | + out.reserve(ids.size()); |
| 131 | + int i = 0; |
| 132 | + while (i < (int)ids.size()) { |
| 133 | + if (i + 1 < (int)ids.size() && |
| 134 | + ids[i] == left && ids[i+1] == right) |
| 135 | + { |
| 136 | + out.push_back(new_id); |
| 137 | + i += 2; |
| 138 | + } else { |
| 139 | + out.push_back(ids[i]); |
| 140 | + ++i; |
| 141 | + } |
| 142 | + } |
| 143 | + ids = std::move(out); |
| 144 | + } |
| 145 | + return ids; |
| 146 | + } |
| 147 | + |
| 148 | + // Train BPE on the input text. |
| 149 | + // Returns the final BPE-encoded token sequence for the full corpus. |
| 150 | + std::vector<int> train_bpe(const std::string &text, int target_vocab) |
| 151 | + { |
| 152 | + // Build base vocabulary from all unique characters. |
| 153 | + std::set<char> chars(text.begin(), text.end()); |
| 154 | + for (char c : chars) { |
| 155 | + std::string s(1, c); |
| 156 | + token_to_id[s] = (int)vocab.size(); |
| 157 | + vocab.push_back(s); |
| 158 | + } |
| 159 | + base_vocab_size = (int)vocab.size(); |
| 160 | + |
| 161 | + if (target_vocab <= base_vocab_size) { |
| 162 | + vocab_size = base_vocab_size; |
| 163 | + return base_encode(text); |
| 164 | + } |
| 165 | + |
| 166 | + // Encode full text as base character-level tokens. |
| 167 | + std::vector<int> ids = base_encode(text); |
| 168 | + |
| 169 | + int num_merges = target_vocab - base_vocab_size; |
| 170 | + merges.reserve(num_merges); |
| 171 | + std::cout << "[BPE] Running " << num_merges << " merges...\n"; |
| 172 | + std::cout.flush(); |
| 173 | + |
| 174 | + for (int step = 0; step < num_merges; ++step) |
| 175 | + { |
| 176 | + // Count frequency of all adjacent token pairs. |
| 177 | + std::unordered_map<long long, int> counts; |
| 178 | + counts.reserve(ids.size()); |
| 179 | + for (int i = 0; i + 1 < (int)ids.size(); ++i) { |
| 180 | + long long key = ((long long)ids[i] << 32) | (unsigned int)ids[i+1]; |
| 181 | + ++counts[key]; |
| 182 | + } |
| 183 | + |
| 184 | + if (counts.empty()) break; |
| 185 | + |
| 186 | + // Find the most frequent pair. |
| 187 | + long long best_key = 0; |
| 188 | + int best_cnt = -1; |
| 189 | + for (auto &kv : counts) { |
| 190 | + if (kv.second > best_cnt) { |
| 191 | + best_cnt = kv.second; |
| 192 | + best_key = kv.first; |
| 193 | + } |
| 194 | + } |
| 195 | + |
| 196 | + int left = (int)(best_key >> 32); |
| 197 | + int right = (int)(best_key & 0xFFFFFFFFLL); |
| 198 | + |
| 199 | + // Create a new token for this merge. |
| 200 | + int new_id = (int)vocab.size(); |
| 201 | + std::string new_tok = vocab[left] + vocab[right]; |
| 202 | + vocab.push_back(new_tok); |
| 203 | + token_to_id[new_tok] = new_id; |
| 204 | + merges.push_back({left, right}); |
| 205 | + merge_rank[{left, right}] = step; |
| 206 | + |
| 207 | + // Apply the merge everywhere in the sequence. |
| 208 | + std::vector<int> out; |
| 209 | + out.reserve(ids.size()); |
| 210 | + int i = 0; |
| 211 | + while (i < (int)ids.size()) { |
| 212 | + if (i + 1 < (int)ids.size() && |
| 213 | + ids[i] == left && ids[i+1] == right) |
| 214 | + { |
| 215 | + out.push_back(new_id); |
| 216 | + i += 2; |
| 217 | + } else { |
| 218 | + out.push_back(ids[i]); |
| 219 | + ++i; |
| 220 | + } |
| 221 | + } |
| 222 | + ids = std::move(out); |
| 223 | + |
| 224 | + // Print progress every 500 merges. |
| 225 | + if ((step + 1) % 500 == 0 || step + 1 == num_merges) { |
| 226 | + std::cout << "[BPE] step " << (step + 1) << "/" << num_merges |
| 227 | + << " vocab=" << (int)vocab.size() |
| 228 | + << " seq=" << ids.size() << "\n"; |
| 229 | + std::cout.flush(); |
| 230 | + } |
| 231 | + } |
| 232 | + |
| 233 | + vocab_size = (int)vocab.size(); |
| 234 | + std::cout << "[BPE] Done. Final vocab size: " << vocab_size << "\n"; |
| 235 | + return ids; |
| 236 | + } |
| 237 | +}; |
0 commit comments