Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor state tests to always use initialized state #3310

Merged
merged 7 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions vms/platformvm/genesis/genesistest/genesis.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package genesistest

import (
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/ava-labs/avalanchego/ids"
"github.com/ava-labs/avalanchego/utils/constants"
"github.com/ava-labs/avalanchego/utils/units"
"github.com/ava-labs/avalanchego/vms/components/avax"
"github.com/ava-labs/avalanchego/vms/platformvm/genesis"
"github.com/ava-labs/avalanchego/vms/platformvm/reward"
"github.com/ava-labs/avalanchego/vms/platformvm/txs"
"github.com/ava-labs/avalanchego/vms/secp256k1fx"
)

var (
AVAXAssetID = ids.GenerateTestID()
AVAXAsset = avax.Asset{ID: AVAXAssetID}

ValidatorNodeID = ids.GenerateTestNodeID()
Time = time.Now().Round(time.Second)
TimeUnix = uint64(Time.Unix())
ValidatorDuration = 28 * 24 * time.Hour
ValidatorEndTime = Time.Add(ValidatorDuration)
ValidatorEndTimeUnix = uint64(ValidatorEndTime.Unix())
ValidatorWeight = units.Avax
ValidatorRewardsOwner = &secp256k1fx.OutputOwners{}
ValidatorDelegationShares uint32 = reward.PercentDenominator

XChainName = "x"

InitialBalance = units.Schmeckle
InitialSupply = ValidatorWeight + InitialBalance
)

func New(t testing.TB) *genesis.Genesis {
require := require.New(t)

genesisValidator := &txs.AddValidatorTx{
Validator: txs.Validator{
NodeID: ValidatorNodeID,
Start: TimeUnix,
End: ValidatorEndTimeUnix,
Wght: ValidatorWeight,
},
StakeOuts: []*avax.TransferableOutput{
{
Asset: AVAXAsset,
Out: &secp256k1fx.TransferOutput{
Amt: ValidatorWeight,
},
},
},
RewardsOwner: ValidatorRewardsOwner,
DelegationShares: ValidatorDelegationShares,
}
genesisValidatorTx := &txs.Tx{Unsigned: genesisValidator}
require.NoError(genesisValidatorTx.Initialize(txs.Codec))

genesisChain := &txs.CreateChainTx{
SubnetID: constants.PrimaryNetworkID,
ChainName: XChainName,
VMID: constants.AVMID,
SubnetAuth: &secp256k1fx.Input{},
}
genesisChainTx := &txs.Tx{Unsigned: genesisChain}
require.NoError(genesisChainTx.Initialize(txs.Codec))

return &genesis.Genesis{
UTXOs: []*genesis.UTXO{
{
UTXO: avax.UTXO{
UTXOID: avax.UTXOID{
TxID: AVAXAssetID,
OutputIndex: 0,
},
Asset: AVAXAsset,
Out: &secp256k1fx.TransferOutput{
Amt: InitialBalance,
},
},
Message: nil,
},
},
Validators: []*txs.Tx{
genesisValidatorTx,
},
Chains: []*txs.Tx{
genesisChainTx,
},
Timestamp: TimeUnix,
InitialSupply: InitialSupply,
}
}

func NewBytes(t testing.TB) []byte {
g := New(t)
genesisBytes, err := genesis.Codec.Marshal(genesis.CodecVersion, g)
require.NoError(t, err)
return genesisBytes
}
19 changes: 10 additions & 9 deletions vms/platformvm/state/diff_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"go.uber.org/mock/gomock"

