diff --git a/p2p/net/connmgr/connmgr.go b/p2p/net/connmgr/connmgr.go
index 3a1acf0db0..c124c6bb81 100644
--- a/p2p/net/connmgr/connmgr.go
+++ b/p2p/net/connmgr/connmgr.go
@@ -62,14 +62,21 @@ type segment struct {
 	peers map[peer.ID]*peerInfo
 }
 
-type segments [256]*segment
+type segments struct {
+	// bucketsMu is used to prevent deadlocks when concurrent processes try to
+	// grab multiple segment locks at once. If you need multiple segment locks
+	// at once, you should grab this lock first. You may release this lock once
+	// you have the segment locks.
+	bucketsMu sync.Mutex
+	buckets   [256]*segment
+}
 
 func (ss *segments) get(p peer.ID) *segment {
-	return ss[byte(p[len(p)-1])]
+	return ss.buckets[byte(p[len(p)-1])]
 }
 
 func (ss *segments) countPeers() (count int) {
-	for _, seg := range ss {
+	for _, seg := range ss.buckets {
 		seg.Lock()
 		count += len(seg.peers)
 		seg.Unlock()
@@ -122,15 +129,15 @@ func NewConnManager(low, hi int, opts ...Option) (*BasicConnMgr, error) {
 		cfg:       cfg,
 		clock:     cfg.clock,
 		protected: make(map[peer.ID]map[string]struct{}, 16),
-		segments: func() (ret segments) {
-			for i := range ret {
-				ret[i] = &segment{
-					peers: make(map[peer.ID]*peerInfo),
-				}
-			}
-			return ret
-		}(),
+		segments:  segments{},
+	}
+
+	for i := range cm.segments.buckets {
+		cm.segments.buckets[i] = &segment{
+			peers: make(map[peer.ID]*peerInfo),
+		}
 	}
+
 	cm.ctx, cm.cancel = context.WithCancel(context.Background())
 
 	if cfg.emergencyTrim {
@@ -246,15 +253,32 @@ type peerInfo struct {
 	firstSeen time.Time // timestamp when we began tracking this peer.
 }
 
-type peerInfos []peerInfo
+type peerInfos []*peerInfo
 
 // SortByValueAndStreams sorts peerInfos by their value and stream count. It
 // will sort peers with no streams before those with streams (all else being
 // equal). If `sortByMoreStreams` is true it will sort peers with more streams
 // before those with fewer streams. This is useful to prioritize freeing memory.
-func (p peerInfos) SortByValueAndStreams(sortByMoreStreams bool) {
+func (p peerInfos) SortByValueAndStreams(segments *segments, sortByMoreStreams bool) {
 	sort.Slice(p, func(i, j int) bool {
 		left, right := p[i], p[j]
+
+		// Grab this lock so that we can grab both segment locks below without deadlocking.
+		segments.bucketsMu.Lock()
+
+		// Lock the segment to protect against concurrent modifications from connect/disconnect events.
+		leftSegment := segments.get(left.id)
+		leftSegment.Lock()
+		defer leftSegment.Unlock()
+
+		rightSegment := segments.get(right.id)
+		if leftSegment != rightSegment {
+			// These two peers are not in the same segment, so grab the second segment's lock too.
+			rightSegment.Lock()
+			defer rightSegment.Unlock()
+		}
+		segments.bucketsMu.Unlock()
+
 		// temporary peers are preferred for pruning.
 		if left.temp != right.temp {
 			return left.temp
@@ -360,31 +384,34 @@ func (cm *BasicConnMgr) getConnsToCloseEmergency(target int) []network.Conn {
 	candidates := make(peerInfos, 0, cm.segments.countPeers())
 
 	cm.plk.RLock()
-	for _, s := range cm.segments {
+	for _, s := range cm.segments.buckets {
 		s.Lock()
 		for id, inf := range s.peers {
 			if _, ok := cm.protected[id]; ok {
 				// skip over protected peer.
 				continue
 			}
-			candidates = append(candidates, *inf)
+			candidates = append(candidates, inf)
 		}
 		s.Unlock()
 	}
 	cm.plk.RUnlock()
 
 	// Sort peers according to their value.
-	candidates.SortByValueAndStreams(true)
+	candidates.SortByValueAndStreams(&cm.segments, true)
 
 	selected := make([]network.Conn, 0, target+10)
 	for _, inf := range candidates {
 		if target <= 0 {
 			break
 		}
+		s := cm.segments.get(inf.id)
+		s.Lock()
 		for c := range inf.conns {
 			selected = append(selected, c)
 		}
 		target -= len(inf.conns)
+		s.Unlock()
 	}
 	if len(selected) >= target {
 		// We found enough connections that were not protected.
@@ -395,24 +422,28 @@ func (cm *BasicConnMgr) getConnsToCloseEmergency(target int) []network.Conn {
 	// We have no choice but to kill some protected connections.
 	candidates = candidates[:0]
 	cm.plk.RLock()
-	for _, s := range cm.segments {
+	for _, s := range cm.segments.buckets {
 		s.Lock()
 		for _, inf := range s.peers {
-			candidates = append(candidates, *inf)
+			candidates = append(candidates, inf)
 		}
 		s.Unlock()
 	}
 	cm.plk.RUnlock()
 
-	candidates.SortByValueAndStreams(true)
+	candidates.SortByValueAndStreams(&cm.segments, true)
 
 	for _, inf := range candidates {
 		if target <= 0 {
 			break
 		}
+		// Lock the segment to protect against concurrent modifications from connect/disconnect events.
+		s := cm.segments.get(inf.id)
+		s.Lock()
 		for c := range inf.conns {
 			selected = append(selected, c)
 		}
 		target -= len(inf.conns)
+		s.Unlock()
 	}
 	return selected
 }
@@ -435,7 +466,7 @@ func (cm *BasicConnMgr) getConnsToClose() []network.Conn {
 	gracePeriodStart := cm.clock.Now().Add(-cm.cfg.gracePeriod)
 
 	cm.plk.RLock()
-	for _, s := range cm.segments {
+	for _, s := range cm.segments.buckets {
 		s.Lock()
 		for id, inf := range s.peers {
 			if _, ok := cm.protected[id]; ok {
@@ -448,7 +479,7 @@ func (cm *BasicConnMgr) getConnsToClose() []network.Conn {
 			}
 			// note that we're copying the entry here,
 			// but since inf.conns is a map, it will still point to the original object
-			candidates = append(candidates, *inf)
+			candidates = append(candidates, inf)
 			ncandidates += len(inf.conns)
 		}
 		s.Unlock()
@@ -465,7 +496,7 @@ func (cm *BasicConnMgr) getConnsToClose() []network.Conn {
 	}
 
 	// Sort peers according to their value.
-	candidates.SortByValueAndStreams(false)
+	candidates.SortByValueAndStreams(&cm.segments, false)
 
 	target := ncandidates - cm.cfg.lowWater
 
diff --git a/p2p/net/connmgr/connmgr_test.go b/p2p/net/connmgr/connmgr_test.go
index 2053e3e6f7..39db7b54cf 100644
--- a/p2p/net/connmgr/connmgr_test.go
+++ b/p2p/net/connmgr/connmgr_test.go
@@ -808,36 +808,57 @@ func (m mockConn) GetStreams() []network.Stream { panic
 func (m mockConn) Scope() network.ConnScope { panic("implement me") }
 func (m mockConn) ConnState() network.ConnectionState { return network.ConnectionState{} }
 
+func makeSegmentsWithPeerInfos(peerInfos peerInfos) *segments {
+	var s = func() *segments {
+		ret := segments{}
+		for i := range ret.buckets {
+			ret.buckets[i] = &segment{
+				peers: make(map[peer.ID]*peerInfo),
+			}
+		}
+		return &ret
+	}()
+
+	for _, pi := range peerInfos {
+		segment := s.get(pi.id)
+		segment.Lock()
+		segment.peers[pi.id] = pi
+		segment.Unlock()
+	}
+
+	return s
+}
+
 func TestPeerInfoSorting(t *testing.T) {
 	t.Run("starts with temporary connections", func(t *testing.T) {
-		p1 := peerInfo{id: peer.ID("peer1")}
-		p2 := peerInfo{id: peer.ID("peer2"), temp: true}
+		p1 := &peerInfo{id: peer.ID("peer1")}
+		p2 := &peerInfo{id: peer.ID("peer2"), temp: true}
 		pis := peerInfos{p1, p2}
-		pis.SortByValueAndStreams(false)
+		pis.SortByValueAndStreams(makeSegmentsWithPeerInfos(pis), false)
 		require.Equal(t, pis, peerInfos{p2, p1})
 	})
 
 	t.Run("starts with low-value connections", func(t *testing.T) {
-		p1 := peerInfo{id: peer.ID("peer1"), value: 40}
-		p2 := peerInfo{id: peer.ID("peer2"), value: 20}
+		p1 := &peerInfo{id: peer.ID("peer1"), value: 40}
+		p2 := &peerInfo{id: peer.ID("peer2"), value: 20}
 		pis := peerInfos{p1, p2}
-		pis.SortByValueAndStreams(false)
+		pis.SortByValueAndStreams(makeSegmentsWithPeerInfos(pis), false)
 		require.Equal(t, pis, peerInfos{p2, p1})
 	})
 
 	t.Run("prefer peers with no streams", func(t *testing.T) {
-		p1 := peerInfo{id: peer.ID("peer1"),
+		p1 := &peerInfo{id: peer.ID("peer1"),
 			conns: map[network.Conn]time.Time{
 				&mockConn{stats: network.ConnStats{NumStreams: 0}}: time.Now(),
 			},
 		}
-		p2 := peerInfo{id: peer.ID("peer2"),
+		p2 := &peerInfo{id: peer.ID("peer2"),
 			conns: map[network.Conn]time.Time{
 				&mockConn{stats: network.ConnStats{NumStreams: 1}}: time.Now(),
 			},
 		}
 		pis := peerInfos{p2, p1}
-		pis.SortByValueAndStreams(false)
+		pis.SortByValueAndStreams(makeSegmentsWithPeerInfos(pis), false)
 		require.Equal(t, pis, peerInfos{p1, p2})
 	})
 
@@ -849,27 +870,27 @@ func TestPeerInfoSorting(t *testing.T) {
 		outgoingSomeStreams := network.ConnStats{Stats: network.Stats{Direction: network.DirOutbound}, NumStreams: 1}
 		outgoingMoreStreams := network.ConnStats{Stats: network.Stats{Direction: network.DirOutbound}, NumStreams: 2}
 
-		p1 := peerInfo{
+		p1 := &peerInfo{
 			id: peer.ID("peer1"),
 			conns: map[network.Conn]time.Time{
 				&mockConn{stats: outgoingSomeStreams}: time.Now(),
 			},
 		}
-		p2 := peerInfo{
+		p2 := &peerInfo{
 			id: peer.ID("peer2"),
 			conns: map[network.Conn]time.Time{
 				&mockConn{stats: outgoingSomeStreams}: time.Now(),
 				&mockConn{stats: incoming}: time.Now(),
 			},
 		}
-		p3 := peerInfo{
+		p3 := &peerInfo{
 			id: peer.ID("peer3"),
 			conns: map[network.Conn]time.Time{
 				&mockConn{stats: outgoing}: time.Now(),
 				&mockConn{stats: incoming}: time.Now(),
 			},
 		}
-		p4 := peerInfo{
+		p4 := &peerInfo{
 			id: peer.ID("peer4"),
 			conns: map[network.Conn]time.Time{
 				&mockConn{stats: outgoingMoreStreams}: time.Now(),
@@ -877,7 +898,7 @@ func TestPeerInfoSorting(t *testing.T) {
 			},
 		}
 		pis := peerInfos{p1, p2, p3, p4}
-		pis.SortByValueAndStreams(true)
+		pis.SortByValueAndStreams(makeSegmentsWithPeerInfos(pis), true)
 		// p3 is first because it is inactive (no streams).
 		// p4 is second because it has the most streams and we prioritize killing
 		// connections with the higher number of streams.
@@ -885,13 +906,13 @@ func TestPeerInfoSorting(t *testing.T) {
 	})
 
 	t.Run("in a memory emergency, starts with connections that have many streams", func(t *testing.T) {
-		p1 := peerInfo{
+		p1 := &peerInfo{
 			id: peer.ID("peer1"),
 			conns: map[network.Conn]time.Time{
 				&mockConn{stats: network.ConnStats{NumStreams: 100}}: time.Now(),
 			},
 		}
-		p2 := peerInfo{
+		p2 := &peerInfo{
 			id: peer.ID("peer2"),
 			conns: map[network.Conn]time.Time{
 				&mockConn{stats: network.ConnStats{NumStreams: 80}}: time.Now(),
@@ -899,7 +920,49 @@ func TestPeerInfoSorting(t *testing.T) {
 			},
 		}
 		pis := peerInfos{p1, p2}
-		pis.SortByValueAndStreams(true)
+		pis.SortByValueAndStreams(makeSegmentsWithPeerInfos(pis), true)
 		require.Equal(t, pis, peerInfos{p2, p1})
 	})
 }
+
+func TestSafeConcurrency(t *testing.T) {
+	t.Run("Safe Concurrency", func(t *testing.T) {
+		cl := clock.NewMock()
+
+		p1 := &peerInfo{id: peer.ID("peer1"), conns: map[network.Conn]time.Time{}}
+		p2 := &peerInfo{id: peer.ID("peer2"), conns: map[network.Conn]time.Time{}}
+		pis := peerInfos{p1, p2}
+
+		ss := makeSegmentsWithPeerInfos(pis)
+
+		const runs = 10
+		const concurrency = 10
+		var wg sync.WaitGroup
+		for i := 0; i < concurrency; i++ {
+			wg.Add(1)
+			go func() {
+				// Add conns. This mimics new connection events.
+				pis := peerInfos{p1, p2}
+				for i := 0; i < runs; i++ {
+					pi := pis[i%len(pis)]
+					s := ss.get(pi.id)
+					s.Lock()
+					s.peers[pi.id].conns[randConn(t, nil)] = cl.Now()
+					s.Unlock()
+				}
+				wg.Done()
+			}()
+
+			wg.Add(1)
+			go func() {
+				pis := peerInfos{p1, p2}
+				for i := 0; i < runs; i++ {
+					pis.SortByValueAndStreams(ss, false)
+				}
+				wg.Done()
+			}()
+		}
+
+		wg.Wait()
+	})
+}
diff --git a/p2p/net/connmgr/decay.go b/p2p/net/connmgr/decay.go
index 0819bd2136..c10214cb8a 100644
--- a/p2p/net/connmgr/decay.go
+++ b/p2p/net/connmgr/decay.go
@@ -177,7 +177,7 @@ func (d *decayer) process() {
 			d.tagsMu.Unlock()
 
 			// Visit each peer, and decay tags that need to be decayed.
-			for _, s := range d.mgr.segments {
+			for _, s := range d.mgr.segments.buckets {
 				s.Lock()
 
 				// Entered a segment that contains peers. Process each peer.
@@ -261,7 +261,7 @@ func (d *decayer) process() {
 			d.tagsMu.Unlock()
 
 			// Remove the tag from all peers that had it in the connmgr.
-			for _, s := range d.mgr.segments {
+			for _, s := range d.mgr.segments.buckets {
 				// visit all segments, and attempt to remove the tag from all the peers it stores.
 				s.Lock()
 				for _, p := range s.peers {
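
Reviewer note: the heart of this change is a lock-ordering rule, not the sorting itself. `SortByValueAndStreams` must now hold two segment locks at once (one per compared peer), while connect/disconnect paths keep taking a single segment lock. A goroutine holding only one segment lock can never be part of a wait cycle, and goroutines that need two are serialized behind `bucketsMu`, so no cycle can form. Below is a minimal standalone sketch of that discipline; `segment`, `lockPair`, and the `peers` counter are illustrative stand-ins, not connmgr code.

```go
package main

import (
	"fmt"
	"sync"
)

// segment stands in for a connmgr bucket: a mutex guarding some state.
type segment struct {
	sync.Mutex
	peers int
}

// lockPair acquires two segment locks without risking deadlock. As in the
// patch, anyone who needs more than one segment lock must hold bucketsMu
// while acquiring them; bucketsMu may be released once both locks are held.
// Single-segment lockers never touch bucketsMu, so the common path stays cheap.
func lockPair(bucketsMu *sync.Mutex, a, b *segment) (unlock func()) {
	bucketsMu.Lock()
	a.Lock()
	if a != b { // both peers may hash to the same segment
		b.Lock()
	}
	bucketsMu.Unlock()
	return func() {
		if a != b {
			b.Unlock()
		}
		a.Unlock()
	}
}

func main() {
	var bucketsMu sync.Mutex
	s1, s2 := &segment{}, &segment{}

	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			// Alternate the acquisition order; taken concurrently in
			// opposite orders without bucketsMu, this could deadlock.
			a, b := s1, s2
			if i%2 == 0 {
				a, b = s2, s1
			}
			unlock := lockPair(&bucketsMu, a, b)
			a.peers++
			unlock()
		}(i)
	}
	wg.Wait()
	fmt.Println(s1.peers + s2.peers) // always prints 8
}
```

A sketch like this can deadlock if the `bucketsMu` serialization is removed; with it, the program completes and `go test -race` stays quiet, which is presumably how the new `TestSafeConcurrency` above is meant to be exercised.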