Skip to content

Commit

Permalink
Merge pull request #2 from irinakhismatullina/read
Browse files Browse the repository at this point in the history
Add function for reading the model dump
  • Loading branch information
vmarkovtsev authored Oct 22, 2019
2 parents 76ebe6c + 5a79713 commit 87388a2
Show file tree
Hide file tree
Showing 6 changed files with 316 additions and 8 deletions.
3 changes: 2 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ install:
script:
- make install-dev-deps
- make check-style
- make test
- make test-coverage
- make codecov

matrix:
fast_finish: true
Expand Down
178 changes: 178 additions & 0 deletions bpe.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
package bpe

import (
	"encoding/binary"
	"errors"
	"io"
	"strings"

	"github.com/sirupsen/logrus"
)

// TokenID is a numerical identifier of the subword token.
type TokenID uint32

// EncodedToken is a sequence of subword token ids encoding a word.
type EncodedToken []TokenID

// rule is a single BPE merge rule: the pair (left, right) merges into result.
type rule struct {
	left   TokenID
	right  TokenID
	result TokenID
}

// specialTokens holds the ids of the reserved tokens:
// unk (unknown), pad (padding), bos (begin of sentence), eos (end of sentence).
// A value of -1 means the token is not set (see newModel).
type specialTokens struct {
	unk int32
	pad int32
	bos int32
	eos int32
}

// Model is a Byte-Pair encoding model, which supports encoding and decoding text into sequences
// of most frequent subword tokens
type Model struct {
	char2id       map[rune]TokenID         // single-char token id for each known char
	id2char       map[TokenID]rune         // inverse of char2id
	rules         []rule                   // merge rules in the order they appear in the dump
	recipe        map[TokenID]EncodedToken // token id -> the single-char token ids it expands to
	revRecipe     map[string]TokenID       // decoded token string -> token id
	specialTokens specialTokens            // unk/pad/bos/eos ids, -1 when unset
}

// newModel creates an empty Model with room for nRules merge rules.
// All maps are initialized empty and every special token id defaults to -1
// ("not set"). Field names are used in the composite literal so that the
// initializer stays correct if the Model struct ever gains or reorders fields.
func newModel(nRules int) *Model {
	return &Model{
		char2id:       make(map[rune]TokenID),
		id2char:       make(map[TokenID]rune),
		rules:         make([]rule, nRules),
		recipe:        make(map[TokenID]EncodedToken),
		revRecipe:     make(map[string]TokenID),
		specialTokens: specialTokens{-1, -1, -1, -1},
	}
}

// DecodeToken converts the sequence of chars' ids into the string -
// sequence of the corresponding chars
// DecodeToken converts the sequence of chars' ids into the string -
// sequence of the corresponding chars.
// It returns an error if any id in token has no entry in id2char.
func DecodeToken(token EncodedToken, id2char map[TokenID]rune) (string, error) {
	// strings.Builder avoids the quadratic cost of repeated string concatenation.
	var word strings.Builder
	for _, id := range token {
		char, ok := id2char[id]
		if !ok {
			logrus.Errorf("Decode failure: %d token id has no corresponding char", id)
			return "", errors.New("key not found in id2char")
		}
		word.WriteRune(char)
	}
	return word.String(), nil
}

// toBinary serializes the special token ids as four big-endian int32 values
// (unk, pad, bos, eos), 16 bytes in total.
func (s specialTokens) toBinary() []byte {
	buf := make([]byte, 16)
	for i, tok := range []int32{s.unk, s.pad, s.bos, s.eos} {
		binary.BigEndian.PutUint32(buf[4*i:], uint32(tok))
	}
	return buf
}

// binaryToSpecialTokens parses four big-endian int32 values (unk, pad, bos, eos)
// from bytesArray. At least 16 bytes are required.
func binaryToSpecialTokens(bytesArray []byte) (specialTokens, error) {
	var s specialTokens
	if len(bytesArray) < 16 {
		logrus.Error("Bytes array length is too small")
		return s, errors.New("bytes array is too small")
	}
	for i, field := range []*int32{&s.unk, &s.pad, &s.bos, &s.eos} {
		*field = int32(binary.BigEndian.Uint32(bytesArray[4*i:]))
	}
	return s, nil
}

// toBinary serializes the rule as three big-endian uint32 values
// (left, right, result), 12 bytes in total.
func (r rule) toBinary() []byte {
	buf := make([]byte, 12)
	for i, id := range []TokenID{r.left, r.right, r.result} {
		binary.BigEndian.PutUint32(buf[4*i:], uint32(id))
	}
	return buf
}

// binaryToRule parses three big-endian uint32 values (left, right, result)
// from bytesArray. At least 12 bytes are required.
func binaryToRule(bytesArray []byte) (rule, error) {
	var r rule
	if len(bytesArray) < 12 {
		logrus.Error("Bytes array length is too small")
		return r, errors.New("bytes array is too small")
	}
	for i, id := range []*TokenID{&r.left, &r.right, &r.result} {
		*id = TokenID(binary.BigEndian.Uint32(bytesArray[4*i:]))
	}
	return r, nil
}