"github.com/ava-labs/avalanchego/database"
"github.com/ava-labs/avalanchego/database/memdb"
"github.com/ava-labs/avalanchego/ids"
"github.com/ava-labs/avalanchego/utils"
"github.com/ava-labs/avalanchego/utils/constants"
Expand All @@ -36,7 +37,7 @@ func TestDiffMissingState(t *testing.T) {
func TestNewDiffOn(t *testing.T) {
require := require.New(t)

state := newInitializedState(require)
state := newTestState(t, memdb.New())

d, err := NewDiffOn(state)
require.NoError(err)
Expand All @@ -47,7 +48,7 @@ func TestNewDiffOn(t *testing.T) {
func TestDiffFeeState(t *testing.T) {
require := require.New(t)

state := newInitializedState(require)
state := newTestState(t, memdb.New())

d, err := NewDiffOn(state)
require.NoError(err)
Expand All @@ -68,7 +69,7 @@ func TestDiffFeeState(t *testing.T) {
func TestDiffCurrentSupply(t *testing.T) {
require := require.New(t)

state := newInitializedState(require)
state := newTestState(t, memdb.New())

d, err := NewDiffOn(state)
require.NoError(err)
Expand Down Expand Up @@ -256,7 +257,7 @@ func TestDiffSubnet(t *testing.T) {
require := require.New(t)
ctrl := gomock.NewController(t)

state := newInitializedState(require)
state := newTestState(t, memdb.New())

// Initialize parent with one subnet
parentStateCreateSubnetTx := &txs.Tx{
Expand Down Expand Up @@ -305,7 +306,7 @@ func TestDiffSubnet(t *testing.T) {
func TestDiffChain(t *testing.T) {
require := require.New(t)

state := newInitializedState(require)
state := newTestState(t, memdb.New())
subnetID := ids.GenerateTestID()

// Initialize parent with one chain
Expand Down Expand Up @@ -402,7 +403,7 @@ func TestDiffTx(t *testing.T) {
func TestDiffRewardUTXO(t *testing.T) {
require := require.New(t)

state := newInitializedState(require)
state := newTestState(t, memdb.New())

// Initialize parent with one reward UTXO
var (
Expand Down Expand Up @@ -531,7 +532,7 @@ func TestDiffSubnetOwner(t *testing.T) {
require := require.New(t)
ctrl := gomock.NewController(t)

state := newInitializedState(require)
state := newTestState(t, memdb.New())

var (
owner1 = fx.NewMockOwner(ctrl)
Expand Down Expand Up @@ -589,7 +590,7 @@ func TestDiffSubnetManager(t *testing.T) {
require := require.New(t)
ctrl := gomock.NewController(t)

state := newInitializedState(require)
state := newTestState(t, memdb.New())

states := NewMockVersions(ctrl)
lastAcceptedID := ids.GenerateTestID()
Expand Down Expand Up @@ -638,7 +639,7 @@ func TestDiffStacking(t *testing.T) {
require := require.New(t)
ctrl := gomock.NewController(t)

state := newInitializedState(require)
state := newTestState(t, memdb.New())

var (
owner1 = fx.NewMockOwner(ctrl)
Expand Down
3 changes: 2 additions & 1 deletion vms/platformvm/state/stakers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/ava-labs/avalanchego/database"
"github.com/ava-labs/avalanchego/ids"
"github.com/ava-labs/avalanchego/vms/platformvm/genesis/genesistest"
"github.com/ava-labs/avalanchego/vms/platformvm/txs"
)

Expand Down Expand Up @@ -219,7 +220,7 @@ func TestDiffStakersDelegator(t *testing.T) {

func newTestStaker() *Staker {
startTime := time.Now().Round(time.Second)
endTime := startTime.Add(28 * 24 * time.Hour)
endTime := startTime.Add(genesistest.ValidatorDuration)
return &Staker{
TxID: ids.GenerateTestID(),
NodeID: ids.GenerateTestNodeID(),
Expand Down
72 changes: 24 additions & 48 deletions vms/platformvm/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -458,38 +458,6 @@ func New(
metrics metrics.Metrics,
rewards reward.Calculator,
) (State, error) {
s, err := newState(
db,
metrics,
cfg,
execCfg,
ctx,
metricsReg,
rewards,
)
if err != nil {
return nil, err
}

if err := s.sync(genesisBytes); err != nil {
// Drop any errors on close to return the first error
_ = s.Close()

return nil, err
}

return s, nil
}

func newState(
db database.Database,
metrics metrics.Metrics,
cfg *config.Config,
execCfg *config.ExecutionConfig,
ctx *snow.Context,
metricsReg prometheus.Registerer,
rewards reward.Calculator,
) (*state, error) {
blockIDCache, err := metercacher.New[uint64, ids.ID](
"block_id_cache",
metricsReg,
Expand Down Expand Up @@ -614,7 +582,7 @@ func newState(
return nil, err
}

return &state{
s := &state{
validatorState: newValidatorState(),

validators: cfg.Validators,
Expand Down Expand Up @@ -694,7 +662,16 @@ func newState(
chainDBCache: chainDBCache,

singletonDB: prefixdb.New(SingletonPrefix, baseDB),
}, nil
}

if err := s.sync(genesisBytes); err != nil {
return nil, errors.Join(
err,
s.Close(),
)
}

return s, nil
}

func (s *state) GetCurrentValidator(subnetID ids.ID, nodeID ids.NodeID) (*Staker, error) {
Expand Down Expand Up @@ -753,15 +730,6 @@ func (s *state) GetPendingStakerIterator() (StakerIterator, error) {
return s.pendingStakers.GetStakerIterator(), nil
}

func (s *state) shouldInit() (bool, error) {
has, err := s.singletonDB.Has(InitializedKey)
return !has, err
}

func (s *state) doneInit() error {
return s.singletonDB.Put(InitializedKey, nil)
}

func (s *state) GetSubnetIDs() ([]ids.ID, error) {
if s.cachedSubnetIDs != nil {
return s.cachedSubnetIDs, nil
Expand Down Expand Up @@ -1751,17 +1719,17 @@ func (s *state) Close() error {
}

func (s *state) sync(genesis []byte) error {
shouldInit, err := s.shouldInit()
wasInitialized, err := isInitialized(s.singletonDB)
if err != nil {
return fmt.Errorf(
"failed to check if the database is initialized: %w",
err,
)
}

// If the database is empty, create the platform chain anew using the
// provided genesis state
if shouldInit {
// If the database wasn't previously initialized, create the platform chain
// anew using the provided genesis state.
if !wasInitialized {
if err := s.init(genesis); err != nil {
return fmt.Errorf(
"failed to initialize the database: %w",
Expand Down Expand Up @@ -1797,7 +1765,7 @@ func (s *state) init(genesisBytes []byte) error {
return err
}

if err := s.doneInit(); err != nil {
if err := markInitialized(s.singletonDB); err != nil {
return err
}

Expand Down Expand Up @@ -2548,6 +2516,14 @@ func (s *state) ReindexBlocks(lock sync.Locker, log logging.Logger) error {
return s.Commit()
}

func markInitialized(db database.KeyValueWriter) error {
return db.Put(InitializedKey, nil)
}

func isInitialized(db database.KeyValueReader) (bool, error) {
return db.Has(InitializedKey)
}

func putFeeState(db database.KeyValueWriter, feeState fee.State) error {
feeStateBytes, err := block.GenesisCodec.Marshal(block.CodecVersion, feeState)
if err != nil {
Expand Down
Loading
Loading