Remove Token constants information from keys (#2197)
Signed-off-by: David Boehm <91908103+dboehm-avalabs@users.noreply.github.com>
Signed-off-by: Dan Laine <daniel.laine@avalabs.org>
Co-authored-by: Darioush Jalali <darioush.jalali@avalabs.org>
Co-authored-by: Dan Laine <daniel.laine@avalabs.org>
3 people authored Nov 6, 2023
1 parent 558d8fb commit e710899
Showing 27 changed files with 955 additions and 1,087 deletions.
53 changes: 27 additions & 26 deletions x/merkledb/codec.go
@@ -44,7 +44,6 @@ var (
trueBytes = []byte{trueByte}
falseBytes = []byte{falseByte}

errTooManyChildren = errors.New("length of children list is larger than branching factor")
errChildIndexTooLarge = errors.New("invalid child index. Must be less than branching factor")
errLeadingZeroes = errors.New("varint has leading zeroes")
errInvalidBool = errors.New("decoded bool is neither true nor false")
@@ -63,13 +62,15 @@ type encoderDecoder interface {
type encoder interface {
// Assumes [n] is non-nil.
encodeDBNode(n *dbNode) []byte
// Assumes [hv] is non-nil.
encodeHashValues(hv *hashValues) []byte

// Returns the bytes that will be hashed to generate [n]'s ID.
// Assumes [n] is non-nil.
encodeHashValues(n *node) []byte
}

type decoder interface {
// Assumes [n] is non-nil.
decodeDBNode(bytes []byte, n *dbNode, factor BranchFactor) error
decodeDBNode(bytes []byte, n *dbNode) error
}

func newCodec() encoderDecoder {
@@ -114,9 +115,9 @@ func (c *codecImpl) encodeDBNode(n *dbNode) []byte {
return buf.Bytes()
}

func (c *codecImpl) encodeHashValues(hv *hashValues) []byte {
func (c *codecImpl) encodeHashValues(n *node) []byte {
var (
numChildren = len(hv.Children)
numChildren = len(n.children)
// Estimate the size of [n] to prevent memory allocations
estimatedLen = minVarIntLen + numChildren*hashValuesChildLen + estimatedValueLen + estimatedKeyLen
buf = bytes.NewBuffer(make([]byte, 0, estimatedLen))
@@ -125,19 +126,20 @@ func (c *codecImpl) encodeHashValues(hv *hashValues) []byte {
c.encodeUint(buf, uint64(numChildren))

// ensure that the order of entries is consistent
for index := 0; BranchFactor(index) < hv.Key.branchFactor; index++ {
if entry, ok := hv.Children[byte(index)]; ok {
c.encodeUint(buf, uint64(index))
_, _ = buf.Write(entry.id[:])
}
keys := maps.Keys(n.children)
slices.Sort(keys)
for _, index := range keys {
entry := n.children[index]
c.encodeUint(buf, uint64(index))
_, _ = buf.Write(entry.id[:])
}
c.encodeMaybeByteSlice(buf, hv.Value)
c.encodeKey(buf, hv.Key)
c.encodeMaybeByteSlice(buf, n.valueDigest)
c.encodeKey(buf, n.key)

return buf.Bytes()
}
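
For context: the new loop must visit children in ascending index order because Go randomizes map iteration order, and the hash preimage has to be byte-for-byte deterministic. A minimal sketch of the pattern, assuming the golang.org/x/exp/maps and golang.org/x/exp/slices packages (the map contents here are hypothetical):

package main

import (
	"fmt"

	"golang.org/x/exp/maps"
	"golang.org/x/exp/slices"
)

func main() {
	// Hypothetical children; iterating the map directly would yield a
	// different order on every run.
	children := map[byte]string{7: "g", 0: "a", 3: "d"}

	// Sort the indices first so the encoded bytes are reproducible.
	keys := maps.Keys(children)
	slices.Sort(keys)
	for _, index := range keys {
		fmt.Printf("index=%d child=%s\n", index, children[index])
	}
}

The old loop achieved the same ordering by probing indices 0 through branchFactor-1, which is why it needed the branch factor; sorting the keys actually present removes that dependency.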

func (c *codecImpl) decodeDBNode(b []byte, n *dbNode, branchFactor BranchFactor) error {
func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error {
if minDBNodeLen > len(b) {
return io.ErrUnexpectedEOF
}
@@ -154,25 +156,23 @@ func (c *codecImpl) decodeDBNode(b []byte, n *dbNode, branchFactor BranchFactor)
switch {
case err != nil:
return err
case numChildren > uint64(branchFactor):
return errTooManyChildren
case numChildren > uint64(src.Len()/minChildLen):
return io.ErrUnexpectedEOF
}

n.children = make(map[byte]child, branchFactor)
n.children = make(map[byte]child, numChildren)
var previousChild uint64
for i := uint64(0); i < numChildren; i++ {
index, err := c.decodeUint(src)
if err != nil {
return err
}
if index >= uint64(branchFactor) || (i != 0 && index <= previousChild) {
if (i != 0 && index <= previousChild) || index > math.MaxUint8 {
return errChildIndexTooLarge
}
previousChild = index

compressedKey, err := c.decodeKey(src, branchFactor)
compressedKey, err := c.decodeKey(src)
if err != nil {
return err
}
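
With the branch factor gone from the decoder, the canonical-form check above reduces to two conditions: child indices must be strictly increasing and must fit in a byte. A standalone sketch of that rule (checkIndices is a hypothetical helper for illustration; the real check is inline in decodeDBNode):

package main

import (
	"errors"
	"fmt"
	"math"
)

var errChildIndexTooLarge = errors.New("invalid child index")

// checkIndices rejects any index sequence that could not have been
// produced by the encoder's ascending-order loop.
func checkIndices(indices []uint64) error {
	var previous uint64
	for i, index := range indices {
		if (i != 0 && index <= previous) || index > math.MaxUint8 {
			return errChildIndexTooLarge
		}
		previous = index
	}
	return nil
}

func main() {
	fmt.Println(checkIndices([]uint64{1, 5, 9})) // <nil>
	fmt.Println(checkIndices([]uint64{5, 5}))    // not strictly increasing
	fmt.Println(checkIndices([]uint64{300}))     // does not fit in a byte
}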
@@ -331,11 +331,11 @@ func (*codecImpl) decodeID(src *bytes.Reader) (ids.ID, error) {
}

func (c *codecImpl) encodeKey(dst *bytes.Buffer, key Key) {
c.encodeUint(dst, uint64(key.tokenLength))
c.encodeUint(dst, uint64(key.length))
_, _ = dst.Write(key.Bytes())
}

func (c *codecImpl) decodeKey(src *bytes.Reader, branchFactor BranchFactor) (Key, error) {
func (c *codecImpl) decodeKey(src *bytes.Reader) (Key, error) {
if minKeyLen > src.Len() {
return Key{}, io.ErrUnexpectedEOF
}
@@ -347,9 +347,10 @@ func (c *codecImpl) decodeKey(src *bytes.Reader, branchFactor BranchFactor
if length > math.MaxInt {
return Key{}, errIntOverflow
}
result := emptyKey(branchFactor)
result.tokenLength = int(length)
keyBytesLen := result.bytesNeeded(result.tokenLength)
result := Key{
length: int(length),
}
keyBytesLen := bytesNeeded(result.length)
if keyBytesLen > src.Len() {
return Key{}, io.ErrUnexpectedEOF
}
@@ -363,8 +364,8 @@ func (c *codecImpl) decodeKey(src *bytes.Reader) (Key, error) {
if result.hasPartialByte() {
// Confirm that the padding bits in the partial byte are 0.
// We want to only look at the bits to the right of the last token, which is at index length-1.
// Generate a mask with (8-bitsToShift) 0s followed by bitsToShift 1s.
paddingMask := byte(0xFF >> (8 - result.bitsToShift(result.tokenLength-1)))
// Generate a mask where the (result.length % 8) left bits are 0.
paddingMask := byte(0xFF >> (result.length % 8))
if buffer[keyBytesLen-1]&paddingMask != 0 {
return Key{}, errNonZeroKeyPadding
}
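
A worked example of the simplified padding check, assuming a hypothetical 11-bit key: it occupies two bytes, the top 11 % 8 = 3 bits of the last byte hold key data, and the remaining 5 bits must be zero.

package main

import "fmt"

func main() {
	const length = 11             // key length in bits
	lastByte := byte(0b101_00000) // 3 data bits, then 5 padding bits

	// The mask zeroes the (length % 8) used bits on the left and keeps
	// the padding bits on the right, which must all be 0.
	paddingMask := byte(0xFF >> (length % 8)) // 0b00011111
	fmt.Printf("mask=%08b ok=%v\n", paddingMask, lastByte&paddingMask == 0)
}

Previously the shift amount came from bitsToShift, which depended on the per-branch-factor token size; with token information removed from Key, it collapses to length % 8.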
98 changes: 40 additions & 58 deletions x/merkledb/codec_test.go
@@ -80,24 +80,22 @@ func FuzzCodecKey(f *testing.F) {
b []byte,
) {
require := require.New(t)
for _, branchFactor := range branchFactors {
codec := codec.(*codecImpl)
reader := bytes.NewReader(b)
startLen := reader.Len()
got, err := codec.decodeKey(reader, branchFactor)
if err != nil {
t.SkipNow()
}
endLen := reader.Len()
numRead := startLen - endLen

// Encoding [got] should be the same as [b].
var buf bytes.Buffer
codec.encodeKey(&buf, got)
bufBytes := buf.Bytes()
require.Len(bufBytes, numRead)
require.Equal(b[:numRead], bufBytes)
codec := codec.(*codecImpl)
reader := bytes.NewReader(b)
startLen := reader.Len()
got, err := codec.decodeKey(reader)
if err != nil {
t.SkipNow()
}
endLen := reader.Len()
numRead := startLen - endLen

// Encoding [got] should be the same as [b].
var buf bytes.Buffer
codec.encodeKey(&buf, got)
bufBytes := buf.Bytes()
require.Len(bufBytes, numRead)
require.Equal(b[:numRead], bufBytes)
},
)
}
@@ -109,17 +107,15 @@ func FuzzCodecDBNodeCanonical(f *testing.F) {
b []byte,
) {
require := require.New(t)
for _, branchFactor := range branchFactors {
codec := codec.(*codecImpl)
node := &dbNode{}
if err := codec.decodeDBNode(b, node, branchFactor); err != nil {
t.SkipNow()
}

// Encoding [node] should be the same as [b].
buf := codec.encodeDBNode(node)
require.Equal(b, buf)
codec := codec.(*codecImpl)
node := &dbNode{}
if err := codec.decodeDBNode(b, node); err != nil {
t.SkipNow()
}

// Encoding [node] should be the same as [b].
buf := codec.encodeDBNode(node)
require.Equal(b, buf)
},
)
}
@@ -133,7 +129,7 @@ func FuzzCodecDBNodeDeterministic(f *testing.F) {
valueBytes []byte,
) {
require := require.New(t)
for _, branchFactor := range branchFactors {
for _, bf := range validBranchFactors {
r := rand.New(rand.NewSource(int64(randSeed))) // #nosec G404

value := maybe.Nothing[[]byte]()
@@ -148,7 +144,7 @@ func FuzzCodecDBNodeDeterministic(f *testing.F) {
value = maybe.Some(valueBytes)
}

numChildren := r.Intn(int(branchFactor)) // #nosec G404
numChildren := r.Intn(int(bf)) // #nosec G404

children := map[byte]child{}
for i := 0; i < numChildren; i++ {
@@ -159,7 +155,7 @@ func FuzzCodecDBNodeDeterministic(f *testing.F) {
_, _ = r.Read(childKeyBytes) // #nosec G404

children[byte(i)] = child{
compressedKey: ToKey(childKeyBytes, branchFactor),
compressedKey: ToKey(childKeyBytes),
id: childID,
}
}
@@ -171,7 +167,7 @@ func FuzzCodecDBNodeDeterministic(f *testing.F) {
nodeBytes := codec.encodeDBNode(&node)

var gotNode dbNode
require.NoError(codec.decodeDBNode(nodeBytes, &gotNode, branchFactor))
require.NoError(codec.decodeDBNode(nodeBytes, &gotNode))
require.Equal(node, gotNode)

nodeBytes2 := codec.encodeDBNode(&gotNode)
@@ -181,31 +177,15 @@
)
}

func TestCodecDecodeDBNode(t *testing.T) {
func TestCodecDecodeDBNode_TooShort(t *testing.T) {
require := require.New(t)

var (
parsedDBNode dbNode
tooShortBytes = make([]byte, minDBNodeLen-1)
)
err := codec.decodeDBNode(tooShortBytes, &parsedDBNode, BranchFactor16)
err := codec.decodeDBNode(tooShortBytes, &parsedDBNode)
require.ErrorIs(err, io.ErrUnexpectedEOF)

proof := dbNode{
value: maybe.Some([]byte{1}),
children: map[byte]child{},
}

nodeBytes := codec.encodeDBNode(&proof)
// Remove num children (0) from end
nodeBytes = nodeBytes[:len(nodeBytes)-minVarIntLen]
proofBytesBuf := bytes.NewBuffer(nodeBytes)

// Put num children > branch factor
codec.(*codecImpl).encodeUint(proofBytesBuf, uint64(BranchFactor16+1))

err = codec.decodeDBNode(proofBytesBuf.Bytes(), &parsedDBNode, BranchFactor16)
require.ErrorIs(err, errTooManyChildren)
}

// Ensure that encodeHashValues is deterministic
@@ -219,18 +199,18 @@ func FuzzEncodeHashValues(f *testing.F) {
randSeed int,
) {
require := require.New(t)
for _, branchFactor := range branchFactors { // Create a random *hashValues
for _, bf := range validBranchFactors { // Create a random node
r := rand.New(rand.NewSource(int64(randSeed))) // #nosec G404

children := map[byte]child{}
numChildren := r.Intn(int(branchFactor)) // #nosec G404
numChildren := r.Intn(int(bf)) // #nosec G404
for i := 0; i < numChildren; i++ {
compressedKeyLen := r.Intn(32) // #nosec G404
compressedKeyBytes := make([]byte, compressedKeyLen)
_, _ = r.Read(compressedKeyBytes) // #nosec G404

children[byte(i)] = child{
compressedKey: ToKey(compressedKeyBytes, branchFactor),
compressedKey: ToKey(compressedKeyBytes),
id: ids.GenerateTestID(),
hasValue: r.Intn(2) == 1, // #nosec G404
}
@@ -247,13 +227,15 @@ func FuzzEncodeHashValues(f *testing.F) {
key := make([]byte, r.Intn(32)) // #nosec G404
_, _ = r.Read(key) // #nosec G404

hv := &hashValues{
Children: children,
Value: value,
Key: ToKey(key, branchFactor),
hv := &node{
key: ToKey(key),
dbNode: dbNode{
children: children,
value: value,
},
}

// Serialize the *hashValues with both codecs
// Serialize hv with both codecs
hvBytes1 := codec1.encodeHashValues(hv)
hvBytes2 := codec2.encodeHashValues(hv)

@@ -267,6 +249,6 @@ func FuzzEncodeHashValues(f *testing.F) {
func TestCodecDecodeKeyLengthOverflowRegression(t *testing.T) {
codec := codec.(*codecImpl)
bytes := bytes.NewReader(binary.AppendUvarint(nil, math.MaxInt))
_, err := codec.decodeKey(bytes, BranchFactor16)
_, err := codec.decodeKey(bytes)
require.ErrorIs(t, err, io.ErrUnexpectedEOF)
}
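
For context on this regression test: decodeKey reads the declared bit length as a uvarint and compares the implied byte count against what remains in the reader before allocating anything, so a length of math.MaxInt with no payload fails fast. A sketch of that arithmetic (bytesNeeded here is a hypothetical stand-in for the codec's bits-to-bytes helper):

package main

import (
	"encoding/binary"
	"fmt"
	"math"
)

// bytesNeeded converts a bit length to a byte length without overflowing.
func bytesNeeded(bits int) int {
	size := bits / 8
	if bits%8 != 0 {
		size++
	}
	return size
}

func main() {
	// A varint declaring a math.MaxInt-bit key, with no payload after it.
	b := binary.AppendUvarint(nil, math.MaxInt)
	length, n := binary.Uvarint(b)
	remaining := len(b) - n // 0

	// The decoder sees the declared size exceeds the remaining bytes and
	// returns io.ErrUnexpectedEOF instead of allocating a huge buffer.
	fmt.Println(bytesNeeded(int(length)) > remaining) // true
}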