Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test: Replace t.error/fatal with assert/request in [raft_paper_test.go] #182

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 61 additions & 149 deletions raft_paper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@ package raft

import (
"fmt"
"reflect"
"sort"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

pb "go.etcd.io/raft/v3/raftpb"
)

Expand Down Expand Up @@ -64,12 +66,8 @@ func testUpdateTermFromMessage(t *testing.T, state StateType) {

r.Step(pb.Message{Type: pb.MsgApp, Term: 2})

if r.Term != 2 {
t.Errorf("term = %d, want %d", r.Term, 2)
}
if r.state != StateFollower {
t.Errorf("state = %v, want %v", r.state, StateFollower)
}
assert.Equal(t, uint64(2), r.Term)
assert.Equal(t, StateFollower, r.state)
}

// TestRejectStaleTermMessage tests that if a server receives a request with
Expand All @@ -88,18 +86,14 @@ func TestRejectStaleTermMessage(t *testing.T) {

r.Step(pb.Message{Type: pb.MsgApp, Term: r.Term - 1})

if called {
t.Errorf("stepFunc called = %v, want %v", called, false)
}
assert.False(t, called)
}

// TestStartAsFollower tests that when servers start up, they begin as followers.
// Reference: section 5.2
func TestStartAsFollower(t *testing.T) {
r := newTestRaft(1, 10, 1, newTestMemoryStorage(withPeers(1, 2, 3)))
if r.state != StateFollower {
t.Errorf("state = %s, want %s", r.state, StateFollower)
}
assert.Equal(t, StateFollower, r.state)
}

// TestLeaderBcastBeat tests that if the leader receives a heartbeat tick,
Expand All @@ -122,13 +116,10 @@ func TestLeaderBcastBeat(t *testing.T) {

msgs := r.readMessages()
sort.Sort(messageSlice(msgs))
wmsgs := []pb.Message{
assert.Equal(t, []pb.Message{
{From: 1, To: 2, Term: 1, Type: pb.MsgHeartbeat},
{From: 1, To: 3, Term: 1, Type: pb.MsgHeartbeat},
}
if !reflect.DeepEqual(msgs, wmsgs) {
t.Errorf("msgs = %v, want %v", msgs, wmsgs)
}
}, msgs)
}

func TestFollowerStartElection(t *testing.T) {
Expand Down Expand Up @@ -164,24 +155,16 @@ func testNonleaderStartElection(t *testing.T, state StateType) {
}
r.advanceMessagesAfterAppend()

if r.Term != 2 {
t.Errorf("term = %d, want 2", r.Term)
}
if r.state != StateCandidate {
t.Errorf("state = %s, want %s", r.state, StateCandidate)
}
if !r.trk.Votes[r.id] {
t.Errorf("vote for self = false, want true")
}
assert.Equal(t, uint64(2), r.Term)
assert.Equal(t, StateCandidate, r.state)
assert.True(t, r.trk.Votes[r.id])

msgs := r.readMessages()
sort.Sort(messageSlice(msgs))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For another PR: consider removing messageSlice, and using slices.SortFunc with some stable comparator.

Copy link
Contributor Author

@MrDXY MrDXY Mar 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this be what you had in mind?

slices.SortStableFunc(data, orderMessageAsc)
func sortMessageAsc(a pb.Message, b pb.Message) int {
	strA := fmt.Sprint(a)
	strB := fmt.Sprint(b)
	if strA > strB {
		return 1
	} else if strA < strB {
		return -1
	} else {
		return 0
	}
}

wmsgs := []pb.Message{
assert.Equal(t, []pb.Message{
{From: 1, To: 2, Term: 2, Type: pb.MsgVote},
{From: 1, To: 3, Term: 2, Type: pb.MsgVote},
}
if !reflect.DeepEqual(msgs, wmsgs) {
t.Errorf("msgs = %v, want %v", msgs, wmsgs)
}
}, msgs)
}

// TestLeaderElectionInOneRoundRPC tests all cases that may happen in
Expand Down Expand Up @@ -224,12 +207,8 @@ func TestLeaderElectionInOneRoundRPC(t *testing.T) {
r.Step(pb.Message{From: id, To: 1, Term: r.Term, Type: pb.MsgVoteResp, Reject: !vote})
}

