
Commit 33fd964

[tokenizers][PR] Parse special_tokens_map.json
Add functionality to hf_tokenizer to parse special_tokens_map.json, which contains the source of truth for which bos/eos to use.

Differential Revision: [D84878533](https://our.internmc.facebook.com/intern/diff/D84878533/)

ghstack-source-id: 317063266

Pull Request resolved: #145
1 parent 2cb27dd commit 33fd964
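In practice, pointing HFTokenizer::load at a model directory now honors special_tokens_map.json when it is present, before falling back to tokenizer_config.json and finally to the name-based heuristic. A minimal usage sketch, not taken from the commit (the directory path is a placeholder; the header path and the bos_tok()/eos_tok() accessors are assumed to follow the library's Tokenizer interface):

#include <iostream>
#include <pytorch/tokenizers/hf_tokenizer.h>

int main() {
  tokenizers::HFTokenizer tokenizer;
  // With this commit, a special_tokens_map.json in the directory takes
  // precedence over tokenizer_config.json when resolving bos/eos.
  if (tokenizer.load("path/to/model_dir") != tokenizers::Error::Ok) {
    std::cerr << "load failed\n";
    return 1;
  }
  std::cout << "bos id: " << tokenizer.bos_tok() << "\n"
            << "eos id: " << tokenizer.eos_tok() << "\n";
  return 0;
}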

File tree

7 files changed: +338 -42 lines changed


include/pytorch/tokenizers/bpe_tokenizer_base.h

Lines changed: 2 additions & 1 deletion
@@ -122,7 +122,8 @@ inline Result<std::unique_ptr<IRegex>> build_special_token_regex(
   if (special_pattern.empty()) {
     return static_cast<std::unique_ptr<IRegex>>(nullptr);
   }
-  return create_regex(special_pattern);
+  // Wrap pattern in parentheses for proper grouping
+  return create_regex("(" + special_pattern + ")");
 }
 
 class BPETokenizerBase : public Tokenizer {
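A note on the one-line fix above: the special-token pattern is a |-joined alternation, and without enclosing parentheses anything later prefixed or appended to that pattern binds only to the first or last alternative. A minimal illustration of the pitfall using std::regex (the project's own IRegex layer is not used here; this only shows the general precedence issue):

#include <iostream>
#include <regex>
#include <string>

int main() {
  const std::string special = "<s>|</s>";  // tokens joined with '|'

  // Ungrouped: "^<s>|</s>" parses as ("^<s>") | ("</s>"), so the anchor
  // binds only to the first alternative.
  std::regex ungrouped("^" + special);
  // Grouped: "^(<s>|</s>)" anchors the whole alternation.
  std::regex grouped("^(" + special + ")");

  std::cout << std::regex_search("x</s>", ungrouped) << "\n";  // 1: "</s>" matched mid-string
  std::cout << std::regex_search("x</s>", grouped) << "\n";    // 0: anchored as intended
}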

src/hf_tokenizer.cpp

Lines changed: 91 additions & 41 deletions
@@ -25,13 +25,30 @@ using json = nlohmann::json;
 
 namespace tokenizers {
 
+namespace {
+// Helper to extract token string from either string or object format
+std::string extract_token_string(const json& token_json) {
+  if (token_json.is_string()) {
+    return token_json.get<std::string>();
+  } else if (token_json.is_object() && token_json.contains("content")) {
+    return token_json["content"].get<std::string>();
+  }
+  return "";
+};
+} // namespace
 // -------------------------private method end-------------------------------
 // -------------------------public method start-------------------------------
 
 Error HFTokenizer::load(const std::string& path) {
   // If this is a directory, look for tokenizer.json and tokenizer_config.json
   std::string model_json = path;
   std::string model_config_json = "";
+  std::string special_tokens_map_json;
+
+  // Check if bos/eos found.
+  bool bos_found = false;
+  bool eos_found = false;
+
   if (fs::is_directory(path)) {
     const fs::path root(path);
     model_json = (root / "tokenizer.json").string();
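The new extract_token_string helper accepts the two shapes Hugging Face configs use for special tokens: a bare string, and an object carrying a "content" field. A self-contained sketch of that behavior (requires only nlohmann/json; the sample values are illustrative):

#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// Same logic as the helper above: handle "</s>" and {"content": "</s>"} forms.
std::string extract_token_string(const json& token_json) {
  if (token_json.is_string()) {
    return token_json.get<std::string>();
  } else if (token_json.is_object() && token_json.contains("content")) {
    return token_json["content"].get<std::string>();
  }
  return "";  // Unrecognized shape: caller treats the token as unresolved.
}

int main() {
  // String form, as some tokenizer_config.json files ship it.
  std::cout << extract_token_string(json::parse(R"("</s>")")) << "\n";
  // Object form, as used by special_tokens_map.json in this commit.
  std::cout << extract_token_string(json::parse(R"({"content": "<|eot_id|>"})"))
            << "\n";
}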
@@ -43,6 +60,11 @@ Error HFTokenizer::load(const std::string& path) {
     if (fs::exists(model_config_json_path)) {
       model_config_json = model_config_json_path.string();
     }
+
+    const auto special_tokens_map_json_path = root / "special_tokens_map.json";
+    if (fs::exists(special_tokens_map_json_path)) {
+      special_tokens_map_json = special_tokens_map_json_path.string();
+    }
   }
 
   // Load the tokenizer.json file
@@ -63,7 +85,6 @@ Error HFTokenizer::load(const std::string& path) {
 
   // Parse the special tokens
   try {
-    std::vector<std::pair<std::string, std::uint64_t>> special_token_pairs;
     const auto& special_tokens = parsed_json.at("added_tokens");
     auto special_token_map_result = detail::build_token_map(
         special_tokens,
@@ -213,8 +234,37 @@ Error HFTokenizer::load(const std::string& path) {
     return Error::LoadFailure;
   }
 
-  // If a tokenizer config file is found, parse it to look up the eos/bos tokens
-  if (!model_config_json.empty()) {
+  // Try special_tokens_map.json first
+  std::string bos_token;
+  std::string eos_token;
+
+  if (!special_tokens_map_json.empty()) {
+    std::ifstream special_file(special_tokens_map_json);
+    if (special_file) {
+      try {
+        json special_tokens_json = json::parse(std::string(
+            (std::istreambuf_iterator<char>(special_file)),
+            std::istreambuf_iterator<char>()));
+
+        if (special_tokens_json.contains("bos_token")) {
+          bos_token = extract_token_string(special_tokens_json["bos_token"]);
+        }
+        if (special_tokens_json.contains("eos_token")) {
+          eos_token = extract_token_string(special_tokens_json["eos_token"]);
+        }
+
+        TK_LOG(
+            Info,
+            "Loaded tokens from special_tokens_map.json: bos='%s', eos='%s'",
+            bos_token.c_str(),
+            eos_token.c_str());
+      } catch (const std::exception& e) {
+        TK_LOG(Info, "Could not parse special_tokens_map.json: %s", e.what());
+      }
+    }
+  }
+  // Try tokenizer_config.json next
+  if ((bos_token.empty() || eos_token.empty()) && !model_config_json.empty()) {
     // Load it and parse it as json
     std::ifstream config_file(model_config_json);
     if (!config_file) {
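Together with the following hunk, this establishes a three-stage resolution order: special_tokens_map.json first, then tokenizer_config.json for whichever of bos/eos is still unset, then the heuristic scan. A condensed sketch of that precedence (the helper and function names here are hypothetical, not the commit's code):

#include <fstream>
#include <optional>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// Hypothetical helper: read one special token from a JSON config, if present.
std::optional<std::string> read_token(const std::string& path, const char* key) {
  std::ifstream file(path);
  if (!file) {
    return std::nullopt;
  }
  try {
    const json parsed = json::parse(file);
    if (parsed.contains(key)) {
      const json& tok = parsed.at(key);
      if (tok.is_string()) {
        return tok.get<std::string>();
      }
      if (tok.is_object() && tok.contains("content")) {
        return tok["content"].get<std::string>();
      }
    }
  } catch (const std::exception&) {
    // Malformed file: fall through to the next source.
  }
  return std::nullopt;
}

// Precedence mirrors the commit: map file first, config file second.
std::string resolve_bos(const std::string& dir) {
  if (auto t = read_token(dir + "/special_tokens_map.json", "bos_token")) {
    return *t;
  }
  if (auto t = read_token(dir + "/tokenizer_config.json", "bos_token")) {
    return *t;
  }
  return "";  // Empty: caller falls back to the heuristic scan below.
}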
@@ -224,59 +274,62 @@ Error HFTokenizer::load(const std::string& path) {
     std::string config_contents(
         (std::istreambuf_iterator<char>(config_file)),
         std::istreambuf_iterator<char>());
-    json parsed_config_json;
     try {
-      parsed_config_json = json::parse(config_contents);
+      json parsed_config_json = json::parse(config_contents);
+      if (bos_token.empty() && parsed_config_json.contains("bos_token")) {
+        bos_token = extract_token_string(parsed_config_json["bos_token"]);
+      }
+      if (eos_token.empty() && parsed_config_json.contains("eos_token")) {
+        eos_token = extract_token_string(parsed_config_json["eos_token"]);
+      }
+      TK_LOG(
+          Info,
+          "Loaded tokens from tokenizer_config.json: bos='%s', eos='%s'",
+          bos_token.c_str(),
+          eos_token.c_str());
     } catch (const std::exception& e) {
       TK_LOG(Error, "Error parsing model config json json file: %s", e.what());
       return Error::LoadFailure;
     }
+  }
 
-    // Pull out the token strings
-    try {
-      const std::string bos_token = parsed_config_json.contains("bos_token") &&
-              !parsed_config_json["bos_token"].is_null()
-          ? parsed_config_json["bos_token"].get<std::string>()
-          : "";
-
-      const std::string eos_token = parsed_config_json.contains("eos_token") &&
-              !parsed_config_json["eos_token"].is_null()
-          ? parsed_config_json["eos_token"].get<std::string>()
-          : "";
-      const auto bos_res = special_token_map_->tryGetInteger(bos_token);
-      const auto eos_res = special_token_map_->tryGetInteger(eos_token);
-      if (!bos_res) {
-        TK_LOG(Error, "BOS token %s not in special tokens", bos_token.c_str());
-        return Error::LoadFailure;
-      }
-      if (!eos_res) {
-        TK_LOG(Error, "EOS token %s not in special tokens", eos_token.c_str());
-        return Error::LoadFailure;
-      }
-      bos_tok_ = *bos_res;
-      eos_tok_ = *eos_res;
-    } catch (const std::exception& e) {
-      TK_LOG(Error, "Could not eos/bos from tokenizer config: %s", e.what());
-      return Error::LoadFailure;
+  // Try to extract the bos/eos tokens.
+  if (!bos_token.empty() && !eos_token.empty()) {
+    auto bos_candidate = special_token_map_->tryGetInteger(bos_token);
+    if (!bos_candidate) {
+      TK_LOG(Info, "BOS token %s not in special tokens", bos_token.c_str());
+    } else {
+      bos_tok_ = *bos_candidate;
+      bos_found = true;
+    }
+
+    auto eos_candidate = special_token_map_->tryGetInteger(eos_token);
+    if (!eos_candidate) {
+      TK_LOG(Info, "EOS token %s not in special tokens", eos_token.c_str());
+    } else {
+      eos_tok_ = *eos_candidate;
+      eos_found = true;
     }
   }
 
   // Otherwise, make an educated guess with the following logic:
   // 1. Look for special tokens with "bos"/"begin" or "eos"/"end" in them
   // 2. Sub-qualify with the word "text" if needed
   // 3. If EOS found, but BOS is not (or vice versa), assume they are the same
-  else {
+  if (!eos_found || !bos_found) {
     std::vector<std::string_view> bos_candidates;
     std::vector<std::string_view> eos_candidates;
     for (std::size_t token_idx = 0; token_idx < special_token_map_->size();
          ++token_idx) {
       const auto [token, _] = special_token_map_->getElement(token_idx);
-      if (token.find("bos") != std::string::npos ||
-          token.find("begin") != std::string::npos) {
+      if (!bos_found &&
+          (token.find("bos") != std::string::npos ||
+           token.find("begin") != std::string::npos)) {
         bos_candidates.push_back(token);
       }
-      if (token.find("eos") != std::string::npos ||
-          token.find("end") != std::string::npos) {
+      if (!eos_found &&
+          (token.find("eos") != std::string::npos ||
+           token.find("end") != std::string::npos)) {
         eos_candidates.push_back(token);
       }
     }
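With the !bos_found / !eos_found guards added above, the substring scan now only collects candidates for slots the JSON configs left unresolved. Step 2 of the comment (sub-qualifying with "text") is not visible in these hunks; a sketch of what such narrowing plausibly looks like, hypothetical rather than the commit's code:

#include <string_view>
#include <vector>

// Hypothetical: narrow a multi-candidate list by an extra keyword, keeping
// the original list if the keyword fails to single out one token.
std::vector<std::string_view> narrow_candidates(
    const std::vector<std::string_view>& candidates,
    std::string_view keyword) {
  if (candidates.size() <= 1) {
    return candidates;  // Already unambiguous.
  }
  std::vector<std::string_view> narrowed;
  for (const auto token : candidates) {
    if (token.find(keyword) != std::string_view::npos) {
      narrowed.push_back(token);
    }
  }
  // Only accept the narrowing if it actually disambiguates.
  return narrowed.size() == 1 ? narrowed : candidates;
}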
@@ -300,14 +353,11 @@ Error HFTokenizer::load(const std::string& path) {
     }
   }
 
-  // Use if a single candidate
-  bool bos_found = false;
-  bool eos_found = false;
-  if (bos_candidates.size() == 1) {
+  if (!bos_found && bos_candidates.size() == 1) {
     bos_found = true;
     bos_tok_ = *(special_token_map_->tryGetInteger(bos_candidates[0]));
   }
-  if (eos_candidates.size() == 1) {
+  if (!eos_found && eos_candidates.size() == 1) {
     eos_found = true;
     eos_tok_ = *(special_token_map_->tryGetInteger(eos_candidates[0]));
   }
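The single-candidate acceptance is now gated on the flags declared at the top of load, so a token resolved from JSON is never overwritten by the heuristic. Step 3 of the comment (mirroring whichever of bos/eos was found into the missing one) lies outside the visible hunks; a sketch of that step under that assumption, not the commit's code:

#include <cstdint>

// Hypothetical continuation: if exactly one of bos/eos was resolved,
// reuse it for the other slot instead of failing the load.
void mirror_missing_token(
    bool& bos_found,
    bool& eos_found,
    std::uint64_t& bos_tok,
    std::uint64_t& eos_tok) {
  if (bos_found && !eos_found) {
    eos_tok = bos_tok;
    eos_found = true;
  } else if (eos_found && !bos_found) {
    bos_tok = eos_tok;
    bos_found = true;
  }
}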
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+{
+  "bos_token": {
+    "content": "<|begin_of_text|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "eos_token": {
+    "content": "<|eot_id|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  }
+}
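This test fixture, a special_tokens_map.json, uses the object form for both tokens. The bare-string form that some models ship is equally accepted by extract_token_string, for example:

{
  "bos_token": "<|begin_of_text|>",
  "eos_token": "<|eot_id|>"
}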
Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
+{
+  "version": "1.0",
+  "truncation": null,
+  "padding": null,
+  "added_tokens": [
+    {
+      "id": 0,
+      "content": "<unk>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 1,
+      "content": "<s>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 2,
+      "content": "</s>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 128000,
+      "content": "<|begin_of_text|>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 128001,
+      "content": "<|end_of_text|>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 128009,
+      "content": "<|eot_id|>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    }
+  ],
+  "normalizer": {
+    "type": "Sequence",
+    "normalizers": [
+      {
+        "type": "Replace",
+        "pattern": {
+          "String": " "
+        },
+        "content": ""
+      }
+    ]
+  },
+  "pre_tokenizer": {
+    "type": "Sequence",
+    "pretokenizers": [
+      {
+        "type": "Split",
+        "pattern": {
+          "Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+        },
+        "behavior": "MergedWithPrevious",
+        "invert": false
+      },
+      {
+        "type": "ByteLevel",
+        "add_prefix_space": false,
+        "trim_offsets": false,
+        "use_regex": false
+      }
+    ]
+  },
+  "post_processor": {
+    "type": "ByteLevel",
+    "add_prefix_space": false,
+    "trim_offsets": false,
+    "use_regex": false
+  },
+  "decoder": {
+    "type": "ByteLevel",
+    "add_prefix_space": false,
+    "trim_offsets": false,
+    "use_regex": false
+  },
+  "model": {
+    "type": "BPE",
+    "dropout": null,
+    "unk_token": null,
+    "continuing_subword_prefix": "",
+    "end_of_word_suffix": "",
+    "fuse_unk": false,
+    "byte_fallback": false,
+    "ignore_merges": false,
+    "vocab": {
+      "<unk>": 0,
+      "<s>": 1,
+      "</s>": 2,
+      "▁": 3,
+      "H": 4,
+      "e": 5,
+      "l": 6,
+      "o": 7,
+      "▁Hello": 8,
+      "▁world!": 9,
+      "w": 10,
+      "r": 11,
+      "d": 12,
+      "!": 13
+    },
+    "merges": [
+      "H e",
+      "e l",
+      "l l",
+      "l o",
+      "▁ H",
+      "▁H e",
+      "▁He l",
+      "▁Hel l",
+      "▁Hell o",
+      "w o",
+      "o r",
+      "r l",
+      "l d",
+      "d !",
+      "▁ w",
+      "▁w o",
+      "▁wo r",
+      "▁wor l",
+      "▁worl d",
+      "▁world !"
+    ]
+  }
+}
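Note how this tokenizer.json fixture pairs with the map above: its added_tokens include both <|end_of_text|> (128001) and <|eot_id|> (128009), so a name-based guess for EOS is ambiguous, and special_tokens_map.json is what pins EOS to <|eot_id|>. A hypothetical test along those lines (the test name, fixture path, and accessors are assumptions, not taken from this commit):

#include <gtest/gtest.h>
#include <pytorch/tokenizers/hf_tokenizer.h>

// Hypothetical: loading a directory holding both fixtures should take
// eos from special_tokens_map.json rather than from the heuristic.
TEST(HFTokenizerTest, SpecialTokensMapPinsEos) {
  tokenizers::HFTokenizer tokenizer;
  ASSERT_EQ(tokenizer.load("path/to/fixture_dir"), tokenizers::Error::Ok);
  EXPECT_EQ(tokenizer.bos_tok(), 128000);  // <|begin_of_text|>
  EXPECT_EQ(tokenizer.eos_tok(), 128009);  // <|eot_id|>
}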
