Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama : improve BPE pre-processing + LLaMA 3 and Deepseek support #6920

Merged
merged 61 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
6fbab2d
merged the changes from deepseeker models to main branch
jaggzh Feb 12, 2024
d2cfc22
Moved regex patterns to unicode.cpp and updated unicode.h
dragnil1 Mar 22, 2024
54f93eb
Moved header files
dragnil1 Mar 22, 2024
1c924e4
Resolved issues
dragnil1 Mar 23, 2024
4056dc5
added and refactored unicode_regex_split and related functions
dragnil1 Mar 31, 2024
c8e7d95
Updated/merged the deepseek coder pr
jaggzh Feb 12, 2024
4c3e882
Refactored code
dragnil1 Apr 13, 2024
a5710a4
Adding unicode regex mappings
dragnil1 Apr 15, 2024
7e308ed
Adding unicode regex function
dragnil1 Apr 15, 2024
feeaf4f
Added needed functionality, testing remains
dragnil1 Apr 15, 2024
7535803
Fixed issues
dragnil1 Apr 15, 2024
36d9832
Fixed issue with gpt2 regex custom preprocessor
dragnil1 Apr 17, 2024
06d3e69
unicode : fix? unicode_wstring_to_utf8
ggerganov Apr 26, 2024
c56e19d
lint : fix whitespaces
ggerganov Apr 26, 2024
7a44e44
tests : add tokenizer tests for numbers
ggerganov Apr 26, 2024
d999cf6
unicode : remove redundant headers
ggerganov Apr 26, 2024
aeafb43
tests : remove and rename tokenizer test scripts
ggerganov Apr 26, 2024
e1b2bf7
tests : add sample usage
ggerganov Apr 26, 2024
ed42711
gguf-py : reader prints warnings on duplicate keys
ggerganov Apr 26, 2024
4907e41
llama : towards llama3 tokenization support (wip)
ggerganov Apr 26, 2024
e8c206b
unicode : shot in the dark to fix tests on Windows
ggerganov Apr 26, 2024
e989176
unicode : first try custom implementations
ggerganov Apr 26, 2024
e3f6dc7
Merge branch 'master' into gg/bpe-preprocess
ggerganov Apr 26, 2024
9b4d63a
convert : add "tokenizer.ggml.pre" GGUF KV (wip)
ggerganov Apr 26, 2024
43e12ce
llama : use new pre-tokenizer type
ggerganov Apr 26, 2024
1b9b79d
convert : fix pre-tokenizer type writing
ggerganov Apr 26, 2024
8791e94
lint : fix
ggerganov Apr 26, 2024
a774d70
make : add test-tokenizer-0-llama-v3
ggerganov Apr 26, 2024
c160818
wip
ggerganov Apr 26, 2024
96965f6
models : add llama v3 vocab file
ggerganov Apr 27, 2024
ad92983
llama : adapt punctuation regex + add llama 3 regex
ggerganov Apr 27, 2024
4434c9d
minor
ggerganov Apr 27, 2024
a22645c
unicode : set bomb
ggerganov Apr 27, 2024
2affd0b
unicode : set bomb
ggerganov Apr 27, 2024
ce5485a
unicode : always use std::wregex
ggerganov Apr 27, 2024
91eaa41
unicode : support \p{N}, \p{L} and \p{P} natively
ggerganov Apr 27, 2024
581c4a0
unicode : try fix windows
ggerganov Apr 27, 2024
b97add5
unicode : category support via std::regex
ggerganov Apr 28, 2024
d63cc90
Merge branch 'master' into gg/bpe-preprocess
ggerganov Apr 28, 2024
e972e6c
unicode : clean-up
ggerganov Apr 28, 2024
ee6d1b3
unicode : simplify
ggerganov Apr 28, 2024
7642973
convert : add convert-hf-to-gguf-update.py
ggerganov Apr 28, 2024
4e3e6d8
lint : update
ggerganov Apr 28, 2024
1c888eb
convert : add falcon
ggerganov Apr 28, 2024
1545550
unicode : normalize signatures
ggerganov Apr 28, 2024
491f233
lint : fix
ggerganov Apr 28, 2024
e8dd4a1
lint : fix
ggerganov Apr 28, 2024
02fd977
convert : remove unused functions
ggerganov Apr 28, 2024
0f9058c
convert : add comments
ggerganov Apr 28, 2024
7808150
convert : exercise contractions
ggerganov Apr 28, 2024
7b1210f
lint : fix
ggerganov Apr 28, 2024
ef4cca9
cmake : refactor test targets
ggerganov Apr 29, 2024
43708d2
tests : refactor vocab tests
ggerganov Apr 29, 2024
c68d259
tests : add more vocabs and tests
ggerganov Apr 29, 2024
af05268
unicode : cleanup
ggerganov Apr 29, 2024
c21ab18
scripts : ignore new update script in check-requirements.sh
ggerganov Apr 29, 2024
120cf37
models : add phi-3, mpt, gpt-2, starcoder
ggerganov Apr 29, 2024
9a7d430
tests : disable obsolete
ggerganov Apr 29, 2024
6d6ce93
tests : use faster bpe test
ggerganov Apr 29, 2024
3202676
llama : more prominent warning for old BPE models
ggerganov Apr 29, 2024
80cb312
tests : disable test-tokenizer-1-bpe due to slowness
ggerganov Apr 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion unicode-data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1666,4 +1666,4 @@ const std::map<std::string, std::wstring> unicode_regex_equivalent_wregex = {

const std::set<std::string> unicode_regex_with_custom_preprocessor = {
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)"
};
};
2 changes: 1 addition & 1 deletion unicode-data.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_control;
extern const std::multimap<uint32_t, uint32_t> unicode_map_nfd;
extern const std::map<char32_t, char32_t> unicode_map_lowercase;
extern const std::map<std::string, std::wstring> unicode_regex_equivalent_wregex;
extern const std::set<std::string> unicode_regex_with_custom_preprocessor;
extern const std::set<std::string> unicode_regex_with_custom_preprocessor;
36 changes: 17 additions & 19 deletions unicode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,18 +197,14 @@ static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
return map;
}