if r.state != tt.state {
t.Errorf("#%d: state = %s, want %s", i, r.state, tt.state)
}
if g := r.Term; g != 1 {
t.Errorf("#%d: term = %d, want %d", i, g, 1)
}
assert.Equal(t, tt.state, r.state, "#%d", i)
assert.Equal(t, uint64(1), r.Term, "#%d", i)
}
}

Expand All @@ -255,13 +234,9 @@ func TestFollowerVote(t *testing.T) {

r.Step(pb.Message{From: tt.nvote, To: 1, Term: 1, Type: pb.MsgVote})

msgs := r.msgsAfterAppend
wmsgs := []pb.Message{
assert.Equal(t, []pb.Message{
{From: 1, To: tt.nvote, Term: 1, Type: pb.MsgVoteResp, Reject: tt.wreject},
}
if !reflect.DeepEqual(msgs, wmsgs) {
t.Errorf("#%d: msgs = %v, want %v", i, msgs, wmsgs)
}
}, r.msgsAfterAppend, "#%d", i)
}
}

Expand All @@ -278,18 +253,12 @@ func TestCandidateFallback(t *testing.T) {
for i, tt := range tests {
r := newTestRaft(1, 10, 1, newTestMemoryStorage(withPeers(1, 2, 3)))
r.Step(pb.Message{From: 1, To: 1, Type: pb.MsgHup})
if r.state != StateCandidate {
t.Fatalf("unexpected state = %s, want %s", r.state, StateCandidate)
}
require.Equal(t, StateCandidate, r.state, "#%d", i)

r.Step(tt)

if g := r.state; g != StateFollower {
t.Errorf("#%d: state = %s, want %s", i, g, StateFollower)
}
if g := r.Term; g != tt.Term {
t.Errorf("#%d: term = %d, want %d", i, g, tt.Term)
}
assert.Equal(t, StateFollower, r.state, "#%d", i)
assert.Equal(t, tt.Term, r.Term, "#%d", i)
}
}

Expand Down Expand Up @@ -328,9 +297,7 @@ func testNonleaderElectionTimeoutRandomized(t *testing.T, state StateType) {
}

for d := et; d < 2*et; d++ {
if !timeouts[d] {
t.Errorf("timeout in %d ticks should happen", d)
}
assert.True(t, timeouts[d], "timeout in %d ticks should happen", d)
}
}

Expand Down Expand Up @@ -383,9 +350,7 @@ func testNonleadersElectionTimeoutNonconflict(t *testing.T, state StateType) {
}
}

if g := float64(conflicts) / 1000; g > 0.3 {
t.Errorf("probability of conflicts = %v, want <= 0.3", g)
}
assert.LessOrEqual(t, float64(conflicts)/1000, 0.3)
}

// TestLeaderStartReplication tests that when receiving client proposals,
Expand All @@ -407,25 +372,18 @@ func TestLeaderStartReplication(t *testing.T) {
ents := []pb.Entry{{Data: []byte("some data")}}
r.Step(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: ents})

if g := r.raftLog.lastIndex(); g != li+1 {
t.Errorf("lastIndex = %d, want %d", g, li+1)
}
if g := r.raftLog.committed; g != li {
t.Errorf("committed = %d, want %d", g, li)
}
assert.Equal(t, li+1, r.raftLog.lastIndex())
assert.Equal(t, li, r.raftLog.committed)
msgs := r.readMessages()
sort.Sort(messageSlice(msgs))
wents := []pb.Entry{{Index: li + 1, Term: 1, Data: []byte("some data")}}
wmsgs := []pb.Message{
assert.Equal(t, []pb.Message{
{From: 1, To: 2, Term: 1, Type: pb.MsgApp, Index: li, LogTerm: 1, Entries: wents, Commit: li},
{From: 1, To: 3, Term: 1, Type: pb.MsgApp, Index: li, LogTerm: 1, Entries: wents, Commit: li},
}
if !reflect.DeepEqual(msgs, wmsgs) {
t.Errorf("msgs = %+v, want %+v", msgs, wmsgs)
}
if g := r.raftLog.nextUnstableEnts(); !reflect.DeepEqual(g, wents) {
t.Errorf("ents = %+v, want %+v", g, wents)
}
}, msgs)
assert.Equal(t, []pb.Entry{
{Index: li + 1, Term: 1, Data: []byte("some data")},
}, r.raftLog.nextUnstableEnts())
}

// TestLeaderCommitEntry tests that when the entry has been safely replicated,
Expand All @@ -448,25 +406,16 @@ func TestLeaderCommitEntry(t *testing.T) {
r.Step(acceptAndReply(m))
}

