diff --git a/README.md b/README.md
index 6b707c2..a7f3584 100644
--- a/README.md
+++ b/README.md
@@ -60,6 +60,9 @@ test_text = "".join([random.choice("abcde ") for _ in range(100)])
 # Training model
 yttm.BPE.train(data=train_data_path, vocab_size=5000, model=model_path)
 
+# Training model with custom tokens
+yttm.BPE.train(data=train_data_path, vocab_size=5000, model=model_path, custom_tokens=[b"[CLS]", b"[MASK]"])
+
 # Loading model
 bpe = yttm.BPE(model=model_path)
 
@@ -71,7 +74,7 @@ print(bpe.encode([test_text], output_type=yttm.OutputType.SUBWORD))
 
 ### Training model
 ```python
-youtokentome.BPE.train(data, model, vocab_size, coverage, n_threads=-1, pad_id=0, unk_id=1, bos_id=2, eos_id=3)
+youtokentome.BPE.train(data, model, vocab_size, coverage, n_threads=-1, pad_id=0, unk_id=1, bos_id=2, eos_id=3, custom_tokens=[])
 ```
 
 Trains BPE model and saves to file.
@@ -86,6 +89,7 @@ Trains BPE model and saves to file.
 * `unk_id`: int, reserved id for unknown symbols
 * `bos_id`: int, reserved id for begin of sentence token
 * `eos_id`: int, reserved id for end of sentence token
+* `custom_tokens`: List[bytes], tokens that will not be split into subwords
 
 **Returns**: Class `youtokentome.BPE` with the loaded model.
 
@@ -191,7 +195,7 @@ Convert each id to subword and concatenate with space symbol.
 
 ### Example
 ```bash
-$ yttm bpe --data TRAINING_DATA_FILE --model OUTPUT_MODEL_FILE --vocab_size 2000
+$ yttm bpe --data TRAINING_DATA_FILE --model OUTPUT_MODEL_FILE --vocab_size 2000 --custom_tokens "[CLS],[MASK]"
 $ yttm encode --model OUTPUT_MODEL_FILE --output_type subword < TEST_DATA_FILE > ENCODED_DATA
 ```
@@ -234,6 +238,9 @@ Options:
   --unk_id INTEGER      Unknown token id.  [default: 1]
   --bos_id INTEGER      'Begin of sentence' token id.  [default: 2]
   --eos_id INTEGER      'End of sentence' token id.  [default: 3]
+  --custom_tokens TEXT  Tokens that will not be split into subwords;
+                        multiple tokens should be provided
+                        comma-separated.
   --help                Show this message and exit.
 ```
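
Taken together, the documented pieces look like this in practice. A minimal end-to-end sketch, assuming the package is importable as `youtokentome`; the file paths are illustrative placeholders:

```python
import youtokentome as yttm

train_data_path = "train_data.txt"  # illustrative path
model_path = "model.yttm"           # illustrative path

# Train a model in which [CLS] and [MASK] are kept whole, never split
# into subwords (custom tokens are passed as bytes).
yttm.BPE.train(
    data=train_data_path,
    model=model_path,
    vocab_size=5000,
    custom_tokens=[b"[CLS]", b"[MASK]"],
)

bpe = yttm.BPE(model=model_path)

# Each custom token should come back as a single subword.
print(bpe.encode(["[CLS] some text [MASK]"], output_type=yttm.OutputType.SUBWORD))
```
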
+ """ + test_text = "[CLS] Lorem ipsum [TOKEN] dolor sit [MASK] amet" + TRAIN_DATA_PATH = "train_data.txt" + MODEL_PATH = "model.yttm" + with open(TRAIN_DATA_PATH, "w") as fin: + fin.write(train_text) + model = yttm.BPE.train(TRAIN_DATA_PATH, MODEL_PATH, 100, custom_tokens=[b'[CLS]',b'[TOKEN]',b'']) + tokenized_text = model.encode([test_text], output_type=yttm.OutputType.SUBWORD) + expected_result = [['▁','[CLS]', '▁', 'L', 'or', 'e', 'm', '▁', 'ip', 's', 'um', '▁', '[TOKEN]', '▁dolor', '▁', '', '▁s', 'it', '▁', '[', 'M', 'A', 'S', 'K', ']', '▁a', 'm', 'e', 't']] + print(tokenized_text) + assert tokenized_text == expected_result + print(tokenized_text) + os.remove(TRAIN_DATA_PATH) diff --git a/youtokentome/cpp/bpe.cpp b/youtokentome/cpp/bpe.cpp index c28ee8a..f0ed155 100644 --- a/youtokentome/cpp/bpe.cpp +++ b/youtokentome/cpp/bpe.cpp @@ -1515,6 +1515,12 @@ void print_config(const string &input_path, const string &model_path, std::cerr << " unk: " << bpe_config.special_tokens.unk_id << std::endl; std::cerr << " bos: " << bpe_config.special_tokens.bos_id << std::endl; std::cerr << " eos: " << bpe_config.special_tokens.eos_id << std::endl; + if (bpe_config.special_tokens.custom_tokens.size()) { + std::cerr << " custom_tokens: "; + for (auto token:bpe_config.special_tokens.custom_tokens) + std::cerr << token << " "; + std::cerr << std::endl; + } std::cerr << std::endl; } @@ -1665,6 +1671,7 @@ DecodeResult BaseEncoder::encode_sentence(const std::string &sentence_utf8, uint32_t new_token_cur = new_tokens_start; list.emplace_back(bpe_state.char2id.at(SPACE_TOKEN), 0); + string utf8_text; for (auto it_char_in_word = begin_of_word; it_char_in_word < end_of_word;) { if (bpe_state.char2id.count(*it_char_in_word) == 0) { @@ -1674,15 +1681,31 @@ DecodeResult BaseEncoder::encode_sentence(const std::string &sentence_utf8, unrecognized_tokens[new_token_cur] = encode_utf8({it_char_in_word, it_unrecognized_word}); + if (custom_token2id.size()) + utf8_text.append(unrecognized_tokens[new_token_cur]); it_char_in_word = it_unrecognized_word; list.emplace_back(new_token_cur, list.size()); new_token_cur++; } else { + if (custom_token2id.size()) + utf8_to_chars(*it_char_in_word, std::back_inserter(utf8_text)); list.emplace_back(bpe_state.char2id.at(*it_char_in_word), list.size()); ++it_char_in_word; } } + + if (custom_token2id.size() && custom_token2id.count(utf8_text)) { + if (output_type == ID) { + output_ids.push_back(bpe_state.char2id.at(SPACE_TOKEN)); + output_ids.push_back(custom_token2id.find(utf8_text) -> second); + } else { + output_pieces.push_back(encode_utf8({SPACE_TOKEN})); + output_pieces.push_back(utf8_text); + } + continue; + } + list.back().next = -1; @@ -1840,6 +1863,11 @@ void BaseEncoder::fill_from_state() { } reversed_recipe[BOS_TOKEN] = bpe_state.special_tokens.bos_id; reversed_recipe[EOS_TOKEN] = bpe_state.special_tokens.eos_id; + uint32_t custom_id = bpe_state.special_tokens.max_predefined_id(); + for (auto token : bpe_state.special_tokens.custom_tokens) { + ++custom_id; + custom_token2id[token] = custom_id; + } } int BaseEncoder::vocab_size() const { @@ -1947,6 +1975,10 @@ Status BaseEncoder::id_to_subword(int id, string *subword, bool replace_space) c *subword = EOS_TOKEN; return Status(); } + if (id <= bpe_state.special_tokens.max_id() && id > bpe_state.special_tokens.max_predefined_id()) { + *subword = bpe_state.special_tokens.custom_tokens[id - bpe_state.special_tokens.max_predefined_id() - 1]; + return Status(); + } assert(recipe.count(id)); if (replace_space) { diff --git 
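
The `encode_sentence` change accumulates each whitespace-delimited word's UTF-8 text in `utf8_text` and, when the whole word matches a registered custom token, emits the token as one unit and `continue`s past the BPE merge loop. A rough Python rendering of that control flow, where `SPACE_ID` and `bpe_merge` are illustrative stand-ins for `char2id.at(SPACE_TOKEN)` and the merge loop rather than real API:

```python
SPACE_ID = 29                                # stand-in space-marker id
custom_token2id = {"[CLS]": 4, "[MASK]": 5}  # ids follow the predefined specials

def encode_word(word_utf8, bpe_merge):
    # Whole-word match: short-circuit with the reserved id; the C++ code
    # emits the space marker plus the custom-token id and skips merging.
    if word_utf8 in custom_token2id:
        return [SPACE_ID, custom_token2id[word_utf8]]
    # Otherwise fall through to ordinary BPE merging of the word.
    return bpe_merge(word_utf8)

# Toy merge function (one id per character) just to make the sketch runnable.
toy_merge = lambda w: [SPACE_ID] + [ord(c) for c in w]
print(encode_word("[CLS]", toy_merge))  # -> [29, 4]
print(encode_word("hello", toy_merge))  # -> [29, 104, 101, ...]
```

Note the match is whole-word only: a custom token embedded inside a longer word goes through the normal BPE path, which is exactly what the test above checks with the unregistered `[MASK]`.
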
diff --git a/youtokentome/cpp/bpe.h b/youtokentome/cpp/bpe.h
index 99464a2..cac5063 100644
--- a/youtokentome/cpp/bpe.h
+++ b/youtokentome/cpp/bpe.h
@@ -27,6 +27,7 @@ class BaseEncoder {
   flat_hash_map<uint32_t, uint32_t> id2char;
   flat_hash_map<uint32_t, std::vector<uint32_t>> recipe;
   flat_hash_map<std::string, uint32_t> reversed_recipe;
+  flat_hash_map<std::string, uint32_t> custom_token2id;
   flat_hash_map<uint64_t, int> rule2id;
 
   int n_threads;
diff --git a/youtokentome/cpp/utf8.h b/youtokentome/cpp/utf8.h
index ec34831..d51cb66 100644
--- a/youtokentome/cpp/utf8.h
+++ b/youtokentome/cpp/utf8.h
@@ -8,6 +8,8 @@ constexpr static uint32_t INVALID_UNICODE = 0x0fffffff;
 
 uint32_t chars_to_utf8(const char* begin, uint64_t size, uint64_t* utf8_len);
 
+void utf8_to_chars(uint32_t x, std::back_insert_iterator<std::string> it);
+
 std::string encode_utf8(const std::vector<uint32_t> &utext);
 
 std::vector<uint32_t> decode_utf8(const char *begin, const char *end);
diff --git a/youtokentome/cpp/utils.cpp b/youtokentome/cpp/utils.cpp
index 768a817..a74af24 100644
--- a/youtokentome/cpp/utils.cpp
+++ b/youtokentome/cpp/utils.cpp
@@ -10,15 +10,20 @@ using std::string;
 using std::vector;
 
 void SpecialTokens::dump(std::ofstream &fout) {
-  fout << unk_id << " " << pad_id << " " << bos_id << " " << eos_id
-       << std::endl;
+  fout << unk_id << " " << pad_id << " " << bos_id << " " << eos_id << " ";
+  for (const auto &token : custom_tokens) fout << token << " ";
+  fout << std::endl;
 }
 
 void SpecialTokens::load(std::ifstream &fin) {
   fin >> unk_id >> pad_id >> bos_id >> eos_id;
+  std::string token;
+  while (fin >> token)
+    custom_tokens.push_back(token);
 }
 
-uint32_t SpecialTokens::max_id() const {
+uint32_t SpecialTokens::max_predefined_id() const {
   int ret = 0;
   ret = std::max(ret, unk_id);
   ret = std::max(ret, pad_id);
@@ -27,8 +32,14 @@ uint32_t SpecialTokens::max_id() const {
   return ret;
 }
 
+uint32_t SpecialTokens::max_id() const {
+  return max_predefined_id() + custom_tokens.size();
+}
+
 bool SpecialTokens::taken_id(int id) const {
-  return id == unk_id || id == pad_id || id == bos_id || id == eos_id;
+  return id == unk_id || id == pad_id || id == bos_id || id == eos_id ||
+         (id > max_predefined_id() && id <= max_id());
 }
 
 uint64_t SpecialTokens::n_special_tokens() const {
@@ -37,6 +48,7 @@ uint64_t SpecialTokens::n_special_tokens() const {
   cnt += (pad_id != -1);
   cnt += (bos_id != -1);
   cnt += (eos_id != -1);
+  cnt += custom_tokens.size();
   return cnt;
 }
diff --git a/youtokentome/cpp/utils.h b/youtokentome/cpp/utils.h
index ce802d5..4a5102c 100644
--- a/youtokentome/cpp/utils.h
+++ b/youtokentome/cpp/utils.h
@@ -26,6 +26,7 @@ struct SpecialTokens {
   int unk_id = -1;
   int bos_id = -1;
   int eos_id = -1;
+  std::vector<std::string> custom_tokens;
 
   SpecialTokens() = default;
 
@@ -40,6 +41,7 @@ struct SpecialTokens {
   bool taken_id(int id) const;
 
   uint64_t n_special_tokens() const;
+  uint32_t max_predefined_id() const;
 };
 
 struct BpeConfig {
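
The `utils.cpp`/`utils.h` changes pin down the id layout: custom tokens take the ids immediately after the largest predefined special id, in the order given at training time, and `max_id`/`taken_id` are extended accordingly. The arithmetic, sketched with the default special ids:

```python
pad_id, unk_id, bos_id, eos_id = 0, 1, 2, 3
custom_tokens = [b"[CLS]", b"[MASK]"]

max_predefined_id = max(pad_id, unk_id, bos_id, eos_id)  # 3
# fill_from_state assigns [CLS] -> 4, [MASK] -> 5
custom_token2id = {t: max_predefined_id + 1 + i for i, t in enumerate(custom_tokens)}
max_id = max_predefined_id + len(custom_tokens)          # 5

def id_to_subword(idx):
    # Mirrors the new branch in BaseEncoder::id_to_subword.
    if max_predefined_id < idx <= max_id:
        return custom_tokens[idx - max_predefined_id - 1]
    raise KeyError(idx)  # the real code falls through to the BPE recipe

assert id_to_subword(4) == b"[CLS]" and id_to_subword(5) == b"[MASK]"
```

One consequence of `SpecialTokens::load` reading with `while (fin >> token)` is that custom tokens cannot contain whitespace, since they are serialized space-separated in the model file.
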
diff --git a/youtokentome/cpp/yttm.pyx b/youtokentome/cpp/yttm.pyx
index 1d7774d..571f13a 100644
--- a/youtokentome/cpp/yttm.pyx
+++ b/youtokentome/cpp/yttm.pyx
@@ -14,6 +14,7 @@ cdef extern from "bpe.h" namespace "vkcom":
         int unk_id
         int bos_id
         int eos_id
+        vector[string] custom_tokens
 
     cdef cppclass BpeConfig:
         double character_coverage
@@ -67,6 +68,7 @@ cdef class BPE:
              vocab_size,
              coverage=1.0,
              n_threads=-1,
+             custom_tokens=[],
              pad_id=0,
              unk_id=1,
              bos_id=2,
              eos_id=3,
@@ -79,6 +81,7 @@ cdef class BPE:
        bpe_config.special_tokens.unk_id = unk_id
        bpe_config.special_tokens.bos_id = bos_id
        bpe_config.special_tokens.eos_id = eos_id
+       bpe_config.special_tokens.custom_tokens = custom_tokens
        cdef Status status = train_bpe(data.encode(), model.encode(),
                                       vocab_size, bpe_config)
        if status.code != 0:
diff --git a/youtokentome/youtokentome.py b/youtokentome/youtokentome.py
index 593febf..8cd1eb9 100644
--- a/youtokentome/youtokentome.py
+++ b/youtokentome/youtokentome.py
@@ -22,6 +22,7 @@ def train(
         data: str,
         model: str,
         vocab_size: int,
+        custom_tokens: List[bytes] = [],
         coverage: float = 1.0,
         n_threads: int = -1,
         pad_id: int = 0,
@@ -35,6 +36,7 @@ def train(
             vocab_size=vocab_size,
             n_threads=n_threads,
             coverage=coverage,
+            custom_tokens=custom_tokens,
             pad_id=pad_id,
             unk_id=unk_id,
             bos_id=bos_id,
diff --git a/youtokentome/yttm_cli.py b/youtokentome/yttm_cli.py
index 7e66879..318aea2 100644
--- a/youtokentome/yttm_cli.py
+++ b/youtokentome/yttm_cli.py
@@ -57,7 +57,14 @@ def main():
     default=3,
     show_default=True,
 )
-def bpe(data, model, vocab_size, coverage, n_threads, pad_id, unk_id, bos_id, eos_id):
+@click.option(
+    "--custom_tokens",
+    type=click.STRING,
+    help="Tokens that will not be split into subwords; multiple tokens should be provided comma-separated.",
+    default="",
+    show_default=True,
+)
+def bpe(data, model, vocab_size, coverage, n_threads, pad_id, unk_id, bos_id, eos_id, custom_tokens):
     """Train BPE model."""
     yttmc.BPE.train(
         data=data,
@@ -69,6 +76,7 @@ def bpe(data, model, vocab_size, coverage, n_threads, pad_id, unk_id, bos_id, eos_id):
         unk_id=unk_id,
         bos_id=bos_id,
         eos_id=eos_id,
+        custom_tokens=[t.encode("utf8") for t in custom_tokens.split(",") if t],
     )
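
One pitfall in the CLI glue: with the default `--custom_tokens` value of `""`, a bare `"".split(",")` returns `[""]`, which would silently register an empty custom token. The comprehension used above filters empty entries and builds a real list (rather than a lazy `map` object) before the value crosses into the Cython layer. Its behaviour in isolation, as a quick sanity check:

```python
def parse_custom_tokens(flag_value):
    # Same expression as in yttm_cli.py: split on commas, drop empties, encode.
    return [t.encode("utf8") for t in flag_value.split(",") if t]

assert parse_custom_tokens("[CLS],[MASK]") == [b"[CLS]", b"[MASK]"]
assert parse_custom_tokens("") == []  # default: no custom tokens registered
```
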