diff --git a/polygon/sync/heimdall.go b/polygon/sync/heimdall.go index c0d7a62d076..31617924c8e 100644 --- a/polygon/sync/heimdall.go +++ b/polygon/sync/heimdall.go @@ -20,6 +20,9 @@ type Heimdall interface { OnMilestoneEvent(ctx context.Context, callback func(*milestone.Milestone)) error } +// ErrIncompleteMilestoneRange happens when FetchMilestones is called with an old start block because old milestones are evicted +var ErrIncompleteMilestoneRange = errors.New("milestone range doesn't contain the start block") + type HeimdallImpl struct { client heimdall.IHeimdallClient pollDelay time.Duration @@ -50,6 +53,10 @@ func cmpBlockNumToCheckpointRange(n uint64, c *checkpoint.Checkpoint) int { return cmpNumToRange(n, c.StartBlock, c.EndBlock) } +func cmpBlockNumToMilestoneRange(n uint64, m *milestone.Milestone) int { + return cmpNumToRange(n, m.StartBlock, m.EndBlock) +} + func reverse[T any](s []T) { for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 { s[i], s[j] = s[j], s[i] @@ -71,7 +78,7 @@ func (impl *HeimdallImpl) FetchCheckpoints(ctx context.Context, start uint64) ([ } cmpResult := cmpBlockNumToCheckpointRange(start, c) - // the start block is past than the last checkpoint + // the start block is past the last checkpoint if cmpResult > 0 { return nil, nil } @@ -89,7 +96,39 @@ func (impl *HeimdallImpl) FetchCheckpoints(ctx context.Context, start uint64) ([ } func (impl *HeimdallImpl) FetchMilestones(ctx context.Context, start uint64) ([]*milestone.Milestone, error) { - panic("not implemented") + count, err := impl.client.FetchMilestoneCount(ctx) + if err != nil { + return nil, err + } + + var milestones []*milestone.Milestone + + for i := count; i >= 1; i-- { + m, err := impl.client.FetchMilestone(ctx, i) + if err != nil { + if errors.Is(err, heimdall.ErrNotInMilestoneList) { + reverse(milestones) + return milestones, ErrIncompleteMilestoneRange + } + return nil, err + } + + cmpResult := cmpBlockNumToMilestoneRange(start, m) + // the start block is past the last milestone + if cmpResult > 0 { + return nil, nil + } + + milestones = append(milestones, m) + + // the checkpoint contains the start block + if cmpResult == 0 { + break + } + } + + reverse(milestones) + return milestones, nil } func (impl *HeimdallImpl) FetchSpan(ctx context.Context) (*span.HeimdallSpan, error) { diff --git a/polygon/sync/heimdall_test.go b/polygon/sync/heimdall_test.go index ebe4328e13f..a5105aa337e 100644 --- a/polygon/sync/heimdall_test.go +++ b/polygon/sync/heimdall_test.go @@ -3,6 +3,7 @@ package sync import ( "context" "github.com/golang/mock/gomock" + heimdall_client "github.com/ledgerwatch/erigon/consensus/bor/heimdall" "github.com/ledgerwatch/erigon/consensus/bor/heimdall/checkpoint" "github.com/ledgerwatch/erigon/consensus/bor/heimdall/milestone" heimdall_mock "github.com/ledgerwatch/erigon/consensus/bor/heimdall/mock" @@ -73,6 +74,22 @@ func (test heimdallTest) setupCheckpoints(count int) []*checkpoint.Checkpoint { return expectedCheckpoints } +func (test heimdallTest) setupMilestones(count int) []*milestone.Milestone { + var expectedMilestones []*milestone.Milestone + for i := 0; i < count; i++ { + m := makeMilestone(uint64(i*16), 16) + expectedMilestones = append(expectedMilestones, m) + } + + client := test.client + client.EXPECT().FetchMilestoneCount(gomock.Any()).Return(int64(len(expectedMilestones)), nil) + client.EXPECT().FetchMilestone(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, number int64) (*milestone.Milestone, error) { + return expectedMilestones[number-1], nil + }).AnyTimes() + + return expectedMilestones +} + func TestFetchCheckpoints1(t *testing.T) { test := newHeimdallTest(t) expectedCheckpoint := test.setupCheckpoints(1)[0] @@ -121,6 +138,83 @@ func TestFetchCheckpointsMiddleStart(t *testing.T) { } } +func TestFetchMilestones1(t *testing.T) { + test := newHeimdallTest(t) + expectedMilestone := test.setupMilestones(1)[0] + + milestones, err := test.heimdall.FetchMilestones(test.ctx, 0) + require.Nil(t, err) + + require.Equal(t, 1, len(milestones)) + assert.Equal(t, expectedMilestone.Timestamp, milestones[0].Timestamp) +} + +func TestFetchMilestonesPastLast(t *testing.T) { + test := newHeimdallTest(t) + _ = test.setupMilestones(1)[0] + + milestones, err := test.heimdall.FetchMilestones(test.ctx, 500) + require.Nil(t, err) + + require.Equal(t, 0, len(milestones)) +} + +func TestFetchMilestones10(t *testing.T) { + test := newHeimdallTest(t) + expectedMilestones := test.setupMilestones(10) + + milestones, err := test.heimdall.FetchMilestones(test.ctx, 0) + require.Nil(t, err) + + require.Equal(t, len(expectedMilestones), len(milestones)) + for i := 0; i < len(milestones); i++ { + assert.Equal(t, expectedMilestones[i].StartBlock.Uint64(), milestones[i].StartBlock.Uint64()) + } +} + +func TestFetchMilestonesMiddleStart(t *testing.T) { + test := newHeimdallTest(t) + expectedMilestones := test.setupMilestones(10) + const offset = 6 + + milestones, err := test.heimdall.FetchMilestones(test.ctx, expectedMilestones[offset].StartBlock.Uint64()) + require.Nil(t, err) + + require.Equal(t, len(expectedMilestones)-offset, len(milestones)) + for i := 0; i < len(milestones); i++ { + assert.Equal(t, expectedMilestones[offset+i].StartBlock.Uint64(), milestones[i].StartBlock.Uint64()) + } +} + +func TestFetchMilestonesStartingBeforeEvictionPoint(t *testing.T) { + test := newHeimdallTest(t) + + var expectedMilestones []*milestone.Milestone + for i := 0; i < 20; i++ { + m := makeMilestone(uint64(i*16), 16) + expectedMilestones = append(expectedMilestones, m) + } + const keptMilestones = 5 + + client := test.client + client.EXPECT().FetchMilestoneCount(gomock.Any()).Return(int64(len(expectedMilestones)), nil) + client.EXPECT().FetchMilestone(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, number int64) (*milestone.Milestone, error) { + if int(number) <= len(expectedMilestones)-keptMilestones { + return nil, heimdall_client.ErrNotInMilestoneList + } + return expectedMilestones[number-1], nil + }).AnyTimes() + + milestones, err := test.heimdall.FetchMilestones(test.ctx, 0) + require.NotNil(t, err) + require.ErrorIs(t, err, ErrIncompleteMilestoneRange) + + require.Equal(t, keptMilestones, len(milestones)) + for i := 0; i < len(milestones); i++ { + assert.Equal(t, expectedMilestones[len(expectedMilestones)-len(milestones)+i].StartBlock.Uint64(), milestones[i].StartBlock.Uint64()) + } +} + func TestOnMilestoneEvent(t *testing.T) { test := newHeimdallTest(t)