-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from irinakhismatullina/read
Add function for reading the model dump
- Loading branch information
Showing
6 changed files
with
316 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
package bpe | ||
|
||
import (
	"encoding/binary"
	"errors"
	"io"
	"strings"

	"github.com/sirupsen/logrus"
)
|
||
// TokenID is a numerical identifier of the subword token.
type TokenID uint32
|
||
// EncodedToken is a sequence of subword token ids.
type EncodedToken []TokenID
|
||
// rule is a single BPE merge rule: the adjacent pair (left, right)
// merges into the result token.
type rule struct {
	left   TokenID
	right  TokenID
	result TokenID
}
|
||
// specialTokens holds the ids of the reserved tokens as stored in the
// model dump. newModel initializes all four to -1 — presumably meaning
// "not set"; confirm against the writer side of the dump format.
type specialTokens struct {
	unk int32 // unknown token
	pad int32 // padding token
	bos int32 // begin-of-sentence token
	eos int32 // end-of-sentence token
}
|
||
// Model is a Byte-Pair encoding model, which supports encoding and decoding text into sequences
// of most frequent subword tokens.
type Model struct {
	char2id       map[rune]TokenID         // single characters -> their token ids
	id2char       map[TokenID]rune         // inverse of char2id
	rules         []rule                   // merge rules in dump order
	recipe        map[TokenID]EncodedToken // token id -> sequence of char ids it expands to
	revRecipe     map[string]TokenID       // decoded token string -> token id
	specialTokens specialTokens            // reserved token ids (unk/pad/bos/eos)
}
|
||
func newModel(nRules int) *Model { | ||
return &Model{ | ||
make(map[rune]TokenID), | ||
make(map[TokenID]rune), | ||
make([]rule, nRules), | ||
make(map[TokenID]EncodedToken), | ||
make(map[string]TokenID), | ||
specialTokens{-1, -1, -1, -1}, | ||
} | ||
} | ||
|
||
// DecodeToken converts the sequence of chars' ids into the string - | ||
// sequence of the corresponding chars | ||
func DecodeToken(token EncodedToken, id2char map[TokenID]rune) (string, error) { | ||
word := "" | ||
for _, id := range token { | ||
if char, ok := id2char[id]; ok { | ||
word = word + string(char) | ||
} else { | ||
logrus.Errorf("Decode failure: %d token id has no corresponding char", id) | ||
return "", errors.New("key not found in id2char") | ||
} | ||
} | ||
return word, nil | ||
} | ||
|
||
func (s specialTokens) toBinary() []byte { | ||
bytesArray := make([]byte, 16) | ||
binary.BigEndian.PutUint32(bytesArray, uint32(s.unk)) | ||
binary.BigEndian.PutUint32(bytesArray[4:], uint32(s.pad)) | ||
binary.BigEndian.PutUint32(bytesArray[8:], uint32(s.bos)) | ||
binary.BigEndian.PutUint32(bytesArray[12:], uint32(s.eos)) | ||
return bytesArray | ||
} | ||
|
||
func binaryToSpecialTokens(bytesArray []byte) (specialTokens, error) { | ||
var s specialTokens | ||
if len(bytesArray) < 16 { | ||
logrus.Error("Bytes array length is too small") | ||
return s, errors.New("bytes array is too small") | ||
} | ||
s.unk = int32(binary.BigEndian.Uint32(bytesArray)) | ||
s.pad = int32(binary.BigEndian.Uint32(bytesArray[4:])) | ||
s.bos = int32(binary.BigEndian.Uint32(bytesArray[8:])) | ||
s.eos = int32(binary.BigEndian.Uint32(bytesArray[12:])) | ||
return s, nil | ||
} | ||
|
||
func (r rule) toBinary() []byte { | ||
bytesArray := make([]byte, 12) | ||
binary.BigEndian.PutUint32(bytesArray, uint32(r.left)) | ||
binary.BigEndian.PutUint32(bytesArray[4:], uint32(r.right)) | ||
binary.BigEndian.PutUint32(bytesArray[8:], uint32(r.result)) | ||
return bytesArray | ||
} | ||
|
||
func binaryToRule(bytesArray []byte) (rule, error) { | ||
var r rule | ||
if len(bytesArray) < 12 { | ||
logrus.Error("Bytes array length is too small") | ||
return r, errors.New("bytes array is too small") | ||
} | ||
r.left = TokenID(binary.BigEndian.Uint32(bytesArray)) | ||
r.right = TokenID(binary.BigEndian.Uint32(bytesArray[4:])) | ||
r.result = TokenID(binary.BigEndian.Uint32(bytesArray[8:])) | ||
return r, nil | ||
} | ||
|
||
// ReadModel loads the BPE model from the binary dump | ||
func ReadModel(reader io.Reader) (*Model, error) { | ||
buf := make([]byte, 4) | ||
var nChars, nRules int | ||
if _, err := io.ReadFull(reader, buf); err != nil { | ||
logrus.Error("Broken input: ", err) | ||
return &Model{}, err | ||
} | ||
nChars = int(binary.BigEndian.Uint32(buf)) | ||
if _, err := io.ReadFull(reader, buf); err != nil { | ||
logrus.Error("Broken input: ", err) | ||
return &Model{}, err | ||
} | ||
nRules = int(binary.BigEndian.Uint32(buf)) | ||
|
||
model := newModel(nRules) | ||
for i := 0; i < nChars; i++ { | ||
var char rune | ||
var charID TokenID | ||
if _, err := io.ReadFull(reader, buf); err != nil { | ||
logrus.Error("Broken input: ", err) | ||
return &Model{}, err | ||
} | ||
char = rune(binary.BigEndian.Uint32(buf)) | ||
if _, err := io.ReadFull(reader, buf); err != nil { | ||
logrus.Error("Broken input: ", err) | ||
return &Model{}, err | ||
} | ||
charID = TokenID(binary.BigEndian.Uint32(buf)) | ||
model.char2id[char] = charID | ||
model.id2char[charID] = char | ||
model.recipe[charID] = EncodedToken{charID} | ||
model.revRecipe[string(char)] = charID | ||
} | ||
ruleBuf := make([]byte, 12) | ||
for i := 0; i < nRules; i++ { | ||
if _, err := io.ReadFull(reader, ruleBuf); err != nil { | ||
logrus.Error("Broken input: ", err) | ||
return &Model{}, err | ||
} | ||
rule, err := binaryToRule(ruleBuf) | ||
if err != nil { | ||
return model, err | ||
} | ||
model.rules[i] = rule | ||
if _, ok := model.recipe[rule.left]; !ok { | ||
logrus.Errorf("%d: token id not described before", rule.left) | ||
return model, errors.New("key not found in id2char") | ||
} | ||
if _, ok := model.recipe[rule.right]; !ok { | ||
logrus.Errorf("%d: token id not described before", rule.right) | ||
return model, errors.New("key not found in id2char") | ||
} | ||
model.recipe[rule.result] = append(model.recipe[rule.left], model.recipe[rule.right]...) | ||
resultString, err := DecodeToken(model.recipe[rule.result], model.id2char) | ||
if err != nil { | ||
logrus.Error("Unexpected token id inside the rules: ", err) | ||
return model, err | ||
} | ||
model.revRecipe[resultString] = rule.result | ||
} | ||
specialTokensBuf := make([]byte, 16) | ||
if _, err := io.ReadFull(reader, specialTokensBuf); err != nil { | ||
logrus.Error("Broken input: ", err) | ||
return &Model{}, err | ||
} | ||
specials, err := binaryToSpecialTokens(specialTokensBuf) | ||
model.specialTokens = specials | ||
return model, err | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
package bpe | ||
|
||
import ( | ||
"bytes" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestNewModel(t *testing.T) { | ||
model := newModel(10) | ||
require.Equal(t, 10, len(model.rules)) | ||
} | ||
|
||
func TestDecodeToken(t *testing.T) { | ||
id2char := map[TokenID]rune{1: []rune("a")[0], 2: []rune("b")[0], 3: []rune("c")[0]} | ||
word, err := DecodeToken(EncodedToken{1, 2, 1, 3, 3}, id2char) | ||
require.NoError(t, err) | ||
require.Equal(t, "abacc", word) | ||
} | ||
|
||
func TestSpecialTokensToBinary(t *testing.T) { | ||
specials := specialTokens{1, 259, 2*256*256 + 37*256 + 2, -256 * 256 * 256 * 127} | ||
bytesArray := []byte{0, 0, 0, 1, 0, 0, 1, 3, 0, 2, 37, 2, 129, 0, 0, 0} | ||
require.Equal(t, bytesArray, specials.toBinary()) | ||
} | ||
|
||
func TestBinaryToSpecialTokens(t *testing.T) { | ||
bytesArray := []byte{0, 0, 0, 1, 0, 0, 1, 3, 0, 2, 37, 2, 129, 0, 0, 0} | ||
expected := specialTokens{1, 259, 2*256*256 + 37*256 + 2, -256 * 256 * 256 * 127} | ||
specials, err := binaryToSpecialTokens(bytesArray) | ||
require.NoError(t, err) | ||
require.Equal(t, expected, specials) | ||
bytesArray = []byte{0, 0, 0, 1, 0, 0, 1, 3, 0, 2, 37, 2, 129, 0, 0} | ||
specials, err = binaryToSpecialTokens(bytesArray) | ||
require.Error(t, err) | ||
bytesArray = []byte{} | ||
specials, err = binaryToSpecialTokens(bytesArray) | ||
require.Error(t, err) | ||
} | ||
|
||
func TestRuleToBinary(t *testing.T) { | ||
rule := rule{1, 2, 257} | ||
bytesArray := []byte{0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 1, 1} | ||
require.Equal(t, bytesArray, rule.toBinary()) | ||
} | ||
|
||
func TestBinaryToRule(t *testing.T) { | ||
expected := rule{1, 2, 257} | ||
bytesArray := []byte{0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 1, 1} | ||
rule, err := binaryToRule(bytesArray) | ||
require.NoError(t, err) | ||
require.Equal(t, expected, rule) | ||
bytesArray = []byte{0, 0, 0, 0, 0, 0, 2, 0, 0, 1, 1} | ||
rule, err = binaryToRule(bytesArray) | ||
require.Error(t, err) | ||
bytesArray = []byte{} | ||
rule, err = binaryToRule(bytesArray) | ||
require.Error(t, err) | ||
} | ||
|
||
// TestReadModel checks ReadModel against hand-built binary dumps.
// Dump layout (big-endian uint32s): nChars, nRules; then nChars
// (char, id) pairs; then nRules (left, right, result) triples; then the
// 16-byte special tokens block (unk, pad, bos, eos).
func TestReadModel(t *testing.T) {
	// Well-formed dump: 5 chars ('c'=6, 'b'=7, '_'=4, 'd'=5, 'a'=8),
	// 4 rules, special tokens {1, 0, 2, 3}.
	reader := bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4,
		0, 0, 0, 99, 0, 0, 0, 6,
		0, 0, 0, 98, 0, 0, 0, 7,
		0, 0, 0, 95, 0, 0, 0, 4,
		0, 0, 0, 100, 0, 0, 0, 5,
		0, 0, 0, 97, 0, 0, 0, 8,
		0, 0, 0, 4, 0, 0, 0, 8, 0, 0, 0, 9,
		0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 10,
		0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
		0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12,
		0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3})
	expected := Model{
		map[rune]TokenID{97: 8, 98: 7, 99: 6, 100: 5, 95: 4},
		map[TokenID]rune{4: 95, 5: 100, 6: 99, 7: 98, 8: 97},
		[]rule{{4, 8, 9}, {4, 6, 10}, {4, 5, 11}, {4, 7, 12}},
		map[TokenID]EncodedToken{4: {4}, 5: {5}, 6: {6}, 7: {7}, 8: {8}, 9: {4, 8}, 10: {4, 6}, 11: {4, 5}, 12: {4, 7}},
		map[string]TokenID{"a": 8, "b": 7, "c": 6, "d": 5, "_": 4,
			"_a": 9, "_b": 12, "_c": 10, "_d": 11},
		specialTokens{1, 0, 2, 3},
	}
	model, err := ReadModel(reader)
	require.NoError(t, err)
	require.Equal(t, expected, *model)

	// Same dump with trailing garbage bytes after the special tokens
	// block: the extra bytes must be ignored.
	reader = bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4,
		0, 0, 0, 99, 0, 0, 0, 6,
		0, 0, 0, 98, 0, 0, 0, 7,
		0, 0, 0, 95, 0, 0, 0, 4,
		0, 0, 0, 100, 0, 0, 0, 5,
		0, 0, 0, 97, 0, 0, 0, 8,
		0, 0, 0, 4, 0, 0, 0, 8, 0, 0, 0, 9,
		0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 10,
		0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
		0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12,
		0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3,
		0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
		0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12})
	model, err = ReadModel(reader)
	require.NoError(t, err)
	require.Equal(t, expected, *model)

	// Truncated dump: the special tokens block is one byte short,
	// so reading must fail.
	reader = bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4,
		0, 0, 0, 99, 0, 0, 0, 6,
		0, 0, 0, 98, 0, 0, 0, 7,
		0, 0, 0, 95, 0, 0, 0, 4,
		0, 0, 0, 100, 0, 0, 0, 5,
		0, 0, 0, 97, 0, 0, 0, 8,
		0, 0, 0, 4, 0, 0, 0, 8, 0, 0, 0, 9,
		0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 10,
		0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
		0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12,
		0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0})
	model, err = ReadModel(reader)
	require.Error(t, err)

	// Invalid dump: the first rule references token id 20, which was
	// never described in the chars section, so reading must fail.
	reader = bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4,
		0, 0, 0, 99, 0, 0, 0, 6,
		0, 0, 0, 98, 0, 0, 0, 7,
		0, 0, 0, 95, 0, 0, 0, 4,
		0, 0, 0, 100, 0, 0, 0, 5,
		0, 0, 0, 97, 0, 0, 0, 8,
		0, 0, 0, 4, 0, 0, 0, 20, 0, 0, 0, 9,
		0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 10,
		0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
		0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12,
		0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3})
	model, err = ReadModel(reader)
	require.Error(t, err)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,8 @@ | ||
module github.com/src-d/go-YouTokenToMe | ||
|
||
go 1.12 | ||
|
||
require ( | ||
github.com/sirupsen/logrus v1.4.2 | ||
github.com/stretchr/testify v1.4.0 | ||
) |