if g := r.raftLog.committed; g != li+1 {
t.Errorf("committed = %d, want %d", g, li+1)
}
wents := []pb.Entry{{Index: li + 1, Term: 1, Data: []byte("some data")}}
if g := r.raftLog.nextCommittedEnts(true); !reflect.DeepEqual(g, wents) {
t.Errorf("nextCommittedEnts = %+v, want %+v", g, wents)
}
assert.Equal(t, li+1, r.raftLog.committed)
assert.Equal(t, []pb.Entry{
{Index: li + 1, Term: 1, Data: []byte("some data")},
}, r.raftLog.nextCommittedEnts(true))
msgs := r.readMessages()
sort.Sort(messageSlice(msgs))
for i, m := range msgs {
if w := uint64(i + 2); m.To != w {
t.Errorf("to = %x, want %x", m.To, w)
}
if m.Type != pb.MsgApp {
t.Errorf("type = %v, want %v", m.Type, pb.MsgApp)
}
if m.Commit != li+1 {
t.Errorf("commit = %d, want %d", m.Commit, li+1)
}
assert.Equal(t, uint64(i+2), m.To)
assert.Equal(t, pb.MsgApp, m.Type)
assert.Equal(t, li+1, m.Commit)
}
}

Expand Down Expand Up @@ -504,9 +453,7 @@ func TestLeaderAcknowledgeCommit(t *testing.T) {
}
}

if g := r.raftLog.committed > li; g != tt.wack {
t.Errorf("#%d: ack commit = %v, want %v", i, g, tt.wack)
}
assert.Equal(t, tt.wack, r.raftLog.committed > li, "#%d", i)
}
}

Expand Down Expand Up @@ -536,10 +483,10 @@ func TestLeaderCommitPrecedingEntries(t *testing.T) {
}

li := uint64(len(tt))
wents := append(tt, pb.Entry{Term: 3, Index: li + 1}, pb.Entry{Term: 3, Index: li + 2, Data: []byte("some data")})
if g := r.raftLog.nextCommittedEnts(true); !reflect.DeepEqual(g, wents) {
t.Errorf("#%d: ents = %+v, want %+v", i, g, wents)
}
assert.Equal(t, append(tt,
pb.Entry{Term: 3, Index: li + 1},
pb.Entry{Term: 3, Index: li + 2, Data: []byte("some data")},
), r.raftLog.nextCommittedEnts(true), "#%d", i)
}
}

Expand Down Expand Up @@ -585,13 +532,8 @@ func TestFollowerCommitEntry(t *testing.T) {

r.Step(pb.Message{From: 2, To: 1, Type: pb.MsgApp, Term: 1, Entries: tt.ents, Commit: tt.commit})

if g := r.raftLog.committed; g != tt.commit {
t.Errorf("#%d: committed = %d, want %d", i, g, tt.commit)
}
wents := tt.ents[:int(tt.commit)]
if g := r.raftLog.nextCommittedEnts(true); !reflect.DeepEqual(g, wents) {
t.Errorf("#%d: nextCommittedEnts = %v, want %v", i, g, wents)
}
assert.Equal(t, tt.commit, r.raftLog.committed, "#%d", i)
assert.Equal(t, tt.ents[:int(tt.commit)], r.raftLog.nextCommittedEnts(true), "#%d", i)
}
}

Expand Down Expand Up @@ -630,13 +572,9 @@ func TestFollowerCheckMsgApp(t *testing.T) {

r.Step(pb.Message{From: 2, To: 1, Type: pb.MsgApp, Term: 2, LogTerm: tt.term, Index: tt.index})

msgs := r.readMessages()
wmsgs := []pb.Message{
assert.Equal(t, []pb.Message{
{From: 1, To: 2, Type: pb.MsgAppResp, Term: 2, Index: tt.windex, Reject: tt.wreject, RejectHint: tt.wrejectHint, LogTerm: tt.wlogterm},
}
if !reflect.DeepEqual(msgs, wmsgs) {
t.Errorf("#%d: msgs = %+v, want %+v", i, msgs, wmsgs)
}
}, r.readMessages(), "#%d", i)
}
}