static inline std::wstring unicode_wstring_from_utf8(const std::string & s)
{
// Decode a UTF-8 encoded std::string into a std::wstring
// (UCS-2 or UCS-4/UTF-32 depending on the platform's wchar_t width).
static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
    std::wstring result = converter.from_bytes(s);
    return result;
}

static inline std::string unicode_wstring_to_utf8(const std::wstring & ws)
{
// code to convert from utf32/utf16 to utf8
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>, wchar_t> converter;
std::string utf8 = converter.to_bytes(ws);
return utf8;
static inline std::string unicode_wstring_to_utf8(const std::wstring & ws) {
std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
return conv.to_bytes(ws);
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dragnil1 Not sure if this is the intent, but the following change of this function makes the tokenizer tests pass on my Mac. Do you think this is OK to change?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change converts a UCS-2 or UCS-4/UTF-32 encoded std::wstring to a UTF-8 encoded std::string, whereas the previous one converts a UTF-16 encoded std::wstring to a UTF-8 encoded std::string, according to the reference. Both work on Ubuntu (tested), but I am not sure about Windows, as it uses UTF-16 encoded std::wstring.

}

static std::vector<std::string> unicode_byte_encoding_process(const std::vector<std::string> & bpe_words) {
Expand All @@ -233,7 +229,7 @@ static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::wstring & wt
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
size_t start = 0;

for(auto offset : offsets) {
for (auto offset : offsets) {
const std::string text = unicode_wstring_to_utf8(std::wstring(wtext, start, offset));

std::string token = "";
Expand All @@ -248,15 +244,17 @@ static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::wstring & wt
text_utf.reserve(text.size());

const auto cpts = unicode_cpts_from_utf8(text);
for (size_t i = 0; i < cpts.size(); ++i)
for (size_t i = 0; i < cpts.size(); ++i) {
text_utf.emplace_back(unicode_cpt_to_utf8(cpts[i]));
}

for (int i = 0; i < (int)text_utf.size(); i++) {
const std::string & utf_char = text_utf[i];
bool split_condition = false;
int bytes_remain = text_utf.size() - i;

// forward backward lookups
const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : "";
const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : "";
const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? text_utf[i + 2] : "";

// handling contractions
Expand Down Expand Up @@ -357,6 +355,7 @@ static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::wstring & wt
token += utf_char;
}
}

start += offset;
}

Expand Down Expand Up @@ -402,8 +401,8 @@ static bool unicode_regex_with_custom_preprocessor_exists(const std::string & re

static std::vector<size_t> unicode_regex_custom_preprocess(const std::string & regex, const std::wstring & wtext, const std::vector<size_t> & offsets) {
std::vector<size_t> bpe_offsets;
if(regex == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {

if (regex == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
bpe_offsets = unicode_gpt2_regex_preprocess(wtext, offsets);
}

Expand Down Expand Up @@ -491,16 +490,15 @@ char32_t unicode_tolower(char32_t cp) {
auto it = unicode_map_lowercase.find(cp);
return it == unicode_map_lowercase.end() ? cp : it->second;
}

std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
std::wstring wtext = unicode_wstring_from_utf8(text);

std::vector<size_t> bpe_offsets = {wtext.size()};

for(auto & regex_expr : regex_exprs) {

for (auto & regex_expr : regex_exprs) {
if (unicode_regex_equivalent_wregex_exists(regex_expr)) {
const std::wstring& wregex_expr = unicode_regex_equivalent_wregex.at(regex_expr);
const std::wstring & wregex_expr = unicode_regex_equivalent_wregex.at(regex_expr);
bpe_offsets = unicode_regex_preprocess(wtext, bpe_offsets, wregex_expr);
} else if (unicode_regex_with_custom_preprocessor_exists(regex_expr)) {
bpe_offsets = unicode_regex_custom_preprocess(regex_expr, wtext, bpe_offsets);
Expand All @@ -512,10 +510,10 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
std::vector<std::string> bpe_words;
bpe_words.reserve(bpe_offsets.size()); // Reserve memory for the approximate size
size_t start = 0;
for(size_t & offset : bpe_offsets) {
for (size_t & offset : bpe_offsets) {
bpe_words.emplace_back(unicode_wstring_to_utf8(std::wstring(wtext, start, offset)));
start += offset;
}

return unicode_byte_encoding_process(bpe_words);
}
}
Loading