diff --git a/dot/state/storage.go b/dot/state/storage.go index b4781b204c..0251d2de99 100644 --- a/dot/state/storage.go +++ b/dot/state/storage.go @@ -14,6 +14,7 @@ import ( "github.com/ChainSafe/gossamer/lib/common" rtstorage "github.com/ChainSafe/gossamer/lib/runtime/storage" "github.com/ChainSafe/gossamer/lib/trie" + "github.com/ChainSafe/gossamer/lib/trie/proof" ) // storagePrefix storage key prefix. @@ -301,6 +302,7 @@ func (s *StorageState) LoadCodeHash(hash *common.Hash) (common.Hash, error) { } // GenerateTrieProof returns the proofs related to the keys on the state root trie -func (s *StorageState) GenerateTrieProof(stateRoot common.Hash, keys [][]byte) ([][]byte, error) { - return trie.GenerateProof(stateRoot[:], keys, s.db) +func (s *StorageState) GenerateTrieProof(stateRoot common.Hash, keys [][]byte) ( + encodedProofNodes [][]byte, err error) { + return proof.Generate(stateRoot[:], keys, s.db) } diff --git a/internal/trie/node/children.go b/internal/trie/node/children.go index b08c711c9e..725366b42e 100644 --- a/internal/trie/node/children.go +++ b/internal/trie/node/children.go @@ -30,3 +30,13 @@ func (n *Node) NumChildren() (count int) { } return count } + +// HasChild returns true if the node has at least one child. +func (n *Node) HasChild() (has bool) { + for _, child := range n.Children { + if child != nil { + return true + } + } + return false +} diff --git a/internal/trie/node/children_test.go b/internal/trie/node/children_test.go index 66a2603009..17ab48b2f1 100644 --- a/internal/trie/node/children_test.go +++ b/internal/trie/node/children_test.go @@ -118,3 +118,42 @@ func Test_Node_NumChildren(t *testing.T) { }) } } + +func Test_Node_HasChild(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + node Node + has bool + }{ + "no child": {}, + "one child at index 0": { + node: Node{ + Children: []*Node{ + {}, + }, + }, + has: true, + }, + "one child at index 1": { + node: Node{ + Children: []*Node{ + nil, + {}, + }, + }, + has: true, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + has := testCase.node.HasChild() + + assert.Equal(t, testCase.has, has) + }) + } +} diff --git a/internal/trie/record/node.go b/internal/trie/record/node.go deleted file mode 100644 index 19a745c82c..0000000000 --- a/internal/trie/record/node.go +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package record - -// Node represents a record of a visited node -type Node struct { - RawData []byte - Hash []byte -} diff --git a/internal/trie/record/recorder.go b/internal/trie/record/recorder.go deleted file mode 100644 index 130b434338..0000000000 --- a/internal/trie/record/recorder.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package record - -// Recorder records the list of nodes found by Lookup.Find -type Recorder struct { - nodes []Node -} - -// NewRecorder creates a new recorder. -func NewRecorder() *Recorder { - return &Recorder{} -} - -// Record appends a node to the list of visited nodes. -func (r *Recorder) Record(hash, rawData []byte) { - r.nodes = append(r.nodes, Node{RawData: rawData, Hash: hash}) -} - -// GetNodes returns all the nodes recorded. -// Note it does not copy its slice of nodes. 
-// It's fine to not copy them since the recorder -// is not used again after a call to GetNodes() -func (r *Recorder) GetNodes() (nodes []Node) { - return r.nodes -} diff --git a/internal/trie/record/recorder_test.go b/internal/trie/record/recorder_test.go deleted file mode 100644 index cdf0ed3eaa..0000000000 --- a/internal/trie/record/recorder_test.go +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package record - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func Test_NewRecorder(t *testing.T) { - t.Parallel() - - expected := &Recorder{} - - recorder := NewRecorder() - - assert.Equal(t, expected, recorder) -} - -func Test_Recorder_Record(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - recorder *Recorder - hash []byte - rawData []byte - expectedRecorder *Recorder - }{ - "nil data": { - recorder: &Recorder{}, - expectedRecorder: &Recorder{ - nodes: []Node{ - {}, - }, - }, - }, - "insert in empty recorder": { - recorder: &Recorder{}, - hash: []byte{1, 2}, - rawData: []byte{3, 4}, - expectedRecorder: &Recorder{ - nodes: []Node{ - {Hash: []byte{1, 2}, RawData: []byte{3, 4}}, - }, - }, - }, - "insert in non-empty recorder": { - recorder: &Recorder{ - nodes: []Node{ - {Hash: []byte{5, 6}, RawData: []byte{7, 8}}, - }, - }, - hash: []byte{1, 2}, - rawData: []byte{3, 4}, - expectedRecorder: &Recorder{ - nodes: []Node{ - {Hash: []byte{5, 6}, RawData: []byte{7, 8}}, - {Hash: []byte{1, 2}, RawData: []byte{3, 4}}, - }, - }, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - testCase.recorder.Record(testCase.hash, testCase.rawData) - - assert.Equal(t, testCase.expectedRecorder, testCase.recorder) - }) - } -} - -func Test_Recorder_GetNodes(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - recorder *Recorder - nodes []Node - }{ - "no node": { - recorder: &Recorder{}, - }, - "get single node from recorder": { - recorder: &Recorder{ - nodes: []Node{ - {Hash: []byte{1, 2}, RawData: []byte{3, 4}}, - }, - }, - nodes: []Node{{Hash: []byte{1, 2}, RawData: []byte{3, 4}}}, - }, - "get node from multiple nodes in recorder": { - recorder: &Recorder{ - nodes: []Node{ - {Hash: []byte{1, 2}, RawData: []byte{3, 4}}, - {Hash: []byte{5, 6}, RawData: []byte{7, 8}}, - {Hash: []byte{9, 6}, RawData: []byte{7, 8}}, - }, - }, - nodes: []Node{ - {Hash: []byte{1, 2}, RawData: []byte{3, 4}}, - {Hash: []byte{5, 6}, RawData: []byte{7, 8}}, - {Hash: []byte{9, 6}, RawData: []byte{7, 8}}, - }, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - nodes := testCase.recorder.GetNodes() - - assert.Equal(t, testCase.nodes, nodes) - }) - } -} diff --git a/lib/runtime/wasmer/imports.go b/lib/runtime/wasmer/imports.go index d62dcfcf26..509c5a3e34 100644 --- a/lib/runtime/wasmer/imports.go +++ b/lib/runtime/wasmer/imports.go @@ -121,6 +121,7 @@ import ( rtstorage "github.com/ChainSafe/gossamer/lib/runtime/storage" "github.com/ChainSafe/gossamer/lib/transaction" "github.com/ChainSafe/gossamer/lib/trie" + "github.com/ChainSafe/gossamer/lib/trie/proof" "github.com/ChainSafe/gossamer/pkg/scale" wasm "github.com/wasmerio/go-ext-wasm/wasmer" @@ -886,8 +887,8 @@ func ext_trie_blake2_256_verify_proof_version_1(context unsafe.Pointer, rootSpan instanceContext := wasm.IntoInstanceContext(context) toDecProofs := asMemorySlice(instanceContext, proofSpan) - var 
decProofs [][]byte - err := scale.Unmarshal(toDecProofs, &decProofs) + var encodedProofNodes [][]byte + err := scale.Unmarshal(toDecProofs, &encodedProofNodes) if err != nil { logger.Errorf("[ext_trie_blake2_256_verify_proof_version_1]: %s", err) return C.int32_t(0) @@ -899,18 +900,13 @@ func ext_trie_blake2_256_verify_proof_version_1(context unsafe.Pointer, rootSpan mem := instanceContext.Memory().Data() trieRoot := mem[rootSpan : rootSpan+32] - exists, err := trie.VerifyProof(decProofs, trieRoot, []trie.Pair{{Key: key, Value: value}}) + err = proof.Verify(encodedProofNodes, trieRoot, key, value) if err != nil { logger.Errorf("[ext_trie_blake2_256_verify_proof_version_1]: %s", err) return C.int32_t(0) } - var result C.int32_t = 0 - if exists { - result = 1 - } - - return result + return C.int32_t(1) } //export ext_misc_print_hex_version_1 diff --git a/lib/runtime/wasmer/imports_test.go b/lib/runtime/wasmer/imports_test.go index 3e28ccd062..747f007fff 100644 --- a/lib/runtime/wasmer/imports_test.go +++ b/lib/runtime/wasmer/imports_test.go @@ -22,6 +22,7 @@ import ( "github.com/ChainSafe/gossamer/lib/runtime" "github.com/ChainSafe/gossamer/lib/runtime/storage" "github.com/ChainSafe/gossamer/lib/trie" + "github.com/ChainSafe/gossamer/lib/trie/proof" "github.com/ChainSafe/gossamer/pkg/scale" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -1796,7 +1797,7 @@ func Test_ext_trie_blake2_256_verify_proof_version_1(t *testing.T) { root := hash.ToBytes() otherRoot := otherHash.ToBytes() - proof, err := trie.GenerateProof(root, keys, memdb) + allProofs, err := proof.Generate(root, keys, memdb) require.NoError(t, err) testcases := map[string]struct { @@ -1805,17 +1806,17 @@ func Test_ext_trie_blake2_256_verify_proof_version_1(t *testing.T) { expect bool }{ "Proof should be true": { - root: root, key: []byte("do"), proof: proof, value: []byte("verb"), expect: true}, + root: root, key: []byte("do"), proof: allProofs, value: []byte("verb"), expect: true}, "Root empty, proof should be false": { - root: []byte{}, key: []byte("do"), proof: proof, value: []byte("verb"), expect: false}, + root: []byte{}, key: []byte("do"), proof: allProofs, value: []byte("verb"), expect: false}, "Other root, proof should be false": { - root: otherRoot, key: []byte("do"), proof: proof, value: []byte("verb"), expect: false}, + root: otherRoot, key: []byte("do"), proof: allProofs, value: []byte("verb"), expect: false}, "Value empty, proof should be true": { - root: root, key: []byte("do"), proof: proof, value: nil, expect: true}, + root: root, key: []byte("do"), proof: allProofs, value: nil, expect: true}, "Unknow key, proof should be false": { - root: root, key: []byte("unknow"), proof: proof, value: nil, expect: false}, + root: root, key: []byte("unknow"), proof: allProofs, value: nil, expect: false}, "Key and value unknow, proof should be false": { - root: root, key: []byte("unknow"), proof: proof, value: []byte("unknow"), expect: false}, + root: root, key: []byte("unknow"), proof: allProofs, value: []byte("unknow"), expect: false}, "Empty proof, should be false": { root: root, key: []byte("do"), proof: [][]byte{}, value: nil, expect: false}, } diff --git a/lib/trie/database.go b/lib/trie/database.go index 6c300a8b04..ecd5ee2ceb 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -5,7 +5,6 @@ package trie import ( "bytes" - "errors" "fmt" "github.com/ChainSafe/gossamer/internal/trie/codec" @@ -15,10 +14,11 @@ import ( "github.com/ChainSafe/chaindb" ) -var ( - ErrEmptyProof = 
errors.New("proof slice empty") - ErrDecodeNode = errors.New("cannot decode node") -) +// Database is an interface to get values from a +// key value database. +type Database interface { + Get(key []byte) (value []byte, err error) +} // Store stores each trie node in the database, // where the key is the hash of the encoded node @@ -76,74 +76,9 @@ func (t *Trie) store(db chaindb.Batch, n *Node) error { return nil } -// LoadFromProof sets a partial trie based on the proof slice of encoded nodes. -// Note this is exported because it is imported is used by: -// https://github.com/ComposableFi/ibc-go/blob/6d62edaa1a3cb0768c430dab81bb195e0b0c72db/modules/light-clients/11-beefy/types/client_state.go#L78 -func (t *Trie) LoadFromProof(proofEncodedNodes [][]byte, rootHash []byte) error { - if len(proofEncodedNodes) == 0 { - return ErrEmptyProof - } - - proofHashToNode := make(map[string]*Node, len(proofEncodedNodes)) - - for i, rawNode := range proofEncodedNodes { - decodedNode, err := node.Decode(bytes.NewReader(rawNode)) - if err != nil { - return fmt.Errorf("%w: at index %d: 0x%x", - ErrDecodeNode, i, rawNode) - } - - const dirty = false - decodedNode.SetDirty(dirty) - decodedNode.Encoding = rawNode - decodedNode.HashDigest = nil - - _, hash, err := decodedNode.EncodeAndHash(false) - if err != nil { - return fmt.Errorf("cannot encode and hash node at index %d: %w", i, err) - } - - proofHash := common.BytesToHex(hash) - proofHashToNode[proofHash] = decodedNode - - if bytes.Equal(hash, rootHash) { - // Found root in proof - t.root = decodedNode - } - } - - t.loadProof(proofHashToNode, t.root) - - return nil -} - -// loadProof is a recursive function that will create all the trie paths based -// on the mapped proofs slice starting at the root -func (t *Trie) loadProof(proofHashToNode map[string]*Node, n *Node) { - if n.Type() != node.Branch { - return - } - - branch := n - for i, child := range branch.Children { - if child == nil { - continue - } - - proofHash := common.BytesToHex(child.HashDigest) - node, ok := proofHashToNode[proofHash] - if !ok { - continue - } - - branch.Children[i] = node - t.loadProof(proofHashToNode, node) - } -} - // Load reconstructs the trie from the database from the given root hash. // It is used when restarting the node to load the current state trie. 
-func (t *Trie) Load(db chaindb.Database, rootHash common.Hash) error { +func (t *Trie) Load(db Database, rootHash common.Hash) error { if rootHash == EmptyHash { t.root = nil return nil @@ -169,7 +104,7 @@ func (t *Trie) Load(db chaindb.Database, rootHash common.Hash) error { return t.load(db, t.root) } -func (t *Trie) load(db chaindb.Database, n *Node) error { +func (t *Trie) load(db Database, n *Node) error { if n.Type() != node.Branch { return nil } diff --git a/lib/trie/lookup.go b/lib/trie/lookup.go deleted file mode 100644 index 4c0a169936..0000000000 --- a/lib/trie/lookup.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package trie - -import ( - "bytes" - - "github.com/ChainSafe/gossamer/internal/trie/node" - "github.com/ChainSafe/gossamer/internal/trie/record" -) - -var _ recorder = (*record.Recorder)(nil) - -type recorder interface { - Record(hash, rawData []byte) -} - -// findAndRecord search for a desired key recording all the nodes in the path including the desired node -func findAndRecord(t *Trie, key []byte, recorder recorder) error { - return find(t.root, key, recorder, true) -} - -func find(parent *Node, key []byte, recorder recorder, isCurrentRoot bool) error { - enc, hash, err := parent.EncodeAndHash(isCurrentRoot) - if err != nil { - return err - } - - recorder.Record(hash, enc) - - if parent.Type() != node.Branch { - return nil - } - - branch := parent - length := lenCommonPrefix(branch.Key, key) - - // found the value at this node - if bytes.Equal(branch.Key, key) || len(key) == 0 { - return nil - } - - // did not find value - if bytes.Equal(branch.Key[:length], key) && len(key) < len(branch.Key) { - return nil - } - - return find(branch.Children[key[length]], key[length+1:], recorder, false) -} diff --git a/lib/trie/proof.go b/lib/trie/proof.go deleted file mode 100644 index 2d8444d2db..0000000000 --- a/lib/trie/proof.go +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package trie - -import ( - "bytes" - "encoding/hex" - "errors" - "fmt" - - "github.com/ChainSafe/chaindb" - "github.com/ChainSafe/gossamer/internal/trie/codec" - "github.com/ChainSafe/gossamer/internal/trie/record" - "github.com/ChainSafe/gossamer/lib/common" -) - -var ( - // ErrEmptyTrieRoot ... - ErrEmptyTrieRoot = errors.New("provided trie must have a root") - - // ErrValueNotFound ... - ErrValueNotFound = errors.New("expected value not found in the trie") - - // ErrKeyNotFound ... - ErrKeyNotFound = errors.New("expected key not found in the trie") - - // ErrDuplicateKeys ... - ErrDuplicateKeys = errors.New("duplicate keys on verify proof") - - // ErrLoadFromProof ... 
- ErrLoadFromProof = errors.New("failed to build the proof trie") -) - -// GenerateProof receive the keys to proof, the trie root and a reference to database -func GenerateProof(root []byte, keys [][]byte, db chaindb.Database) ([][]byte, error) { - trackedProofs := make(map[string][]byte) - - proofTrie := NewEmptyTrie() - if err := proofTrie.Load(db, common.BytesToHash(root)); err != nil { - return nil, err - } - - for _, k := range keys { - nk := codec.KeyLEToNibbles(k) - - recorder := record.NewRecorder() - err := findAndRecord(proofTrie, nk, recorder) - if err != nil { - return nil, err - } - - for _, recNode := range recorder.GetNodes() { - nodeHashHex := common.BytesToHex(recNode.Hash) - if _, ok := trackedProofs[nodeHashHex]; !ok { - trackedProofs[nodeHashHex] = recNode.RawData - } - } - } - - proofs := make([][]byte, 0) - for _, p := range trackedProofs { - proofs = append(proofs, p) - } - - return proofs, nil -} - -// Pair holds the key and value to check while verifying the proof -type Pair struct{ Key, Value []byte } - -// VerifyProof ensure a given key is inside a proof by creating a proof trie based on the proof slice -// this function ignores the order of proofs -func VerifyProof(proof [][]byte, root []byte, items []Pair) (bool, error) { - set := make(map[string]struct{}, len(items)) - - // check for duplicate keys - for _, item := range items { - hexKey := hex.EncodeToString(item.Key) - if _, ok := set[hexKey]; ok { - return false, ErrDuplicateKeys - } - set[hexKey] = struct{}{} - } - - proofTrie := NewEmptyTrie() - if err := proofTrie.LoadFromProof(proof, root); err != nil { - return false, fmt.Errorf("%w: %s", ErrLoadFromProof, err) - } - - for _, item := range items { - recValue := proofTrie.Get(item.Key) - if recValue == nil { - return false, ErrKeyNotFound - } - // here we need to compare value only if the caller pass the value - if len(item.Value) > 0 && !bytes.Equal(item.Value, recValue) { - return false, ErrValueNotFound - } - } - - return true, nil -} diff --git a/lib/trie/proof/generate.go b/lib/trie/proof/generate.go new file mode 100644 index 0000000000..9ab7c1dc84 --- /dev/null +++ b/lib/trie/proof/generate.go @@ -0,0 +1,139 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package proof + +import ( + "bytes" + "errors" + "fmt" + + "github.com/ChainSafe/gossamer/internal/trie/codec" + "github.com/ChainSafe/gossamer/internal/trie/node" + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie" +) + +var ( + ErrKeyNotFound = errors.New("key not found") +) + +// Database defines a key value Get method used +// for proof generation. +type Database interface { + Get(key []byte) (value []byte, err error) +} + +// Generate generates and deduplicates the encoded proof nodes +// for the trie corresponding to the root hash given, and for +// the slice of (Little Endian) full keys given. The database given +// is used to load the trie using the root hash given. 
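+//
+// A minimal usage sketch, with hypothetical variable and key names;
+// database is any value store satisfying the Database interface above:
+//
+//	encodedProofNodes, err := proof.Generate(
+//		stateRootHash.ToBytes(), [][]byte{[]byte("somekey")}, database)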
+func Generate(rootHash []byte, fullKeys [][]byte, database Database) ( + encodedProofNodes [][]byte, err error) { + trie := trie.NewEmptyTrie() + if err := trie.Load(database, common.BytesToHash(rootHash)); err != nil { + return nil, fmt.Errorf("loading trie: %w", err) + } + rootNode := trie.RootNode() + + hashesSeen := make(map[string]struct{}) + for _, fullKey := range fullKeys { + fullKeyNibbles := codec.KeyLEToNibbles(fullKey) + const isRoot = true + newEncodedProofNodes, err := walk(rootNode, fullKeyNibbles, isRoot) + if err != nil { + // Note we wrap the full key context here since walk is recursive and + // may not be aware of the initial full key. + return nil, fmt.Errorf("walking to node at key 0x%x: %w", fullKey, err) + } + + for _, encodedProofNode := range newEncodedProofNodes { + digest, err := common.Blake2bHash(encodedProofNode) + if err != nil { + return nil, fmt.Errorf("blake2b hash: %w", err) + } + hashString := string(digest.ToBytes()) + + _, seen := hashesSeen[hashString] + if seen { + continue + } + hashesSeen[hashString] = struct{}{} + + encodedProofNodes = append(encodedProofNodes, encodedProofNode) + } + } + + return encodedProofNodes, nil +} + +func walk(parent *node.Node, fullKey []byte, isRoot bool) ( + encodedProofNodes [][]byte, err error) { + if parent == nil { + if len(fullKey) == 0 { + return nil, nil + } + return nil, ErrKeyNotFound + } + + // Note we do not use sync.Pool buffers since we would have + // to copy it so it persists in encodedProofNodes. + encodingBuffer := bytes.NewBuffer(nil) + err = parent.Encode(encodingBuffer) + if err != nil { + return nil, fmt.Errorf("encode node: %w", err) + } + + if isRoot || encodingBuffer.Len() >= 32 { + // Only add the root node encoding (whatever its length) + // and child node encodings greater or equal to 32 bytes. + // This is because child node encodings of less than 32 bytes + // are inlined in the parent node encoding, so there is no need + // to duplicate them in the proof generated. + encodedProofNodes = append(encodedProofNodes, encodingBuffer.Bytes()) + } + + nodeFound := len(fullKey) == 0 || bytes.Equal(parent.Key, fullKey) + if nodeFound { + return encodedProofNodes, nil + } + + if parent.Type() == node.Leaf && !nodeFound { + return nil, ErrKeyNotFound + } + + nodeIsDeeper := len(fullKey) > len(parent.Key) + if !nodeIsDeeper { + return nil, ErrKeyNotFound + } + + commonLength := lenCommonPrefix(parent.Key, fullKey) + childIndex := fullKey[commonLength] + nextChild := parent.Children[childIndex] + nextFullKey := fullKey[commonLength+1:] + isRoot = false + deeperEncodedProofNodes, err := walk(nextChild, nextFullKey, isRoot) + if err != nil { + return nil, err // note: do not wrap since this is recursive + } + + encodedProofNodes = append(encodedProofNodes, deeperEncodedProofNodes...) + return encodedProofNodes, nil +} + +// lenCommonPrefix returns the length of the +// common prefix between two byte slices. 
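+// For example, lenCommonPrefix([]byte{1, 2, 3}, []byte{1, 2, 4}) returns 2.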
+func lenCommonPrefix(a, b []byte) (length int) { + min := len(a) + if len(b) < min { + min = len(b) + } + + for length = 0; length < min; length++ { + if a[length] != b[length] { + break + } + } + + return length +} diff --git a/lib/trie/proof/generate_test.go b/lib/trie/proof/generate_test.go new file mode 100644 index 0000000000..9eaffde1bf --- /dev/null +++ b/lib/trie/proof/generate_test.go @@ -0,0 +1,613 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package proof + +import ( + "errors" + "testing" + + "github.com/ChainSafe/gossamer/internal/trie/codec" + "github.com/ChainSafe/gossamer/internal/trie/node" + "github.com/ChainSafe/gossamer/lib/trie" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_Generate(t *testing.T) { + t.Parallel() + + errTest := errors.New("test error") + + someHash := make([]byte, 32) + for i := range someHash { + someHash[i] = byte(i) + } + + largeValue := generateBytes(t, 40) + assertLongEncoding(t, node.Node{Value: largeValue}) + + testCases := map[string]struct { + rootHash []byte + fullKeysNibbles [][]byte + databaseBuilder func(ctrl *gomock.Controller) Database + encodedProofNodes [][]byte + errWrapped error + errMessage string + }{ + "failed loading trie": { + rootHash: someHash, + databaseBuilder: func(ctrl *gomock.Controller) Database { + mockDatabase := NewMockDatabase(ctrl) + mockDatabase.EXPECT().Get(someHash). + Return(nil, errTest) + return mockDatabase + }, + errWrapped: errTest, + errMessage: "loading trie: " + + "failed to find root key " + + "0x000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f: " + + "test error", + }, + "walk error": { + rootHash: someHash, + fullKeysNibbles: [][]byte{{1}}, + databaseBuilder: func(ctrl *gomock.Controller) Database { + mockDatabase := NewMockDatabase(ctrl) + encodedRoot := encodeNode(t, node.Node{ + Key: []byte{1}, + Value: []byte{2}, + }) + mockDatabase.EXPECT().Get(someHash). + Return(encodedRoot, nil) + return mockDatabase + }, + errWrapped: ErrKeyNotFound, + errMessage: "walking to node at key 0x01: key not found", + }, + "leaf root": { + rootHash: someHash, + fullKeysNibbles: [][]byte{{}}, + databaseBuilder: func(ctrl *gomock.Controller) Database { + mockDatabase := NewMockDatabase(ctrl) + encodedRoot := encodeNode(t, node.Node{ + Key: []byte{1}, + Value: []byte{2}, + }) + mockDatabase.EXPECT().Get(someHash). + Return(encodedRoot, nil) + return mockDatabase + }, + encodedProofNodes: [][]byte{ + encodeNode(t, node.Node{ + Key: []byte{1}, + Value: []byte{2}, + }), + }, + }, + "branch root": { + rootHash: someHash, + fullKeysNibbles: [][]byte{{}}, + databaseBuilder: func(ctrl *gomock.Controller) Database { + mockDatabase := NewMockDatabase(ctrl) + encodedRoot := encodeNode(t, node.Node{ + Key: []byte{1}, + Value: []byte{2}, + Children: padRightChildren([]*node.Node{ + nil, nil, + { + Key: []byte{3}, + Value: []byte{4}, + }, + }), + }) + mockDatabase.EXPECT().Get(someHash). 
+ Return(encodedRoot, nil) + return mockDatabase + }, + encodedProofNodes: [][]byte{ + encodeNode(t, node.Node{ + Key: []byte{1}, + Value: []byte{2}, + Children: padRightChildren([]*node.Node{ + nil, nil, + { + Key: []byte{3}, + Value: []byte{4}, + }, + }), + }), + }, + }, + "target leaf of branch": { + rootHash: someHash, + fullKeysNibbles: [][]byte{ + {1, 2, 3, 4}, + }, + databaseBuilder: func(ctrl *gomock.Controller) Database { + mockDatabase := NewMockDatabase(ctrl) + + rootNode := node.Node{ + Key: []byte{1, 2}, + Value: []byte{2}, + Children: padRightChildren([]*node.Node{ + nil, nil, nil, + { // full key 1, 2, 3, 4 + Key: []byte{4}, + Value: largeValue, + }, + }), + } + + mockDatabase.EXPECT().Get(someHash). + Return(encodeNode(t, rootNode), nil) + + encodedChild := encodeNode(t, *rootNode.Children[3]) + mockDatabase.EXPECT().Get(blake2b(t, encodedChild)). + Return(encodedChild, nil) + + return mockDatabase + }, + encodedProofNodes: [][]byte{ + encodeNode(t, node.Node{ + Key: []byte{1, 2}, + Value: []byte{2}, + Children: padRightChildren([]*node.Node{ + nil, nil, nil, + { + Key: []byte{4}, + Value: largeValue, + }, + }), + }), + encodeNode(t, node.Node{ + Key: []byte{4}, + Value: largeValue, + }), + }, + }, + "deduplicate proof nodes": { + rootHash: someHash, + fullKeysNibbles: [][]byte{ + {1, 2, 3, 4}, + {1, 2, 4, 4}, + {1, 2, 5, 5}, + }, + databaseBuilder: func(ctrl *gomock.Controller) Database { + mockDatabase := NewMockDatabase(ctrl) + + rootNode := node.Node{ + Key: []byte{1, 2}, + Value: []byte{2}, + Children: padRightChildren([]*node.Node{ + nil, nil, nil, + { // full key 1, 2, 3, 4 + Key: []byte{4}, + Value: largeValue, + }, + { // full key 1, 2, 4, 4 + Key: []byte{4}, + Value: largeValue, + }, + { // full key 1, 2, 5, 5 + Key: []byte{5}, + Value: largeValue, + }, + }), + } + + mockDatabase.EXPECT().Get(someHash). + Return(encodeNode(t, rootNode), nil) + + encodedLargeChild1 := encodeNode(t, *rootNode.Children[3]) + mockDatabase.EXPECT().Get(blake2b(t, encodedLargeChild1)). + Return(encodedLargeChild1, nil).Times(2) + + encodedLargeChild2 := encodeNode(t, *rootNode.Children[5]) + mockDatabase.EXPECT().Get(blake2b(t, encodedLargeChild2)). 
+ Return(encodedLargeChild2, nil) + + return mockDatabase + }, + encodedProofNodes: [][]byte{ + encodeNode(t, node.Node{ + Key: []byte{1, 2}, + Value: []byte{2}, + Children: padRightChildren([]*node.Node{ + nil, nil, nil, + { // full key 1, 2, 3, 4 + Key: []byte{4}, + Value: largeValue, + }, + { // full key 1, 2, 4, 4 + Key: []byte{4}, + Value: largeValue, + }, + { // full key 1, 2, 5, 5 + Key: []byte{5}, + Value: largeValue, + }, + }), + }), + encodeNode(t, node.Node{ + Key: []byte{4}, + Value: largeValue, + }), + encodeNode(t, node.Node{ + Key: []byte{5}, + Value: largeValue, + }), + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + database := testCase.databaseBuilder(ctrl) + fullKeysLE := make([][]byte, len(testCase.fullKeysNibbles)) + for i, fullKeyNibbles := range testCase.fullKeysNibbles { + fullKeysLE[i] = codec.NibblesToKeyLE(fullKeyNibbles) + } + + encodedProofNodes, err := Generate(testCase.rootHash, + fullKeysLE, database) + + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.encodedProofNodes, encodedProofNodes) + }) + } +} + +func Test_walk(t *testing.T) { + t.Parallel() + + largeValue := generateBytes(t, 40) + assertLongEncoding(t, node.Node{Value: largeValue}) + + testCases := map[string]struct { + parent *node.Node + fullKey []byte // nibbles + isRoot bool + encodedProofNodes [][]byte + errWrapped error + errMessage string + }{ + "nil parent and empty full key": {}, + "nil parent and non empty full key": { + fullKey: []byte{1}, + errWrapped: ErrKeyNotFound, + errMessage: "key not found", + }, + // The parent encode error cannot be triggered here + // since it can only be caused by a buffer.Write error. 
+ "parent leaf and empty full key": { + parent: &node.Node{ + Key: []byte{1, 2}, + Value: []byte{1}, + }, + isRoot: true, + encodedProofNodes: [][]byte{encodeNode(t, node.Node{ + Key: []byte{1, 2}, + Value: []byte{1}, + })}, + }, + "parent leaf and shorter full key": { + parent: &node.Node{ + Key: []byte{1, 2}, + Value: []byte{1}, + }, + fullKey: []byte{1}, + errWrapped: ErrKeyNotFound, + errMessage: "key not found", + }, + "parent leaf and mismatching full key": { + parent: &node.Node{ + Key: []byte{1, 2}, + Value: []byte{1}, + }, + fullKey: []byte{1, 3}, + errWrapped: ErrKeyNotFound, + errMessage: "key not found", + }, + "parent leaf and longer full key": { + parent: &node.Node{ + Key: []byte{1, 2}, + Value: []byte{1}, + }, + fullKey: []byte{1, 2, 3}, + errWrapped: ErrKeyNotFound, + errMessage: "key not found", + }, + "branch and empty search key": { + parent: &node.Node{ + Key: []byte{1, 2}, + Value: []byte{3}, + Children: padRightChildren([]*node.Node{ + { + Key: []byte{4}, + Value: []byte{5}, + }, + }), + }, + isRoot: true, + encodedProofNodes: [][]byte{ + encodeNode(t, node.Node{ + Key: []byte{1, 2}, + Value: []byte{3}, + Children: padRightChildren([]*node.Node{ + { + Key: []byte{4}, + Value: []byte{5}, + }, + }), + }), + }, + }, + "branch and shorter full key": { + parent: &node.Node{ + Key: []byte{1, 2}, + Value: []byte{3}, + Children: padRightChildren([]*node.Node{ + { + Key: []byte{4}, + Value: []byte{5}, + }, + }), + }, + fullKey: []byte{1}, + errWrapped: ErrKeyNotFound, + errMessage: "key not found", + }, + "branch and mismatching full key": { + parent: &node.Node{ + Key: []byte{1, 2}, + Value: []byte{3}, + Children: padRightChildren([]*node.Node{ + { + Key: []byte{4}, + Value: []byte{5}, + }, + }), + }, + fullKey: []byte{1, 3}, + errWrapped: ErrKeyNotFound, + errMessage: "key not found", + }, + "branch and matching search key": { + parent: &node.Node{ + Key: []byte{1, 2}, + Value: []byte{3}, + Children: padRightChildren([]*node.Node{ + { + Key: []byte{4}, + Value: []byte{5}, + }, + }), + }, + fullKey: []byte{1, 2}, + isRoot: true, + encodedProofNodes: [][]byte{ + encodeNode(t, node.Node{ + Key: []byte{1, 2}, + Value: []byte{3}, + Children: padRightChildren([]*node.Node{ + { + Key: []byte{4}, + Value: []byte{5}, + }, + }), + }), + }, + }, + "branch and matching search key for small leaf encoding": { + parent: &node.Node{ + Key: []byte{1, 2}, + Value: []byte{3}, + Children: padRightChildren([]*node.Node{ + { // full key 1, 2, 0, 1, 2 + Key: []byte{1, 2}, + Value: []byte{3}, + }, + }), + }, + fullKey: []byte{1, 2, 0, 1, 2}, + isRoot: true, + encodedProofNodes: [][]byte{ + encodeNode(t, node.Node{ + Key: []byte{1, 2}, + Value: []byte{3}, + Children: padRightChildren([]*node.Node{ + { // full key 1, 2, 0, 1, 2 + Key: []byte{1, 2}, + Value: []byte{3}, + }, + }), + }), + // Note the leaf encoding is not added since its encoding + // is less than 32 bytes. 
+ }, + }, + "branch and matching search key for large leaf encoding": { + parent: &node.Node{ + Key: []byte{1, 2}, + Value: []byte{3}, + Children: padRightChildren([]*node.Node{ + { // full key 1, 2, 0, 1, 2 + Key: []byte{1, 2}, + Value: largeValue, + }, + }), + }, + fullKey: []byte{1, 2, 0, 1, 2}, + isRoot: true, + encodedProofNodes: [][]byte{ + encodeNode(t, node.Node{ + Key: []byte{1, 2}, + Value: []byte{3}, + Children: padRightChildren([]*node.Node{ + { // full key 1, 2, 0, 1, 2 + Key: []byte{1, 2}, + Value: largeValue, + }, + }), + }), + encodeNode(t, node.Node{ + Key: []byte{1, 2}, + Value: largeValue, + }), + }, + }, + "key not found at deeper level": { + parent: &node.Node{ + Key: []byte{1, 2}, + Value: []byte{3}, + Children: padRightChildren([]*node.Node{ + { + Key: []byte{4, 5}, + Value: []byte{5}, + }, + }), + }, + fullKey: []byte{1, 2, 0x04, 4}, + errWrapped: ErrKeyNotFound, + errMessage: "key not found", + }, + "found leaf at deeper level": { + parent: &node.Node{ + Key: []byte{1, 2}, + Value: []byte{3}, + Children: padRightChildren([]*node.Node{ + { + Key: []byte{4}, + Value: []byte{5}, + }, + }), + }, + fullKey: []byte{1, 2, 0x04}, + isRoot: true, + encodedProofNodes: [][]byte{ + encodeNode(t, node.Node{ + Key: []byte{1, 2}, + Value: []byte{3}, + Children: padRightChildren([]*node.Node{ + { + Key: []byte{4}, + Value: []byte{5}, + }, + }), + }), + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + encodedProofNodes, err := walk(testCase.parent, testCase.fullKey, testCase.isRoot) + + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.encodedProofNodes, encodedProofNodes) + }) + } +} + +func Test_lenCommonPrefix(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + a []byte + b []byte + length int + }{ + "nil slices": {}, + "empty slices": { + a: []byte{}, + b: []byte{}, + }, + "fully different": { + a: []byte{1, 2, 3}, + b: []byte{4, 5, 6}, + }, + "fully same": { + a: []byte{1, 2, 3}, + b: []byte{1, 2, 3}, + length: 3, + }, + "different and common prefix": { + a: []byte{1, 2, 3, 4}, + b: []byte{1, 2, 4, 4}, + length: 2, + }, + "first bigger than second": { + a: []byte{1, 2, 3}, + b: []byte{1, 2}, + length: 2, + }, + "first smaller than second": { + a: []byte{1, 2}, + b: []byte{1, 2, 3}, + length: 2, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + length := lenCommonPrefix(testCase.a, testCase.b) + + assert.Equal(t, testCase.length, length) + }) + } +} + +// Note on the performance of walk: +// It was tried to optimise appending to the encodedProofNodes +// slice by: +// 1. appending to the same slice *[][]byte passed as argument +// 2. appending the upper node to the deeper nodes slice +// In both cases, the performance difference is very small +// so the code is kept to this inefficient-looking append, +// which is in the end quite performant still. +func Benchmark_walk(b *testing.B) { + trie := trie.NewEmptyTrie() + + // Build a deep trie. 
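+	// Each key is all zero bytes and one byte longer than the previous
+	// one, so the keys form a single path of trieDepth nodes.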
+ const trieDepth = 1000 + for i := 0; i < trieDepth; i++ { + keySize := 1 + i + key := make([]byte, keySize) + const trieValueSize = 10 + value := make([]byte, trieValueSize) + + trie.Put(key, value) + } + + longestKeyLE := make([]byte, trieDepth) + longestKeyNibbles := codec.KeyLEToNibbles(longestKeyLE) + + rootNode := trie.RootNode() + const isRoot = true + encodedProofNodes, err := walk(rootNode, longestKeyNibbles, isRoot) + require.NoError(b, err) + require.Equal(b, len(encodedProofNodes), trieDepth) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = walk(rootNode, longestKeyNibbles, isRoot) + } +} diff --git a/lib/trie/proof/helpers_test.go b/lib/trie/proof/helpers_test.go new file mode 100644 index 0000000000..f7279f75e2 --- /dev/null +++ b/lib/trie/proof/helpers_test.go @@ -0,0 +1,102 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package proof + +import ( + "bytes" + "math/rand" + "testing" + + "github.com/ChainSafe/gossamer/internal/trie/node" + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/pkg/scale" + "github.com/stretchr/testify/require" +) + +func padRightChildren(slice []*node.Node) (paddedSlice []*node.Node) { + paddedSlice = make([]*node.Node, node.ChildrenCapacity) + copy(paddedSlice, slice) + return paddedSlice +} + +func encodeNode(t *testing.T, node node.Node) (encoded []byte) { + t.Helper() + buffer := bytes.NewBuffer(nil) + err := node.Encode(buffer) + require.NoError(t, err) + return buffer.Bytes() +} + +func blake2bNode(t *testing.T, node node.Node) (digest []byte) { + t.Helper() + encoding := encodeNode(t, node) + return blake2b(t, encoding) +} + +func scaleEncode(t *testing.T, data []byte) (encoded []byte) { + t.Helper() + encoded, err := scale.Marshal(data) + require.NoError(t, err) + return encoded +} + +func blake2b(t *testing.T, data []byte) (digest []byte) { + t.Helper() + digestHash, err := common.Blake2bHash(data) + require.NoError(t, err) + digest = digestHash[:] + return digest +} + +func concatBytes(slices [][]byte) (concatenated []byte) { + for _, slice := range slices { + concatenated = append(concatenated, slice...) + } + return concatenated +} + +// generateBytes generates a pseudo random byte slice +// of the given length. It uses `0` as its seed so +// calling it multiple times will generate the same +// byte slice. This is designed as such in order to have +// deterministic unit tests. +func generateBytes(t *testing.T, length uint) (bytes []byte) { + t.Helper() + generator := rand.New(rand.NewSource(0)) + bytes = make([]byte, length) + _, err := generator.Read(bytes) + require.NoError(t, err) + return bytes +} + +// getBadNodeEncoding returns a particular bad node encoding of 33 bytes. 
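+// Its header byte 0x1 matches no known node variant, so decoding it
+// fails, as Test_getBadNodeEncoding checks below.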
+func getBadNodeEncoding() (badEncoding []byte) { + return []byte{ + 0x1, 0x94, 0xfd, 0xc2, 0xfa, 0x2f, 0xfc, 0xc0, 0x41, 0xd3, + 0xff, 0x12, 0x4, 0x5b, 0x73, 0xc8, 0x6e, 0x4f, 0xf9, 0x5f, + 0xf6, 0x62, 0xa5, 0xee, 0xe8, 0x2a, 0xbd, 0xf4, 0x4a, 0x2d, + 0xb, 0x75, 0xfb} +} + +func Test_getBadNodeEncoding(t *testing.T) { + t.Parallel() + + badEncoding := getBadNodeEncoding() + _, err := node.Decode(bytes.NewBuffer(badEncoding)) + require.Error(t, err) +} + +func assertLongEncoding(t *testing.T, node node.Node) { + t.Helper() + + encoding := encodeNode(t, node) + require.Greater(t, len(encoding), 32) +} + +func assertShortEncoding(t *testing.T, node node.Node) { + t.Helper() + + encoding := encodeNode(t, node) + require.LessOrEqual(t, len(encoding), 32) +} diff --git a/lib/trie/proof/mocks_generate_test.go b/lib/trie/proof/mocks_generate_test.go new file mode 100644 index 0000000000..81a9e78f9b --- /dev/null +++ b/lib/trie/proof/mocks_generate_test.go @@ -0,0 +1,6 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package proof + +//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Database diff --git a/lib/trie/proof/mocks_test.go b/lib/trie/proof/mocks_test.go new file mode 100644 index 0000000000..69262dc315 --- /dev/null +++ b/lib/trie/proof/mocks_test.go @@ -0,0 +1,49 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ChainSafe/gossamer/lib/trie/proof (interfaces: Database) + +// Package proof is a generated GoMock package. +package proof + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockDatabase is a mock of Database interface. +type MockDatabase struct { + ctrl *gomock.Controller + recorder *MockDatabaseMockRecorder +} + +// MockDatabaseMockRecorder is the mock recorder for MockDatabase. +type MockDatabaseMockRecorder struct { + mock *MockDatabase +} + +// NewMockDatabase creates a new mock instance. +func NewMockDatabase(ctrl *gomock.Controller) *MockDatabase { + mock := &MockDatabase{ctrl: ctrl} + mock.recorder = &MockDatabaseMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDatabase) EXPECT() *MockDatabaseMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockDatabase) Get(arg0 []byte) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. 
+func (mr *MockDatabaseMockRecorder) Get(arg0 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockDatabase)(nil).Get), arg0)
+}
diff --git a/lib/trie/proof/proof_test.go b/lib/trie/proof/proof_test.go
new file mode 100644
index 0000000000..7a403553fb
--- /dev/null
+++ b/lib/trie/proof/proof_test.go
@@ -0,0 +1,52 @@
+// Copyright 2022 ChainSafe Systems (ON)
+// SPDX-License-Identifier: LGPL-3.0-only
+
+package proof
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/ChainSafe/chaindb"
+	"github.com/ChainSafe/gossamer/lib/trie"
+	"github.com/stretchr/testify/require"
+)
+
+func Test_Generate_Verify(t *testing.T) {
+	t.Parallel()
+
+	keys := []string{
+		"cat",
+		"catapulta",
+		"catapora",
+		"dog",
+		"doguinho",
+	}
+
+	trie := trie.NewEmptyTrie()
+
+	for i, key := range keys {
+		value := fmt.Sprintf("%x-%d", key, i)
+		trie.Put([]byte(key), []byte(value))
+	}
+
+	rootHash, err := trie.Hash()
+	require.NoError(t, err)
+
+	database, err := chaindb.NewBadgerDB(&chaindb.Config{
+		InMemory: true,
+	})
+	require.NoError(t, err)
+	err = trie.Store(database)
+	require.NoError(t, err)
+
+	for i, key := range keys {
+		fullKeys := [][]byte{[]byte(key)}
+		proof, err := Generate(rootHash.ToBytes(), fullKeys, database)
+		require.NoError(t, err)
+
+		expectedValue := fmt.Sprintf("%x-%d", key, i)
+		err = Verify(proof, rootHash.ToBytes(), []byte(key), []byte(expectedValue))
+		require.NoError(t, err)
+	}
+}
diff --git a/lib/trie/proof/verify.go b/lib/trie/proof/verify.go
new file mode 100644
index 0000000000..b98b15476d
--- /dev/null
+++ b/lib/trie/proof/verify.go
@@ -0,0 +1,177 @@
+// Copyright 2021 ChainSafe Systems (ON)
+// SPDX-License-Identifier: LGPL-3.0-only
+
+package proof
+
+import (
+	"bytes"
+	"errors"
+	"fmt"
+	"strings"
+
+	"github.com/ChainSafe/gossamer/internal/trie/node"
+	"github.com/ChainSafe/gossamer/lib/common"
+	"github.com/ChainSafe/gossamer/lib/trie"
+)
+
+var (
+	ErrKeyNotFoundInProofTrie = errors.New("key not found in proof trie")
+	ErrValueMismatchProofTrie = errors.New("value found in proof trie does not match")
+)
+
+// Verify verifies that a given key and value belong to the trie by creating
+// a proof trie based on the encoded proof nodes given. The order of proofs is ignored.
+// A nil error is returned on success.
+// Note this is exported because it is imported and used by:
+// https://github.com/ComposableFi/ibc-go/blob/6d62edaa1a3cb0768c430dab81bb195e0b0c72db/modules/light-clients/11-beefy/types/client_state.go#L78
+func Verify(encodedProofNodes [][]byte, rootHash, key, value []byte) (err error) {
+	proofTrie, err := buildTrie(encodedProofNodes, rootHash)
+	if err != nil {
+		return fmt.Errorf("building trie from proof encoded nodes: %w", err)
+	}
+
+	proofTrieValue := proofTrie.Get(key)
+	if proofTrieValue == nil {
+		return fmt.Errorf("%w: %s in proof trie for root hash 0x%x",
+			ErrKeyNotFoundInProofTrie, bytesToString(key), rootHash)
+	}
+
+	// compare the value only if the caller passes a non-empty value
+	if len(value) > 0 && !bytes.Equal(value, proofTrieValue) {
+		return fmt.Errorf("%w: expected value %s but got value %s from proof trie",
+			ErrValueMismatchProofTrie, bytesToString(value), bytesToString(proofTrieValue))
+	}
+
+	return nil
+}
+
+var (
+	ErrEmptyProof       = errors.New("proof slice empty")
+	ErrRootNodeNotFound = errors.New("root node not found in proof")
+)
+
+// buildTrie builds a partial trie based on the proof slice of encoded nodes.
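+// For example, given encoded proof nodes [A, B] where blake2b(A) equals
+// rootHash, A is decoded as the root, and B is only decoded if the root
+// or one of its descendants references B's Merkle value.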
+func buildTrie(encodedProofNodes [][]byte, rootHash []byte) (t *trie.Trie, err error) {
+	if len(encodedProofNodes) == 0 {
+		return nil, fmt.Errorf("%w: for Merkle root hash 0x%x",
+			ErrEmptyProof, rootHash)
+	}
+
+	merkleValueToEncoding := make(map[string][]byte, len(encodedProofNodes))
+
+	// This loop finds the root node and decodes it.
+	// The other nodes have their Merkle value (blake2b digest or the encoding itself)
+	// inserted into a map from Merkle value to encoding. They are only decoded
+	// later if the root or one of its descendant nodes references their Merkle value.
+	var root *node.Node
+	for _, encodedProofNode := range encodedProofNodes {
+		var digest []byte
+		if root == nil {
+			// root node not found yet
+			digestHash, err := common.Blake2bHash(encodedProofNode)
+			if err != nil {
+				return nil, fmt.Errorf("blake2b hash: %w", err)
+			}
+			digest = digestHash[:]
+
+			if bytes.Equal(digest, rootHash) {
+				root, err = node.Decode(bytes.NewReader(encodedProofNode))
+				if err != nil {
+					return nil, fmt.Errorf("decoding root node: %w", err)
+				}
+				continue // no need to add the root to the map of Merkle value to encoding
+			}
+		}
+
+		var merkleValue []byte
+		if len(encodedProofNode) <= 32 {
+			merkleValue = encodedProofNode
+		} else {
+			if digest == nil {
+				digestHash, err := common.Blake2bHash(encodedProofNode)
+				if err != nil {
+					return nil, fmt.Errorf("blake2b hash: %w", err)
+				}
+				digest = digestHash[:]
+			}
+			merkleValue = digest
+		}
+
+		merkleValueToEncoding[string(merkleValue)] = encodedProofNode
+	}
+
+	if root == nil {
+		proofMerkleValues := make([]string, 0, len(merkleValueToEncoding))
+		for merkleValueString := range merkleValueToEncoding {
+			merkleValueHex := common.BytesToHex([]byte(merkleValueString))
+			proofMerkleValues = append(proofMerkleValues, merkleValueHex)
+		}
+		return nil, fmt.Errorf("%w: for Merkle root hash 0x%x in proof Merkle value(s) %s",
+			ErrRootNodeNotFound, rootHash, strings.Join(proofMerkleValues, ", "))
+	}
+
+	err = loadProof(merkleValueToEncoding, root)
+	if err != nil {
+		return nil, fmt.Errorf("loading proof: %w", err)
+	}
+
+	return trie.NewTrie(root), nil
+}
+
+// loadProof is a recursive function that builds all the trie paths based
+// on the map from Merkle value to encoding, starting at the given root.
+func loadProof(merkleValueToEncoding map[string][]byte, n *node.Node) (err error) {
+	if n.Type() != node.Branch {
+		return nil
+	}
+
+	branch := n
+	for i, child := range branch.Children {
+		if child == nil {
+			continue
+		}
+
+		merkleValue := child.HashDigest
+		encoding, ok := merkleValueToEncoding[string(merkleValue)]
+		if !ok {
+			inlinedChild := len(child.Value) > 0 || child.HasChild()
+			if !inlinedChild {
+				// Merkle value not found and the child is not inlined,
+				// so clear the child from the branch.
+				branch.Descendants -= 1 + child.Descendants
+				branch.Children[i] = nil
+				if !branch.HasChild() {
+					// Convert branch to a leaf if all its children are nil.
+ branch.Children = nil + } + } + continue + } + + child, err := node.Decode(bytes.NewReader(encoding)) + if err != nil { + return fmt.Errorf("decoding child node for Merkle value 0x%x: %w", + merkleValue, err) + } + + branch.Children[i] = child + branch.Descendants += child.Descendants + err = loadProof(merkleValueToEncoding, child) + if err != nil { + return err // do not wrap error since this is recursive + } + } + + return nil +} + +func bytesToString(b []byte) (s string) { + switch { + case b == nil: + return "nil" + case len(b) <= 20: + return fmt.Sprintf("0x%x", b) + default: + return fmt.Sprintf("0x%x...%x", b[:8], b[len(b)-8:]) + } +} diff --git a/lib/trie/proof/verify_test.go b/lib/trie/proof/verify_test.go new file mode 100644 index 0000000000..db3b346a1b --- /dev/null +++ b/lib/trie/proof/verify_test.go @@ -0,0 +1,628 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package proof + +import ( + "testing" + + "github.com/ChainSafe/gossamer/internal/trie/node" + "github.com/ChainSafe/gossamer/lib/trie" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_Verify(t *testing.T) { + t.Parallel() + + leafA := node.Node{ + Key: []byte{1}, + Value: []byte{1}, + } + + // leafB is a leaf encoding to more than 32 bytes + leafB := node.Node{ + Key: []byte{2}, + Value: generateBytes(t, 40), + } + assertLongEncoding(t, leafB) + + branch := node.Node{ + Key: []byte{3, 4}, + Value: []byte{1}, + Children: padRightChildren([]*node.Node{ + &leafB, + nil, + &leafA, + &leafB, + }), + } + assertLongEncoding(t, branch) + + testCases := map[string]struct { + encodedProofNodes [][]byte + rootHash []byte + keyLE []byte + value []byte + errWrapped error + errMessage string + }{ + "failed building proof trie": { + rootHash: []byte{1, 2, 3}, + errWrapped: ErrEmptyProof, + errMessage: "building trie from proof encoded nodes: " + + "proof slice empty: for Merkle root hash 0x010203", + }, + "value not found": { + encodedProofNodes: [][]byte{ + encodeNode(t, branch), + encodeNode(t, leafB), + // Note leaf A is small enough to be inlined in branch + }, + rootHash: blake2bNode(t, branch), + keyLE: []byte{1, 1}, // nil child of branch + errWrapped: ErrKeyNotFoundInProofTrie, + errMessage: "key not found in proof trie: " + + "0x0101 in proof trie for root hash " + + "0xec4bb0acfcf778ae8746d3ac3325fc73c3d9b376eb5f8d638dbf5eb462f5e703", + }, + "key found with nil search value": { + encodedProofNodes: [][]byte{ + encodeNode(t, branch), + encodeNode(t, leafB), + // Note leaf A is small enough to be inlined in branch + }, + rootHash: blake2bNode(t, branch), + keyLE: []byte{0x34, 0x21}, // inlined short leaf of branch + }, + "key found with mismatching value": { + encodedProofNodes: [][]byte{ + encodeNode(t, branch), + encodeNode(t, leafB), + // Note leaf A is small enough to be inlined in branch + }, + rootHash: blake2bNode(t, branch), + keyLE: []byte{0x34, 0x21}, // inlined short leaf of branch + value: []byte{2}, + errWrapped: ErrValueMismatchProofTrie, + errMessage: "value found in proof trie does not match: " + + "expected value 0x02 but got value 0x01 from proof trie", + }, + "key found with matching value": { + encodedProofNodes: [][]byte{ + encodeNode(t, branch), + encodeNode(t, leafB), + // Note leaf A is small enough to be inlined in branch + }, + rootHash: blake2bNode(t, branch), + keyLE: []byte{0x34, 0x32}, // large hash-referenced leaf of branch + value: generateBytes(t, 40), + }, + } + + for name, testCase := range testCases 
{ + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + err := Verify(testCase.encodedProofNodes, testCase.rootHash, testCase.keyLE, testCase.value) + + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } +} + +func Test_buildTrie(t *testing.T) { + t.Parallel() + + leafAShort := node.Node{ + Key: []byte{1}, + Value: []byte{2}, + } + assertShortEncoding(t, leafAShort) + + leafBLarge := node.Node{ + Key: []byte{2}, + Value: generateBytes(t, 40), + } + assertLongEncoding(t, leafBLarge) + + leafCLarge := node.Node{ + Key: []byte{3}, + Value: generateBytes(t, 40), + } + assertLongEncoding(t, leafCLarge) + + testCases := map[string]struct { + encodedProofNodes [][]byte + rootHash []byte + expectedTrie *trie.Trie + errWrapped error + errMessage string + }{ + "no proof node": { + errWrapped: ErrEmptyProof, + rootHash: []byte{1}, + errMessage: "proof slice empty: for Merkle root hash 0x01", + }, + "root node decoding error": { + encodedProofNodes: [][]byte{ + getBadNodeEncoding(), + }, + rootHash: blake2b(t, getBadNodeEncoding()), + errWrapped: node.ErrVariantUnknown, + errMessage: "decoding root node: decoding header: " + + "decoding header byte: node variant is unknown: " + + "for header byte 00000001", + }, + "root proof encoding smaller than 32 bytes": { + encodedProofNodes: [][]byte{ + encodeNode(t, leafAShort), + }, + rootHash: blake2bNode(t, leafAShort), + expectedTrie: trie.NewTrie(&node.Node{ + Key: leafAShort.Key, + Value: leafAShort.Value, + Dirty: true, + }), + }, + "root proof encoding larger than 32 bytes": { + encodedProofNodes: [][]byte{ + encodeNode(t, leafBLarge), + }, + rootHash: blake2bNode(t, leafBLarge), + expectedTrie: trie.NewTrie(&node.Node{ + Key: leafBLarge.Key, + Value: leafBLarge.Value, + Dirty: true, + }), + }, + "discard unused node": { + encodedProofNodes: [][]byte{ + encodeNode(t, leafAShort), + encodeNode(t, leafBLarge), + }, + rootHash: blake2bNode(t, leafAShort), + expectedTrie: trie.NewTrie(&node.Node{ + Key: leafAShort.Key, + Value: leafAShort.Value, + Dirty: true, + }), + }, + "multiple unordered nodes": { + encodedProofNodes: [][]byte{ + encodeNode(t, leafBLarge), // chilren 1 and 3 + encodeNode(t, node.Node{ // root + Key: []byte{1}, + Children: padRightChildren([]*node.Node{ + &leafAShort, // inlined + &leafBLarge, // referenced by Merkle value hash + &leafCLarge, // referenced by Merkle value hash + &leafBLarge, // referenced by Merkle value hash + }), + }), + encodeNode(t, leafCLarge), // children 2 + }, + rootHash: blake2bNode(t, node.Node{ + Key: []byte{1}, + Children: padRightChildren([]*node.Node{ + &leafAShort, + &leafBLarge, + &leafCLarge, + &leafBLarge, + }), + }), + expectedTrie: trie.NewTrie(&node.Node{ + Key: []byte{1}, + Descendants: 4, + Dirty: true, + Children: padRightChildren([]*node.Node{ + { + Key: leafAShort.Key, + Value: leafAShort.Value, + Dirty: true, + }, + { + Key: leafBLarge.Key, + Value: leafBLarge.Value, + Dirty: true, + }, + { + Key: leafCLarge.Key, + Value: leafCLarge.Value, + Dirty: true, + }, + { + Key: leafBLarge.Key, + Value: leafBLarge.Value, + Dirty: true, + }, + }), + }), + }, + "load proof decoding error": { + encodedProofNodes: [][]byte{ + getBadNodeEncoding(), + // root with one child pointing to hash of bad encoding above. 
+ concatBytes([][]byte{ + {0b1000_0000 | 0b0000_0001}, // branch with key size 1 + {1}, // key + {0b0000_0001, 0b0000_0000}, // children bitmap + scaleEncode(t, blake2b(t, getBadNodeEncoding())), // child hash + }), + }, + rootHash: blake2b(t, concatBytes([][]byte{ + {0b1000_0000 | 0b0000_0001}, // branch with key size 1 + {1}, // key + {0b0000_0001, 0b0000_0000}, // children bitmap + scaleEncode(t, blake2b(t, getBadNodeEncoding())), // child hash + })), + errWrapped: node.ErrVariantUnknown, + errMessage: "loading proof: decoding child node for Merkle value " + + "0xcfa21f0ec11a3658d77701b7b1f52fbcb783fe3df662977b6e860252b6c37e1e: " + + "decoding header: decoding header byte: " + + "node variant is unknown: for header byte 00000001", + }, + "root not found": { + encodedProofNodes: [][]byte{ + encodeNode(t, node.Node{ + Key: []byte{1}, + Value: []byte{2}, + }), + }, + rootHash: []byte{3}, + errWrapped: ErrRootNodeNotFound, + errMessage: "root node not found in proof: " + + "for Merkle root hash 0x03 in proof Merkle value(s) 0x41010402", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + trie, err := buildTrie(testCase.encodedProofNodes, testCase.rootHash) + + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + + if testCase.expectedTrie != nil { + require.NotNil(t, trie) + require.Equal(t, testCase.expectedTrie.String(), trie.String()) + } + assert.Equal(t, testCase.expectedTrie, trie) + }) + } +} + +func Test_loadProof(t *testing.T) { + t.Parallel() + + largeValue := generateBytes(t, 40) + + leafLarge := node.Node{ + Key: []byte{3}, + Value: largeValue, + } + assertLongEncoding(t, leafLarge) + + testCases := map[string]struct { + merkleValueToEncoding map[string][]byte + node *node.Node + expectedNode *node.Node + errWrapped error + errMessage string + }{ + "leaf node": { + node: &node.Node{ + Key: []byte{1}, + Value: []byte{2}, + }, + expectedNode: &node.Node{ + Key: []byte{1}, + Value: []byte{2}, + }, + }, + "branch node with child hash not found": { + node: &node.Node{ + Key: []byte{1}, + Value: []byte{2}, + Descendants: 1, + Dirty: true, + Children: padRightChildren([]*node.Node{ + {HashDigest: []byte{3}}, + }), + }, + merkleValueToEncoding: map[string][]byte{}, + expectedNode: &node.Node{ + Key: []byte{1}, + Value: []byte{2}, + Dirty: true, + }, + }, + "branch node with child hash found": { + node: &node.Node{ + Key: []byte{1}, + Value: []byte{2}, + Descendants: 1, + Dirty: true, + Children: padRightChildren([]*node.Node{ + {HashDigest: []byte{2}}, + }), + }, + merkleValueToEncoding: map[string][]byte{ + string([]byte{2}): encodeNode(t, node.Node{ + Key: []byte{3}, + Value: []byte{1}, + }), + }, + expectedNode: &node.Node{ + Key: []byte{1}, + Value: []byte{2}, + Descendants: 1, + Dirty: true, + Children: padRightChildren([]*node.Node{ + { + Key: []byte{3}, + Value: []byte{1}, + Dirty: true, + }, + }), + }, + }, + "branch node with one child hash found and one not found": { + node: &node.Node{ + Key: []byte{1}, + Value: []byte{2}, + Descendants: 2, + Dirty: true, + Children: padRightChildren([]*node.Node{ + {HashDigest: []byte{2}}, // found + {HashDigest: []byte{3}}, // not found + }), + }, + merkleValueToEncoding: map[string][]byte{ + string([]byte{2}): encodeNode(t, node.Node{ + Key: []byte{3}, + Value: []byte{1}, + }), + }, + expectedNode: &node.Node{ + Key: []byte{1}, + Value: []byte{2}, + Descendants: 1, + Dirty: true, + 
Children: padRightChildren([]*node.Node{ + { + Key: []byte{3}, + Value: []byte{1}, + Dirty: true, + }, + }), + }, + }, + "branch node with branch child hash": { + node: &node.Node{ + Key: []byte{1}, + Value: []byte{2}, + Descendants: 2, + Dirty: true, + Children: padRightChildren([]*node.Node{ + {HashDigest: []byte{2}}, + }), + }, + merkleValueToEncoding: map[string][]byte{ + string([]byte{2}): encodeNode(t, node.Node{ + Key: []byte{3}, + Value: []byte{1}, + Children: padRightChildren([]*node.Node{ + {Key: []byte{4}, Value: []byte{2}}, + }), + }), + }, + expectedNode: &node.Node{ + Key: []byte{1}, + Value: []byte{2}, + Descendants: 3, + Dirty: true, + Children: padRightChildren([]*node.Node{ + { + Key: []byte{3}, + Value: []byte{1}, + Dirty: true, + Descendants: 1, + Children: padRightChildren([]*node.Node{ + { + Key: []byte{4}, + Value: []byte{2}, + Dirty: true, + }, + }), + }, + }), + }, + }, + "child decoding error": { + node: &node.Node{ + Key: []byte{1}, + Value: []byte{2}, + Descendants: 1, + Dirty: true, + Children: padRightChildren([]*node.Node{ + {HashDigest: []byte{2}}, + }), + }, + merkleValueToEncoding: map[string][]byte{ + string([]byte{2}): getBadNodeEncoding(), + }, + expectedNode: &node.Node{ + Key: []byte{1}, + Value: []byte{2}, + Descendants: 1, + Dirty: true, + Children: padRightChildren([]*node.Node{ + {HashDigest: []byte{2}}, + }), + }, + errWrapped: node.ErrVariantUnknown, + errMessage: "decoding child node for Merkle value 0x02: " + + "decoding header: decoding header byte: node variant is unknown: " + + "for header byte 00000001", + }, + "grandchild": { + node: &node.Node{ + Key: []byte{1}, + Value: []byte{1}, + Descendants: 1, + Dirty: true, + Children: padRightChildren([]*node.Node{ + {HashDigest: []byte{2}}, + }), + }, + merkleValueToEncoding: map[string][]byte{ + string([]byte{2}): encodeNode(t, node.Node{ + Key: []byte{2}, + Value: []byte{2}, + Descendants: 1, + Dirty: true, + Children: padRightChildren([]*node.Node{ + &leafLarge, // encoded to hash + }), + }), + string(blake2bNode(t, leafLarge)): encodeNode(t, leafLarge), + }, + expectedNode: &node.Node{ + Key: []byte{1}, + Value: []byte{1}, + Descendants: 2, + Dirty: true, + Children: padRightChildren([]*node.Node{ + { + Key: []byte{2}, + Value: []byte{2}, + Descendants: 1, + Dirty: true, + Children: padRightChildren([]*node.Node{ + { + Key: leafLarge.Key, + Value: leafLarge.Value, + Dirty: true, + }, + }), + }, + }), + }, + }, + "grandchild load proof error": { + node: &node.Node{ + Key: []byte{1}, + Value: []byte{1}, + Descendants: 1, + Dirty: true, + Children: padRightChildren([]*node.Node{ + {HashDigest: []byte{2}}, + }), + }, + merkleValueToEncoding: map[string][]byte{ + string([]byte{2}): encodeNode(t, node.Node{ + Key: []byte{2}, + Value: []byte{2}, + Descendants: 1, + Dirty: true, + Children: padRightChildren([]*node.Node{ + &leafLarge, // encoded to hash + }), + }), + string(blake2bNode(t, leafLarge)): getBadNodeEncoding(), + }, + expectedNode: &node.Node{ + Key: []byte{1}, + Value: []byte{1}, + Descendants: 2, + Dirty: true, + Children: padRightChildren([]*node.Node{ + { + Key: []byte{2}, + Value: []byte{2}, + Descendants: 1, + Dirty: true, + Children: padRightChildren([]*node.Node{ + { + HashDigest: blake2bNode(t, leafLarge), + Dirty: true, + }, + }), + }, + }), + }, + errWrapped: node.ErrVariantUnknown, + errMessage: "decoding child node for Merkle value " + + "0x6888b9403129c11350c6054b46875292c0ffedcfd581e66b79bdf350b775ebf2: " + + "decoding header: decoding header byte: node variant is
unknown: " + + "for header byte 00000001", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + err := loadProof(testCase.merkleValueToEncoding, testCase.node) + + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + + assert.Equal(t, testCase.expectedNode.String(), testCase.node.String()) + }) + } +} + +func Test_bytesToString(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + b []byte + s string + }{ + "nil slice": { + s: "nil", + }, + "empty slice": { + b: []byte{}, + s: "0x", + }, + "small slice": { + b: []byte{1, 2, 3}, + s: "0x010203", + }, + "big slice": { + b: []byte{ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + }, + s: "0x0001020304050607...0203040506070809", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + s := bytesToString(testCase.b) + + assert.Equal(t, testCase.s, s) + }) + } +} diff --git a/lib/trie/proof_test.go b/lib/trie/proof_test.go deleted file mode 100644 index 78d58a1675..0000000000 --- a/lib/trie/proof_test.go +++ /dev/null @@ -1,227 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package trie - -import ( - "testing" - - "github.com/ChainSafe/chaindb" - "github.com/stretchr/testify/require" -) - -func TestProofGeneration(t *testing.T) { - t.Parallel() - - tmp := t.TempDir() - - memdb, err := chaindb.NewBadgerDB(&chaindb.Config{ - InMemory: true, - DataDir: tmp, - }) - require.NoError(t, err) - - const size = 32 - generator := newGenerator() - - expectedValue := generateRandBytes(t, size, generator) - - trie := NewEmptyTrie() - trie.Put([]byte("cat"), generateRandBytes(t, size, generator)) - trie.Put([]byte("catapulta"), generateRandBytes(t, size, generator)) - trie.Put([]byte("catapora"), expectedValue) - trie.Put([]byte("dog"), generateRandBytes(t, size, generator)) - trie.Put([]byte("doguinho"), generateRandBytes(t, size, generator)) - - err = trie.Store(memdb) - require.NoError(t, err) - - hash, err := trie.Hash() - require.NoError(t, err) - - proof, err := GenerateProof(hash.ToBytes(), [][]byte{[]byte("catapulta"), []byte("catapora")}, memdb) - require.NoError(t, err) - - require.Equal(t, 5, len(proof)) - - pl := []Pair{ - {Key: []byte("catapora"), Value: expectedValue}, - } - - v, err := VerifyProof(proof, hash.ToBytes(), pl) - require.True(t, v) - require.NoError(t, err) -} - -func testGenerateProof(t *testing.T, entries []Pair, keys [][]byte) ([]byte, [][]byte, []Pair) { - t.Helper() - - tmp := t.TempDir() - - memdb, err := chaindb.NewBadgerDB(&chaindb.Config{ - InMemory: true, - DataDir: tmp, - }) - require.NoError(t, err) - - trie := NewEmptyTrie() - for _, e := range entries { - trie.Put(e.Key, e.Value) - } - - err = trie.Store(memdb) - require.NoError(t, err) - - root := trie.root.HashDigest - proof, err := GenerateProof(root, keys, memdb) - require.NoError(t, err) - - items := make([]Pair, len(keys)) - for idx, key := range keys { - value := trie.Get(key) - require.NotNil(t, value) - - items[idx] = Pair{ - Key: key, - Value: value, - } - } - - return root, proof, items -} - -func TestVerifyProof_ShouldReturnTrue(t *testing.T) { - t.Parallel() - - entries := []Pair{ - {Key: []byte("alpha"), Value: make([]byte, 32)}, - {Key: []byte("bravo"), Value: []byte("bravo")}, - {Key: []byte("do"), Value: 
[]byte("verb")}, - {Key: []byte("dogea"), Value: []byte("puppy")}, - {Key: []byte("dogeb"), Value: []byte("puppy")}, - {Key: []byte("horse"), Value: []byte("stallion")}, - {Key: []byte("house"), Value: []byte("building")}, - } - - keys := [][]byte{ - []byte("do"), - []byte("dogea"), - []byte("dogeb"), - } - - root, proof, pairs := testGenerateProof(t, entries, keys) - v, err := VerifyProof(proof, root, pairs) - - require.NoError(t, err) - require.True(t, v) -} - -func TestVerifyProof_ShouldReturnDuplicateKeysError(t *testing.T) { - t.Parallel() - - pl := []Pair{ - {Key: []byte("do"), Value: []byte("verb")}, - {Key: []byte("do"), Value: []byte("puppy")}, - } - - v, err := VerifyProof([][]byte{}, []byte{}, pl) - require.False(t, v) - require.Error(t, err, ErrDuplicateKeys) -} - -func TestVerifyProof_ShouldReturnTrueWithouCompareValues(t *testing.T) { - t.Parallel() - - entries := []Pair{ - {Key: []byte("alpha"), Value: make([]byte, 32)}, - {Key: []byte("bravo"), Value: []byte("bravo")}, - {Key: []byte("do"), Value: []byte("verb")}, - {Key: []byte("dog"), Value: []byte("puppy")}, - {Key: []byte("doge"), Value: make([]byte, 32)}, - {Key: []byte("horse"), Value: []byte("stallion")}, - {Key: []byte("house"), Value: []byte("building")}, - } - - keys := [][]byte{ - []byte("do"), - []byte("dog"), - []byte("doge"), - } - - root, proof, _ := testGenerateProof(t, entries, keys) - - pl := []Pair{ - {Key: []byte("do"), Value: nil}, - {Key: []byte("dog"), Value: nil}, - {Key: []byte("doge"), Value: nil}, - } - - v, err := VerifyProof(proof, root, pl) - require.True(t, v) - require.NoError(t, err) -} - -func TestBranchNodes_SameHash_DifferentPaths_GenerateAndVerifyProof(t *testing.T) { - value := []byte("somevalue") - entries := []Pair{ - {Key: []byte("d"), Value: value}, - {Key: []byte("b"), Value: value}, - {Key: []byte("dxyz"), Value: value}, - {Key: []byte("bxyz"), Value: value}, - {Key: []byte("dxyzi"), Value: value}, - {Key: []byte("bxyzi"), Value: value}, - } - - keys := [][]byte{ - []byte("d"), - []byte("b"), - []byte("dxyz"), - []byte("bxyz"), - []byte("dxyzi"), - []byte("bxyzi"), - } - - root, proof, pairs := testGenerateProof(t, entries, keys) - - ok, err := VerifyProof(proof, root, pairs) - require.NoError(t, err) - require.True(t, ok) -} - -func TestLeafNodes_SameHash_DifferentPaths_GenerateAndVerifyProof(t *testing.T) { - tmp := t.TempDir() - - memdb, err := chaindb.NewBadgerDB(&chaindb.Config{ - InMemory: true, - DataDir: tmp, - }) - require.NoError(t, err) - - var ( - value = []byte("somevalue") - key1 = []byte("worlda") - key2 = []byte("worldb") - ) - - tt := NewEmptyTrie() - tt.Put(key1, value) - tt.Put(key2, value) - - err = tt.Store(memdb) - require.NoError(t, err) - - hash, err := tt.Hash() - require.NoError(t, err) - - proof, err := GenerateProof(hash.ToBytes(), [][]byte{key1, key2}, memdb) - require.NoError(t, err) - - pairs := []Pair{ - {Key: key1, Value: value}, - {Key: key2, Value: value}, - } - - ok, err := VerifyProof(proof, hash.ToBytes(), pairs) - require.NoError(t, err) - require.True(t, ok) -}