// readUint32 reads one big-endian uint32 from reader into the 4-byte scratch buf.
func readUint32(reader io.Reader, buf []byte) (uint32, error) {
	if _, err := io.ReadFull(reader, buf); err != nil {
		logrus.Error("Broken input: ", err)
		return 0, err
	}
	return binary.BigEndian.Uint32(buf), nil
}

// ReadModel loads the BPE model from the binary dump.
//
// Dump layout (all integers big-endian uint32): nChars, nRules, then nChars
// (char, charID) pairs, then nRules (left, right, result) triples, then the
// four special token ids (unk, pad, bos, eos).
func ReadModel(reader io.Reader) (*Model, error) {
	buf := make([]byte, 4)
	nCharsWord, err := readUint32(reader, buf)
	if err != nil {
		return &Model{}, err
	}
	nRulesWord, err := readUint32(reader, buf)
	if err != nil {
		return &Model{}, err
	}
	nChars, nRules := int(nCharsWord), int(nRulesWord)

	model := newModel(nRules)
	for i := 0; i < nChars; i++ {
		charWord, err := readUint32(reader, buf)
		if err != nil {
			return &Model{}, err
		}
		idWord, err := readUint32(reader, buf)
		if err != nil {
			return &Model{}, err
		}
		char, charID := rune(charWord), TokenID(idWord)
		model.char2id[char] = charID
		model.id2char[charID] = char
		model.recipe[charID] = EncodedToken{charID}
		model.revRecipe[string(char)] = charID
	}
	ruleBuf := make([]byte, 12)
	for i := 0; i < nRules; i++ {
		if _, err := io.ReadFull(reader, ruleBuf); err != nil {
			logrus.Error("Broken input: ", err)
			return &Model{}, err
		}
		rule, err := binaryToRule(ruleBuf)
		if err != nil {
			return model, err
		}
		model.rules[i] = rule
		// Both operands of a merge must have been defined by an earlier entry.
		for _, id := range []TokenID{rule.left, rule.right} {
			if _, ok := model.recipe[id]; !ok {
				logrus.Errorf("%d: token id not described before", id)
				return model, errors.New("key not found in id2char")
			}
		}
		// Build the merged recipe in a freshly allocated slice. Appending
		// directly to model.recipe[rule.left] could write into spare capacity
		// shared with a previously stored recipe and silently corrupt it.
		left, right := model.recipe[rule.left], model.recipe[rule.right]
		combined := make(EncodedToken, 0, len(left)+len(right))
		combined = append(combined, left...)
		combined = append(combined, right...)
		model.recipe[rule.result] = combined
		resultString, err := DecodeToken(combined, model.id2char)
		if err != nil {
			logrus.Error("Unexpected token id inside the rules: ", err)
			return model, err
		}
		model.revRecipe[resultString] = rule.result
	}
	specialTokensBuf := make([]byte, 16)
	if _, err := io.ReadFull(reader, specialTokensBuf); err != nil {
		logrus.Error("Broken input: ", err)
		return &Model{}, err
	}
	specials, err := binaryToSpecialTokens(specialTokensBuf)
	model.specialTokens = specials
	return model, err
}
131 changes: 131 additions & 0 deletions bpe_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package bpe

import (
"bytes"
"testing"

"github.com/stretchr/testify/require"
)

// TestNewModel checks that newModel pre-allocates the rules slice.
func TestNewModel(t *testing.T) {
	m := newModel(10)
	require.Equal(t, 10, len(m.rules))
}

// TestDecodeToken checks decoding of a token id sequence into a string.
func TestDecodeToken(t *testing.T) {
	id2char := map[TokenID]rune{1: 'a', 2: 'b', 3: 'c'}
	word, err := DecodeToken(EncodedToken{1, 2, 1, 3, 3}, id2char)
	require.NoError(t, err)
	require.Equal(t, "abacc", word)
}

// TestSpecialTokensToBinary checks big-endian serialization of special token ids,
// including a negative value.
func TestSpecialTokensToBinary(t *testing.T) {
	specials := specialTokens{
		unk: 1,
		pad: 259,
		bos: 2*256*256 + 37*256 + 2,
		eos: -256 * 256 * 256 * 127,
	}
	expected := []byte{0, 0, 0, 1, 0, 0, 1, 3, 0, 2, 37, 2, 129, 0, 0, 0}
	require.Equal(t, expected, specials.toBinary())
}

// TestBinaryToSpecialTokens checks deserialization of special token ids and
// that inputs shorter than 16 bytes are rejected.
func TestBinaryToSpecialTokens(t *testing.T) {
	bytesArray := []byte{0, 0, 0, 1, 0, 0, 1, 3, 0, 2, 37, 2, 129, 0, 0, 0}
	expected := specialTokens{1, 259, 2*256*256 + 37*256 + 2, -256 * 256 * 256 * 127}
	specials, err := binaryToSpecialTokens(bytesArray)
	require.NoError(t, err)
	require.Equal(t, expected, specials)
	// Truncated input (15 bytes) must fail; the result is discarded.
	_, err = binaryToSpecialTokens(bytesArray[:15])
	require.Error(t, err)
	// Empty input must fail as well.
	_, err = binaryToSpecialTokens([]byte{})
	require.Error(t, err)
}

