diff --git a/bitswap.go b/bitswap.go index 3abbc197..28c1589b 100644 --- a/bitswap.go +++ b/bitswap.go @@ -97,8 +97,8 @@ func New(parent context.Context, network bsnet.BitSwapNetwork, return nil }) - peerQueueFactory := func(p peer.ID) bspm.PeerQueue { - return bsmq.New(p, network) + peerQueueFactory := func(ctx context.Context, p peer.ID) bspm.PeerQueue { + return bsmq.New(ctx, p, network) } wm := bswm.New(ctx) diff --git a/decision/peer_request_queue.go b/decision/peer_request_queue.go index c7aaf553..0fa78c8a 100644 --- a/decision/peer_request_queue.go +++ b/decision/peer_request_queue.go @@ -60,7 +60,7 @@ func (tl *prq) Push(to peer.ID, entries ...*wantlist.Entry) { defer partner.activelk.Unlock() var priority int - newEntries := make([]*wantlist.Entry, 0, len(entries)) + newEntries := make([]*peerRequestTaskEntry, 0, len(entries)) for _, entry := range entries { if partner.activeBlocks.Has(entry.Cid) { continue @@ -75,7 +75,7 @@ func (tl *prq) Push(to peer.ID, entries ...*wantlist.Entry) { if entry.Priority > priority { priority = entry.Priority } - newEntries = append(newEntries, entry) + newEntries = append(newEntries, &peerRequestTaskEntry{entry, false}) } if len(newEntries) == 0 { @@ -86,7 +86,7 @@ func (tl *prq) Push(to peer.ID, entries ...*wantlist.Entry) { Entries: newEntries, Target: to, created: time.Now(), - Done: func(e []*wantlist.Entry) { + Done: func(e []*peerRequestTaskEntry) { tl.lock.Lock() for _, entry := range e { partner.TaskDone(entry.Cid) @@ -117,10 +117,10 @@ func (tl *prq) Pop() *peerRequestTask { for partner.taskQueue.Len() > 0 && partner.freezeVal == 0 { out = partner.taskQueue.Pop().(*peerRequestTask) - newEntries := make([]*wantlist.Entry, 0, len(out.Entries)) + newEntries := make([]*peerRequestTaskEntry, 0, len(out.Entries)) for _, entry := range out.Entries { delete(tl.taskMap, taskEntryKey{out.Target, entry.Cid}) - if entry.Trash { + if entry.trash { continue } partner.requests-- @@ -150,7 +150,7 @@ func (tl *prq) Remove(k cid.Cid, p peer.ID) { // remove the task "lazily" // simply mark it as trash, so it'll be dropped when popped off the // queue. - entry.Trash = true + entry.trash = true break } } @@ -197,13 +197,18 @@ func (tl *prq) thawRound() { } } +type peerRequestTaskEntry struct { + *wantlist.Entry + // trash in a book-keeping field + trash bool +} type peerRequestTask struct { - Entries []*wantlist.Entry + Entries []*peerRequestTaskEntry Priority int Target peer.ID // A callback to signal that this task has been completed - Done func([]*wantlist.Entry) + Done func([]*peerRequestTaskEntry) // created marks the time that the task was added to the queue created time.Time diff --git a/messagequeue/messagequeue.go b/messagequeue/messagequeue.go index 6d2cd1ce..e9204652 100644 --- a/messagequeue/messagequeue.go +++ b/messagequeue/messagequeue.go @@ -2,7 +2,6 @@ package messagequeue import ( "context" - "sync" "time" bsmsg "github.com/ipfs/go-bitswap/message" @@ -23,68 +22,72 @@ type MessageNetwork interface { NewMessageSender(context.Context, peer.ID) (bsnet.MessageSender, error) } +type request interface { + handle(mq *MessageQueue) +} + // MessageQueue implements queue of want messages to send to peers. type MessageQueue struct { - p peer.ID - - outlk sync.Mutex - out bsmsg.BitSwapMessage + ctx context.Context + p peer.ID network MessageNetwork - wl *wantlist.ThreadSafe - sender bsnet.MessageSender + newRequests chan request + outgoingMessages chan bsmsg.BitSwapMessage + done chan struct{} + + // do not touch out of run loop + wl *wantlist.SessionTrackedWantlist + nextMessage bsmsg.BitSwapMessage + sender bsnet.MessageSender +} + +type messageRequest struct { + entries []*bsmsg.Entry + ses uint64 +} - work chan struct{} - done chan struct{} +type wantlistRequest struct { + wl *wantlist.SessionTrackedWantlist } // New creats a new MessageQueue. -func New(p peer.ID, network MessageNetwork) *MessageQueue { +func New(ctx context.Context, p peer.ID, network MessageNetwork) *MessageQueue { return &MessageQueue{ - done: make(chan struct{}), - work: make(chan struct{}, 1), - wl: wantlist.NewThreadSafe(), - network: network, - p: p, + ctx: ctx, + wl: wantlist.NewSessionTrackedWantlist(), + network: network, + p: p, + newRequests: make(chan request, 16), + outgoingMessages: make(chan bsmsg.BitSwapMessage), + done: make(chan struct{}), } } // AddMessage adds new entries to an outgoing message for a given session. func (mq *MessageQueue) AddMessage(entries []*bsmsg.Entry, ses uint64) { - if !mq.addEntries(entries, ses) { - return - } select { - case mq.work <- struct{}{}: - default: + case mq.newRequests <- &messageRequest{entries, ses}: + case <-mq.ctx.Done(): } } // AddWantlist adds a complete session tracked want list to a message queue -func (mq *MessageQueue) AddWantlist(initialEntries []*wantlist.Entry) { - if len(initialEntries) > 0 { - if mq.out == nil { - mq.out = bsmsg.New(false) - } +func (mq *MessageQueue) AddWantlist(initialWants *wantlist.SessionTrackedWantlist) { + wl := wantlist.NewSessionTrackedWantlist() + initialWants.CopyWants(wl) - for _, e := range initialEntries { - for k := range e.SesTrk { - mq.wl.AddEntry(e, k) - } - mq.out.AddEntry(e.Cid, e.Priority) - } - - select { - case mq.work <- struct{}{}: - default: - } + select { + case mq.newRequests <- &wantlistRequest{wl}: + case <-mq.ctx.Done(): } } // Startup starts the processing of messages, and creates an initial message // based on the given initial wantlist. -func (mq *MessageQueue) Startup(ctx context.Context) { - go mq.runQueue(ctx) +func (mq *MessageQueue) Startup() { + go mq.runQueue() + go mq.sendMessages() } // Shutdown stops the processing of messages for a message queue. @@ -92,17 +95,26 @@ func (mq *MessageQueue) Shutdown() { close(mq.done) } -func (mq *MessageQueue) runQueue(ctx context.Context) { +func (mq *MessageQueue) runQueue() { + outgoingMessages := func() chan bsmsg.BitSwapMessage { + if mq.nextMessage == nil { + return nil + } + return mq.outgoingMessages + } + for { select { - case <-mq.work: // there is work to be done - mq.doWork(ctx) + case newRequest := <-mq.newRequests: + newRequest.handle(mq) + case outgoingMessages() <- mq.nextMessage: + mq.nextMessage = nil case <-mq.done: if mq.sender != nil { mq.sender.Close() } return - case <-ctx.Done(): + case <-mq.ctx.Done(): if mq.sender != nil { mq.sender.Reset() } @@ -111,63 +123,77 @@ func (mq *MessageQueue) runQueue(ctx context.Context) { } } -func (mq *MessageQueue) addEntries(entries []*bsmsg.Entry, ses uint64) bool { - var work bool - mq.outlk.Lock() - defer mq.outlk.Unlock() - // if we have no message held allocate a new one - if mq.out == nil { - mq.out = bsmsg.New(false) +func (mr *messageRequest) handle(mq *MessageQueue) { + mq.addEntries(mr.entries, mr.ses) +} + +func (wr *wantlistRequest) handle(mq *MessageQueue) { + initialWants := wr.wl + initialWants.CopyWants(mq.wl) + if initialWants.Len() > 0 { + if mq.nextMessage == nil { + mq.nextMessage = bsmsg.New(false) + } + for _, e := range initialWants.Entries() { + mq.nextMessage.AddEntry(e.Cid, e.Priority) + } } +} - // TODO: add a msg.Combine(...) method - // otherwise, combine the one we are holding with the - // one passed in +func (mq *MessageQueue) addEntries(entries []*bsmsg.Entry, ses uint64) { for _, e := range entries { if e.Cancel { if mq.wl.Remove(e.Cid, ses) { - work = true - mq.out.Cancel(e.Cid) + if mq.nextMessage == nil { + mq.nextMessage = bsmsg.New(false) + } + mq.nextMessage.Cancel(e.Cid) } } else { if mq.wl.Add(e.Cid, e.Priority, ses) { - work = true - mq.out.AddEntry(e.Cid, e.Priority) + if mq.nextMessage == nil { + mq.nextMessage = bsmsg.New(false) + } + mq.nextMessage.AddEntry(e.Cid, e.Priority) } } } - - return work } -func (mq *MessageQueue) doWork(ctx context.Context) { - - wlm := mq.extractOutgoingMessage() - if wlm == nil || wlm.Empty() { - return +func (mq *MessageQueue) sendMessages() { + for { + select { + case nextMessage := <-mq.outgoingMessages: + mq.sendMessage(nextMessage) + case <-mq.done: + return + case <-mq.ctx.Done(): + return + } } +} + +func (mq *MessageQueue) sendMessage(message bsmsg.BitSwapMessage) { - // NB: only open a stream if we actually have data to send - err := mq.initializeSender(ctx) + err := mq.initializeSender() if err != nil { log.Infof("cant open message sender to peer %s: %s", mq.p, err) // TODO: cant connect, what now? return } - // send wantlist updates for i := 0; i < maxRetries; i++ { // try to send this message until we fail. - if mq.attemptSendAndRecovery(ctx, wlm) { + if mq.attemptSendAndRecovery(message) { return } } } -func (mq *MessageQueue) initializeSender(ctx context.Context) error { +func (mq *MessageQueue) initializeSender() error { if mq.sender != nil { return nil } - nsender, err := openSender(ctx, mq.network, mq.p) + nsender, err := openSender(mq.ctx, mq.network, mq.p) if err != nil { return err } @@ -175,8 +201,8 @@ func (mq *MessageQueue) initializeSender(ctx context.Context) error { return nil } -func (mq *MessageQueue) attemptSendAndRecovery(ctx context.Context, wlm bsmsg.BitSwapMessage) bool { - err := mq.sender.SendMsg(ctx, wlm) +func (mq *MessageQueue) attemptSendAndRecovery(message bsmsg.BitSwapMessage) bool { + err := mq.sender.SendMsg(mq.ctx, message) if err == nil { return true } @@ -188,14 +214,14 @@ func (mq *MessageQueue) attemptSendAndRecovery(ctx context.Context, wlm bsmsg.Bi select { case <-mq.done: return true - case <-ctx.Done(): + case <-mq.ctx.Done(): return true case <-time.After(time.Millisecond * 100): // wait 100ms in case disconnect notifications are still propogating log.Warning("SendMsg errored but neither 'done' nor context.Done() were set") } - err = mq.initializeSender(ctx) + err = mq.initializeSender() if err != nil { log.Infof("couldnt open sender again after SendMsg(%s) failed: %s", mq.p, err) // TODO(why): what do we do now? @@ -215,15 +241,6 @@ func (mq *MessageQueue) attemptSendAndRecovery(ctx context.Context, wlm bsmsg.Bi return false } -func (mq *MessageQueue) extractOutgoingMessage() bsmsg.BitSwapMessage { - // grab outgoing message - mq.outlk.Lock() - wlm := mq.out - mq.out = nil - mq.outlk.Unlock() - return wlm -} - func openSender(ctx context.Context, network MessageNetwork, p peer.ID) (bsnet.MessageSender, error) { // allow ten minutes for connections this includes looking them up in the // dht dialing them, and handshaking diff --git a/messagequeue/messagequeue_test.go b/messagequeue/messagequeue_test.go index b780678d..aeb903dd 100644 --- a/messagequeue/messagequeue_test.go +++ b/messagequeue/messagequeue_test.go @@ -27,7 +27,6 @@ func (fmn *fakeMessageNetwork) NewMessageSender(context.Context, peer.ID) (bsnet return fmn.messageSender, nil } return nil, fmn.messageSenderError - } type fakeMessageSender struct { @@ -77,12 +76,12 @@ func TestStartupAndShutdown(t *testing.T) { fakeSender := &fakeMessageSender{nil, fullClosedChan, resetChan, messagesSent} fakenet := &fakeMessageNetwork{nil, nil, fakeSender} peerID := testutil.GeneratePeers(1)[0] - messageQueue := New(peerID, fakenet) + messageQueue := New(ctx, peerID, fakenet) ses := testutil.GenerateSessionID() wl := testutil.GenerateWantlist(10, ses) - messageQueue.Startup(ctx) - messageQueue.AddWantlist(wl.Entries()) + messageQueue.Startup() + messageQueue.AddWantlist(wl) messages := collectMessages(ctx, t, messagesSent, 10*time.Millisecond) if len(messages) != 1 { t.Fatal("wrong number of messages were sent for initial wants") @@ -119,11 +118,11 @@ func TestSendingMessagesDeduped(t *testing.T) { fakeSender := &fakeMessageSender{nil, fullClosedChan, resetChan, messagesSent} fakenet := &fakeMessageNetwork{nil, nil, fakeSender} peerID := testutil.GeneratePeers(1)[0] - messageQueue := New(peerID, fakenet) + messageQueue := New(ctx, peerID, fakenet) ses1 := testutil.GenerateSessionID() ses2 := testutil.GenerateSessionID() entries := testutil.GenerateMessageEntries(10, false) - messageQueue.Startup(ctx) + messageQueue.Startup() messageQueue.AddMessage(entries, ses1) messageQueue.AddMessage(entries, ses2) @@ -142,13 +141,13 @@ func TestSendingMessagesPartialDupe(t *testing.T) { fakeSender := &fakeMessageSender{nil, fullClosedChan, resetChan, messagesSent} fakenet := &fakeMessageNetwork{nil, nil, fakeSender} peerID := testutil.GeneratePeers(1)[0] - messageQueue := New(peerID, fakenet) + messageQueue := New(ctx, peerID, fakenet) ses1 := testutil.GenerateSessionID() ses2 := testutil.GenerateSessionID() entries := testutil.GenerateMessageEntries(10, false) moreEntries := testutil.GenerateMessageEntries(5, false) secondEntries := append(entries[5:], moreEntries...) - messageQueue.Startup(ctx) + messageQueue.Startup() messageQueue.AddMessage(entries, ses1) messageQueue.AddMessage(secondEntries, ses2) diff --git a/peermanager/peermanager.go b/peermanager/peermanager.go index 48c8de43..b1b8ee9a 100644 --- a/peermanager/peermanager.go +++ b/peermanager/peermanager.go @@ -20,13 +20,13 @@ var ( // PeerQueue provides a queer of messages to be sent for a single peer. type PeerQueue interface { AddMessage(entries []*bsmsg.Entry, ses uint64) - Startup(ctx context.Context) - AddWantlist(initialEntries []*wantlist.Entry) + Startup() + AddWantlist(initialWants *wantlist.SessionTrackedWantlist) Shutdown() } // PeerQueueFactory provides a function that will create a PeerQueue. -type PeerQueueFactory func(p peer.ID) PeerQueue +type PeerQueueFactory func(ctx context.Context, p peer.ID) PeerQueue type peerMessage interface { handle(pm *PeerManager) @@ -69,13 +69,13 @@ func (pm *PeerManager) ConnectedPeers() []peer.ID { // Connected is called to add a new peer to the pool, and send it an initial set // of wants. -func (pm *PeerManager) Connected(p peer.ID, initialEntries []*wantlist.Entry) { +func (pm *PeerManager) Connected(p peer.ID, initialWants *wantlist.SessionTrackedWantlist) { pm.peerQueuesLk.Lock() pq := pm.getOrCreate(p) if pq.refcnt == 0 { - pq.pq.AddWantlist(initialEntries) + pq.pq.AddWantlist(initialWants) } pq.refcnt++ @@ -128,8 +128,8 @@ func (pm *PeerManager) SendMessage(entries []*bsmsg.Entry, targets []peer.ID, fr func (pm *PeerManager) getOrCreate(p peer.ID) *peerQueueInstance { pqi, ok := pm.peerQueues[p] if !ok { - pq := pm.createPeerQueue(p) - pq.Startup(pm.ctx) + pq := pm.createPeerQueue(pm.ctx, p) + pq.Startup() pqi = &peerQueueInstance{0, pq} pm.peerQueues[p] = pqi } diff --git a/peermanager/peermanager_test.go b/peermanager/peermanager_test.go index ac8595d5..1d56d042 100644 --- a/peermanager/peermanager_test.go +++ b/peermanager/peermanager_test.go @@ -24,15 +24,15 @@ type fakePeer struct { messagesSent chan messageSent } -func (fp *fakePeer) Startup(ctx context.Context) {} -func (fp *fakePeer) Shutdown() {} +func (fp *fakePeer) Startup() {} +func (fp *fakePeer) Shutdown() {} func (fp *fakePeer) AddMessage(entries []*bsmsg.Entry, ses uint64) { fp.messagesSent <- messageSent{fp.p, entries, ses} } -func (fp *fakePeer) AddWantlist(initialEntries []*wantlist.Entry) {} +func (fp *fakePeer) AddWantlist(initialWants *wantlist.SessionTrackedWantlist) {} func makePeerQueueFactory(messagesSent chan messageSent) PeerQueueFactory { - return func(p peer.ID) PeerQueue { + return func(ctx context.Context, p peer.ID) PeerQueue { return &fakePeer{ p: p, messagesSent: messagesSent, diff --git a/testutil/testutil.go b/testutil/testutil.go index 3d799666..05fd152b 100644 --- a/testutil/testutil.go +++ b/testutil/testutil.go @@ -39,8 +39,8 @@ func GenerateCids(n int) []cid.Cid { } // GenerateWantlist makes a populated wantlist. -func GenerateWantlist(n int, ses uint64) *wantlist.ThreadSafe { - wl := wantlist.NewThreadSafe() +func GenerateWantlist(n int, ses uint64) *wantlist.SessionTrackedWantlist { + wl := wantlist.NewSessionTrackedWantlist() for i := 0; i < n; i++ { prioritySeq++ entry := wantlist.NewRefEntry(blockGenerator.Next().Cid(), prioritySeq) diff --git a/wantlist/wantlist.go b/wantlist/wantlist.go index 947c964d..1da4ed97 100644 --- a/wantlist/wantlist.go +++ b/wantlist/wantlist.go @@ -1,20 +1,17 @@ -// package wantlist implements an object for bitswap that contains the keys +// Package wantlist implements an object for bitswap that contains the keys // that a given peer wants. package wantlist import ( "sort" - "sync" cid "github.com/ipfs/go-cid" ) -type ThreadSafe struct { - lk sync.RWMutex - set map[cid.Cid]*Entry +type SessionTrackedWantlist struct { + set map[cid.Cid]*sessionTrackedEntry } -// not threadsafe type Wantlist struct { set map[cid.Cid]*Entry } @@ -22,10 +19,11 @@ type Wantlist struct { type Entry struct { Cid cid.Cid Priority int +} - SesTrk map[uint64]struct{} - // Trash in a book-keeping field - Trash bool +type sessionTrackedEntry struct { + *Entry + sesTrk map[uint64]struct{} } // NewRefEntry creates a new reference tracked wantlist entry. @@ -33,7 +31,6 @@ func NewRefEntry(c cid.Cid, p int) *Entry { return &Entry{ Cid: c, Priority: p, - SesTrk: make(map[uint64]struct{}), } } @@ -43,9 +40,9 @@ func (es entrySlice) Len() int { return len(es) } func (es entrySlice) Swap(i, j int) { es[i], es[j] = es[j], es[i] } func (es entrySlice) Less(i, j int) bool { return es[i].Priority > es[j].Priority } -func NewThreadSafe() *ThreadSafe { - return &ThreadSafe{ - set: make(map[cid.Cid]*Entry), +func NewSessionTrackedWantlist() *SessionTrackedWantlist { + return &SessionTrackedWantlist{ + set: make(map[cid.Cid]*sessionTrackedEntry), } } @@ -63,33 +60,31 @@ func New() *Wantlist { // TODO: think through priority changes here // Add returns true if the cid did not exist in the wantlist before this call // (even if it was under a different session). -func (w *ThreadSafe) Add(c cid.Cid, priority int, ses uint64) bool { - w.lk.Lock() - defer w.lk.Unlock() +func (w *SessionTrackedWantlist) Add(c cid.Cid, priority int, ses uint64) bool { + if e, ok := w.set[c]; ok { - e.SesTrk[ses] = struct{}{} + e.sesTrk[ses] = struct{}{} return false } - w.set[c] = &Entry{ - Cid: c, - Priority: priority, - SesTrk: map[uint64]struct{}{ses: struct{}{}}, + w.set[c] = &sessionTrackedEntry{ + Entry: &Entry{Cid: c, Priority: priority}, + sesTrk: map[uint64]struct{}{ses: struct{}{}}, } return true } // AddEntry adds given Entry to the wantlist. For more information see Add method. -func (w *ThreadSafe) AddEntry(e *Entry, ses uint64) bool { - w.lk.Lock() - defer w.lk.Unlock() +func (w *SessionTrackedWantlist) AddEntry(e *Entry, ses uint64) bool { if ex, ok := w.set[e.Cid]; ok { - ex.SesTrk[ses] = struct{}{} + ex.sesTrk[ses] = struct{}{} return false } - w.set[e.Cid] = e - e.SesTrk[ses] = struct{}{} + w.set[e.Cid] = &sessionTrackedEntry{ + Entry: e, + sesTrk: map[uint64]struct{}{ses: struct{}{}}, + } return true } @@ -97,16 +92,14 @@ func (w *ThreadSafe) AddEntry(e *Entry, ses uint64) bool { // 'true' is returned if this call to Remove removed the final session ID // tracking the cid. (meaning true will be returned iff this call caused the // value of 'Contains(c)' to change from true to false) -func (w *ThreadSafe) Remove(c cid.Cid, ses uint64) bool { - w.lk.Lock() - defer w.lk.Unlock() +func (w *SessionTrackedWantlist) Remove(c cid.Cid, ses uint64) bool { e, ok := w.set[c] if !ok { return false } - delete(e.SesTrk, ses) - if len(e.SesTrk) == 0 { + delete(e.sesTrk, ses) + if len(e.sesTrk) == 0 { delete(w.set, c) return true } @@ -115,35 +108,40 @@ func (w *ThreadSafe) Remove(c cid.Cid, ses uint64) bool { // Contains returns true if the given cid is in the wantlist tracked by one or // more sessions. -func (w *ThreadSafe) Contains(k cid.Cid) (*Entry, bool) { - w.lk.RLock() - defer w.lk.RUnlock() +func (w *SessionTrackedWantlist) Contains(k cid.Cid) (*Entry, bool) { e, ok := w.set[k] - return e, ok + if !ok { + return nil, false + } + return e.Entry, true } -func (w *ThreadSafe) Entries() []*Entry { - w.lk.RLock() - defer w.lk.RUnlock() +func (w *SessionTrackedWantlist) Entries() []*Entry { es := make([]*Entry, 0, len(w.set)) for _, e := range w.set { - es = append(es, e) + es = append(es, e.Entry) } return es } -func (w *ThreadSafe) SortedEntries() []*Entry { +func (w *SessionTrackedWantlist) SortedEntries() []*Entry { es := w.Entries() sort.Sort(entrySlice(es)) return es } -func (w *ThreadSafe) Len() int { - w.lk.RLock() - defer w.lk.RUnlock() +func (w *SessionTrackedWantlist) Len() int { return len(w.set) } +func (w *SessionTrackedWantlist) CopyWants(to *SessionTrackedWantlist) { + for _, e := range w.set { + for k := range e.sesTrk { + to.AddEntry(e.Entry, k) + } + } +} + func (w *Wantlist) Len() int { return len(w.set) } diff --git a/wantlist/wantlist_test.go b/wantlist/wantlist_test.go index 4ce31949..d11f6b7f 100644 --- a/wantlist/wantlist_test.go +++ b/wantlist/wantlist_test.go @@ -82,8 +82,8 @@ func TestBasicWantlist(t *testing.T) { } } -func TestSesRefWantlist(t *testing.T) { - wl := NewThreadSafe() +func TestSessionTrackedWantlist(t *testing.T) { + wl := NewSessionTrackedWantlist() if !wl.Add(testcids[0], 5, 1) { t.Fatal("should have added") diff --git a/wantmanager/wantmanager.go b/wantmanager/wantmanager.go index 57bd65f8..17f76bb2 100644 --- a/wantmanager/wantmanager.go +++ b/wantmanager/wantmanager.go @@ -24,7 +24,7 @@ const ( // managed by the WantManager. type PeerHandler interface { Disconnected(p peer.ID) - Connected(p peer.ID, initialEntries []*wantlist.Entry) + Connected(p peer.ID, initialWants *wantlist.SessionTrackedWantlist) SendMessage(entries []*bsmsg.Entry, targets []peer.ID, from uint64) } @@ -42,8 +42,8 @@ type WantManager struct { wantMessages chan wantMessage // synchronized by Run loop, only touch inside there - wl *wantlist.ThreadSafe - bcwl *wantlist.ThreadSafe + wl *wantlist.SessionTrackedWantlist + bcwl *wantlist.SessionTrackedWantlist ctx context.Context cancel func() @@ -59,8 +59,8 @@ func New(ctx context.Context) *WantManager { "Number of items in wantlist.").Gauge() return &WantManager{ wantMessages: make(chan wantMessage, 10), - wl: wantlist.NewThreadSafe(), - bcwl: wantlist.NewThreadSafe(), + wl: wantlist.NewSessionTrackedWantlist(), + bcwl: wantlist.NewSessionTrackedWantlist(), ctx: ctx, cancel: cancel, wantlistGauge: wantlistGauge, @@ -274,7 +274,7 @@ type connectedMessage struct { } func (cm *connectedMessage) handle(wm *WantManager) { - wm.peerHandler.Connected(cm.p, wm.bcwl.Entries()) + wm.peerHandler.Connected(cm.p, wm.bcwl) } type disconnectedMessage struct { diff --git a/wantmanager/wantmanager_test.go b/wantmanager/wantmanager_test.go index 46d1d0b3..4cb05ac0 100644 --- a/wantmanager/wantmanager_test.go +++ b/wantmanager/wantmanager_test.go @@ -25,8 +25,8 @@ func (fph *fakePeerHandler) SendMessage(entries []*bsmsg.Entry, targets []peer.I fph.lk.Unlock() } -func (fph *fakePeerHandler) Connected(p peer.ID, initialEntries []*wantlist.Entry) {} -func (fph *fakePeerHandler) Disconnected(p peer.ID) {} +func (fph *fakePeerHandler) Connected(p peer.ID, initialWants *wantlist.SessionTrackedWantlist) {} +func (fph *fakePeerHandler) Disconnected(p peer.ID) {} func (fph *fakePeerHandler) getLastWantSet() wantSet { fph.lk.Lock()