diff --git a/dot/state/helpers_test.go b/dot/state/helpers_test.go index 10029e9f01..1a92a9791a 100644 --- a/dot/state/helpers_test.go +++ b/dot/state/helpers_test.go @@ -14,12 +14,13 @@ import ( "github.com/ChainSafe/gossamer/lib/runtime" "github.com/ChainSafe/gossamer/lib/trie" "github.com/ChainSafe/gossamer/lib/utils" + lrucache "github.com/ChainSafe/gossamer/lib/utils/lru-cache" "github.com/stretchr/testify/require" ) func newTriesEmpty() *Tries { return &Tries{ - rootToTrie: make(map[common.Hash]*trie.Trie), + rootToTrie: lrucache.NewLRUCache[common.Hash, *trie.Trie](MaxInMemoryTries), triesGauge: triesGauge, setCounter: setCounter, deleteCounter: deleteCounter, diff --git a/dot/state/tries.go b/dot/state/tries.go index d7c87bc625..bcc6c0f090 100644 --- a/dot/state/tries.go +++ b/dot/state/tries.go @@ -4,10 +4,9 @@ package state import ( - "sync" - "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/trie" + lrucache "github.com/ChainSafe/gossamer/lib/utils/lru-cache" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" ) @@ -30,11 +29,12 @@ var ( }) ) +const MaxInMemoryTries = 100 + // Tries is a thread safe map of root hash // to trie. type Tries struct { - rootToTrie map[common.Hash]*trie.Trie - mapMutex sync.RWMutex + rootToTrie *lrucache.LRUCache[common.Hash, *trie.Trie] triesGauge prometheus.Gauge setCounter prometheus.Counter deleteCounter prometheus.Counter @@ -43,8 +43,10 @@ type Tries struct { // NewTries creates a new thread safe map of root hash // to trie. func NewTries() (tries *Tries) { + cache := lrucache.NewLRUCache[common.Hash, *trie.Trie](MaxInMemoryTries) + return &Tries{ - rootToTrie: make(map[common.Hash]*trie.Trie), + rootToTrie: cache, triesGauge: triesGauge, setCounter: setCounter, deleteCounter: deleteCounter, @@ -66,41 +68,27 @@ func (t *Tries) SetTrie(tr *trie.Trie) { // softSet sets the given trie at the given root hash // in the memory map only if it is not already set. func (t *Tries) softSet(root common.Hash, trie *trie.Trie) { - t.mapMutex.Lock() - defer t.mapMutex.Unlock() - - _, has := t.rootToTrie[root] - if has { - return + if t.rootToTrie.SoftPut(root, trie) { + t.triesGauge.Inc() + t.setCounter.Inc() } - - t.triesGauge.Inc() - t.setCounter.Inc() - t.rootToTrie[root] = trie } func (t *Tries) delete(root common.Hash) { - t.mapMutex.Lock() - defer t.mapMutex.Unlock() - delete(t.rootToTrie, root) - // Note we use .Set instead of .Dec in case nothing - // was deleted since nothing existed at the hash given. - t.triesGauge.Set(float64(len(t.rootToTrie))) - t.deleteCounter.Inc() + if t.rootToTrie.Delete(root) { + t.triesGauge.Dec() + t.deleteCounter.Inc() + } } // get retrieves the trie corresponding to the root hash given // from the in-memory thread safe map. func (t *Tries) get(root common.Hash) (tr *trie.Trie) { - t.mapMutex.RLock() - defer t.mapMutex.RUnlock() - return t.rootToTrie[root] + return t.rootToTrie.Get(root) } // len returns the current numbers of tries // stored in the in-memory map. func (t *Tries) len() int { - t.mapMutex.RLock() - defer t.mapMutex.RUnlock() - return len(t.rootToTrie) + return t.rootToTrie.Len() } diff --git a/dot/state/tries_test.go b/dot/state/tries_test.go index cff32bfac1..b8284e66e7 100644 --- a/dot/state/tries_test.go +++ b/dot/state/tries_test.go @@ -9,17 +9,20 @@ import ( "github.com/ChainSafe/gossamer/internal/trie/node" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/trie" + lrucache "github.com/ChainSafe/gossamer/lib/utils/lru-cache" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" ) +var emptyTrie = trie.NewEmptyTrie() + func Test_NewTries(t *testing.T) { t.Parallel() rootToTrie := NewTries() expectedTries := &Tries{ - rootToTrie: map[common.Hash]*trie.Trie{}, + rootToTrie: lrucache.NewLRUCache[common.Hash, *trie.Trie](MaxInMemoryTries), triesGauge: triesGauge, setCounter: setCounter, deleteCounter: deleteCounter, @@ -35,13 +38,12 @@ func Test_Tries_SetEmptyTrie(t *testing.T) { tries.SetEmptyTrie() expectedTries := &Tries{ - rootToTrie: map[common.Hash]*trie.Trie{ - trie.EmptyHash: trie.NewEmptyTrie(), - }, + rootToTrie: lrucache.NewLRUCache[common.Hash, *trie.Trie](MaxInMemoryTries), triesGauge: triesGauge, setCounter: setCounter, deleteCounter: deleteCounter, } + expectedTries.rootToTrie.Put(trie.EmptyHash, trie.NewEmptyTrie()) assert.Equal(t, expectedTries, tries) } @@ -58,45 +60,54 @@ func Test_Tries_SetTrie(t *testing.T) { tries.SetTrie(tr) expectedTries := &Tries{ - rootToTrie: map[common.Hash]*trie.Trie{ - tr.MustHash(trie.NoMaxInlineValueSize): tr, - }, + rootToTrie: func() *lrucache.LRUCache[common.Hash, *trie.Trie] { + cache := lrucache.NewLRUCache[common.Hash, *trie.Trie](MaxInMemoryTries) + cache.Put(tr.MustHash(trie.NoMaxInlineValueSize), tr) + return cache + }(), triesGauge: triesGauge, setCounter: setCounter, deleteCounter: deleteCounter, } + expectedTries.rootToTrie.Put(tr.MustHash(trie.NoMaxInlineValueSize), tr) + assert.Equal(t, expectedTries, tries) } func Test_Tries_softSet(t *testing.T) { t.Parallel() - testCases := map[string]struct { - rootToTrie map[common.Hash]*trie.Trie + rootToTrie *lrucache.LRUCache[common.Hash, *trie.Trie] root common.Hash trie *trie.Trie triesGaugeInc bool - expectedRootToTrie map[common.Hash]*trie.Trie + expectedRootToTrie *lrucache.LRUCache[common.Hash, *trie.Trie] }{ "set_new_in_map": { - rootToTrie: map[common.Hash]*trie.Trie{}, + rootToTrie: lrucache.NewLRUCache[common.Hash, *trie.Trie](MaxInMemoryTries), root: common.Hash{1, 2, 3}, trie: trie.NewEmptyTrie(), triesGaugeInc: true, - expectedRootToTrie: map[common.Hash]*trie.Trie{ - {1, 2, 3}: trie.NewEmptyTrie(), - }, + expectedRootToTrie: func() *lrucache.LRUCache[common.Hash, *trie.Trie] { + cache := lrucache.NewLRUCache[common.Hash, *trie.Trie](MaxInMemoryTries) + cache.Put(common.Hash{1, 2, 3}, trie.NewEmptyTrie()) + return cache + }(), }, "do_not_override_in_map": { - rootToTrie: map[common.Hash]*trie.Trie{ - {1, 2, 3}: {}, - }, + rootToTrie: func() *lrucache.LRUCache[common.Hash, *trie.Trie] { + cache := lrucache.NewLRUCache[common.Hash, *trie.Trie](MaxInMemoryTries) + cache.Put(common.Hash{1, 2, 3}, emptyTrie) + return cache + }(), root: common.Hash{1, 2, 3}, - trie: trie.NewEmptyTrie(), - expectedRootToTrie: map[common.Hash]*trie.Trie{ - {1, 2, 3}: {}, - }, + trie: emptyTrie, + expectedRootToTrie: func() *lrucache.LRUCache[common.Hash, *trie.Trie] { + cache := lrucache.NewLRUCache[common.Hash, *trie.Trie](MaxInMemoryTries) + cache.Put(common.Hash{1, 2, 3}, emptyTrie) + return cache + }(), }, } @@ -133,34 +144,42 @@ func Test_Tries_delete(t *testing.T) { t.Parallel() testCases := map[string]struct { - rootToTrie map[common.Hash]*trie.Trie + rootToTrie *lrucache.LRUCache[common.Hash, *trie.Trie] root common.Hash - deleteCounterInc bool - expectedRootToTrie map[common.Hash]*trie.Trie + counterUpdated bool + expectedRootToTrie *lrucache.LRUCache[common.Hash, *trie.Trie] triesGaugeSet float64 }{ "not_found": { - rootToTrie: map[common.Hash]*trie.Trie{ - {3, 4, 5}: {}, - }, + rootToTrie: func() *lrucache.LRUCache[common.Hash, *trie.Trie] { + cache := lrucache.NewLRUCache[common.Hash, *trie.Trie](MaxInMemoryTries) + cache.Put(common.Hash{3, 4, 5}, emptyTrie) + return cache + }(), root: common.Hash{1, 2, 3}, triesGaugeSet: 1, - expectedRootToTrie: map[common.Hash]*trie.Trie{ - {3, 4, 5}: {}, - }, - deleteCounterInc: true, + expectedRootToTrie: func() *lrucache.LRUCache[common.Hash, *trie.Trie] { + cache := lrucache.NewLRUCache[common.Hash, *trie.Trie](MaxInMemoryTries) + cache.Put(common.Hash{3, 4, 5}, emptyTrie) + return cache + }(), + counterUpdated: false, }, "deleted": { - rootToTrie: map[common.Hash]*trie.Trie{ - {1, 2, 3}: {}, - {3, 4, 5}: {}, - }, + rootToTrie: func() *lrucache.LRUCache[common.Hash, *trie.Trie] { + cache := lrucache.NewLRUCache[common.Hash, *trie.Trie](MaxInMemoryTries) + cache.Put(common.Hash{1, 2, 3}, emptyTrie) + cache.Put(common.Hash{3, 4, 5}, emptyTrie) + return cache + }(), root: common.Hash{1, 2, 3}, triesGaugeSet: 1, - expectedRootToTrie: map[common.Hash]*trie.Trie{ - {3, 4, 5}: {}, - }, - deleteCounterInc: true, + expectedRootToTrie: func() *lrucache.LRUCache[common.Hash, *trie.Trie] { + cache := lrucache.NewLRUCache[common.Hash, *trie.Trie](MaxInMemoryTries) + cache.Put(common.Hash{3, 4, 5}, emptyTrie) + return cache + }(), + counterUpdated: true, }, } @@ -170,11 +189,11 @@ func Test_Tries_delete(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) triesGauge := NewMockGauge(ctrl) - triesGauge.EXPECT().Set(testCase.triesGaugeSet) - deleteCounter := NewMockCounter(ctrl) - if testCase.deleteCounterInc { + + if testCase.counterUpdated { deleteCounter.EXPECT().Inc() + triesGauge.EXPECT().Dec() } tries := &Tries{ @@ -189,6 +208,7 @@ func Test_Tries_delete(t *testing.T) { }) } } + func Test_Tries_get(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) @@ -202,12 +222,15 @@ func Test_Tries_get(t *testing.T) { }{ "found_in_map": { tries: &Tries{ - rootToTrie: map[common.Hash]*trie.Trie{ - {1, 2, 3}: trie.NewTrie(&node.Node{ + rootToTrie: func() *lrucache.LRUCache[common.Hash, *trie.Trie] { + cache := lrucache.NewLRUCache[common.Hash, *trie.Trie](MaxInMemoryTries) + tr := trie.NewTrie(&node.Node{ PartialKey: []byte{1, 2, 3}, StorageValue: []byte{1}, - }, db), - }, + }, db) + cache.Put(common.Hash{1, 2, 3}, tr) + return cache + }(), }, root: common.Hash{1, 2, 3}, trie: trie.NewTrie(&node.Node{ @@ -218,7 +241,7 @@ func Test_Tries_get(t *testing.T) { "not_found_in_map": { // similar to not found in database tries: &Tries{ - rootToTrie: map[common.Hash]*trie.Trie{}, + rootToTrie: lrucache.NewLRUCache[common.Hash, *trie.Trie](MaxInMemoryTries), }, root: common.Hash{1, 2, 3}, }, @@ -245,14 +268,16 @@ func Test_Tries_len(t *testing.T) { }{ "empty_map": { tries: &Tries{ - rootToTrie: map[common.Hash]*trie.Trie{}, + rootToTrie: lrucache.NewLRUCache[common.Hash, *trie.Trie](MaxInMemoryTries), }, }, "non_empty_map": { tries: &Tries{ - rootToTrie: map[common.Hash]*trie.Trie{ - {1, 2, 3}: {}, - }, + rootToTrie: func() *lrucache.LRUCache[common.Hash, *trie.Trie] { + cache := lrucache.NewLRUCache[common.Hash, *trie.Trie](MaxInMemoryTries) + cache.Put(common.Hash{1, 2, 3}, emptyTrie) + return cache + }(), }, length: 1, }, diff --git a/lib/trie/cache/cache.go b/lib/trie/cache/cache.go new file mode 100644 index 0000000000..3b13e2d1c8 --- /dev/null +++ b/lib/trie/cache/cache.go @@ -0,0 +1,148 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package cache + +import ( + "bytes" + "container/list" + "sync" + + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie" +) + +const DefaultCapacity = 8 * 1024 * 1024 // bytes + +// LRUCache represents the LRU cache. +type LRUCache struct { + sync.RWMutex + capacity uint + memoryUsed uint + cache map[common.Hash]*list.Element + lruList *list.List +} + +// Entry represents an item in the cache. +type Entry struct { + key common.Hash + value trie.Node +} + +// NewLRUCache creates a new LRU cache with the specified capacity. +func NewLRUCache(capacity uint) *LRUCache { + if capacity < 1 { + capacity = DefaultCapacity + } + + return &LRUCache{ + capacity: capacity, + cache: make(map[common.Hash]*list.Element), + lruList: list.New(), + } +} + +// Get retrieves the value associated with the given key from the cache. +func (c *LRUCache) Get(key common.Hash) trie.Node { + c.RLock() + defer c.RUnlock() + + if elem, exists := c.cache[key]; exists { + c.lruList.MoveToFront(elem) + return elem.Value.(*Entry).value + } + + var zeroV trie.Node + return zeroV +} + +// Has checks if the cache contains the given key. +func (c *LRUCache) Has(key common.Hash) bool { + c.RLock() + defer c.RUnlock() + + _, exists := c.cache[key] + return exists +} + +// Put adds a key-value pair to the cache. +func (c *LRUCache) Put(key common.Hash, value trie.Node) error { + c.Lock() + defer c.Unlock() + + return c.insertEntry(key, value) +} + +// SoftPut adds a key-value pair to the cache if it does not already exist. +func (c *LRUCache) SoftPut(key common.Hash, value trie.Node) (bool, error) { + c.Lock() + defer c.Unlock() + + if _, exists := c.cache[key]; exists { + return false, nil + } + + err := c.insertEntry(key, value) + if err != nil { + return false, err + } + return true, nil +} + +// Delete removes the given key from the cache. +func (c *LRUCache) Delete(key common.Hash) bool { + c.Lock() + defer c.Unlock() + + val, exists := c.cache[key] + if !exists { + return false + } + + c.lruList.Remove(val) + + delete(c.cache, key) + return true +} + +// Len returns the number of items in the cache. +func (c *LRUCache) Len() int { + c.Lock() + defer c.Unlock() + + return len(c.cache) +} + +func (c *LRUCache) insertEntry(key common.Hash, value trie.Node) error { + // If the key already exists in the cache, update its value and move it to the front. + if elem, exists := c.cache[key]; exists { + elem.Value.(*Entry).value = value + c.lruList.MoveToFront(elem) + return nil + } + + buffer := bytes.NewBuffer(nil) + err := value.Encode(buffer, trie.NoMaxInlineValueSize) + if err != nil { + return err + } + + // If the cache is full, remove the least recently used item (from the back of the list). + if buffer.Len() >= int(c.capacity) { + // Get the least recently used item (back of the list). + lastElem := c.lruList.Back() + if lastElem != nil { + delete(c.cache, lastElem.Value.(*Entry).key) + c.lruList.Remove(lastElem) + } + } + + // Add the new key-value pair to the cache (at the front of the list). + newEntry := &Entry{key: key, value: value} + newElem := c.lruList.PushFront(newEntry) + c.cache[key] = newElem + + c.memoryUsed += uint(buffer.Len()) + + return nil +}