@@ -25,13 +25,30 @@ using json = nlohmann::json;
 
 namespace tokenizers {
 
+namespace {
+// Helper to extract token string from either string or object format
+std::string extract_token_string(const json& token_json) {
+  if (token_json.is_string()) {
+    return token_json.get<std::string>();
+  } else if (token_json.is_object() && token_json.contains("content")) {
+    return token_json["content"].get<std::string>();
+  }
+  return "";
+}
+} // namespace
 // -------------------------private method end-------------------------------
 // -------------------------public method start-------------------------------
 
 Error HFTokenizer::load(const std::string& path) {
   // If this is a directory, look for tokenizer.json and tokenizer_config.json
   std::string model_json = path;
   std::string model_config_json = "";
+  std::string special_tokens_map_json;
+
+  // Track whether the bos/eos tokens have been resolved.
+  bool bos_found = false;
+  bool eos_found = false;
+
   if (fs::is_directory(path)) {
     const fs::path root(path);
     model_json = (root / "tokenizer.json").string();
@@ -43,6 +60,11 @@ Error HFTokenizer::load(const std::string& path) {
     if (fs::exists(model_config_json_path)) {
       model_config_json = model_config_json_path.string();
     }
+
+    const auto special_tokens_map_json_path = root / "special_tokens_map.json";
+    if (fs::exists(special_tokens_map_json_path)) {
+      special_tokens_map_json = special_tokens_map_json_path.string();
+    }
   }
 
   // Load the tokenizer.json file
@@ -63,7 +85,6 @@ Error HFTokenizer::load(const std::string& path) {
 
   // Parse the special tokens
   try {
-    std::vector<std::pair<std::string, std::uint64_t>> special_token_pairs;
     const auto& special_tokens = parsed_json.at("added_tokens");
     auto special_token_map_result = detail::build_token_map(
         special_tokens,
@@ -213,8 +234,37 @@ Error HFTokenizer::load(const std::string& path) {
     return Error::LoadFailure;
   }
 
-  // If a tokenizer config file is found, parse it to look up the eos/bos tokens
-  if (!model_config_json.empty()) {
+  // Try special_tokens_map.json first
+  std::string bos_token;
+  std::string eos_token;
+
+  if (!special_tokens_map_json.empty()) {
+    std::ifstream special_file(special_tokens_map_json);
+    if (special_file) {
+      try {
+        json special_tokens_json = json::parse(std::string(
+            (std::istreambuf_iterator<char>(special_file)),
+            std::istreambuf_iterator<char>()));
+
+        if (special_tokens_json.contains("bos_token")) {
+          bos_token = extract_token_string(special_tokens_json["bos_token"]);
+        }
+        if (special_tokens_json.contains("eos_token")) {
+          eos_token = extract_token_string(special_tokens_json["eos_token"]);
+        }
+
+        TK_LOG(
+            Info,
+            "Loaded tokens from special_tokens_map.json: bos='%s', eos='%s'",
+            bos_token.c_str(),
+            eos_token.c_str());
+      } catch (const std::exception& e) {
+        TK_LOG(Info, "Could not parse special_tokens_map.json: %s", e.what());
+      }
+    }
+  }
+  // Try tokenizer_config.json next
+  if ((bos_token.empty() || eos_token.empty()) && !model_config_json.empty()) {
     // Load it and parse it as json
     std::ifstream config_file(model_config_json);
     if (!config_file) {
@@ -224,59 +274,62 @@ Error HFTokenizer::load(const std::string& path) {
     std::string config_contents(
         (std::istreambuf_iterator<char>(config_file)),
         std::istreambuf_iterator<char>());
-    json parsed_config_json;
     try {
-      parsed_config_json = json::parse(config_contents);
+      json parsed_config_json = json::parse(config_contents);
+      if (bos_token.empty() && parsed_config_json.contains("bos_token")) {
+        bos_token = extract_token_string(parsed_config_json["bos_token"]);
+      }
+      if (eos_token.empty() && parsed_config_json.contains("eos_token")) {
+        eos_token = extract_token_string(parsed_config_json["eos_token"]);
+      }
+      TK_LOG(
+          Info,
+          "Loaded tokens from tokenizer_config.json: bos='%s', eos='%s'",
+          bos_token.c_str(),
+          eos_token.c_str());
     } catch (const std::exception& e) {
       TK_LOG(Error, "Error parsing model config json file: %s", e.what());
       return Error::LoadFailure;
     }
+  }
 
-    // Pull out the token strings
-    try {
-      const std::string bos_token = parsed_config_json.contains("bos_token") &&
-              !parsed_config_json["bos_token"].is_null()
-          ? parsed_config_json["bos_token"].get<std::string>()
-          : "";
-
-      const std::string eos_token = parsed_config_json.contains("eos_token") &&
-              !parsed_config_json["eos_token"].is_null()
-          ? parsed_config_json["eos_token"].get<std::string>()
-          : "";
-      const auto bos_res = special_token_map_->tryGetInteger(bos_token);
-      const auto eos_res = special_token_map_->tryGetInteger(eos_token);
-      if (!bos_res) {
-        TK_LOG(Error, "BOS token %s not in special tokens", bos_token.c_str());
-        return Error::LoadFailure;
-      }
-      if (!eos_res) {
-        TK_LOG(Error, "EOS token %s not in special tokens", eos_token.c_str());
-        return Error::LoadFailure;
-      }
-      bos_tok_ = *bos_res;
-      eos_tok_ = *eos_res;
-    } catch (const std::exception& e) {
-      TK_LOG(Error, "Could not eos/bos from tokenizer config: %s", e.what());
-      return Error::LoadFailure;
+  // Try to extract the bos/eos tokens.
+  if (!bos_token.empty() && !eos_token.empty()) {
+    auto bos_candidate = special_token_map_->tryGetInteger(bos_token);
+    if (!bos_candidate) {
+      TK_LOG(Info, "BOS token %s not in special tokens", bos_token.c_str());
+    } else {
+      bos_tok_ = *bos_candidate;
+      bos_found = true;
+    }
+
+    auto eos_candidate = special_token_map_->tryGetInteger(eos_token);
+    if (!eos_candidate) {
+      TK_LOG(Info, "EOS token %s not in special tokens", eos_token.c_str());
+    } else {
+      eos_tok_ = *eos_candidate;
+      eos_found = true;
     }
   }
 
   // Otherwise, make an educated guess with the following logic:
   // 1. Look for special tokens with "bos"/"begin" or "eos"/"end" in them
   // 2. Sub-qualify with the word "text" if needed
   // 3. If EOS found, but BOS is not (or vice versa), assume they are the same
-  else {
+  if (!eos_found || !bos_found) {
     std::vector<std::string_view> bos_candidates;
     std::vector<std::string_view> eos_candidates;
     for (std::size_t token_idx = 0; token_idx < special_token_map_->size();
          ++token_idx) {
       const auto [token, _] = special_token_map_->getElement(token_idx);
-      if (token.find("bos") != std::string::npos ||
-          token.find("begin") != std::string::npos) {
+      if (!bos_found &&
+          (token.find("bos") != std::string::npos ||
+           token.find("begin") != std::string::npos)) {
         bos_candidates.push_back(token);
       }
-      if (token.find("eos") != std::string::npos ||
-          token.find("end") != std::string::npos) {
+      if (!eos_found &&
+          (token.find("eos") != std::string::npos ||
+           token.find("end") != std::string::npos)) {
         eos_candidates.push_back(token);
       }
     }
@@ -300,14 +353,11 @@ Error HFTokenizer::load(const std::string& path) {
       }
     }
 
-    // Use if a single candidate
-    bool bos_found = false;
-    bool eos_found = false;
-    if (bos_candidates.size() == 1) {
+    if (!bos_found && bos_candidates.size() == 1) {
       bos_found = true;
       bos_tok_ = *(special_token_map_->tryGetInteger(bos_candidates[0]));
     }
-    if (eos_candidates.size() == 1) {
+    if (!eos_found && eos_candidates.size() == 1) {
       eos_found = true;
       eos_tok_ = *(special_token_map_->tryGetInteger(eos_candidates[0]));
     }
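
Illustrative note (not part of the commit): the sketch below is a minimal, standalone approximation of the two added-token shapes the new extract_token_string helper is written to handle, a bare JSON string and an object carrying a "content" field. The JSON fragment and the token names (<|begin_of_text|>, <|end_of_text|>) are hypothetical examples; real special_tokens_map.json or tokenizer_config.json files may use different names and extra fields such as "lstrip" or "normalized".

// Standalone sketch (assumption: nlohmann/json is available, as in the patch).
#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// Mirrors the helper added above: accept either a bare string or an
// object with a "content" field; return "" when neither shape matches.
static std::string extract_token_string(const json& token_json) {
  if (token_json.is_string()) {
    return token_json.get<std::string>();
  } else if (token_json.is_object() && token_json.contains("content")) {
    return token_json["content"].get<std::string>();
  }
  return "";
}

int main() {
  // Hypothetical special_tokens_map.json content: bos_token in object form,
  // eos_token in plain-string form.
  json special_tokens = json::parse(R"({
    "bos_token": {"content": "<|begin_of_text|>", "lstrip": false},
    "eos_token": "<|end_of_text|>"
  })");

  std::cout << extract_token_string(special_tokens["bos_token"]) << "\n"; // <|begin_of_text|>
  std::cout << extract_token_string(special_tokens["eos_token"]) << "\n"; // <|end_of_text|>
  return 0;
}

Handling both shapes matters because tokenizer configs are inconsistent on this point: some store bos_token/eos_token as plain strings, others as AddedToken-style objects.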