Skip to content

Commit

Permalink
Remove merkledb codec struct (#2883)
Browse files Browse the repository at this point in the history
  • Loading branch information
StephenButtolph authored Mar 29, 2024
1 parent 10b881f commit b01d98d
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 126 deletions.
144 changes: 59 additions & 85 deletions x/merkledb/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ const (
)

var (
_ encoderDecoder = (*codecImpl)(nil)

trueBytes = []byte{trueByte}
falseBytes = []byte{falseByte}

Expand All @@ -49,131 +47,107 @@ var (
errIntOverflow = errors.New("value overflows int")
)

// encoderDecoder defines the interface needed by merkleDB to marshal
// and unmarshal relevant types.
type encoderDecoder interface {
encoder
decoder
}

type encoder interface {
// Assumes [n] is non-nil.
encodeDBNode(n *dbNode) []byte
encodedDBNodeSize(n *dbNode) int

// Returns the bytes that will be hashed to generate [n]'s ID.
// Assumes [n] is non-nil.
encodeHashValues(n *node) []byte
encodeKey(key Key) []byte
}

type decoder interface {
// Assumes [n] is non-nil.
decodeDBNode(bytes []byte, n *dbNode) error
decodeKey(bytes []byte) (Key, error)
}

func newCodec() encoderDecoder {
return &codecImpl{}
}

// Note that bytes.Buffer.Write always returns nil, so we
// can ignore its return values in [codecImpl] methods.
type codecImpl struct{}
// Note that bytes.Buffer.Write always returns nil, so we ignore its return
// values in all encode methods.

func (c *codecImpl) childSize(index byte, childEntry *child) int {
func childSize(index byte, childEntry *child) int {
// * index
// * child ID
// * child key
// * bool indicating whether the child has a value
return c.uintSize(uint64(index)) + ids.IDLen + c.keySize(childEntry.compressedKey) + boolLen
return uintSize(uint64(index)) + ids.IDLen + keySize(childEntry.compressedKey) + boolLen
}

// based on the current implementation of codecImpl.encodeUint which uses binary.PutUvarint
func (*codecImpl) uintSize(value uint64) int {
// based on the implementation of encodeUint which uses binary.PutUvarint
func uintSize(value uint64) int {
if value == 0 {
return 1
}
return (bits.Len64(value) + 6) / 7
}

func (c *codecImpl) keySize(p Key) int {
return c.uintSize(uint64(p.length)) + bytesNeeded(p.length)
func keySize(p Key) int {
return uintSize(uint64(p.length)) + bytesNeeded(p.length)
}

func (c *codecImpl) encodedDBNodeSize(n *dbNode) int {
// Assumes [n] is non-nil.
func encodedDBNodeSize(n *dbNode) int {
// * number of children
// * bool indicating whether [n] has a value
// * the value (optional)
// * children
size := c.uintSize(uint64(len(n.children))) + boolLen
size := uintSize(uint64(len(n.children))) + boolLen
if n.value.HasValue() {
valueLen := len(n.value.Value())
size += c.uintSize(uint64(valueLen)) + valueLen
size += uintSize(uint64(valueLen)) + valueLen
}
// for each non-nil entry, we add the additional size of the child entry
for index, entry := range n.children {
size += c.childSize(index, entry)
size += childSize(index, entry)
}
return size
}

func (c *codecImpl) encodeDBNode(n *dbNode) []byte {
buf := bytes.NewBuffer(make([]byte, 0, c.encodedDBNodeSize(n)))
c.encodeMaybeByteSlice(buf, n.value)
c.encodeUint(buf, uint64(len(n.children)))
// Assumes [n] is non-nil.
func encodeDBNode(n *dbNode) []byte {
buf := bytes.NewBuffer(make([]byte, 0, encodedDBNodeSize(n)))
encodeMaybeByteSlice(buf, n.value)
encodeUint(buf, uint64(len(n.children)))
// Note we insert children in order of increasing index
// for determinism.
keys := maps.Keys(n.children)
slices.Sort(keys)
for _, index := range keys {
entry := n.children[index]
c.encodeUint(buf, uint64(index))
c.encodeKeyToBuffer(buf, entry.compressedKey)
encodeUint(buf, uint64(index))
encodeKeyToBuffer(buf, entry.compressedKey)
_, _ = buf.Write(entry.id[:])
c.encodeBool(buf, entry.hasValue)
encodeBool(buf, entry.hasValue)
}
return buf.Bytes()
}

func (c *codecImpl) encodeHashValues(n *node) []byte {
// Returns the bytes that will be hashed to generate [n]'s ID.
// Assumes [n] is non-nil.
func encodeHashValues(n *node) []byte {
var (
numChildren = len(n.children)
// Estimate size [hv] to prevent memory allocations
estimatedLen = minVarIntLen + numChildren*hashValuesChildLen + estimatedValueLen + estimatedKeyLen
buf = bytes.NewBuffer(make([]byte, 0, estimatedLen))
)

c.encodeUint(buf, uint64(numChildren))
encodeUint(buf, uint64(numChildren))

// ensure that the order of entries is consistent
keys := maps.Keys(n.children)
slices.Sort(keys)
for _, index := range keys {
entry := n.children[index]
c.encodeUint(buf, uint64(index))
encodeUint(buf, uint64(index))
_, _ = buf.Write(entry.id[:])
}
c.encodeMaybeByteSlice(buf, n.valueDigest)
c.encodeKeyToBuffer(buf, n.key)
encodeMaybeByteSlice(buf, n.valueDigest)
encodeKeyToBuffer(buf, n.key)

return buf.Bytes()
}

func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error {
// Assumes [n] is non-nil.
func decodeDBNode(b []byte, n *dbNode) error {
if minDBNodeLen > len(b) {
return io.ErrUnexpectedEOF
}

src := bytes.NewReader(b)

value, err := c.decodeMaybeByteSlice(src)
value, err := decodeMaybeByteSlice(src)
if err != nil {
return err
}
n.value = value

numChildren, err := c.decodeUint(src)
numChildren, err := decodeUint(src)
switch {
case err != nil:
return err
Expand All @@ -184,7 +158,7 @@ func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error {
n.children = make(map[byte]*child, numChildren)
var previousChild uint64
for i := uint64(0); i < numChildren; i++ {
index, err := c.decodeUint(src)
index, err := decodeUint(src)
if err != nil {
return err
}
Expand All @@ -193,15 +167,15 @@ func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error {
}
previousChild = index

compressedKey, err := c.decodeKeyFromReader(src)
compressedKey, err := decodeKeyFromReader(src)
if err != nil {
return err
}
childID, err := c.decodeID(src)
childID, err := decodeID(src)
if err != nil {
return err
}
hasValue, err := c.decodeBool(src)
hasValue, err := decodeBool(src)
if err != nil {
return err
}
Expand All @@ -217,15 +191,15 @@ func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error {
return nil
}

func (*codecImpl) encodeBool(dst *bytes.Buffer, value bool) {
func encodeBool(dst *bytes.Buffer, value bool) {
bytesValue := falseBytes
if value {
bytesValue = trueBytes
}
_, _ = dst.Write(bytesValue)
}

func (*codecImpl) decodeBool(src *bytes.Reader) (bool, error) {
func decodeBool(src *bytes.Reader) (bool, error) {
boolByte, err := src.ReadByte()
switch {
case err == io.EOF:
Expand All @@ -241,7 +215,7 @@ func (*codecImpl) decodeBool(src *bytes.Reader) (bool, error) {
}
}

func (*codecImpl) decodeUint(src *bytes.Reader) (uint64, error) {
func decodeUint(src *bytes.Reader) (uint64, error) {
// To ensure encoding/decoding is canonical, we need to check for leading
// zeroes in the varint.
// The last byte of the varint we read is the most significant byte.
Expand Down Expand Up @@ -274,43 +248,43 @@ func (*codecImpl) decodeUint(src *bytes.Reader) (uint64, error) {
return val64, nil
}

func (*codecImpl) encodeUint(dst *bytes.Buffer, value uint64) {
func encodeUint(dst *bytes.Buffer, value uint64) {
var buf [binary.MaxVarintLen64]byte
size := binary.PutUvarint(buf[:], value)
_, _ = dst.Write(buf[:size])
}

func (c *codecImpl) encodeMaybeByteSlice(dst *bytes.Buffer, maybeValue maybe.Maybe[[]byte]) {
func encodeMaybeByteSlice(dst *bytes.Buffer, maybeValue maybe.Maybe[[]byte]) {
hasValue := maybeValue.HasValue()
c.encodeBool(dst, hasValue)
encodeBool(dst, hasValue)
if hasValue {
c.encodeByteSlice(dst, maybeValue.Value())
encodeByteSlice(dst, maybeValue.Value())
}
}

func (c *codecImpl) decodeMaybeByteSlice(src *bytes.Reader) (maybe.Maybe[[]byte], error) {
func decodeMaybeByteSlice(src *bytes.Reader) (maybe.Maybe[[]byte], error) {
if minMaybeByteSliceLen > src.Len() {
return maybe.Nothing[[]byte](), io.ErrUnexpectedEOF
}

if hasValue, err := c.decodeBool(src); err != nil || !hasValue {
if hasValue, err := decodeBool(src); err != nil || !hasValue {
return maybe.Nothing[[]byte](), err
}

rawBytes, err := c.decodeByteSlice(src)
rawBytes, err := decodeByteSlice(src)
if err != nil {
return maybe.Nothing[[]byte](), err
}

return maybe.Some(rawBytes), nil
}

func (c *codecImpl) decodeByteSlice(src *bytes.Reader) ([]byte, error) {
func decodeByteSlice(src *bytes.Reader) ([]byte, error) {
if minByteSliceLen > src.Len() {
return nil, io.ErrUnexpectedEOF
}

length, err := c.decodeUint(src)
length, err := decodeUint(src)
switch {
case err == io.EOF:
return nil, io.ErrUnexpectedEOF
Expand All @@ -330,14 +304,14 @@ func (c *codecImpl) decodeByteSlice(src *bytes.Reader) ([]byte, error) {
return result, err
}

func (c *codecImpl) encodeByteSlice(dst *bytes.Buffer, value []byte) {
c.encodeUint(dst, uint64(len(value)))
func encodeByteSlice(dst *bytes.Buffer, value []byte) {
encodeUint(dst, uint64(len(value)))
if value != nil {
_, _ = dst.Write(value)
}
}

func (*codecImpl) decodeID(src *bytes.Reader) (ids.ID, error) {
func decodeID(src *bytes.Reader) (ids.ID, error) {
if ids.IDLen > src.Len() {
return ids.ID{}, io.ErrUnexpectedEOF
}
Expand All @@ -350,21 +324,21 @@ func (*codecImpl) decodeID(src *bytes.Reader) (ids.ID, error) {
return id, err
}

func (c *codecImpl) encodeKey(key Key) []byte {
func encodeKey(key Key) []byte {
estimatedLen := binary.MaxVarintLen64 + len(key.Bytes())
dst := bytes.NewBuffer(make([]byte, 0, estimatedLen))
c.encodeKeyToBuffer(dst, key)
encodeKeyToBuffer(dst, key)
return dst.Bytes()
}

func (c *codecImpl) encodeKeyToBuffer(dst *bytes.Buffer, key Key) {
c.encodeUint(dst, uint64(key.length))
func encodeKeyToBuffer(dst *bytes.Buffer, key Key) {
encodeUint(dst, uint64(key.length))
_, _ = dst.Write(key.Bytes())
}

func (c *codecImpl) decodeKey(b []byte) (Key, error) {
func decodeKey(b []byte) (Key, error) {
src := bytes.NewReader(b)
key, err := c.decodeKeyFromReader(src)
key, err := decodeKeyFromReader(src)
if err != nil {
return Key{}, err
}
Expand All @@ -374,12 +348,12 @@ func (c *codecImpl) decodeKey(b []byte) (Key, error) {
return key, err
}

func (c *codecImpl) decodeKeyFromReader(src *bytes.Reader) (Key, error) {
func decodeKeyFromReader(src *bytes.Reader) (Key, error) {
if minKeyLen > src.Len() {
return Key{}, io.ErrUnexpectedEOF
}

length, err := c.decodeUint(src)
length, err := decodeUint(src)
if err != nil {
return Key{}, err
}
Expand Down
Loading

0 comments on commit b01d98d

Please sign in to comment.