Skip to content

Commit

Permalink
Fix cold state generation (#5770)
Browse files Browse the repository at this point in the history
  • Loading branch information
terencechain authored May 6, 2020
1 parent d5b1f9f commit ef4dead
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 5 deletions.
23 changes: 21 additions & 2 deletions beacon-chain/state/stategen/cold.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,32 @@ func (s *State) loadColdStateByRoot(ctx context.Context, blockRoot [32]byte) (*s
return nil, errors.Wrap(err, "could not get state summary")
}

return s.ComputeStateUpToSlot(ctx, summary.Slot)
return s.loadColdStateBySlot(ctx, summary.Slot)
}

// This loads a cold state by slot.
func (s *State) loadColdStateBySlot(ctx context.Context, slot uint64) (*state.BeaconState, error) {
ctx, span := trace.StartSpan(ctx, "stateGen.loadColdStateBySlot")
defer span.End()

return s.ComputeStateUpToSlot(ctx, slot)
if slot == 0 {
return s.beaconDB.GenesisState(ctx)
}

archivedState, err := s.archivedState(ctx, slot)
if err != nil {
return nil, err
}
if archivedState == nil {
archivedRoot := s.archivedRoot(ctx, slot)
archivedState, err = s.recoverStateByRoot(ctx, archivedRoot)
if err != nil {
return nil, err
}
if archivedState == nil {
return nil, errUnknownState
}
}

return s.processStateUpTo(ctx, archivedState, slot)
}
6 changes: 6 additions & 0 deletions beacon-chain/state/stategen/cold_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ func TestLoadColdStateByRoot_CanGet(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if err := service.beaconDB.SaveArchivedPointRoot(ctx, blkRoot, 0); err != nil {
t.Fatal(err)
}
if err := service.beaconDB.SaveGenesisBlockRoot(ctx, blkRoot); err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -112,6 +115,9 @@ func TestLoadColdStateBySlot_CanGet(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if err := service.beaconDB.SaveArchivedPointRoot(ctx, blkRoot, 0); err != nil {
t.Fatal(err)
}
if err := service.beaconDB.SaveGenesisBlockRoot(ctx, blkRoot); err != nil {
t.Fatal(err)
}
Expand Down
9 changes: 6 additions & 3 deletions beacon-chain/state/stategen/getter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ func TestStateByRoot_ColdState(t *testing.T) {
if err := beaconState.SetSlot(1); err != nil {
t.Fatal(err)
}
if err := service.beaconDB.SaveArchivedPointRoot(ctx, bRoot, 0); err != nil {
t.Fatal(err)
}
if err := service.beaconDB.SaveState(ctx, beaconState, bRoot); err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -155,14 +158,14 @@ func TestStateBySlot_ColdState(t *testing.T) {
if err := db.SaveState(ctx, beaconState, bRoot); err != nil {
t.Fatal(err)
}
if err := service.beaconDB.SaveArchivedPointRoot(ctx, bRoot, 0); err != nil {
t.Fatal(err)
}
if err := db.SaveGenesisBlockRoot(ctx, bRoot); err != nil {
t.Fatal(err)
}

r := [32]byte{}
if err := service.beaconDB.SaveArchivedPointRoot(ctx, r, 0); err != nil {
t.Fatal(err)
}
if err := service.beaconDB.SaveArchivedPointRoot(ctx, r, 1); err != nil {
t.Fatal(err)
}
Expand Down
75 changes: 75 additions & 0 deletions beacon-chain/state/stategen/replay.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil"
"github.com/prysmaticlabs/prysm/shared/bytesutil"
"github.com/prysmaticlabs/prysm/shared/featureconfig"
"github.com/prysmaticlabs/prysm/shared/params"
"go.opencensus.io/trace"
)

Expand Down Expand Up @@ -306,3 +307,77 @@ func (s *State) genesisRoot(ctx context.Context) ([32]byte, error) {
}
return stateutil.BlockRoot(b.Block)
}

// This retrieves the archived root in the DB.
func (s *State) archivedRoot(ctx context.Context, slot uint64) [32]byte {
archivedIndex := uint64(0)
if slot/params.BeaconConfig().SlotsPerArchivedPoint > 1 {
archivedIndex = slot/params.BeaconConfig().SlotsPerArchivedPoint - 1
}
return s.beaconDB.ArchivedPointRoot(ctx, archivedIndex)
}

// This retrieves the archived state in the DB.
func (s *State) archivedState(ctx context.Context, slot uint64) (*state.BeaconState, error) {
archivedRoot := s.archivedRoot(ctx, slot)
return s.beaconDB.State(ctx, archivedRoot)
}

// This recomputes a state given the block root.
func (s *State) recoverStateByRoot(ctx context.Context, root [32]byte) (*state.BeaconState, error) {
ctx, span := trace.StartSpan(ctx, "stateGen.recoverStateByRoot")
defer span.End()

lastAncestorState, err := s.lastAncestorState(ctx, root)
if err != nil {
return nil, err
}
if lastAncestorState == nil {
return nil, errUnknownState
}

targetBlk, err := s.beaconDB.Block(ctx, root)
if err != nil {
return nil, err
}
if targetBlk == nil {
return nil, errUnknownBlock
}
blks, err := s.LoadBlocks(ctx, lastAncestorState.Slot()+1, targetBlk.Block.Slot, root)
if err != nil {
return nil, errors.Wrap(err, "could not load blocks for cold state using root")
}

return s.ReplayBlocks(ctx, lastAncestorState, blks, targetBlk.Block.Slot)
}

// This processes a state up to input slot.
func (s *State) processStateUpTo(ctx context.Context, state *state.BeaconState, slot uint64) (*state.BeaconState, error) {
ctx, span := trace.StartSpan(ctx, "stateGen.processStateUpTo")
defer span.End()

// Short circuit if the slot is already less than pre state.
if state.Slot() >= slot {
return state, nil
}

lastBlockRoot, lastBlockSlot, err := s.lastSavedBlock(ctx, slot)
if err != nil {
return nil, errors.Wrap(err, "could not get last saved block")
}
// Short circuit if no block was saved, replay using slots only.
if lastBlockSlot == 0 {
return s.ReplayBlocks(ctx, state, []*ethpb.SignedBeaconBlock{}, slot)
}

blks, err := s.LoadBlocks(ctx, state.Slot()+1, lastBlockSlot, lastBlockRoot)
if err != nil {
return nil, errors.Wrap(err, "could not load blocks")
}
state, err = s.ReplayBlocks(ctx, state, blks, slot)
if err != nil {
return nil, errors.Wrap(err, "could not replay blocks")
}

return state, nil
}
73 changes: 73 additions & 0 deletions beacon-chain/state/stategen/replay_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,79 @@ func TestLastSavedState_NoSavedBlockState(t *testing.T) {
}
}

func TestArchivedRoot_CanGet(t *testing.T) {
ctx := context.Background()
db := testDB.SetupDB(t)
service := New(db, cache.NewStateSummaryCache())

r := [32]byte{'a'}
if err := db.SaveArchivedPointRoot(ctx, r, 0); err != nil {
t.Fatal(err)
}
got := service.archivedRoot(ctx, params.BeaconConfig().SlotsPerArchivedPoint)
if r != got {
t.Error("Did not get wanted root")
}
}

func TestArchivedState_CanGet(t *testing.T) {
ctx := context.Background()
db := testDB.SetupDB(t)
service := New(db, cache.NewStateSummaryCache())

r := [32]byte{'a'}
if err := db.SaveArchivedPointRoot(ctx, r, 0); err != nil {
t.Fatal(err)
}
beaconState, _ := testutil.DeterministicGenesisState(t, 32)
if err := db.SaveState(ctx, beaconState, r); err != nil {
t.Fatal(err)
}
got, err := service.archivedState(ctx, params.BeaconConfig().SlotsPerArchivedPoint)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(got.InnerStateUnsafe(), beaconState.InnerStateUnsafe()) {
t.Error("Did not get wanted state")
}
}

func TestProcessStateUpToSlot_CanExitEarly(t *testing.T) {
ctx := context.Background()
db := testDB.SetupDB(t)

service := New(db, cache.NewStateSummaryCache())
beaconState, _ := testutil.DeterministicGenesisState(t, 32)
if err := beaconState.SetSlot(params.BeaconConfig().SlotsPerEpoch + 1); err != nil {
t.Fatal(err)
}
s, err := service.processStateUpTo(ctx, beaconState, params.BeaconConfig().SlotsPerEpoch)
if err != nil {
t.Fatal(err)
}

if s.Slot() != params.BeaconConfig().SlotsPerEpoch+1 {
t.Error("Did not receive correct processed state")
}
}

func TestProcessStateUpToSlot_CanProcess(t *testing.T) {
ctx := context.Background()
db := testDB.SetupDB(t)

service := New(db, cache.NewStateSummaryCache())
beaconState, _ := testutil.DeterministicGenesisState(t, 32)

s, err := service.processStateUpTo(ctx, beaconState, params.BeaconConfig().SlotsPerEpoch+1)
if err != nil {
t.Fatal(err)
}

if s.Slot() != params.BeaconConfig().SlotsPerEpoch+1 {
t.Error("Did not receive correct processed state")
}
}

// tree1 constructs the following tree:
// B0 - B1 - - B3 -- B5
// \- B2 -- B4 -- B6 ----- B8
Expand Down

0 comments on commit ef4dead

Please sign in to comment.