Skip to content
This repository has been archived by the owner on Apr 23, 2024. It is now read-only.

[WIP] File objects #21

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions youtokentome/cpp/bpe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ void rename_tokens(ska::flat_hash_map<uint32_t, uint32_t> &char2id,
}

BPEState learn_bpe_from_string(string &text_utf8, int n_tokens,
const string &output_file,
StreamWriter &output,
BpeConfig bpe_config) {
vector<std::thread> threads;
assert(bpe_config.n_threads >= 1 || bpe_config.n_threads == -1);
Expand Down Expand Up @@ -1294,8 +1294,8 @@ BPEState learn_bpe_from_string(string &text_utf8, int n_tokens,
rename_tokens(char2id, rules, bpe_config.special_tokens, n_tokens);

BPEState bpe_state = {char2id, rules, bpe_config.special_tokens};
bpe_state.dump(output_file);
std::cerr << "model saved to: " << output_file << std::endl;
bpe_state.dump(output);
std::cerr << "model saved to: " << output.name() << std::endl;
return bpe_state;
}

Expand Down Expand Up @@ -1450,7 +1450,8 @@ void train_bpe(const string &input_path, const string &model_path,
std::cerr << "reading file..." << std::endl;
auto data = fast_read_file_utf8(input_path);
std::cerr << "learning bpe..." << std::endl;
learn_bpe_from_string(data, vocab_size, model_path, bpe_config);
auto fout = StreamWriter.open(model_path);
learn_bpe_from_string(data, vocab_size, fout, bpe_config);
}

DecodeResult BaseEncoder::encode_sentence(const std::string &sentence_utf8,
Expand Down
2 changes: 1 addition & 1 deletion youtokentome/cpp/bpe.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class BaseEncoder {

explicit BaseEncoder(BPEState bpe_state, int _n_threads);

explicit BaseEncoder(const std::string& model_path, int n_threads);
explicit BaseEncoder(const StreamReader& model_path, int n_threads);

void fill_from_state();

Expand Down
179 changes: 146 additions & 33 deletions youtokentome/cpp/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,112 @@
#include <cassert>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <tuple>
#include <vector>


namespace vkcom {
using std::string;
using std::vector;

void SpecialTokens::dump(std::ofstream &fout) {
fout << unk_id << " " << pad_id << " " << bos_id << " " << eos_id
<< std::endl;
class FileWriter : public StreamWriter {
public:
FileWriter(const std::string &file_name) {
this->file_name = file_name;
this->fout = std::ofstream(file_name, std::ios::out | std::ios::binary);
if (fout.fail()) {
std::cerr << "Can't open file: " << file_name << std::endl;
assert(false);
}
}

virtual int write(const char *buffer, int size) override {
return fout.write(buffer, size);
}

virtual std::string name() const noexcept override {
return file_name;
}

private:
std::string file_name;
std::ofstream fout;
};

class FileReader : public StreamReader {
public:
FileReader(const std::string &file_name) {
this->file_name = file_name;
this->fin = std::ifstream(file_name, std::ios::in | std::ios::binary);
if (fin.fail()) {
std::cerr << "Can't open file: " << file_name << std::endl;
assert(false);
}
}

virtual int read(const char *buffer, int size) override {
return fin.read(buffer, size);
}

virtual std::string name() const noexcept override {
return file_name;
}

private:
std::string file_name;
std::ifstream fin;
};

// Factory: open a file-backed writer.
// NOTE(review): StreamWriter is abstract (pure virtual write/name), so
// returning it by value cannot compile — and even with a concrete base it
// would slice the FileWriter. This should return
// std::unique_ptr<StreamWriter>, together with the matching declaration
// change in utils.h and the call site in bpe.cpp.
StreamWriter StreamWriter::open(const std::string &file_name) {
return FileWriter(file_name);
}

// Factory: open a file-backed reader.
// NOTE(review): same defect as StreamWriter::open above — StreamReader is
// abstract, so it cannot be returned by value; this should return
// std::unique_ptr<StreamReader> (with the matching change in utils.h).
StreamReader StreamReader::open(const std::string &file_name) {
return FileReader(file_name);
}

// Decodes 4 little-endian bytes from `val` into the integral type T.
// Counterpart of int_to_bin; used by the binary model (de)serialization.
template<typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
T bin_to_int(const char *val) {
  // Guard against silent truncation: this routine always decodes 32 bits.
  static_assert(sizeof(T) >= sizeof(uint32_t),
                "bin_to_int decodes 4 bytes; T must hold at least 32 bits");
  // Go through unsigned char so bytes >= 0x80 don't sign-extend.
  uint32_t ret = static_cast<unsigned char>(val[0]);
  ret |= static_cast<uint32_t>(static_cast<unsigned char>(val[1])) << 8;
  ret |= static_cast<uint32_t>(static_cast<unsigned char>(val[2])) << 16;
  ret |= static_cast<uint32_t>(static_cast<unsigned char>(val[3])) << 24;
  return static_cast<T>(ret);  // explicit, instead of implicit narrowing
}

// Encodes the low 32 bits of integral `val` as 4 little-endian bytes in a
// freshly allocated buffer. Counterpart of bin_to_int.
template<typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
std::unique_ptr<char[]> int_to_bin(T val) {
  const auto u32 = static_cast<uint32_t>(val);
  std::unique_ptr<char[]> ret(new char[4]);
  ret[0] = static_cast<char>(u32 & 0xFF);
  ret[1] = static_cast<char>((u32 >> 8) & 0xFF);
  ret[2] = static_cast<char>((u32 >> 16) & 0xFF);
  ret[3] = static_cast<char>(u32 >> 24);  // top byte needs no mask
  // BUG FIX: `return std::move(ret);` blocked NRVO (pessimizing move);
  // a local is returned by value and moved implicitly anyway.
  return ret;
}

// Serializes the four special-token ids (unk, pad, bos, eos — in that
// order) as 4-byte little-endian integers.
void SpecialTokens::dump(StreamWriter &fout) {
  const int ids[4] = {unk_id, pad_id, bos_id, eos_id};
  for (int id : ids) {
    auto bytes = int_to_bin(id);
    fout.write(bytes.get(), 4);
  }
}

void SpecialTokens::load(std::ifstream &fin) {
fin >> unk_id >> pad_id >> bos_id >> eos_id;
// Deserializes the four special-token ids in the same order dump() wrote
// them: unk, pad, bos, eos — each a 4-byte little-endian integer.
void SpecialTokens::load(StreamReader &fin) {
  int *fields[4] = {&unk_id, &pad_id, &bos_id, &eos_id};
  for (int *field : fields) {
    char buf[4];
    fin.read(buf, 4);
    *field = bin_to_int<int>(buf);
  }
}

uint32_t SpecialTokens::max_id() const {
Expand Down Expand Up @@ -49,48 +141,69 @@ bool BPE_Rule::operator==(const BPE_Rule &other) const {

BPE_Rule::BPE_Rule(uint32_t x, uint32_t y, uint32_t z) : x(x), y(y), z(z) {}

void BPEState::dump(const string &file_name) {
std::ofstream fout(file_name, std::ios::out);
if (fout.fail()) {
std::cerr << "Can't open file: " << file_name << std::endl;
assert(false);
void BPEState::dump(StreamWriter &fout) {
std::unique_ptr<char[]> char2id_ptr(int_to_bin(char2id.size())),
rules_ptr(int_to_bin(rules.size()));
fout.write(char2id_ptr.get(), 4);
fout.write(rules_ptr.get(), 4);
for (auto &s : char2id) {
std::unique_ptr<char[]> first_ptr(int_to_bin(s.first)),
second_ptr(int_to_bin(s.second));
fout.write(first_ptr.get(), 4);
fout.write(second_ptr.get(), 4);
}
fout << char2id.size() << " " << rules.size() << std::endl;
for (auto s : char2id) {
fout << s.first << " " << s.second << std::endl;
for (auto &rule : rules) {
std::unique_ptr<char[]> rule_ptr(int_to_bin(rule.x));
fout.write(rule_ptr.get(), 4);
}

for (auto rule : rules) {
fout << rule.x << " " << rule.y << " " << rule.z << std::endl;
for (auto &rule : rules) {
std::unique_ptr<char[]> rule_ptr(int_to_bin(rule.y));
fout.write(rule_ptr.get(), 4);
}
for (auto &rule : rules) {
std::unique_ptr<char[]> rule_ptr(int_to_bin(rule.z));
fout.write(rule_ptr.get(), 4);
}
special_tokens.dump(fout);
fout.close();
}

void BPEState::load(const string &file_name) {
// Deserializes a model written by BPEState::dump (see its layout comment):
// sizes, char2id pairs, column-major rule components, special tokens.
// All integers are 4-byte little-endian.
void BPEState::load(StreamReader &fin) {
  char2id.clear();
  rules.clear();
  char n_bs[4], m_bs[4];
  fin.read(n_bs, 4);
  fin.read(m_bs, 4);
  auto n = bin_to_int<int>(n_bs);  // number of char2id entries
  auto m = bin_to_int<int>(m_bs);  // number of merge rules
  for (int i = 0; i < n; i++) {
    char inner_id_bs[4], utf32_id_bs[4];
    fin.read(inner_id_bs, 4);
    fin.read(utf32_id_bs, 4);
    auto inner_id = bin_to_int<uint32_t>(inner_id_bs);
    auto utf32_id = bin_to_int<uint32_t>(utf32_id_bs);
    char2id[inner_id] = utf32_id;
  }
  // Rules arrive column-major (all x, then all y, then all z); gather the
  // columns into per-rule tuples before constructing BPE_Rule objects.
  std::vector<std::tuple<uint32_t, uint32_t, uint32_t>> rules_xyz(m);
  for (int j = 0; j < 3; j++) {
    for (int i = 0; i < m; i++) {
      char val[4];
      fin.read(val, 4);
      auto value = bin_to_int<uint32_t>(val);
      // BUG FIX: the original switch had no break statements, so every
      // case fell through and all three components were written to the
      // z slot (via a dangling `element` pointer pattern besides).
      switch (j) {
        case 0: std::get<0>(rules_xyz[i]) = value; break;
        case 1: std::get<1>(rules_xyz[i]) = value; break;
        case 2: std::get<2>(rules_xyz[i]) = value; break;
      }
    }
  }
  for (int i = 0; i < m; i++) {
    rules.emplace_back(std::get<0>(rules_xyz[i]), std::get<1>(rules_xyz[i]),
                       std::get<2>(rules_xyz[i]));
  }
  special_tokens.load(fin);
}

BpeConfig::BpeConfig(double _character_coverage, int _n_threads,
Expand Down
24 changes: 20 additions & 4 deletions youtokentome/cpp/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,22 @@
namespace vkcom {
const uint32_t SPACE_TOKEN = 9601;

// Abstract byte sink for serializing BPE models (file, Python file
// object, ...). `write` returns an int per the implementations' contract
// (bytes written); `name` is a human-readable identity for log messages.
// NOTE(review): `open` returns the abstract StreamWriter by value — this
// cannot compile (pure virtuals) and would slice a concrete FileWriter.
// It should return std::unique_ptr<StreamWriter>, with matching changes
// in utils.cpp and the call site in bpe.cpp.
struct StreamWriter {
virtual int write(const char *buffer, int size) = 0;
virtual std::string name() const noexcept = 0;
virtual ~StreamWriter() = default;

static StreamWriter open(const std::string &file_name);
};

// Abstract byte source for deserializing BPE models.
// NOTE(review): two defects flagged, not changed here to stay in sync
// with utils.cpp: (1) `open` returns the abstract type by value — cannot
// compile / would slice; should be std::unique_ptr<StreamReader>.
// (2) `read`'s destination buffer is declared `const char *` although
// implementations must write into it; it should be `char *`.
struct StreamReader {
virtual int read(const char *buffer, int size) = 0;
virtual std::string name() const noexcept = 0;
virtual ~StreamReader() = default;

static StreamReader open(const std::string &file_name);
};

struct BPE_Rule {
// x + y -> z
uint32_t x{0};
Expand All @@ -31,9 +47,9 @@ struct SpecialTokens {

SpecialTokens(int pad_id, int unk_id, int bos_id, int eos_id);

void dump(std::ofstream &fout);
void dump(StreamWriter &fout);

void load(std::ifstream &fin);
void load(StreamReader &fin);

uint32_t max_id() const;

Expand All @@ -58,9 +74,9 @@ struct BPEState {
std::vector<BPE_Rule> rules;
SpecialTokens special_tokens;

void dump(const std::string &file_name);
void dump(StreamWriter &fout);

void load(const std::string &file_name);
void load(StreamReader &fin);
};

struct DecodeResult {
Expand Down
67 changes: 50 additions & 17 deletions youtokentome/youtokentome.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,32 @@
from enum import Enum
from typing import List, Union
from functools import wraps
from typing import BinaryIO, List, Optional, Union

import _youtokentome_cython


class OutputType(Enum):
    # Encode/decode to integer token ids.
    ID = 1
    # Encode/decode to subword strings.
    SUBWORD = 2


class BPE:
def __init__(self, model: str, n_threads: int = -1):
self.bpe_cython = _youtokentome_cython.BPE(
model_path=model, n_threads=n_threads
)
def __init__(self, model: Union[str, BinaryIO], n_threads: int = -1):
    """Load a BPE model from a filesystem path or a readable binary file object.

    :param model: FS path or readable binary file object with the model.
    :param n_threads: number of threads for encoding (-1 = all cores).
    """
    if isinstance(model, str):
        # We opened the stream ourselves, so close it when done.
        with open(model, "rb") as fobj:
            self.bpe_cython = _youtokentome_cython.BPE(
                model_fobj=fobj, n_threads=n_threads
            )
    else:
        # Caller-supplied file object: the caller keeps ownership.
        self.bpe_cython = _youtokentome_cython.BPE(
            model_fobj=model, n_threads=n_threads
        )

@staticmethod
def train(
    data: str,
    model: Optional[Union[str, BinaryIO]],
    vocab_size: int,
    coverage: float = 1.0,
    n_threads: int = -1,
    pad_id: int = 0,
    unk_id: int = 1,
    bos_id: int = 2,
    eos_id: int = 3,
) -> "BPE":
    """Train a BPE model on the text in ``data`` and write it to ``model``.

    :param data: FS path to the training text file.
    :param model: FS path or writeable binary file object for the model.
    :param vocab_size: target vocabulary size.
    :return: a BPE instance loaded from the freshly trained model.
    """
    own_obj = isinstance(model, str)
    # BUG FIX: the original rebound ``model`` to the opened-and-then-closed
    # file object, so the trailing ``BPE(model=model, ...)`` received a
    # closed stream whenever a path was passed. Keep ``model`` untouched
    # and train through a separate handle.
    fobj = open(model, "wb") if own_obj else model
    try:
        _youtokentome_cython.BPE.train(
            data=data,
            model=fobj,
            vocab_size=vocab_size,
            n_threads=n_threads,
            coverage=coverage,
            pad_id=pad_id,
            unk_id=unk_id,
            bos_id=bos_id,
            eos_id=eos_id,
        )
    finally:
        if own_obj:
            fobj.close()
    if not own_obj:
        # Rewind so the model just written can be read back below.
        # NOTE(review): assumes a seekable read/write stream (e.g. "w+b")
        # — confirm against the intended cython API.
        model.seek(0)
    return BPE(model=model, n_threads=n_threads)

Expand All @@ -61,6 +78,22 @@ def encode(
reverse=reverse,
)

def save(self, where: Union[str, BinaryIO]):
    """
    Write the model to FS or any writeable file object.

    :param where: FS path or writeable file object.
    :return: None
    """
    if isinstance(where, str):
        # Path given: open, write, and close the file ourselves.
        with open(where, "wb") as fobj:
            self.bpe_cython.save(where=fobj)
    else:
        # File object given: write to it and leave it open for the caller.
        self.bpe_cython.save(where=where)

def vocab_size(self) -> int:
    # Delegates to the Cython extension object backing this wrapper.
    return self.bpe_cython.vocab_size()

Expand Down