
Fix concurrent map access in connmgr (#1860)
MarcoPolo committed Nov 9, 2022
1 parent 9a18c47 commit d6e725b
Showing 3 changed files with 135 additions and 41 deletions.
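The race being fixed stems from how trimming built its candidate list: peerInfo values were copied into the candidates slice while connect/disconnect handlers kept mutating the per-peer conns map, and copying a Go struct does not copy the maps it contains. The sketch below is plain illustrative Go, not library code (the peerInfo and conns names only mirror the real ones); it shows how a struct copy still aliases the same underlying map, which is why sorting candidates concurrently with connection events was a data race:

package main

import "fmt"

// peerInfo mirrors the shape of the real struct for illustration only:
// a value field plus a map of connections.
type peerInfo struct {
	value int
	conns map[string]int
}

func main() {
	original := &peerInfo{value: 1, conns: make(map[string]int)}

	// Copying the struct (as the old `candidates = append(candidates, *inf)` did)
	// copies the value field, but the conns map header still points at the
	// same underlying map.
	snapshot := *original

	// A "connect event" mutating the original...
	original.conns["conn-1"] = 42

	// ...is visible through the copy, because both structs share one map.
	// Doing this from two goroutines without a lock is a data race.
	fmt.Println(len(snapshot.conns)) // prints 1
}

The diff below therefore stores *peerInfo pointers in peerInfos and takes the owning segment's lock whenever the comparator or the pruning loops read conns.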
75 changes: 53 additions & 22 deletions p2p/net/connmgr/connmgr.go
@@ -62,14 +62,21 @@ type segment struct {
peers map[peer.ID]*peerInfo
}

type segments [256]*segment
type segments struct {
// bucketsMu is used to prevent deadlocks when concurrent processes try to
// grab multiple segment locks at once. If you need multiple segment locks
// at once, you should grab this lock first. You may release this lock once
// you have the segment locks.
bucketsMu sync.Mutex
buckets [256]*segment
}

func (ss *segments) get(p peer.ID) *segment {
return ss[byte(p[len(p)-1])]
return ss.buckets[byte(p[len(p)-1])]
}

func (ss *segments) countPeers() (count int) {
for _, seg := range ss {
for _, seg := range ss.buckets {
seg.Lock()
count += len(seg.peers)
seg.Unlock()
@@ -122,15 +129,15 @@ func NewConnManager(low, hi int, opts ...Option) (*BasicConnMgr, error) {
cfg: cfg,
clock: cfg.clock,
protected: make(map[peer.ID]map[string]struct{}, 16),
segments: func() (ret segments) {
for i := range ret {
ret[i] = &segment{
peers: make(map[peer.ID]*peerInfo),
}
}
return ret
}(),
segments: segments{},
}

for i := range cm.segments.buckets {
cm.segments.buckets[i] = &segment{
peers: make(map[peer.ID]*peerInfo),
}
}

cm.ctx, cm.cancel = context.WithCancel(context.Background())

if cfg.emergencyTrim {
@@ -246,15 +253,32 @@ type peerInfo struct {
firstSeen time.Time // timestamp when we began tracking this peer.
}

type peerInfos []peerInfo
type peerInfos []*peerInfo

// SortByValueAndStreams sorts peerInfos by their value and stream count. It
// will sort peers with no streams before those with streams (all else being
// equal). If `sortByMoreStreams` is true it will sort peers with more streams
// before those with fewer streams. This is useful to prioritize freeing memory.
func (p peerInfos) SortByValueAndStreams(sortByMoreStreams bool) {
func (p peerInfos) SortByValueAndStreams(segments *segments, sortByMoreStreams bool) {
sort.Slice(p, func(i, j int) bool {
left, right := p[i], p[j]

// Grab this lock so that we can grab both segment locks below without deadlocking.
segments.bucketsMu.Lock()

// lock this to protect against concurrent modifications from connect/disconnect events
leftSegment := segments.get(left.id)
leftSegment.Lock()
defer leftSegment.Unlock()

rightSegment := segments.get(right.id)
if leftSegment != rightSegment {
// These two peers are not in the same segment; let's grab that lock too
rightSegment.Lock()
defer rightSegment.Unlock()
}
segments.bucketsMu.Unlock()

// temporary peers are preferred for pruning.
if left.temp != right.temp {
return left.temp
@@ -360,31 +384,34 @@ func (cm *BasicConnMgr) getConnsToCloseEmergency(target int) []network.Conn {
candidates := make(peerInfos, 0, cm.segments.countPeers())

cm.plk.RLock()
for _, s := range cm.segments {
for _, s := range cm.segments.buckets {
s.Lock()
for id, inf := range s.peers {
if _, ok := cm.protected[id]; ok {
// skip over protected peer.
continue
}
candidates = append(candidates, *inf)
candidates = append(candidates, inf)
}
s.Unlock()
}
cm.plk.RUnlock()

// Sort peers according to their value.
candidates.SortByValueAndStreams(true)
candidates.SortByValueAndStreams(&cm.segments, true)

selected := make([]network.Conn, 0, target+10)
for _, inf := range candidates {
if target <= 0 {
break
}
s := cm.segments.get(inf.id)
s.Lock()
for c := range inf.conns {
selected = append(selected, c)
}
target -= len(inf.conns)
s.Unlock()
}
if len(selected) >= target {
// We found enough connections that were not protected.
@@ -395,24 +422,28 @@ func (cm *BasicConnMgr) getConnsToCloseEmergency(target int) []network.Conn {
// We have no choice but to kill some protected connections.
candidates = candidates[:0]
cm.plk.RLock()
for _, s := range cm.segments {
for _, s := range cm.segments.buckets {
s.Lock()
for _, inf := range s.peers {
candidates = append(candidates, *inf)
candidates = append(candidates, inf)
}
s.Unlock()
}
cm.plk.RUnlock()

candidates.SortByValueAndStreams(true)
candidates.SortByValueAndStreams(&cm.segments, true)
for _, inf := range candidates {
if target <= 0 {
break
}
// lock this to protect against concurrent modifications from connect/disconnect events
s := cm.segments.get(inf.id)
s.Lock()
for c := range inf.conns {
selected = append(selected, c)
}
target -= len(inf.conns)
s.Unlock()
}
return selected
}
@@ -435,7 +466,7 @@ func (cm *BasicConnMgr) getConnsToClose() []network.Conn {
gracePeriodStart := cm.clock.Now().Add(-cm.cfg.gracePeriod)

cm.plk.RLock()
for _, s := range cm.segments {
for _, s := range cm.segments.buckets {
s.Lock()
for id, inf := range s.peers {
if _, ok := cm.protected[id]; ok {
Expand All @@ -448,7 +479,7 @@ func (cm *BasicConnMgr) getConnsToClose() []network.Conn {
}
// note that we're copying the entry here,
// but since inf.conns is a map, it will still point to the original object
candidates = append(candidates, *inf)
candidates = append(candidates, inf)
ncandidates += len(inf.conns)
}
s.Unlock()
Expand All @@ -465,7 +496,7 @@ func (cm *BasicConnMgr) getConnsToClose() []network.Conn {
}

// Sort peers according to their value.
candidates.SortByValueAndStreams(false)
candidates.SortByValueAndStreams(&cm.segments, false)

target := ncandidates - cm.cfg.lowWater

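The new comparator needs to hold the locks of two different segments at once, and the bucketsMu comment above documents the rule that makes this safe: anyone taking more than one segment lock must hold bucketsMu while acquiring them, so two goroutines can never grab the same pair of locks in opposite orders. Here is a minimal sketch of that pattern with hypothetical names (lockPair, the two-element bucket array, and the peers stand-in are illustrative, not the library's API):

package main

import "sync"

// segment and segments mirror the real types in shape only.
type segment struct {
	sync.Mutex
	peers map[string]int
}

type segments struct {
	bucketsMu sync.Mutex // must be held while acquiring more than one segment lock
	buckets   [2]*segment
}

// lockPair acquires the locks of both segments without risking deadlock:
// holding bucketsMu while locking prevents two goroutines from grabbing the
// same pair of segment locks in opposite orders. It returns a func that
// releases exactly what was locked.
func (ss *segments) lockPair(a, b *segment) (unlock func()) {
	ss.bucketsMu.Lock()
	a.Lock()
	if a != b { // both peers may hash into the same segment
		b.Lock()
	}
	ss.bucketsMu.Unlock() // safe to drop once the segment locks are held
	return func() {
		a.Unlock()
		if a != b {
			b.Unlock()
		}
	}
}

func main() {
	ss := &segments{}
	for i := range ss.buckets {
		ss.buckets[i] = &segment{peers: make(map[string]int)}
	}

	unlock := ss.lockPair(ss.buckets[0], ss.buckets[1])
	// ... compare two peers while both of their segments are protected ...
	unlock()
}

Note that bucketsMu can be released as soon as both segment locks are held, which keeps the serialized section short; this matches the comment on the field and the early segments.bucketsMu.Unlock() in SortByValueAndStreams.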
97 changes: 80 additions & 17 deletions p2p/net/connmgr/connmgr_test.go
@@ -807,36 +807,57 @@ func (m mockConn) NewStream(ctx context.Context) (network.Stream, error) { panic
func (m mockConn) GetStreams() []network.Stream { panic("implement me") }
func (m mockConn) Scope() network.ConnScope { panic("implement me") }

func makeSegmentsWithPeerInfos(peerInfos peerInfos) *segments {
var s = func() *segments {
ret := segments{}
for i := range ret.buckets {
ret.buckets[i] = &segment{
peers: make(map[peer.ID]*peerInfo),
}
}
return &ret
}()

for _, pi := range peerInfos {
segment := s.get(pi.id)
segment.Lock()
segment.peers[pi.id] = pi
segment.Unlock()
}

return s
}

func TestPeerInfoSorting(t *testing.T) {
t.Run("starts with temporary connections", func(t *testing.T) {
p1 := peerInfo{id: peer.ID("peer1")}
p2 := peerInfo{id: peer.ID("peer2"), temp: true}
p1 := &peerInfo{id: peer.ID("peer1")}
p2 := &peerInfo{id: peer.ID("peer2"), temp: true}
pis := peerInfos{p1, p2}
pis.SortByValueAndStreams(false)
pis.SortByValueAndStreams(makeSegmentsWithPeerInfos(pis), false)
require.Equal(t, pis, peerInfos{p2, p1})
})

t.Run("starts with low-value connections", func(t *testing.T) {
p1 := peerInfo{id: peer.ID("peer1"), value: 40}
p2 := peerInfo{id: peer.ID("peer2"), value: 20}
p1 := &peerInfo{id: peer.ID("peer1"), value: 40}
p2 := &peerInfo{id: peer.ID("peer2"), value: 20}
pis := peerInfos{p1, p2}
pis.SortByValueAndStreams(false)
pis.SortByValueAndStreams(makeSegmentsWithPeerInfos(pis), false)
require.Equal(t, pis, peerInfos{p2, p1})
})

t.Run("prefer peers with no streams", func(t *testing.T) {
p1 := peerInfo{id: peer.ID("peer1"),
p1 := &peerInfo{id: peer.ID("peer1"),
conns: map[network.Conn]time.Time{
&mockConn{stats: network.ConnStats{NumStreams: 0}}: time.Now(),
},
}
p2 := peerInfo{id: peer.ID("peer2"),
p2 := &peerInfo{id: peer.ID("peer2"),
conns: map[network.Conn]time.Time{
&mockConn{stats: network.ConnStats{NumStreams: 1}}: time.Now(),
},
}
pis := peerInfos{p2, p1}
pis.SortByValueAndStreams(false)
pis.SortByValueAndStreams(makeSegmentsWithPeerInfos(pis), false)
require.Equal(t, pis, peerInfos{p1, p2})
})

@@ -848,57 +869,99 @@ func TestPeerInfoSorting(t *testing.T) {

outgoingSomeStreams := network.ConnStats{Stats: network.Stats{Direction: network.DirOutbound}, NumStreams: 1}
outgoingMoreStreams := network.ConnStats{Stats: network.Stats{Direction: network.DirOutbound}, NumStreams: 2}
p1 := peerInfo{
p1 := &peerInfo{
id: peer.ID("peer1"),
conns: map[network.Conn]time.Time{
&mockConn{stats: outgoingSomeStreams}: time.Now(),
},
}
p2 := peerInfo{
p2 := &peerInfo{
id: peer.ID("peer2"),
conns: map[network.Conn]time.Time{
&mockConn{stats: outgoingSomeStreams}: time.Now(),
&mockConn{stats: incoming}: time.Now(),
},
}
p3 := peerInfo{
p3 := &peerInfo{
id: peer.ID("peer3"),
conns: map[network.Conn]time.Time{
&mockConn{stats: outgoing}: time.Now(),
&mockConn{stats: incoming}: time.Now(),
},
}
p4 := peerInfo{
p4 := &peerInfo{
id: peer.ID("peer4"),
conns: map[network.Conn]time.Time{
&mockConn{stats: outgoingMoreStreams}: time.Now(),
&mockConn{stats: incoming}: time.Now(),
},
}
pis := peerInfos{p1, p2, p3, p4}
pis.SortByValueAndStreams(true)
pis.SortByValueAndStreams(makeSegmentsWithPeerInfos(pis), true)
// p3 is first because it is inactive (no streams).
// p4 is second because it has the most streams and we prioritize killing
// connections with the higher number of streams.
require.Equal(t, pis, peerInfos{p3, p4, p2, p1})
})

t.Run("in a memory emergency, starts with connections that have many streams", func(t *testing.T) {
p1 := peerInfo{
p1 := &peerInfo{
id: peer.ID("peer1"),
conns: map[network.Conn]time.Time{
&mockConn{stats: network.ConnStats{NumStreams: 100}}: time.Now(),
},
}
p2 := peerInfo{
p2 := &peerInfo{
id: peer.ID("peer2"),
conns: map[network.Conn]time.Time{
&mockConn{stats: network.ConnStats{NumStreams: 80}}: time.Now(),
&mockConn{stats: network.ConnStats{NumStreams: 40}}: time.Now(),
},
}
pis := peerInfos{p1, p2}
pis.SortByValueAndStreams(true)
pis.SortByValueAndStreams(makeSegmentsWithPeerInfos(pis), true)
require.Equal(t, pis, peerInfos{p2, p1})
})
}

func TestSafeConcurrency(t *testing.T) {
t.Run("Safe Concurrency", func(t *testing.T) {
cl := clock.NewMock()

p1 := &peerInfo{id: peer.ID("peer1"), conns: map[network.Conn]time.Time{}}
p2 := &peerInfo{id: peer.ID("peer2"), conns: map[network.Conn]time.Time{}}
pis := peerInfos{p1, p2}

ss := makeSegmentsWithPeerInfos(pis)

const runs = 10
const concurrency = 10
var wg sync.WaitGroup
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func() {
// add conns. This mimics new connection events
pis := peerInfos{p1, p2}
for i := 0; i < runs; i++ {
pi := pis[i%len(pis)]
s := ss.get(pi.id)
s.Lock()
s.peers[pi.id].conns[randConn(t, nil)] = cl.Now()
s.Unlock()
}
wg.Done()
}()

wg.Add(1)
go func() {
pis := peerInfos{p1, p2}
for i := 0; i < runs; i++ {
pis.SortByValueAndStreams(ss, false)
}
wg.Done()
}()
}

wg.Wait()
})
}
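To exercise this path locally, the new TestSafeConcurrency test is most useful under Go's race detector, e.g. go test -race -run TestSafeConcurrency ./p2p/net/connmgr (the exact invocation is a suggestion, not part of this commit); run against sorting code without the segment locking added here, it should flag the concurrent access to the shared conns maps that this change removes.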
4 changes: 2 additions & 2 deletions p2p/net/connmgr/decay.go
@@ -177,7 +177,7 @@ func (d *decayer) process() {
d.tagsMu.Unlock()

// Visit each peer, and decay tags that need to be decayed.
for _, s := range d.mgr.segments {
for _, s := range d.mgr.segments.buckets {
s.Lock()

// Entered a segment that contains peers. Process each peer.
@@ -261,7 +261,7 @@ func (d *decayer) process() {
d.tagsMu.Unlock()

// Remove the tag from all peers that had it in the connmgr.
for _, s := range d.mgr.segments {
for _, s := range d.mgr.segments.buckets {
// visit all segments, and attempt to remove the tag from all the peers it stores.
s.Lock()
for _, p := range s.peers {
