From 6561b9c680c65acf79c4a3b26ecda24b33b0a4d7 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 4 Oct 2022 11:11:00 +0000 Subject: [PATCH] fix(trie): `PopulateMerkleValues` behavior - Include Merkle value of the parent node passed as argument - Calculate missing Merkle values if needed - Add unit test --- dot/state/offline_pruner.go | 5 +- lib/trie/database.go | 41 ++++++++++++---- lib/trie/database_test.go | 95 ++++++++++++++++++++++++++++++++++++- 3 files changed, 130 insertions(+), 11 deletions(-) diff --git a/dot/state/offline_pruner.go b/dot/state/offline_pruner.go index 1c9d03d9971..6e163139c09 100644 --- a/dot/state/offline_pruner.go +++ b/dot/state/offline_pruner.go @@ -121,7 +121,10 @@ func (p *OfflinePruner) SetBloomFilter() (err error) { return err } - tr.PopulateMerkleValues(tr.RootNode(), merkleValues) + err = tr.PopulateMerkleValues(tr.RootNode(), merkleValues) + if err != nil { + return fmt.Errorf("populating Merkle values from trie: %w", err) + } // get parent header of current block header, err = p.blockState.GetHeader(header.ParentHash) diff --git a/lib/trie/database.go b/lib/trie/database.go index fe908df6ecb..bd0501d34c1 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -187,21 +187,44 @@ func (t *Trie) loadNode(db Database, n *Node) error { // PopulateMerkleValues writes the Merkle value of each children of the node given // as keys to the map merkleValues. -func (t *Trie) PopulateMerkleValues(n *Node, merkleValues map[string]struct{}) { - if n.Kind() != node.Branch { - return +func (t *Trie) PopulateMerkleValues(n *Node, + merkleValues map[string]struct{}) (err error) { + if n == nil { + return nil } - branch := n - for _, child := range branch.Children { - if child == nil { - continue + merkleValue := n.MerkleValue + if len(merkleValue) == 0 { + // Compute and cache node Merkle value if it is absent. + if n == t.root { + merkleValue, err = n.CalculateRootMerkleValue() + if err != nil { + return fmt.Errorf("calculating Merkle value for root node: %w", err) + } + } else { + merkleValue, err = n.CalculateMerkleValue() + if err != nil { + return fmt.Errorf("calculating Merkle value for node: %w", err) + } } + } - merkleValues[string(child.MerkleValue)] = struct{}{} + merkleValues[string(merkleValue)] = struct{}{} + + if n.Kind() == node.Leaf { + return nil + } - t.PopulateMerkleValues(child, merkleValues) + branch := n + for _, child := range branch.Children { + err = t.PopulateMerkleValues(child, merkleValues) + if err != nil { + // Note: do not wrap error since this is recursive. + return err + } } + + return nil } // GetFromDB retrieves a value at the given key from the trie using the database. diff --git a/lib/trie/database_test.go b/lib/trie/database_test.go index 41fb5379539..cfca9497dde 100644 --- a/lib/trie/database_test.go +++ b/lib/trie/database_test.go @@ -158,7 +158,100 @@ func Test_Trie_WriteDirty_ClearPrefix(t *testing.T) { assert.Equal(t, trie.String(), trieFromDB.String()) } -func Test_Trie_GetFromDB(t *testing.T) { +func Test_PopulateMerkleValues(t *testing.T) { + t.Parallel() + + someNode := &Node{Key: []byte{1}, SubValue: []byte{2}} + + testCases := map[string]struct { + trie *Trie + node *Node + merkleValues map[string]struct{} + errSentinel error + errMessage string + }{ + "nil node": { + trie: &Trie{}, + merkleValues: map[string]struct{}{}, + }, + "leaf node": { + trie: &Trie{}, + node: &Node{MerkleValue: []byte("a")}, + merkleValues: map[string]struct{}{ + "a": {}, + }, + }, + "leaf node without Merkle value": { + trie: &Trie{}, + node: &Node{Key: []byte{1}, SubValue: []byte{2}}, + merkleValues: map[string]struct{}{ + "A\x01\x04\x02": {}, + }, + }, + "root leaf node without Merkle value": { + trie: &Trie{ + root: someNode, + }, + node: someNode, + merkleValues: map[string]struct{}{ + "`Qm\v\xb6\xe1\xbb\xfb\x12\x93\xf1\xb2v\xea\x95\x05\xe9\xf4\xa4\xe7ُb\r\x05\x11^\v\x85'J\xe1": {}, + }, + }, + "branch node": { + trie: &Trie{}, + node: &Node{ + MerkleValue: []byte("a"), + Children: padRightChildren([]*Node{ + {MerkleValue: []byte("b")}, + }), + }, + merkleValues: map[string]struct{}{ + "a": {}, + "b": {}, + }, + }, + "nested branch node": { + trie: &Trie{}, + node: &Node{ + MerkleValue: []byte("a"), + Children: padRightChildren([]*Node{ + {MerkleValue: []byte("b")}, + { + MerkleValue: []byte("c"), + Children: padRightChildren([]*Node{ + {MerkleValue: []byte("d")}, + }), + }, + }), + }, + merkleValues: map[string]struct{}{ + "a": {}, + "b": {}, + "c": {}, + "d": {}, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + merkleValues := make(map[string]struct{}) + + err := testCase.trie.PopulateMerkleValues(testCase.node, merkleValues) + + assert.ErrorIs(t, err, testCase.errSentinel) + if testCase.errSentinel != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.merkleValues, merkleValues) + }) + } +} + +func Test_GetFromDB(t *testing.T) { t.Parallel() const size = 1000