// TestRuleToBinary checks big-endian serialization of a merge rule.
// The local is named r to avoid shadowing the rule type.
func TestRuleToBinary(t *testing.T) {
	r := rule{left: 1, right: 2, result: 257}
	expected := []byte{0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 1, 1}
	require.Equal(t, expected, r.toBinary())
}

// TestBinaryToRule checks deserialization of a merge rule and that inputs
// shorter than 12 bytes are rejected.
func TestBinaryToRule(t *testing.T) {
	expected := rule{1, 2, 257}
	bytesArray := []byte{0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 1, 1}
	r, err := binaryToRule(bytesArray)
	require.NoError(t, err)
	require.Equal(t, expected, r)
	// 11 bytes: one short of a full rule; the result is discarded.
	_, err = binaryToRule([]byte{0, 0, 0, 0, 0, 0, 2, 0, 0, 1, 1})
	require.Error(t, err)
	// Empty input must fail as well.
	_, err = binaryToRule([]byte{})
	require.Error(t, err)
}

// TestReadModel checks loading a full model dump: a valid dump, a dump with
// trailing garbage (ignored), a truncated dump, and a dump whose rules
// reference an undefined token id.
func TestReadModel(t *testing.T) {
	// Valid dump: 5 chars, 4 rules, then the special tokens.
	reader := bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4,
		0, 0, 0, 99, 0, 0, 0, 6,
		0, 0, 0, 98, 0, 0, 0, 7,
		0, 0, 0, 95, 0, 0, 0, 4,
		0, 0, 0, 100, 0, 0, 0, 5,
		0, 0, 0, 97, 0, 0, 0, 8,
		0, 0, 0, 4, 0, 0, 0, 8, 0, 0, 0, 9,
		0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 10,
		0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
		0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12,
		0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3})
	expected := Model{
		map[rune]TokenID{97: 8, 98: 7, 99: 6, 100: 5, 95: 4},
		map[TokenID]rune{4: 95, 5: 100, 6: 99, 7: 98, 8: 97},
		[]rule{{4, 8, 9}, {4, 6, 10}, {4, 5, 11}, {4, 7, 12}},
		map[TokenID]EncodedToken{4: {4}, 5: {5}, 6: {6}, 7: {7}, 8: {8}, 9: {4, 8}, 10: {4, 6}, 11: {4, 5}, 12: {4, 7}},
		map[string]TokenID{"a": 8, "b": 7, "c": 6, "d": 5, "_": 4,
			"_a": 9, "_b": 12, "_c": 10, "_d": 11},
		specialTokens{1, 0, 2, 3},
	}
	model, err := ReadModel(reader)
	require.NoError(t, err)
	require.Equal(t, expected, *model)

	// Same dump with extra trailing bytes: they must be ignored.
	reader = bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4,
		0, 0, 0, 99, 0, 0, 0, 6,
		0, 0, 0, 98, 0, 0, 0, 7,
		0, 0, 0, 95, 0, 0, 0, 4,
		0, 0, 0, 100, 0, 0, 0, 5,
		0, 0, 0, 97, 0, 0, 0, 8,
		0, 0, 0, 4, 0, 0, 0, 8, 0, 0, 0, 9,
		0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 10,
		0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
		0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12,
		0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3,
		0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
		0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12})
	model, err = ReadModel(reader)
	require.NoError(t, err)
	require.Equal(t, expected, *model)

	// Truncated dump (last special-tokens byte missing): must fail.
	reader = bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4,
		0, 0, 0, 99, 0, 0, 0, 6,
		0, 0, 0, 98, 0, 0, 0, 7,
		0, 0, 0, 95, 0, 0, 0, 4,
		0, 0, 0, 100, 0, 0, 0, 5,
		0, 0, 0, 97, 0, 0, 0, 8,
		0, 0, 0, 4, 0, 0, 0, 8, 0, 0, 0, 9,
		0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 10,
		0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
		0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12,
		0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0})
	_, err = ReadModel(reader)
	require.Error(t, err)

	// A rule referencing token id 20, which was never defined: must fail.
	reader = bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4,
		0, 0, 0, 99, 0, 0, 0, 6,
		0, 0, 0, 98, 0, 0, 0, 7,
		0, 0, 0, 95, 0, 0, 0, 4,
		0, 0, 0, 100, 0, 0, 0, 5,
		0, 0, 0, 97, 0, 0, 0, 8,
		0, 0, 0, 4, 0, 0, 0, 20, 0, 0, 0, 9,
		0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 10,
		0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
		0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12,
		0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3})
	_, err = ReadModel(reader)
	require.Error(t, err)
}
5 changes: 5 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
module github.com/src-d/go-YouTokenToMe

go 1.12

require (
github.com/sirupsen/logrus v1.4.2
github.com/stretchr/testify v1.4.0
)
Empty file removed go.sum
Empty file.
7 changes: 0 additions & 7 deletions main.go

This file was deleted.

0 comments on commit 87388a2

Please sign in to comment.