From b01d98d3441460e590916cbe6c512c9b0000ab5f Mon Sep 17 00:00:00 2001
From: Stephen Buttolph
Date: Fri, 29 Mar 2024 00:43:33 -0400
Subject: [PATCH] Remove merkledb codec struct (#2883)

---
 x/merkledb/codec.go      | 144 ++++++++++++++++----------------------
 x/merkledb/codec_test.go |  54 ++++++---------
 x/merkledb/db.go         |   8 +--
 x/merkledb/node.go       |   6 +-
 4 files changed, 86 insertions(+), 126 deletions(-)

diff --git a/x/merkledb/codec.go b/x/merkledb/codec.go
index 3161fd16c86f..d98646a1f267 100644
--- a/x/merkledb/codec.go
+++ b/x/merkledb/codec.go
@@ -36,8 +36,6 @@ const (
 )
 
 var (
-	_ encoderDecoder = (*codecImpl)(nil)
-
 	trueBytes  = []byte{trueByte}
 	falseBytes = []byte{falseByte}
 
@@ -49,94 +47,69 @@ var (
 	errIntOverflow = errors.New("value overflows int")
 )
 
-// encoderDecoder defines the interface needed by merkleDB to marshal
-// and unmarshal relevant types.
-type encoderDecoder interface {
-	encoder
-	decoder
-}
-
-type encoder interface {
-	// Assumes [n] is non-nil.
-	encodeDBNode(n *dbNode) []byte
-	encodedDBNodeSize(n *dbNode) int
-
-	// Returns the bytes that will be hashed to generate [n]'s ID.
-	// Assumes [n] is non-nil.
-	encodeHashValues(n *node) []byte
-	encodeKey(key Key) []byte
-}
-
-type decoder interface {
-	// Assumes [n] is non-nil.
-	decodeDBNode(bytes []byte, n *dbNode) error
-	decodeKey(bytes []byte) (Key, error)
-}
-
-func newCodec() encoderDecoder {
-	return &codecImpl{}
-}
-
-// Note that bytes.Buffer.Write always returns nil, so we
-// can ignore its return values in [codecImpl] methods.
-type codecImpl struct{}
+// Note that bytes.Buffer.Write always returns nil, so we ignore its return
+// values in all encode methods.
 
-func (c *codecImpl) childSize(index byte, childEntry *child) int {
+func childSize(index byte, childEntry *child) int {
 	// * index
 	// * child ID
 	// * child key
 	// * bool indicating whether the child has a value
-	return c.uintSize(uint64(index)) + ids.IDLen + c.keySize(childEntry.compressedKey) + boolLen
+	return uintSize(uint64(index)) + ids.IDLen + keySize(childEntry.compressedKey) + boolLen
 }
 
-// based on the current implementation of codecImpl.encodeUint which uses binary.PutUvarint
-func (*codecImpl) uintSize(value uint64) int {
+// based on the implementation of encodeUint which uses binary.PutUvarint
+func uintSize(value uint64) int {
 	if value == 0 {
 		return 1
 	}
 	return (bits.Len64(value) + 6) / 7
 }
 
-func (c *codecImpl) keySize(p Key) int {
-	return c.uintSize(uint64(p.length)) + bytesNeeded(p.length)
+func keySize(p Key) int {
+	return uintSize(uint64(p.length)) + bytesNeeded(p.length)
 }
 
-func (c *codecImpl) encodedDBNodeSize(n *dbNode) int {
+// Assumes [n] is non-nil.
+func encodedDBNodeSize(n *dbNode) int {
 	// * number of children
 	// * bool indicating whether [n] has a value
 	// * the value (optional)
 	// * children
-	size := c.uintSize(uint64(len(n.children))) + boolLen
+	size := uintSize(uint64(len(n.children))) + boolLen
 	if n.value.HasValue() {
 		valueLen := len(n.value.Value())
-		size += c.uintSize(uint64(valueLen)) + valueLen
+		size += uintSize(uint64(valueLen)) + valueLen
 	}
 	// for each non-nil entry, we add the additional size of the child entry
 	for index, entry := range n.children {
-		size += c.childSize(index, entry)
+		size += childSize(index, entry)
 	}
 	return size
 }
 
-func (c *codecImpl) encodeDBNode(n *dbNode) []byte {
-	buf := bytes.NewBuffer(make([]byte, 0, c.encodedDBNodeSize(n)))
-	c.encodeMaybeByteSlice(buf, n.value)
-	c.encodeUint(buf, uint64(len(n.children)))
+// Assumes [n] is non-nil.
+func encodeDBNode(n *dbNode) []byte {
+	buf := bytes.NewBuffer(make([]byte, 0, encodedDBNodeSize(n)))
+	encodeMaybeByteSlice(buf, n.value)
+	encodeUint(buf, uint64(len(n.children)))
 	// Note we insert children in order of increasing index
 	// for determinism.
 	keys := maps.Keys(n.children)
 	slices.Sort(keys)
 	for _, index := range keys {
 		entry := n.children[index]
-		c.encodeUint(buf, uint64(index))
-		c.encodeKeyToBuffer(buf, entry.compressedKey)
+		encodeUint(buf, uint64(index))
+		encodeKeyToBuffer(buf, entry.compressedKey)
 		_, _ = buf.Write(entry.id[:])
-		c.encodeBool(buf, entry.hasValue)
+		encodeBool(buf, entry.hasValue)
 	}
 	return buf.Bytes()
 }
 
-func (c *codecImpl) encodeHashValues(n *node) []byte {
+// Returns the bytes that will be hashed to generate [n]'s ID.
+// Assumes [n] is non-nil.
+func encodeHashValues(n *node) []byte {
 	var (
 		numChildren = len(n.children)
 		// Estimate size [hv] to prevent memory allocations
@@ -144,36 +117,37 @@ func (c *codecImpl) encodeHashValues(n *node) []byte {
 		buf = bytes.NewBuffer(make([]byte, 0, estimatedLen))
 	)
 
-	c.encodeUint(buf, uint64(numChildren))
+	encodeUint(buf, uint64(numChildren))
 	// ensure that the order of entries is consistent
 	keys := maps.Keys(n.children)
 	slices.Sort(keys)
 	for _, index := range keys {
 		entry := n.children[index]
-		c.encodeUint(buf, uint64(index))
+		encodeUint(buf, uint64(index))
 		_, _ = buf.Write(entry.id[:])
 	}
-	c.encodeMaybeByteSlice(buf, n.valueDigest)
-	c.encodeKeyToBuffer(buf, n.key)
+	encodeMaybeByteSlice(buf, n.valueDigest)
+	encodeKeyToBuffer(buf, n.key)
 	return buf.Bytes()
 }
 
-func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error {
+// Assumes [n] is non-nil.
+func decodeDBNode(b []byte, n *dbNode) error {
 	if minDBNodeLen > len(b) {
 		return io.ErrUnexpectedEOF
 	}
 
 	src := bytes.NewReader(b)
 
-	value, err := c.decodeMaybeByteSlice(src)
+	value, err := decodeMaybeByteSlice(src)
 	if err != nil {
 		return err
 	}
 	n.value = value
 
-	numChildren, err := c.decodeUint(src)
+	numChildren, err := decodeUint(src)
 	switch {
 	case err != nil:
 		return err
@@ -184,7 +158,7 @@ func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error {
 	n.children = make(map[byte]*child, numChildren)
 	var previousChild uint64
 	for i := uint64(0); i < numChildren; i++ {
-		index, err := c.decodeUint(src)
+		index, err := decodeUint(src)
 		if err != nil {
 			return err
 		}
@@ -193,15 +167,15 @@ func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error {
 		}
 		previousChild = index
 
-		compressedKey, err := c.decodeKeyFromReader(src)
+		compressedKey, err := decodeKeyFromReader(src)
 		if err != nil {
 			return err
 		}
-		childID, err := c.decodeID(src)
+		childID, err := decodeID(src)
 		if err != nil {
 			return err
 		}
-		hasValue, err := c.decodeBool(src)
+		hasValue, err := decodeBool(src)
 		if err != nil {
 			return err
 		}
@@ -217,7 +191,7 @@ func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error {
 	return nil
 }
 
-func (*codecImpl) encodeBool(dst *bytes.Buffer, value bool) {
+func encodeBool(dst *bytes.Buffer, value bool) {
 	bytesValue := falseBytes
 	if value {
 		bytesValue = trueBytes
@@ -225,7 +199,7 @@ func (*codecImpl) encodeBool(dst *bytes.Buffer, value bool) {
 	_, _ = dst.Write(bytesValue)
 }
 
-func (*codecImpl) decodeBool(src *bytes.Reader) (bool, error) {
+func decodeBool(src *bytes.Reader) (bool, error) {
 	boolByte, err := src.ReadByte()
 	switch {
 	case err == io.EOF:
@@ -241,7 +215,7 @@ func (*codecImpl) decodeBool(src *bytes.Reader) (bool, error) {
 	}
 }
 
-func (*codecImpl) decodeUint(src *bytes.Reader) (uint64, error) {
+func decodeUint(src *bytes.Reader) (uint64, error) {
 	// To ensure encoding/decoding is canonical, we need to check for leading
 	// zeroes in the varint.
 	// The last byte of the varint we read is the most significant byte.
@@ -274,30 +248,30 @@ func (*codecImpl) decodeUint(src *bytes.Reader) (uint64, error) {
 	return val64, nil
 }
 
-func (*codecImpl) encodeUint(dst *bytes.Buffer, value uint64) {
+func encodeUint(dst *bytes.Buffer, value uint64) {
 	var buf [binary.MaxVarintLen64]byte
 	size := binary.PutUvarint(buf[:], value)
 	_, _ = dst.Write(buf[:size])
 }
 
-func (c *codecImpl) encodeMaybeByteSlice(dst *bytes.Buffer, maybeValue maybe.Maybe[[]byte]) {
+func encodeMaybeByteSlice(dst *bytes.Buffer, maybeValue maybe.Maybe[[]byte]) {
 	hasValue := maybeValue.HasValue()
-	c.encodeBool(dst, hasValue)
+	encodeBool(dst, hasValue)
 	if hasValue {
-		c.encodeByteSlice(dst, maybeValue.Value())
+		encodeByteSlice(dst, maybeValue.Value())
 	}
 }
 
-func (c *codecImpl) decodeMaybeByteSlice(src *bytes.Reader) (maybe.Maybe[[]byte], error) {
+func decodeMaybeByteSlice(src *bytes.Reader) (maybe.Maybe[[]byte], error) {
 	if minMaybeByteSliceLen > src.Len() {
 		return maybe.Nothing[[]byte](), io.ErrUnexpectedEOF
 	}
 
-	if hasValue, err := c.decodeBool(src); err != nil || !hasValue {
+	if hasValue, err := decodeBool(src); err != nil || !hasValue {
 		return maybe.Nothing[[]byte](), err
 	}
 
-	rawBytes, err := c.decodeByteSlice(src)
+	rawBytes, err := decodeByteSlice(src)
 	if err != nil {
 		return maybe.Nothing[[]byte](), err
 	}
@@ -305,12 +279,12 @@ func (c *codecImpl) decodeMaybeByteSlice(src *bytes.Reader) (maybe.Maybe[[]byte]
 	return maybe.Some(rawBytes), nil
 }
 
-func (c *codecImpl) decodeByteSlice(src *bytes.Reader) ([]byte, error) {
+func decodeByteSlice(src *bytes.Reader) ([]byte, error) {
 	if minByteSliceLen > src.Len() {
 		return nil, io.ErrUnexpectedEOF
 	}
 
-	length, err := c.decodeUint(src)
+	length, err := decodeUint(src)
 	switch {
 	case err == io.EOF:
 		return nil, io.ErrUnexpectedEOF
@@ -330,14 +304,14 @@ func (c *codecImpl) decodeByteSlice(src *bytes.Reader) ([]byte, error) {
 	return result, err
 }
 
-func (c *codecImpl) encodeByteSlice(dst *bytes.Buffer, value []byte) {
-	c.encodeUint(dst, uint64(len(value)))
+func encodeByteSlice(dst *bytes.Buffer, value []byte) {
+	encodeUint(dst, uint64(len(value)))
 	if value != nil {
 		_, _ = dst.Write(value)
 	}
 }
 
-func (*codecImpl) decodeID(src *bytes.Reader) (ids.ID, error) {
+func decodeID(src *bytes.Reader) (ids.ID, error) {
 	if ids.IDLen > src.Len() {
 		return ids.ID{}, io.ErrUnexpectedEOF
 	}
@@ -350,21 +324,21 @@ func (*codecImpl) decodeID(src *bytes.Reader) (ids.ID, error) {
 	return id, err
 }
 
-func (c *codecImpl) encodeKey(key Key) []byte {
+func encodeKey(key Key) []byte {
 	estimatedLen := binary.MaxVarintLen64 + len(key.Bytes())
 	dst := bytes.NewBuffer(make([]byte, 0, estimatedLen))
-	c.encodeKeyToBuffer(dst, key)
+	encodeKeyToBuffer(dst, key)
 	return dst.Bytes()
 }
 
-func (c *codecImpl) encodeKeyToBuffer(dst *bytes.Buffer, key Key) {
-	c.encodeUint(dst, uint64(key.length))
+func encodeKeyToBuffer(dst *bytes.Buffer, key Key) {
+	encodeUint(dst, uint64(key.length))
 	_, _ = dst.Write(key.Bytes())
 }
 
-func (c *codecImpl) decodeKey(b []byte) (Key, error) {
+func decodeKey(b []byte) (Key, error) {
 	src := bytes.NewReader(b)
-	key, err := c.decodeKeyFromReader(src)
+	key, err := decodeKeyFromReader(src)
 	if err != nil {
 		return Key{}, err
 	}
@@ -374,12 +348,12 @@ func (c *codecImpl) decodeKey(b []byte) (Key, error) {
 	return key, err
 }
 
-func (c *codecImpl) decodeKeyFromReader(src *bytes.Reader) (Key, error) {
+func decodeKeyFromReader(src *bytes.Reader) (Key, error) {
 	if minKeyLen > src.Len() {
 		return Key{}, io.ErrUnexpectedEOF
 	}
 
-	length, err := c.decodeUint(src)
+	length, err := decodeUint(src)
 	if err != nil {
 		return Key{}, err
 	}
diff --git a/x/merkledb/codec_test.go b/x/merkledb/codec_test.go
index 7d8e6028cfe2..685c6453471a 100644
--- a/x/merkledb/codec_test.go
+++ b/x/merkledb/codec_test.go
@@ -26,10 +26,9 @@ func FuzzCodecBool(f *testing.F) {
 		) {
 			require := require.New(t)
 
-			codec := codec.(*codecImpl)
 			reader := bytes.NewReader(b)
 			startLen := reader.Len()
-			got, err := codec.decodeBool(reader)
+			got, err := decodeBool(reader)
 			if err != nil {
 				t.SkipNow()
 			}
@@ -38,7 +37,7 @@ func FuzzCodecBool(f *testing.F) {
 
 			// Encoding [got] should be the same as [b].
 			var buf bytes.Buffer
-			codec.encodeBool(&buf, got)
+			encodeBool(&buf, got)
 			bufBytes := buf.Bytes()
 			require.Len(bufBytes, numRead)
 			require.Equal(b[:numRead], bufBytes)
@@ -54,10 +53,9 @@ func FuzzCodecInt(f *testing.F) {
 		) {
 			require := require.New(t)
 
-			codec := codec.(*codecImpl)
 			reader := bytes.NewReader(b)
 			startLen := reader.Len()
-			got, err := codec.decodeUint(reader)
+			got, err := decodeUint(reader)
 			if err != nil {
 				t.SkipNow()
 			}
@@ -66,7 +64,7 @@ func FuzzCodecInt(f *testing.F) {
 
 			// Encoding [got] should be the same as [b].
 			var buf bytes.Buffer
-			codec.encodeUint(&buf, got)
+			encodeUint(&buf, got)
 			bufBytes := buf.Bytes()
 			require.Len(bufBytes, numRead)
 			require.Equal(b[:numRead], bufBytes)
@@ -81,14 +79,13 @@ func FuzzCodecKey(f *testing.F) {
 			b []byte,
 		) {
 			require := require.New(t)
 
-			codec := codec.(*codecImpl)
-			got, err := codec.decodeKey(b)
+			got, err := decodeKey(b)
 			if err != nil {
 				t.SkipNow()
 			}
 
 			// Encoding [got] should be the same as [b].
-			gotBytes := codec.encodeKey(got)
+			gotBytes := encodeKey(got)
 			require.Equal(b, gotBytes)
 		},
 	)
@@ -101,14 +98,13 @@ func FuzzCodecDBNodeCanonical(f *testing.F) {
 			b []byte,
 		) {
 			require := require.New(t)
 
-			codec := codec.(*codecImpl)
 			node := &dbNode{}
-			if err := codec.decodeDBNode(b, node); err != nil {
+			if err := decodeDBNode(b, node); err != nil {
 				t.SkipNow()
 			}
 
 			// Encoding [node] should be the same as [b].
-			buf := codec.encodeDBNode(node)
+			buf := encodeDBNode(node)
 			require.Equal(b, buf)
 		},
 	)
@@ -158,13 +154,13 @@ func FuzzCodecDBNodeDeterministic(f *testing.F) {
 				children: children,
 			}
 
-			nodeBytes := codec.encodeDBNode(&node)
-			require.Len(nodeBytes, codec.encodedDBNodeSize(&node))
+			nodeBytes := encodeDBNode(&node)
+			require.Len(nodeBytes, encodedDBNodeSize(&node))
 
 			var gotNode dbNode
-			require.NoError(codec.decodeDBNode(nodeBytes, &gotNode))
+			require.NoError(decodeDBNode(nodeBytes, &gotNode))
 			require.Equal(node, gotNode)
 
-			nodeBytes2 := codec.encodeDBNode(&gotNode)
+			nodeBytes2 := encodeDBNode(&gotNode)
 			require.Equal(nodeBytes, nodeBytes2)
 		}
 	},
@@ -178,15 +174,12 @@ func TestCodecDecodeDBNode_TooShort(t *testing.T) {
 		parsedDBNode  dbNode
 		tooShortBytes = make([]byte, minDBNodeLen-1)
 	)
-	err := codec.decodeDBNode(tooShortBytes, &parsedDBNode)
+	err := decodeDBNode(tooShortBytes, &parsedDBNode)
 	require.ErrorIs(err, io.ErrUnexpectedEOF)
 }
 
 // Ensure that encodeHashValues is deterministic
 func FuzzEncodeHashValues(f *testing.F) {
-	codec1 := newCodec()
-	codec2 := newCodec()
-
 	f.Fuzz(
 		func(
 			t *testing.T,
@@ -229,9 +222,9 @@ func FuzzEncodeHashValues(f *testing.F) {
 			},
 		}
 
-		// Serialize hv with both codecs
-		hvBytes1 := codec1.encodeHashValues(hv)
-		hvBytes2 := codec2.encodeHashValues(hv)
+		// Serialize hv multiple times
+		hvBytes1 := encodeHashValues(hv)
+		hvBytes2 := encodeHashValues(hv)
 
 		// Make sure they're the same
 		require.Equal(hvBytes1, hvBytes2)
@@ -241,43 +234,38 @@
 }
 
 func TestCodecDecodeKeyLengthOverflowRegression(t *testing.T) {
-	codec := codec.(*codecImpl)
-	_, err := codec.decodeKey(binary.AppendUvarint(nil, math.MaxInt))
+	_, err := decodeKey(binary.AppendUvarint(nil, math.MaxInt))
 	require.ErrorIs(t, err, io.ErrUnexpectedEOF)
 }
 
 func TestUintSize(t *testing.T) {
-	c := codec.(*codecImpl)
-
 	// Test lower bound
-	expectedSize := c.uintSize(0)
+	expectedSize := uintSize(0)
 	actualSize := binary.PutUvarint(make([]byte, binary.MaxVarintLen64), 0)
 	require.Equal(t, expectedSize, actualSize)
 
 	// Test upper bound
-	expectedSize = c.uintSize(math.MaxUint64)
+	expectedSize = uintSize(math.MaxUint64)
 	actualSize = binary.PutUvarint(make([]byte, binary.MaxVarintLen64), math.MaxUint64)
 	require.Equal(t, expectedSize, actualSize)
 
 	// Test powers of 2
 	for power := 0; power < 64; power++ {
 		n := uint64(1) << uint(power)
-		expectedSize := c.uintSize(n)
+		expectedSize := uintSize(n)
 		actualSize := binary.PutUvarint(make([]byte, binary.MaxVarintLen64), n)
 		require.Equal(t, expectedSize, actualSize, power)
 	}
 }
 
 func Benchmark_EncodeUint(b *testing.B) {
-	c := codec.(*codecImpl)
-
 	var dst bytes.Buffer
 	dst.Grow(binary.MaxVarintLen64)
 
 	for _, v := range []uint64{0, 1, 2, 32, 1024, 32768} {
 		b.Run(strconv.FormatUint(v, 10), func(b *testing.B) {
 			for i := 0; i < b.N; i++ {
-				c.encodeUint(&dst, v)
+				encodeUint(&dst, v)
 				dst.Reset()
 			}
 		})
diff --git a/x/merkledb/db.go b/x/merkledb/db.go
index 4775c08ac536..b32518782168 100644
--- a/x/merkledb/db.go
+++ b/x/merkledb/db.go
@@ -41,8 +41,6 @@ const (
 
 var (
 	_ MerkleDB = (*merkleDB)(nil)
 
-	codec = newCodec()
-
 	metadataPrefix         = []byte{0}
 	valueNodePrefix        = []byte{1}
 	intermediateNodePrefix = []byte{2}
@@ -985,7 +983,7 @@ func (db *merkleDB) commitChanges(ctx context.Context, trieToCommit *view) error
 		return db.baseDB.Delete(rootDBKey)
 	}
 
-	rootKey := codec.encodeKey(db.root.Value().key)
+	rootKey := encodeKey(db.root.Value().key)
 	return db.baseDB.Put(rootDBKey, rootKey)
 }
 
@@ -1177,7 +1175,7 @@ func (db *merkleDB) initializeRoot() error {
 	}
 
 	// Root is on disk.
-	rootKey, err := codec.decodeKey(rootKeyBytes)
+	rootKey, err := decodeKey(rootKeyBytes)
 	if err != nil {
 		return err
 	}
@@ -1351,5 +1349,5 @@ func cacheEntrySize(key Key, n *node) int {
 	if n == nil {
 		return cacheEntryOverHead + len(key.Bytes())
 	}
-	return cacheEntryOverHead + len(key.Bytes()) + codec.encodedDBNodeSize(&n.dbNode)
+	return cacheEntryOverHead + len(key.Bytes()) + encodedDBNodeSize(&n.dbNode)
 }
diff --git a/x/merkledb/node.go b/x/merkledb/node.go
index dd1f2ed65cd2..60b5cecbb3a6 100644
--- a/x/merkledb/node.go
+++ b/x/merkledb/node.go
@@ -45,7 +45,7 @@ func newNode(key Key) *node {
 // Parse [nodeBytes] to a node and set its key to [key].
 func parseNode(key Key, nodeBytes []byte) (*node, error) {
 	n := dbNode{}
-	if err := codec.decodeDBNode(nodeBytes, &n); err != nil {
+	if err := decodeDBNode(nodeBytes, &n); err != nil {
 		return nil, err
 	}
 	result := &node{
@@ -64,13 +64,13 @@ func (n *node) hasValue() bool {
 
 // Returns the byte representation of this node.
 func (n *node) bytes() []byte {
-	return codec.encodeDBNode(&n.dbNode)
+	return encodeDBNode(&n.dbNode)
 }
 
 // Returns and caches the ID of this node.
 func (n *node) calculateID(metrics merkleMetrics) ids.ID {
 	metrics.HashCalculated()
-	bytes := codec.encodeHashValues(n)
+	bytes := encodeHashValues(n)
 	return hashing.ComputeHash256Array(bytes)
}
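Reviewer note (not part of the patch): uintSize's closed form, (bits.Len64(value)+6)/7, is a ceiling division. A uvarint stores 7 value bits per byte, so a non-zero value needs ceil(bitlen/7) bytes, and zero is special-cased to one byte. The sketch below checks that identity against binary.PutUvarint; it mirrors the patch's TestUintSize but is runnable on its own, outside the merkledb package.

package main

import (
	"encoding/binary"
	"fmt"
	"math"
	"math/bits"
)

// uintSize mirrors the helper from the patch: ceil(bits.Len64(v)/7),
// with 0 taking exactly one byte.
func uintSize(value uint64) int {
	if value == 0 {
		return 1
	}
	return (bits.Len64(value) + 6) / 7
}

func main() {
	buf := make([]byte, binary.MaxVarintLen64)
	// Values chosen to straddle the 7-bit-per-byte boundaries.
	for _, v := range []uint64{0, 1, 127, 128, 16383, 16384, math.MaxUint64} {
		got := uintSize(v)
		want := binary.PutUvarint(buf, v)
		fmt.Printf("v=%d size=%d matches=%t\n", v, got, want == got)
	}
}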
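Reviewer note (not part of the patch): the leading-zero check in decodeUint exists because Go's uvarint format admits more than one byte string for the same value, and merkledb needs decode-then-encode to be byte-for-byte stable (FuzzCodecInt asserts exactly that round trip). A small standalone sketch, using only encoding/binary, of a padded encoding that binary.Uvarint accepts but a canonical decoder must reject:

package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	// Canonical varint encoding of 0: the single byte 0x00.
	canonical := binary.AppendUvarint(nil, 0)
	fmt.Printf("canonical: %x\n", canonical) // 00

	// Non-canonical encoding of the same value: 0x80 0x00. The first
	// byte sets the continuation bit but carries no value bits; the
	// second byte is a leading zero.
	padded := []byte{0x80, 0x00}
	v, n := binary.Uvarint(padded)
	fmt.Printf("padded decodes to %d using %d bytes\n", v, n) // 0 using 2

	// Re-encoding the decoded value yields 1 byte, not 2, so a decoder
	// without the leading-zero check would break the property that
	// encode(decode(b)) == b.
	fmt.Println(len(binary.AppendUvarint(nil, v))) // 1
}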
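Reviewer note (not part of the patch): with the codec struct gone, the codec surface is plain package-level functions, so callers inside the package invoke them directly. A hypothetical in-package round-trip test sketch follows; the test name and node literal are illustrative only, and ToKey, ids.GenerateTestID, maybe.Some, and the dbNode/child fields are assumed from the surrounding merkledb, ids, and maybe packages as they appear in this diff.

func TestDBNodeRoundTrip(t *testing.T) {
	require := require.New(t)

	n := &dbNode{
		value: maybe.Some([]byte{1, 2, 3}),
		children: map[byte]*child{
			5: {
				compressedKey: ToKey([]byte{6, 7}),
				id:            ids.GenerateTestID(),
				hasValue:      true,
			},
		},
	}

	// The encoder pre-sizes its buffer, so the encoded length must
	// agree with encodedDBNodeSize exactly.
	b := encodeDBNode(n)
	require.Len(b, encodedDBNodeSize(n))

	// Decoding and re-encoding must reproduce both the node and the bytes.
	var got dbNode
	require.NoError(decodeDBNode(b, &got))
	require.Equal(*n, got)
	require.Equal(b, encodeDBNode(&got))
}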