3 changes: 2 additions & 1 deletion include/pytorch/tokenizers/bpe_tokenizer_base.h
@@ -122,7 +122,8 @@ inline Result<std::unique_ptr<IRegex>> build_special_token_regex(
   if (special_pattern.empty()) {
     return static_cast<std::unique_ptr<IRegex>>(nullptr);
   }
-  return create_regex(special_pattern);
+  // Wrap pattern in parentheses for proper grouping
+  return create_regex("(" + special_pattern + ")");
 }
 
 class BPETokenizerBase : public Tokenizer {
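A note on the one-line change above: '|' has the lowest precedence in a regex, so splicing an unparenthesized alternation into a larger pattern silently changes what gets anchored or repeated. The hunk does not show how callers embed the returned pattern, so the following is a minimal standalone sketch of the precedence pitfall, using std::regex instead of the repo's IRegex/create_regex abstraction:

// Illustrative only: why "(" + special_pattern + ")" matters once the
// alternation is combined with other regex fragments.
#include <cassert>
#include <regex>
#include <string>

int main() {
  const std::string special_pattern = "<s>|</s>"; // tokens joined with '|'

  // Unwrapped, "^" + pattern parses as ("^<s>") | ("</s>"), so "</s>"
  // matches even in the middle of the input.
  assert(std::regex_search("x</s>", std::regex("^" + special_pattern)));

  // Wrapped, the alternation is a single unit and must appear at the start.
  assert(!std::regex_search("x</s>", std::regex("^(" + special_pattern + ")")));
  assert(std::regex_search("<s>x", std::regex("^(" + special_pattern + ")")));
  return 0;
}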
132 changes: 91 additions & 41 deletions src/hf_tokenizer.cpp
@@ -25,13 +25,30 @@ using json = nlohmann::json;
 
 namespace tokenizers {
 
+namespace {
+// Helper to extract token string from either string or object format
+std::string extract_token_string(const json& token_json) {
+  if (token_json.is_string()) {
+    return token_json.get<std::string>();
+  } else if (token_json.is_object() && token_json.contains("content")) {
+    return token_json["content"].get<std::string>();
+  }
+  return "";
+}
+} // namespace
 // -------------------------private method end-------------------------------
 // -------------------------public method start-------------------------------
 
 Error HFTokenizer::load(const std::string& path) {
   // If this is a directory, look for tokenizer.json and tokenizer_config.json
   std::string model_json = path;
   std::string model_config_json = "";
+  std::string special_tokens_map_json;
+
+  // Check if bos/eos found.
+  bool bos_found = false;
+  bool eos_found = false;
+
   if (fs::is_directory(path)) {
     const fs::path root(path);
     model_json = (root / "tokenizer.json").string();
@@ -43,6 +60,11 @@ Error HFTokenizer::load(const std::string& path) {
     if (fs::exists(model_config_json_path)) {
       model_config_json = model_config_json_path.string();
     }
+
+    const auto special_tokens_map_json_path = root / "special_tokens_map.json";
+    if (fs::exists(special_tokens_map_json_path)) {
+      special_tokens_map_json = special_tokens_map_json_path.string();
+    }
   }
 
   // Load the tokenizer.json file
@@ -63,7 +85,6 @@
 
   // Parse the special tokens
   try {
-    std::vector<std::pair<std::string, std::uint64_t>> special_token_pairs;
     const auto& special_tokens = parsed_json.at("added_tokens");
     auto special_token_map_result = detail::build_token_map(
         special_tokens,
@@ -213,8 +234,37 @@ Error HFTokenizer::load(const std::string& path) {
     return Error::LoadFailure;
   }
 
-  // If a tokenizer config file is found, parse it to look up the eos/bos tokens
-  if (!model_config_json.empty()) {
+  // Try special_tokens_map.json first
+  std::string bos_token;
+  std::string eos_token;
+
+  if (!special_tokens_map_json.empty()) {
+    std::ifstream special_file(special_tokens_map_json);
+    if (special_file) {
+      try {
+        json special_tokens_json = json::parse(std::string(
+            (std::istreambuf_iterator<char>(special_file)),
+            std::istreambuf_iterator<char>()));
+
+        if (special_tokens_json.contains("bos_token")) {
+          bos_token = extract_token_string(special_tokens_json["bos_token"]);
+        }
+        if (special_tokens_json.contains("eos_token")) {
+          eos_token = extract_token_string(special_tokens_json["eos_token"]);
+        }
+
+        TK_LOG(
+            Info,
+            "Loaded tokens from special_tokens_map.json: bos='%s', eos='%s'",
+            bos_token.c_str(),
+            eos_token.c_str());
+      } catch (const std::exception& e) {
+        TK_LOG(Info, "Could not parse special_tokens_map.json: %s", e.what());
+      }
+    }
+  }
+  // Try tokenizer_config.json next
+  if ((bos_token.empty() || eos_token.empty()) && !model_config_json.empty()) {
     // Load it and parse it as json
     std::ifstream config_file(model_config_json);
     if (!config_file) {
@@ -224,59 +274,62 @@ Error HFTokenizer::load(const std::string& path) {
     std::string config_contents(
         (std::istreambuf_iterator<char>(config_file)),
         std::istreambuf_iterator<char>());
-    json parsed_config_json;
     try {
-      parsed_config_json = json::parse(config_contents);
+      json parsed_config_json = json::parse(config_contents);
+      if (bos_token.empty() && parsed_config_json.contains("bos_token")) {
+        bos_token = extract_token_string(parsed_config_json["bos_token"]);
+      }
+      if (eos_token.empty() && parsed_config_json.contains("eos_token")) {
+        eos_token = extract_token_string(parsed_config_json["eos_token"]);
+      }
+      TK_LOG(
+          Info,
+          "Loaded tokens from tokenizer_config.json: bos='%s', eos='%s'",
+          bos_token.c_str(),
+          eos_token.c_str());
     } catch (const std::exception& e) {
       TK_LOG(Error, "Error parsing model config json json file: %s", e.what());
       return Error::LoadFailure;
     }
-
-    // Pull out the token strings
-    try {
-      const std::string bos_token = parsed_config_json.contains("bos_token") &&
-              !parsed_config_json["bos_token"].is_null()
-          ? parsed_config_json["bos_token"].get<std::string>()
-          : "";
-
-      const std::string eos_token = parsed_config_json.contains("eos_token") &&
-              !parsed_config_json["eos_token"].is_null()
-          ? parsed_config_json["eos_token"].get<std::string>()
-          : "";
-      const auto bos_res = special_token_map_->tryGetInteger(bos_token);
-      const auto eos_res = special_token_map_->tryGetInteger(eos_token);
-      if (!bos_res) {
-        TK_LOG(Error, "BOS token %s not in special tokens", bos_token.c_str());
-        return Error::LoadFailure;
-      }
-      if (!eos_res) {
-        TK_LOG(Error, "EOS token %s not in special tokens", eos_token.c_str());
-        return Error::LoadFailure;
-      }
-      bos_tok_ = *bos_res;
-      eos_tok_ = *eos_res;
-    } catch (const std::exception& e) {
-      TK_LOG(Error, "Could not eos/bos from tokenizer config: %s", e.what());
-      return Error::LoadFailure;
-    }
   }
 
+  // Try to extract the bos/eos tokens.
+  if (!bos_token.empty() && !eos_token.empty()) {
+    auto bos_candidate = special_token_map_->tryGetInteger(bos_token);
+    if (!bos_candidate) {
+      TK_LOG(Info, "BOS token %s not in special tokens", bos_token.c_str());
+    } else {
+      bos_tok_ = *bos_candidate;
+      bos_found = true;
+    }
+
+    auto eos_candidate = special_token_map_->tryGetInteger(eos_token);
+    if (!eos_candidate) {
+      TK_LOG(Info, "EOS token %s not in special tokens", eos_token.c_str());
+    } else {
+      eos_tok_ = *eos_candidate;
+      eos_found = true;
+    }
+  }
+
   // Otherwise, make an educated guess with the following logic:
   // 1. Look for special tokens with "bos"/"begin" or "eos"/"end" in them
   // 2. Sub-qualify with the word "text" if needed
   // 3. If EOS found, but BOS is not (or vice versa), assume they are the same
-  else {
+  if (!eos_found || !bos_found) {
     std::vector<std::string_view> bos_candidates;
     std::vector<std::string_view> eos_candidates;
     for (std::size_t token_idx = 0; token_idx < special_token_map_->size();
          ++token_idx) {
       const auto [token, _] = special_token_map_->getElement(token_idx);
-      if (token.find("bos") != std::string::npos ||
-          token.find("begin") != std::string::npos) {
+      if (!bos_found &&
+          (token.find("bos") != std::string::npos ||
+           token.find("begin") != std::string::npos)) {
         bos_candidates.push_back(token);
       }
-      if (token.find("eos") != std::string::npos ||
-          token.find("end") != std::string::npos) {
+      if (!eos_found &&
+          (token.find("eos") != std::string::npos ||
+           token.find("end") != std::string::npos)) {
         eos_candidates.push_back(token);
       }
     }
@@ -300,14 +353,11 @@
       }
     }
 
-    // Use if a single candidate
-    bool bos_found = false;
-    bool eos_found = false;
-    if (bos_candidates.size() == 1) {
+    if (!bos_found && bos_candidates.size() == 1) {
       bos_found = true;
       bos_tok_ = *(special_token_map_->tryGetInteger(bos_candidates[0]));
     }
-    if (eos_candidates.size() == 1) {
+    if (!eos_found && eos_candidates.size() == 1) {
       eos_found = true;
       eos_tok_ = *(special_token_map_->tryGetInteger(eos_candidates[0]));
     }
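The new extract_token_string helper above is what makes both token spellings work: HF configs store bos_token/eos_token either as a bare string or as an object with a "content" field. A self-contained sketch of that behavior (same nlohmann::json dependency the PR uses; the main() harness is illustrative only):

#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// Mirrors the helper added in src/hf_tokenizer.cpp.
std::string extract_token_string(const json& token_json) {
  if (token_json.is_string()) {
    return token_json.get<std::string>();
  } else if (token_json.is_object() && token_json.contains("content")) {
    return token_json["content"].get<std::string>();
  }
  return ""; // null or an unexpected shape: treat the token as absent
}

int main() {
  json string_form = json::parse(R"({"eos_token": "</s>"})");
  json object_form = json::parse(
      R"({"eos_token": {"content": "<|eot_id|>", "lstrip": false}})");

  std::cout << extract_token_string(string_form["eos_token"]) << "\n";  // </s>
  std::cout << extract_token_string(object_form["eos_token"]) << "\n";  // <|eot_id|>
  return 0;
}

The resulting lookup order in load() is: special_tokens_map.json first, then tokenizer_config.json for whichever of bos/eos is still empty, and finally the substring heuristic over the special-token map.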
16 changes: 16 additions & 0 deletions test/resources/hf_tokenizer_dir/special_tokens_map.json
@@ -0,0 +1,16 @@
+{
+  "bos_token": {
+    "content": "<|begin_of_text|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "eos_token": {
+    "content": "<|eot_id|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  }
+}
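With this fixture present, a directory load should resolve bos/eos from special_tokens_map.json before any fallback kicks in. A hypothetical test sketch (the include path, the Error::Ok success value, and the exact load-check idiom are assumptions, not shown in this diff):

// Hypothetical usage; include path and Error::Ok are assumed, not confirmed.
#include <pytorch/tokenizers/hf_tokenizer.h>

int main() {
  tokenizers::HFTokenizer tok;
  if (tok.load("test/resources/hf_tokenizer_dir") != tokenizers::Error::Ok) {
    return 1;
  }
  // Per the fixtures, bos resolves to <|begin_of_text|> (id 128000) and
  // eos to <|eot_id|> (id 128009), both taken from special_tokens_map.json.
  return 0;
}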
152 changes: 152 additions & 0 deletions test/resources/hf_tokenizer_dir/tokenizer.json
@@ -0,0 +1,152 @@
+{
+  "version": "1.0",
+  "truncation": null,
+  "padding": null,
+  "added_tokens": [
+    {
+      "id": 0,
+      "content": "<unk>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 1,
+      "content": "<s>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 2,
+      "content": "</s>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 128000,
+      "content": "<|begin_of_text|>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 128001,
+      "content": "<|end_of_text|>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 128009,
+      "content": "<|eot_id|>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    }
+  ],
+  "normalizer": {
+    "type": "Sequence",
+    "normalizers": [
+      {
+        "type": "Replace",
+        "pattern": {
+          "String": " "
+        },
+        "content": "▁"
+      }
+    ]
+  },
+  "pre_tokenizer": {
+    "type": "Sequence",
+    "pretokenizers": [
+      {
+        "type": "Split",
+        "pattern": {
+          "Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+        },
+        "behavior": "MergedWithPrevious",
+        "invert": false
+      },
+      {
+        "type": "ByteLevel",
+        "add_prefix_space": false,
+        "trim_offsets": false,
+        "use_regex": false
+      }
+    ]
+  },
+  "post_processor": {
+    "type": "ByteLevel",
+    "add_prefix_space": false,
+    "trim_offsets": false,
+    "use_regex": false
+  },
+  "decoder": {
+    "type": "ByteLevel",
+    "add_prefix_space": false,
+    "trim_offsets": false,
+    "use_regex": false
+  },
+  "model": {
+    "type": "BPE",
+    "dropout": null,
+    "unk_token": null,
+    "continuing_subword_prefix": "",
+    "end_of_word_suffix": "",
+    "fuse_unk": false,
+    "byte_fallback": false,
+    "ignore_merges": false,
+    "vocab": {
+      "<unk>": 0,
+      "<s>": 1,
+      "</s>": 2,
+      "▁": 3,
+      "H": 4,
+      "e": 5,
+      "l": 6,
+      "o": 7,
+      "▁Hello": 8,
+      "▁world!": 9,
+      "w": 10,
+      "r": 11,
+      "d": 12,
+      "!": 13
+    },
+    "merges": [
+      "H e",
+      "e l",
+      "l l",
+      "l o",
+      "▁ H",
+      "▁H e",
+      "▁He l",
+      "▁Hel l",
+      "▁Hell o",
+      "w o",
+      "o r",
+      "r l",
+      "l d",
+      "d !",
+      "▁ w",
+      "▁w o",
+      "▁wo r",
+      "▁wor l",
+      "▁worl d",
+      "▁world !"
+    ]
+  }
+}
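If neither special_tokens_map.json nor tokenizer_config.json were present, the heuristic in hf_tokenizer.cpp would have to guess from this file's added_tokens alone. A standalone replica of that substring filter over the fixture's tokens (illustrative; it mirrors the candidate-collection loop in the diff):

#include <iostream>
#include <string>
#include <vector>

int main() {
  // The six special tokens from added_tokens above.
  const std::vector<std::string> tokens = {
      "<unk>", "<s>", "</s>",
      "<|begin_of_text|>", "<|end_of_text|>", "<|eot_id|>"};

  std::vector<std::string> bos_candidates;
  std::vector<std::string> eos_candidates;
  for (const auto& token : tokens) {
    if (token.find("bos") != std::string::npos ||
        token.find("begin") != std::string::npos) {
      bos_candidates.push_back(token);
    }
    if (token.find("eos") != std::string::npos ||
        token.find("end") != std::string::npos) {
      eos_candidates.push_back(token);
    }
  }

  // Exactly one candidate survives on each side, so the guess is
  // unambiguous: bos = <|begin_of_text|>, eos = <|end_of_text|>.
  std::cout << bos_candidates.size() << " " << eos_candidates.size() << "\n"; // 1 1
  return 0;
}

Note that the heuristic's eos guess (<|end_of_text|>, id 128001) differs from the <|eot_id|> (id 128009) pinned by special_tokens_map.json, which is exactly why the map is consulted first.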