Expand Down Expand Up @@ -685,12 +623,8 @@ func TestFollowerAppendEntries(t *testing.T) {

r.Step(pb.Message{From: 2, To: 1, Type: pb.MsgApp, Term: 2, LogTerm: tt.term, Index: tt.index, Entries: tt.ents})

if g := r.raftLog.allEntries(); !reflect.DeepEqual(g, tt.wents) {
t.Errorf("#%d: ents = %+v, want %+v", i, g, tt.wents)
}
if g := r.raftLog.nextUnstableEnts(); !reflect.DeepEqual(g, tt.wunstable) {
t.Errorf("#%d: unstableEnts = %+v, want %+v", i, g, tt.wunstable)
}
assert.Equal(t, tt.wents, r.raftLog.allEntries(), "#%d", i)
assert.Equal(t, tt.wunstable, r.raftLog.nextUnstableEnts(), "#%d", i)
}
}

Expand Down Expand Up @@ -727,9 +661,7 @@ func TestLeaderSyncFollowerLog(t *testing.T) {

n.send(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{}}})

if g := diffu(ltoa(lead.raftLog), ltoa(follower.raftLog)); g != "" {
t.Errorf("#%d: log diff:\n%s", i, g)
}
assert.Empty(t, diffu(ltoa(lead.raftLog), ltoa(follower.raftLog)), "#%d", i)
}
}

Expand Down Expand Up @@ -757,26 +689,14 @@ func TestVoteRequest(t *testing.T) {

msgs := r.readMessages()
sort.Sort(messageSlice(msgs))
if len(msgs) != 2 {
t.Fatalf("#%d: len(msg) = %d, want %d", j, len(msgs), 2)
}
require.Len(t, msgs, 2, "#%d", j)
for i, m := range msgs {
if m.Type != pb.MsgVote {
t.Errorf("#%d: msgType = %d, want %d", i, m.Type, pb.MsgVote)
}
if m.To != uint64(i+2) {
t.Errorf("#%d: to = %d, want %d", i, m.To, i+2)
}
if m.Term != tt.wterm {
t.Errorf("#%d: term = %d, want %d", i, m.Term, tt.wterm)
}
windex, wlogterm := tt.ents[len(tt.ents)-1].Index, tt.ents[len(tt.ents)-1].Term
if m.Index != windex {
t.Errorf("#%d: index = %d, want %d", i, m.Index, windex)
}
if m.LogTerm != wlogterm {
t.Errorf("#%d: logterm = %d, want %d", i, m.LogTerm, wlogterm)
}
assert.Equal(t, pb.MsgVote, m.Type, "#%d.%d", j, i)
assert.Equal(t, uint64(i+2), m.To, "#%d.%d", j, i)
assert.Equal(t, tt.wterm, m.Term, "#%d.%d", j, i)

assert.Equal(t, tt.ents[len(tt.ents)-1].Index, m.Index, "#%d.%d", j, i)
assert.Equal(t, tt.ents[len(tt.ents)-1].Term, m.LogTerm, "#%d.%d", j, i)
}
}
}
Expand Down Expand Up @@ -814,16 +734,10 @@ func TestVoter(t *testing.T) {
r.Step(pb.Message{From: 2, To: 1, Type: pb.MsgVote, Term: 3, LogTerm: tt.logterm, Index: tt.index})

msgs := r.readMessages()
if len(msgs) != 1 {
t.Fatalf("#%d: len(msg) = %d, want %d", i, len(msgs), 1)
}
require.Len(t, msgs, 1, "#%d", i)
m := msgs[0]
if m.Type != pb.MsgVoteResp {
t.Errorf("#%d: msgType = %d, want %d", i, m.Type, pb.MsgVoteResp)
}
if m.Reject != tt.wreject {
t.Errorf("#%d: reject = %t, want %t", i, m.Reject, tt.wreject)
}
assert.Equal(t, pb.MsgVoteResp, m.Type, "#%d", i)
assert.Equal(t, tt.wreject, m.Reject, "#%d", i)
}
}

Expand Down Expand Up @@ -856,9 +770,7 @@ func TestLeaderOnlyCommitsLogFromCurrentTerm(t *testing.T) {

r.Step(pb.Message{From: 2, To: 1, Type: pb.MsgAppResp, Term: r.Term, Index: tt.index})
r.advanceMessagesAfterAppend()
if r.raftLog.committed != tt.wcommit {
t.Errorf("#%d: commit = %d, want %d", i, r.raftLog.committed, tt.wcommit)
}
assert.Equal(t, tt.wcommit, r.raftLog.committed, "#%d", i)
}
}

Expand Down
Loading