From 5db627fe21da6f7355756ed402c676f5507ee9e3 Mon Sep 17 00:00:00 2001
From: hannahhoward
Date: Tue, 22 Jan 2019 17:55:05 -0800
Subject: [PATCH 1/9] feat(bitswap): Add a ProviderQueryManager

Add a manager for querying providers on blocks, in charge of managing
requests, deduping, and rate limiting

---
 providerquerymanager/providerquerymanager.go  | 343 ++++++++++++++++++
 .../providerquerymanager_test.go              | 274 ++++++++++++++
 2 files changed, 617 insertions(+)
 create mode 100644 providerquerymanager/providerquerymanager.go
 create mode 100644 providerquerymanager/providerquerymanager_test.go

diff --git a/providerquerymanager/providerquerymanager.go b/providerquerymanager/providerquerymanager.go
new file mode 100644
index 00000000..49075a20
--- /dev/null
+++ b/providerquerymanager/providerquerymanager.go
@@ -0,0 +1,343 @@
+package providerquerymanager
+
+import (
+	"context"
+	"sync"
+
+	"github.com/ipfs/go-cid"
+	logging "github.com/ipfs/go-log"
+	peer "github.com/libp2p/go-libp2p-peer"
+)
+
+var log = logging.Logger("bitswap")
+
+const (
+	maxProviders         = 10
+	maxInProcessRequests = 6
+)
+
+type inProgressRequestStatus struct {
+	providersSoFar []peer.ID
+	listeners      map[uint64]chan peer.ID
+}
+
+// ProviderQueryNetwork is an interface for finding providers and connecting to
+// peers.
+type ProviderQueryNetwork interface {
+	ConnectTo(context.Context, peer.ID) error
+	FindProvidersAsync(context.Context, cid.Cid, int) <-chan peer.ID
+}
+
+type providerQueryMessage interface {
+	handle(pqm *ProviderQueryManager)
+}
+
+type receivedProviderMessage struct {
+	k cid.Cid
+	p peer.ID
+}
+
+type finishedProviderQueryMessage struct {
+	k cid.Cid
+}
+
+type newProvideQueryMessage struct {
+	ses                   uint64
+	k                     cid.Cid
+	inProgressRequestChan chan<- inProgressRequest
+}
+
+type cancelRequestMessage struct {
+	ses uint64
+	k   cid.Cid
+}
+
+// ProviderQueryManager manages requests to find more providers for blocks
+// for bitswap sessions. Its main goals are to:
+// - rate limit requests -- don't have too many find provider calls running
+// simultaneously
+// - connect to found peers and filter them if it can't connect
+// - ensure two findprovider calls for the same block don't run concurrently
+// TODO:
+// - manage timeouts
+type ProviderQueryManager struct {
+	ctx                   context.Context
+	network               ProviderQueryNetwork
+	providerQueryMessages chan providerQueryMessage
+
+	// do not touch outside the run loop
+	providerRequestsProcessing   chan cid.Cid
+	incomingFindProviderRequests chan cid.Cid
+	inProgressRequestStatuses    map[cid.Cid]*inProgressRequestStatus
+}
+
+// New initializes a new ProviderQueryManager for a given context and a given
+// network provider.
+func New(ctx context.Context, network ProviderQueryNetwork) *ProviderQueryManager {
+	return &ProviderQueryManager{
+		ctx:                          ctx,
+		network:                      network,
+		providerQueryMessages:        make(chan providerQueryMessage, 16),
+		providerRequestsProcessing:   make(chan cid.Cid),
+		incomingFindProviderRequests: make(chan cid.Cid),
+		inProgressRequestStatuses:    make(map[cid.Cid]*inProgressRequestStatus),
+	}
+}
+
+// Startup starts processing for the ProviderQueryManager.
+func (pqm *ProviderQueryManager) Startup() {
+	go pqm.run()
+}
+
+type inProgressRequest struct {
+	providersSoFar []peer.ID
+	incoming       <-chan peer.ID
+}
+
+// FindProvidersAsync finds providers for the given block.
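+//
+// A hedged usage sketch (editor's illustration, mirroring this patch's
+// tests rather than adding new API): the returned channel yields each
+// provider the manager finds and successfully connects to, and is closed
+// when the query finishes or the session context is cancelled.
+//
+//	providers := pqm.FindProvidersAsync(sessionCtx, c, sessionID)
+//	for p := range providers {
+//		// p is a provider of c that was successfully connected to
+//	}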
+func (pqm *ProviderQueryManager) FindProvidersAsync(sessionCtx context.Context, k cid.Cid, ses uint64) <-chan peer.ID { + inProgressRequestChan := make(chan inProgressRequest) + + select { + case pqm.providerQueryMessages <- &newProvideQueryMessage{ + ses: ses, + k: k, + inProgressRequestChan: inProgressRequestChan, + }: + case <-pqm.ctx.Done(): + return nil + case <-sessionCtx.Done(): + return nil + } + + var receivedInProgressRequest inProgressRequest + select { + case <-sessionCtx.Done(): + return nil + case receivedInProgressRequest = <-inProgressRequestChan: + } + + return pqm.receiveProviders(sessionCtx, k, ses, receivedInProgressRequest) +} + +func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k cid.Cid, ses uint64, receivedInProgressRequest inProgressRequest) <-chan peer.ID { + // maintains an unbuffered queue for incoming providers for given request for a given session + // essentially, as a provider comes in, for a given CID, we want to immediately broadcast to all + // sessions that queried that CID, without worrying about whether the client code is actually + // reading from the returned channel -- so that the broadcast never blocks + // based on: https://medium.com/capital-one-tech/building-an-unbounded-channel-in-go-789e175cd2cd + returnedProviders := make(chan peer.ID) + receivedProviders := append([]peer.ID(nil), receivedInProgressRequest.providersSoFar[0:]...) + incomingProviders := receivedInProgressRequest.incoming + + go func() { + defer close(returnedProviders) + outgoingProviders := func() chan<- peer.ID { + if len(receivedProviders) == 0 { + return nil + } + return returnedProviders + } + nextProvider := func() peer.ID { + if len(receivedProviders) == 0 { + return "" + } + return receivedProviders[0] + } + for len(receivedProviders) > 0 || incomingProviders != nil { + select { + case <-sessionCtx.Done(): + pqm.providerQueryMessages <- &cancelRequestMessage{ + ses: ses, + k: k, + } + // clear out any remaining providers + for range incomingProviders { + } + return + case provider, ok := <-incomingProviders: + if !ok { + incomingProviders = nil + } else { + receivedProviders = append(receivedProviders, provider) + } + case outgoingProviders() <- nextProvider(): + receivedProviders = receivedProviders[1:] + } + } + }() + return returnedProviders +} + +func (pqm *ProviderQueryManager) findProviderWorker() { + // findProviderWorker just cycles through incoming provider queries one + // at a time. 
We have six of these workers running at once + // to let requests go in parallel but keep them rate limited + for { + select { + case k, ok := <-pqm.providerRequestsProcessing: + if !ok { + return + } + + providers := pqm.network.FindProvidersAsync(pqm.ctx, k, maxProviders) + wg := &sync.WaitGroup{} + for p := range providers { + wg.Add(1) + go func(p peer.ID) { + defer wg.Done() + err := pqm.network.ConnectTo(pqm.ctx, p) + if err != nil { + log.Debugf("failed to connect to provider %s: %s", p, err) + return + } + select { + case pqm.providerQueryMessages <- &receivedProviderMessage{ + k: k, + p: p, + }: + case <-pqm.ctx.Done(): + return + } + }(p) + } + wg.Wait() + select { + case pqm.providerQueryMessages <- &finishedProviderQueryMessage{ + k: k, + }: + case <-pqm.ctx.Done(): + } + case <-pqm.ctx.Done(): + return + } + } +} + +func (pqm *ProviderQueryManager) providerRequestBufferWorker() { + // the provider request buffer worker just maintains an unbounded + // buffer for incoming provider queries and dispatches to the find + // provider workers as they become available + // based on: https://medium.com/capital-one-tech/building-an-unbounded-channel-in-go-789e175cd2cd + var providerQueryRequestBuffer []cid.Cid + nextProviderQuery := func() cid.Cid { + if len(providerQueryRequestBuffer) == 0 { + return cid.Cid{} + } + return providerQueryRequestBuffer[0] + } + outgoingRequests := func() chan<- cid.Cid { + if len(providerQueryRequestBuffer) == 0 { + return nil + } + return pqm.providerRequestsProcessing + } + + for { + select { + case incomingRequest, ok := <-pqm.incomingFindProviderRequests: + if !ok { + return + } + providerQueryRequestBuffer = append(providerQueryRequestBuffer, incomingRequest) + case outgoingRequests() <- nextProviderQuery(): + providerQueryRequestBuffer = providerQueryRequestBuffer[1:] + case <-pqm.ctx.Done(): + return + } + } +} + +func (pqm *ProviderQueryManager) cleanupInProcessRequests() { + for _, requestStatus := range pqm.inProgressRequestStatuses { + for _, listener := range requestStatus.listeners { + close(listener) + } + } +} + +func (pqm *ProviderQueryManager) run() { + defer close(pqm.incomingFindProviderRequests) + defer close(pqm.providerRequestsProcessing) + defer pqm.cleanupInProcessRequests() + + go pqm.providerRequestBufferWorker() + for i := 0; i < maxInProcessRequests; i++ { + go pqm.findProviderWorker() + } + + for { + select { + case nextMessage := <-pqm.providerQueryMessages: + nextMessage.handle(pqm) + case <-pqm.ctx.Done(): + return + } + } +} + +func (rpm *receivedProviderMessage) handle(pqm *ProviderQueryManager) { + requestStatus, ok := pqm.inProgressRequestStatuses[rpm.k] + if !ok { + log.Errorf("Received provider (%s) for cid (%s) not requested", rpm.p.String(), rpm.k.String()) + return + } + requestStatus.providersSoFar = append(requestStatus.providersSoFar, rpm.p) + for _, listener := range requestStatus.listeners { + select { + case listener <- rpm.p: + case <-pqm.ctx.Done(): + return + } + } +} + +func (fpqm *finishedProviderQueryMessage) handle(pqm *ProviderQueryManager) { + requestStatus, ok := pqm.inProgressRequestStatuses[fpqm.k] + if !ok { + log.Errorf("Ended request for cid (%s) not in progress", fpqm.k.String()) + return + } + for _, listener := range requestStatus.listeners { + close(listener) + } + delete(pqm.inProgressRequestStatuses, fpqm.k) +} + +func (npqm *newProvideQueryMessage) handle(pqm *ProviderQueryManager) { + requestStatus, ok := pqm.inProgressRequestStatuses[npqm.k] + if !ok { + requestStatus = 
&inProgressRequestStatus{ + listeners: make(map[uint64]chan peer.ID), + } + pqm.inProgressRequestStatuses[npqm.k] = requestStatus + select { + case pqm.incomingFindProviderRequests <- npqm.k: + case <-pqm.ctx.Done(): + return + } + } + requestStatus.listeners[npqm.ses] = make(chan peer.ID) + select { + case npqm.inProgressRequestChan <- inProgressRequest{ + providersSoFar: requestStatus.providersSoFar, + incoming: requestStatus.listeners[npqm.ses], + }: + case <-pqm.ctx.Done(): + } +} + +func (crm *cancelRequestMessage) handle(pqm *ProviderQueryManager) { + requestStatus, ok := pqm.inProgressRequestStatuses[crm.k] + if !ok { + log.Errorf("Attempt to cancel request for session (%d) for cid (%s) not in progress", crm.ses, crm.k.String()) + return + } + listener, ok := requestStatus.listeners[crm.ses] + if !ok { + log.Errorf("Attempt to cancel request for session (%d) for cid (%s) this is not a listener", crm.ses, crm.k.String()) + return + } + close(listener) + delete(requestStatus.listeners, crm.ses) +} diff --git a/providerquerymanager/providerquerymanager_test.go b/providerquerymanager/providerquerymanager_test.go new file mode 100644 index 00000000..68893198 --- /dev/null +++ b/providerquerymanager/providerquerymanager_test.go @@ -0,0 +1,274 @@ +package providerquerymanager + +import ( + "context" + "errors" + "reflect" + "testing" + "time" + + "github.com/ipfs/go-bitswap/testutil" + + cid "github.com/ipfs/go-cid" + "github.com/libp2p/go-libp2p-peer" +) + +type fakeProviderNetwork struct { + peersFound []peer.ID + connectError error + delay time.Duration + connectDelay time.Duration + queriesMade int +} + +func (fpn *fakeProviderNetwork) ConnectTo(context.Context, peer.ID) error { + time.Sleep(fpn.connectDelay) + return fpn.connectError +} + +func (fpn *fakeProviderNetwork) FindProvidersAsync(ctx context.Context, k cid.Cid, max int) <-chan peer.ID { + fpn.queriesMade++ + incomingPeers := make(chan peer.ID) + go func() { + defer close(incomingPeers) + for _, p := range fpn.peersFound { + time.Sleep(fpn.delay) + select { + case incomingPeers <- p: + case <-ctx.Done(): + return + } + } + }() + return incomingPeers +} + +func TestNormalSimultaneousFetch(t *testing.T) { + peers := testutil.GeneratePeers(10) + fpn := &fakeProviderNetwork{ + peersFound: peers, + delay: 1 * time.Millisecond, + } + ctx := context.Background() + providerQueryManager := New(ctx, fpn) + providerQueryManager.Startup() + keys := testutil.GenerateCids(2) + sessionID1 := testutil.GenerateSessionID() + sessionID2 := testutil.GenerateSessionID() + + sessionCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) + defer cancel() + firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0], sessionID1) + secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[1], sessionID2) + + var firstPeersReceived []peer.ID + for p := range firstRequestChan { + firstPeersReceived = append(firstPeersReceived, p) + } + + var secondPeersReceived []peer.ID + for p := range secondRequestChan { + secondPeersReceived = append(secondPeersReceived, p) + } + + if len(firstPeersReceived) != len(peers) || len(secondPeersReceived) != len(peers) { + t.Fatal("Did not collect all peers for request that was completed") + } + + if fpn.queriesMade != 2 { + t.Fatal("Did not dedup provider requests running simultaneously") + } +} + +func TestDedupingProviderRequests(t *testing.T) { + peers := testutil.GeneratePeers(10) + fpn := &fakeProviderNetwork{ + peersFound: peers, + delay: 1 * time.Millisecond, + } + 
ctx := context.Background() + providerQueryManager := New(ctx, fpn) + providerQueryManager.Startup() + key := testutil.GenerateCids(1)[0] + sessionID1 := testutil.GenerateSessionID() + sessionID2 := testutil.GenerateSessionID() + + sessionCtx, cancel := context.WithTimeout(ctx, 20*time.Millisecond) + defer cancel() + firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, sessionID1) + secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, sessionID2) + + var firstPeersReceived []peer.ID + for p := range firstRequestChan { + firstPeersReceived = append(firstPeersReceived, p) + } + + var secondPeersReceived []peer.ID + for p := range secondRequestChan { + secondPeersReceived = append(secondPeersReceived, p) + } + + if len(firstPeersReceived) != len(peers) || len(secondPeersReceived) != len(peers) { + t.Fatal("Did not collect all peers for request that was completed") + } + + if !reflect.DeepEqual(firstPeersReceived, secondPeersReceived) { + t.Fatal("Did not receive the same response to both find provider requests") + } + + if fpn.queriesMade != 1 { + t.Fatal("Did not dedup provider requests running simultaneously") + } +} + +func TestCancelOneRequestDoesNotTerminateAnother(t *testing.T) { + peers := testutil.GeneratePeers(10) + fpn := &fakeProviderNetwork{ + peersFound: peers, + delay: 1 * time.Millisecond, + } + ctx := context.Background() + providerQueryManager := New(ctx, fpn) + providerQueryManager.Startup() + + key := testutil.GenerateCids(1)[0] + sessionID1 := testutil.GenerateSessionID() + sessionID2 := testutil.GenerateSessionID() + + // first session will cancel before done + firstSessionCtx, firstCancel := context.WithTimeout(ctx, 3*time.Millisecond) + defer firstCancel() + firstRequestChan := providerQueryManager.FindProvidersAsync(firstSessionCtx, key, sessionID1) + secondSessionCtx, secondCancel := context.WithTimeout(ctx, 20*time.Millisecond) + defer secondCancel() + secondRequestChan := providerQueryManager.FindProvidersAsync(secondSessionCtx, key, sessionID2) + + var firstPeersReceived []peer.ID + for p := range firstRequestChan { + firstPeersReceived = append(firstPeersReceived, p) + } + + var secondPeersReceived []peer.ID + for p := range secondRequestChan { + secondPeersReceived = append(secondPeersReceived, p) + } + + if len(secondPeersReceived) != len(peers) { + t.Fatal("Did not collect all peers for request that was completed") + } + + if len(firstPeersReceived) >= len(peers) { + t.Fatal("Collected all peers on cancelled peer, should have been cancelled immediately") + } + + if fpn.queriesMade != 1 { + t.Fatal("Did not dedup provider requests running simultaneously") + } +} + +func TestCancelManagerExitsGracefully(t *testing.T) { + peers := testutil.GeneratePeers(10) + fpn := &fakeProviderNetwork{ + peersFound: peers, + delay: 1 * time.Millisecond, + } + ctx := context.Background() + managerCtx, managerCancel := context.WithTimeout(ctx, 5*time.Millisecond) + defer managerCancel() + providerQueryManager := New(managerCtx, fpn) + providerQueryManager.Startup() + + key := testutil.GenerateCids(1)[0] + sessionID1 := testutil.GenerateSessionID() + sessionID2 := testutil.GenerateSessionID() + + sessionCtx, cancel := context.WithTimeout(ctx, 20*time.Millisecond) + defer cancel() + firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, sessionID1) + secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, sessionID2) + + var firstPeersReceived []peer.ID + for p := range firstRequestChan { + 
firstPeersReceived = append(firstPeersReceived, p)
+	}
+
+	var secondPeersReceived []peer.ID
+	for p := range secondRequestChan {
+		secondPeersReceived = append(secondPeersReceived, p)
+	}
+
+	if len(firstPeersReceived) <= 0 ||
+		len(firstPeersReceived) >= len(peers) ||
+		len(secondPeersReceived) <= 0 ||
+		len(secondPeersReceived) >= len(peers) {
+		t.Fatal("Did not cancel requests in progress correctly")
+	}
+}
+
+func TestPeersWithConnectionErrorsNotAddedToPeerList(t *testing.T) {
+	peers := testutil.GeneratePeers(10)
+	fpn := &fakeProviderNetwork{
+		peersFound:   peers,
+		connectError: errors.New("not able to connect"),
+		delay:        1 * time.Millisecond,
+	}
+	ctx := context.Background()
+	providerQueryManager := New(ctx, fpn)
+	providerQueryManager.Startup()
+
+	key := testutil.GenerateCids(1)[0]
+	sessionID1 := testutil.GenerateSessionID()
+	sessionID2 := testutil.GenerateSessionID()
+
+	sessionCtx, cancel := context.WithTimeout(ctx, 20*time.Millisecond)
+	defer cancel()
+	firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, sessionID1)
+	secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, sessionID2)
+
+	var firstPeersReceived []peer.ID
+	for p := range firstRequestChan {
+		firstPeersReceived = append(firstPeersReceived, p)
+	}
+
+	var secondPeersReceived []peer.ID
+	for p := range secondRequestChan {
+		secondPeersReceived = append(secondPeersReceived, p)
+	}
+
+	if len(firstPeersReceived) != 0 || len(secondPeersReceived) != 0 {
+		t.Fatal("Did not filter out peers with connection issues")
+	}
+
+}
+
+func TestRateLimitingRequests(t *testing.T) {
+	peers := testutil.GeneratePeers(10)
+	fpn := &fakeProviderNetwork{
+		peersFound: peers,
+		delay:      1 * time.Millisecond,
+	}
+	ctx := context.Background()
+	providerQueryManager := New(ctx, fpn)
+	providerQueryManager.Startup()
+
+	keys := testutil.GenerateCids(maxInProcessRequests + 1)
+	sessionID := testutil.GenerateSessionID()
+	sessionCtx, cancel := context.WithTimeout(ctx, 20*time.Millisecond)
+	defer cancel()
+	var requestChannels []<-chan peer.ID
+	for i := 0; i < maxInProcessRequests+1; i++ {
+		requestChannels = append(requestChannels, providerQueryManager.FindProvidersAsync(sessionCtx, keys[i], sessionID))
+	}
+	time.Sleep(2 * time.Millisecond)
+	if fpn.queriesMade != maxInProcessRequests {
+		t.Fatal("Did not limit parallel requests to rate limit")
+	}
+	for i := 0; i < maxInProcessRequests+1; i++ {
+		for range requestChannels[i] {
+		}
+	}
+
+	if fpn.queriesMade != maxInProcessRequests+1 {
+		t.Fatal("Did not make all separate requests")
+	}
+}

From 1f2b49efe3f888ace93fd7ccf1b200a134627243 Mon Sep 17 00:00:00 2001
From: hannahhoward
Date: Tue, 22 Jan 2019 18:18:29 -0800
Subject: [PATCH 2/9] feat(ProviderQueryManager): manage timeouts

Add functionality to time out find provider requests so they don't run
forever

---
 providerquerymanager/providerquerymanager.go      | 32 ++++++++++++++-----
 .../providerquerymanager_test.go                  | 26 +++++++++++++++
 2 files changed, 50 insertions(+), 8 deletions(-)

diff --git a/providerquerymanager/providerquerymanager.go b/providerquerymanager/providerquerymanager.go
index 49075a20..d2ba9e72 100644
--- a/providerquerymanager/providerquerymanager.go
+++ b/providerquerymanager/providerquerymanager.go
@@ -3,6 +3,7 @@ package providerquerymanager
 import (
 	"context"
 	"sync"
+	"time"
 
 	"github.com/ipfs/go-cid"
 	logging "github.com/ipfs/go-log"
@@ -14,6 +15,7 @@ var log = logging.Logger("bitswap")
 const (
 	maxProviders         = 10
 	maxInProcessRequests = 6
+	defaultTimeout       = 10 * time.Second
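+	// defaultTimeout bounds each underlying FindProviders query; it can be
+	// adjusted at runtime via SetFindProviderTimeout (added below).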
) type inProgressRequestStatus struct { @@ -58,17 +60,19 @@ type cancelRequestMessage struct { // simultaneously // - connect to found peers and filter them if it can't connect // - ensure two findprovider calls for the same block don't run concurrently -// TODO: // - manage timeouts type ProviderQueryManager struct { - ctx context.Context - network ProviderQueryNetwork - providerQueryMessages chan providerQueryMessage - - // do not touch outside the run loop + ctx context.Context + network ProviderQueryNetwork + providerQueryMessages chan providerQueryMessage providerRequestsProcessing chan cid.Cid incomingFindProviderRequests chan cid.Cid - inProgressRequestStatuses map[cid.Cid]*inProgressRequestStatus + + findProviderTimeout time.Duration + timeoutMutex sync.RWMutex + + // do not touch outside the run loop + inProgressRequestStatuses map[cid.Cid]*inProgressRequestStatus } // New initializes a new ProviderQueryManager for a given context and a given @@ -81,6 +85,7 @@ func New(ctx context.Context, network ProviderQueryNetwork) *ProviderQueryManage providerRequestsProcessing: make(chan cid.Cid), incomingFindProviderRequests: make(chan cid.Cid), inProgressRequestStatuses: make(map[cid.Cid]*inProgressRequestStatus), + findProviderTimeout: defaultTimeout, } } @@ -94,6 +99,13 @@ type inProgressRequest struct { incoming <-chan peer.ID } +// SetFindProviderTimeout changes the timeout for finding providers +func (pqm *ProviderQueryManager) SetFindProviderTimeout(findProviderTimeout time.Duration) { + pqm.timeoutMutex.Lock() + pqm.findProviderTimeout = findProviderTimeout + pqm.timeoutMutex.Unlock() +} + // FindProvidersAsync finds providers for the given block. func (pqm *ProviderQueryManager) FindProvidersAsync(sessionCtx context.Context, k cid.Cid, ses uint64) <-chan peer.ID { inProgressRequestChan := make(chan inProgressRequest) @@ -180,7 +192,11 @@ func (pqm *ProviderQueryManager) findProviderWorker() { return } - providers := pqm.network.FindProvidersAsync(pqm.ctx, k, maxProviders) + pqm.timeoutMutex.RLock() + findProviderCtx, cancel := context.WithTimeout(pqm.ctx, pqm.findProviderTimeout) + pqm.timeoutMutex.RUnlock() + defer cancel() + providers := pqm.network.FindProvidersAsync(findProviderCtx, k, maxProviders) wg := &sync.WaitGroup{} for p := range providers { wg.Add(1) diff --git a/providerquerymanager/providerquerymanager_test.go b/providerquerymanager/providerquerymanager_test.go index 68893198..f2e6f036 100644 --- a/providerquerymanager/providerquerymanager_test.go +++ b/providerquerymanager/providerquerymanager_test.go @@ -272,3 +272,29 @@ func TestRateLimitingRequests(t *testing.T) { t.Fatal("Did not make all seperate requests") } } + +func TestFindProviderTimeout(t *testing.T) { + peers := testutil.GeneratePeers(10) + fpn := &fakeProviderNetwork{ + peersFound: peers, + delay: 1 * time.Millisecond, + } + ctx := context.Background() + providerQueryManager := New(ctx, fpn) + providerQueryManager.Startup() + providerQueryManager.SetFindProviderTimeout(3 * time.Millisecond) + keys := testutil.GenerateCids(1) + sessionID1 := testutil.GenerateSessionID() + + sessionCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) + defer cancel() + firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0], sessionID1) + var firstPeersReceived []peer.ID + for p := range firstRequestChan { + firstPeersReceived = append(firstPeersReceived, p) + } + if len(firstPeersReceived) <= 0 || + len(firstPeersReceived) >= len(peers) { + t.Fatal("Find provider request should have timed 
out, did not") + } +} From 843391e63fe3534c85f1c3fc4892b809fd850d72 Mon Sep 17 00:00:00 2001 From: hannahhoward Date: Tue, 22 Jan 2019 18:46:42 -0800 Subject: [PATCH 3/9] feat(ProviderQueryManager): integrate in sessions Integrate the ProviderQueryManager into the SessionPeerManager and bitswap in general re #52, re #49 --- bitswap.go | 10 +- sessionpeermanager/sessionpeermanager.go | 66 +++++----- sessionpeermanager/sessionpeermanager_test.go | 114 ++++++------------ 3 files changed, 74 insertions(+), 116 deletions(-) diff --git a/bitswap.go b/bitswap.go index c4b8e887..ee0c939f 100644 --- a/bitswap.go +++ b/bitswap.go @@ -18,6 +18,7 @@ import ( bsnet "github.com/ipfs/go-bitswap/network" notifications "github.com/ipfs/go-bitswap/notifications" bspm "github.com/ipfs/go-bitswap/peermanager" + bspqm "github.com/ipfs/go-bitswap/providerquerymanager" bssession "github.com/ipfs/go-bitswap/session" bssm "github.com/ipfs/go-bitswap/sessionmanager" bsspm "github.com/ipfs/go-bitswap/sessionpeermanager" @@ -105,11 +106,13 @@ func New(parent context.Context, network bsnet.BitSwapNetwork, } wm := bswm.New(ctx) + pqm := bspqm.New(ctx, network) + sessionFactory := func(ctx context.Context, id uint64, pm bssession.PeerManager, srs bssession.RequestSplitter) bssm.Session { return bssession.New(ctx, id, wm, pm, srs) } sessionPeerManagerFactory := func(ctx context.Context, id uint64) bssession.PeerManager { - return bsspm.New(ctx, id, network) + return bsspm.New(ctx, id, network.ConnectionManager(), pqm) } sessionRequestSplitterFactory := func(ctx context.Context) bssession.RequestSplitter { return bssrs.New(ctx) @@ -125,6 +128,7 @@ func New(parent context.Context, network bsnet.BitSwapNetwork, newBlocks: make(chan cid.Cid, HasBlockBufferSize), provideKeys: make(chan cid.Cid, provideKeysBufferSize), wm: wm, + pqm: pqm, pm: bspm.New(ctx, peerQueueFactory), sm: bssm.New(ctx, sessionFactory, sessionPeerManagerFactory, sessionRequestSplitterFactory), counters: new(counters), @@ -136,6 +140,7 @@ func New(parent context.Context, network bsnet.BitSwapNetwork, bs.wm.SetDelegate(bs.pm) bs.pm.Startup() bs.wm.Startup() + bs.pqm.Startup() network.SetDelegate(bs) // Start up bitswaps async worker routines @@ -161,6 +166,9 @@ type Bitswap struct { // the wantlist tracks global wants for bitswap wm *bswm.WantManager + // the provider query manager manages requests to find providers + pqm *bspqm.ProviderQueryManager + // the engine is the bit of logic that decides who to send which blocks to engine *decision.Engine diff --git a/sessionpeermanager/sessionpeermanager.go b/sessionpeermanager/sessionpeermanager.go index 225f1901..091e1c7e 100644 --- a/sessionpeermanager/sessionpeermanager.go +++ b/sessionpeermanager/sessionpeermanager.go @@ -8,7 +8,6 @@ import ( logging "github.com/ipfs/go-log" cid "github.com/ipfs/go-cid" - ifconnmgr "github.com/libp2p/go-libp2p-interface-connmgr" peer "github.com/libp2p/go-libp2p-peer" ) @@ -19,11 +18,15 @@ const ( reservePeers = 2 ) -// PeerNetwork is an interface for finding providers and managing connections -type PeerNetwork interface { - ConnectionManager() ifconnmgr.ConnManager - ConnectTo(context.Context, peer.ID) error - FindProvidersAsync(context.Context, cid.Cid, int) <-chan peer.ID +// PeerTagger is an interface for tagging peers with metadata +type PeerTagger interface { + TagPeer(peer.ID, string, int) + UntagPeer(p peer.ID, tag string) +} + +// PeerProviderFinder is an interface for finding providers +type PeerProviderFinder interface { + FindProvidersAsync(context.Context, 
cid.Cid, uint64) <-chan peer.ID } type peerMessage interface { @@ -33,9 +36,11 @@ type peerMessage interface { // SessionPeerManager tracks and manages peers for a session, and provides // the best ones to the session type SessionPeerManager struct { - ctx context.Context - network PeerNetwork - tag string + ctx context.Context + tagger PeerTagger + providerFinder PeerProviderFinder + tag string + id uint64 peerMessages chan peerMessage @@ -46,12 +51,14 @@ type SessionPeerManager struct { } // New creates a new SessionPeerManager -func New(ctx context.Context, id uint64, network PeerNetwork) *SessionPeerManager { +func New(ctx context.Context, id uint64, tagger PeerTagger, providerFinder PeerProviderFinder) *SessionPeerManager { spm := &SessionPeerManager{ - ctx: ctx, - network: network, - peerMessages: make(chan peerMessage, 16), - activePeers: make(map[peer.ID]bool), + id: id, + ctx: ctx, + tagger: tagger, + providerFinder: providerFinder, + peerMessages: make(chan peerMessage, 16), + activePeers: make(map[peer.ID]bool), } spm.tag = fmt.Sprint("bs-ses-", id) @@ -101,24 +108,13 @@ func (spm *SessionPeerManager) GetOptimizedPeers() []peer.ID { // providers for the given Cid func (spm *SessionPeerManager) FindMorePeers(ctx context.Context, c cid.Cid) { go func(k cid.Cid) { - // TODO: have a task queue setup for this to: - // - rate limit - // - manage timeouts - // - ensure two 'findprovs' calls for the same block don't run concurrently - // - share peers between sessions based on interest set - for p := range spm.network.FindProvidersAsync(ctx, k, 10) { - go func(p peer.ID) { - // TODO: Also use context from spm. - err := spm.network.ConnectTo(ctx, p) - if err != nil { - log.Debugf("failed to connect to provider %s: %s", p, err) - } - select { - case spm.peerMessages <- &peerFoundMessage{p}: - case <-ctx.Done(): - case <-spm.ctx.Done(): - } - }(p) + for p := range spm.providerFinder.FindProvidersAsync(ctx, k, spm.id) { + + select { + case spm.peerMessages <- &peerFoundMessage{p}: + case <-ctx.Done(): + case <-spm.ctx.Done(): + } } }(c) } @@ -136,8 +132,7 @@ func (spm *SessionPeerManager) run(ctx context.Context) { } func (spm *SessionPeerManager) tagPeer(p peer.ID) { - cmgr := spm.network.ConnectionManager() - cmgr.TagPeer(p, spm.tag, 10) + spm.tagger.TagPeer(p, spm.tag, 10) } func (spm *SessionPeerManager) insertOptimizedPeer(p peer.ID) { @@ -223,8 +218,7 @@ func (prm *peerReqMessage) handle(spm *SessionPeerManager) { } func (spm *SessionPeerManager) handleShutdown() { - cmgr := spm.network.ConnectionManager() for p := range spm.activePeers { - cmgr.UntagPeer(p, spm.tag) + spm.tagger.UntagPeer(p, spm.tag) } } diff --git a/sessionpeermanager/sessionpeermanager_test.go b/sessionpeermanager/sessionpeermanager_test.go index 2ec38f0a..68862942 100644 --- a/sessionpeermanager/sessionpeermanager_test.go +++ b/sessionpeermanager/sessionpeermanager_test.go @@ -2,7 +2,6 @@ package sessionpeermanager import ( "context" - "errors" "math/rand" "sync" "testing" @@ -11,35 +10,19 @@ import ( "github.com/ipfs/go-bitswap/testutil" cid "github.com/ipfs/go-cid" - ifconnmgr "github.com/libp2p/go-libp2p-interface-connmgr" - inet "github.com/libp2p/go-libp2p-net" peer "github.com/libp2p/go-libp2p-peer" ) -type fakePeerNetwork struct { - peers []peer.ID - connManager ifconnmgr.ConnManager - completed chan struct{} - connect chan struct{} +type fakePeerProviderFinder struct { + peers []peer.ID + completed chan struct{} } -func (fpn *fakePeerNetwork) ConnectionManager() ifconnmgr.ConnManager { - return 
fpn.connManager -} - -func (fpn *fakePeerNetwork) ConnectTo(ctx context.Context, p peer.ID) error { - select { - case fpn.connect <- struct{}{}: - return nil - case <-ctx.Done(): - return errors.New("Timeout Occurred") - } -} - -func (fpn *fakePeerNetwork) FindProvidersAsync(ctx context.Context, c cid.Cid, num int) <-chan peer.ID { +func (fppf *fakePeerProviderFinder) FindProvidersAsync(ctx context.Context, c cid.Cid, ses uint64) <-chan peer.ID { peerCh := make(chan peer.ID) go func() { - for _, p := range fpn.peers { + + for _, p := range fppf.peers { select { case peerCh <- p: case <-ctx.Done(): @@ -50,52 +33,48 @@ func (fpn *fakePeerNetwork) FindProvidersAsync(ctx context.Context, c cid.Cid, n close(peerCh) select { - case fpn.completed <- struct{}{}: + case fppf.completed <- struct{}{}: case <-ctx.Done(): } }() return peerCh } -type fakeConnManager struct { +type fakePeerTagger struct { taggedPeers []peer.ID wait sync.WaitGroup } -func (fcm *fakeConnManager) TagPeer(p peer.ID, tag string, n int) { - fcm.wait.Add(1) - fcm.taggedPeers = append(fcm.taggedPeers, p) +func (fpt *fakePeerTagger) TagPeer(p peer.ID, tag string, n int) { + fpt.wait.Add(1) + fpt.taggedPeers = append(fpt.taggedPeers, p) } -func (fcm *fakeConnManager) UntagPeer(p peer.ID, tag string) { - defer fcm.wait.Done() - for i := 0; i < len(fcm.taggedPeers); i++ { - if fcm.taggedPeers[i] == p { - fcm.taggedPeers[i] = fcm.taggedPeers[len(fcm.taggedPeers)-1] - fcm.taggedPeers = fcm.taggedPeers[:len(fcm.taggedPeers)-1] +func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) { + defer fpt.wait.Done() + + for i := 0; i < len(fpt.taggedPeers); i++ { + if fpt.taggedPeers[i] == p { + fpt.taggedPeers[i] = fpt.taggedPeers[len(fpt.taggedPeers)-1] + fpt.taggedPeers = fpt.taggedPeers[:len(fpt.taggedPeers)-1] return } } } -func (*fakeConnManager) GetTagInfo(p peer.ID) *ifconnmgr.TagInfo { return nil } -func (*fakeConnManager) TrimOpenConns(ctx context.Context) {} -func (*fakeConnManager) Notifee() inet.Notifiee { return nil } - func TestFindingMorePeers(t *testing.T) { ctx := context.Background() ctx, cancel := context.WithCancel(ctx) defer cancel() completed := make(chan struct{}) - connect := make(chan struct{}) peers := testutil.GeneratePeers(5) - fcm := &fakeConnManager{} - fpn := &fakePeerNetwork{peers, fcm, completed, connect} + fpt := &fakePeerTagger{} + fppf := &fakePeerProviderFinder{peers, completed} c := testutil.GenerateCids(1)[0] id := testutil.GenerateSessionID() - sessionPeerManager := New(ctx, id, fpn) + sessionPeerManager := New(ctx, id, fpt, fppf) findCtx, findCancel := context.WithTimeout(ctx, 10*time.Millisecond) defer findCancel() @@ -105,13 +84,6 @@ func TestFindingMorePeers(t *testing.T) { case <-findCtx.Done(): t.Fatal("Did not finish finding providers") } - for range peers { - select { - case <-connect: - case <-findCtx.Done(): - t.Fatal("Did not connect to peer") - } - } time.Sleep(2 * time.Millisecond) sessionPeers := sessionPeerManager.GetOptimizedPeers() @@ -123,7 +95,7 @@ func TestFindingMorePeers(t *testing.T) { t.Fatal("incorrect peer found through finding providers") } } - if len(fcm.taggedPeers) != len(peers) { + if len(fpt.taggedPeers) != len(peers) { t.Fatal("Peers were not tagged!") } } @@ -133,12 +105,12 @@ func TestRecordingReceivedBlocks(t *testing.T) { ctx, cancel := context.WithCancel(ctx) defer cancel() p := testutil.GeneratePeers(1)[0] - fcm := &fakeConnManager{} - fpn := &fakePeerNetwork{nil, fcm, nil, nil} + fpt := &fakePeerTagger{} + fppf := &fakePeerProviderFinder{} c := 
testutil.GenerateCids(1)[0] id := testutil.GenerateSessionID() - sessionPeerManager := New(ctx, id, fpn) + sessionPeerManager := New(ctx, id, fpt, fppf) sessionPeerManager.RecordPeerResponse(p, c) time.Sleep(10 * time.Millisecond) sessionPeers := sessionPeerManager.GetOptimizedPeers() @@ -148,7 +120,7 @@ func TestRecordingReceivedBlocks(t *testing.T) { if sessionPeers[0] != p { t.Fatal("incorrect peer added on receive") } - if len(fcm.taggedPeers) != 1 { + if len(fpt.taggedPeers) != 1 { t.Fatal("Peers was not tagged!") } } @@ -159,12 +131,11 @@ func TestOrderingPeers(t *testing.T) { defer cancel() peers := testutil.GeneratePeers(100) completed := make(chan struct{}) - connect := make(chan struct{}) - fcm := &fakeConnManager{} - fpn := &fakePeerNetwork{peers, fcm, completed, connect} + fpt := &fakePeerTagger{} + fppf := &fakePeerProviderFinder{peers, completed} c := testutil.GenerateCids(1) id := testutil.GenerateSessionID() - sessionPeerManager := New(ctx, id, fpn) + sessionPeerManager := New(ctx, id, fpt, fppf) // add all peers to session sessionPeerManager.FindMorePeers(ctx, c[0]) @@ -173,13 +144,6 @@ func TestOrderingPeers(t *testing.T) { case <-ctx.Done(): t.Fatal("Did not finish finding providers") } - for range peers { - select { - case <-connect: - case <-ctx.Done(): - t.Fatal("Did not connect to peer") - } - } time.Sleep(2 * time.Millisecond) // record broadcast @@ -237,13 +201,12 @@ func TestUntaggingPeers(t *testing.T) { defer cancel() peers := testutil.GeneratePeers(5) completed := make(chan struct{}) - connect := make(chan struct{}) - fcm := &fakeConnManager{} - fpn := &fakePeerNetwork{peers, fcm, completed, connect} + fpt := &fakePeerTagger{} + fppf := &fakePeerProviderFinder{peers, completed} c := testutil.GenerateCids(1)[0] id := testutil.GenerateSessionID() - sessionPeerManager := New(ctx, id, fpn) + sessionPeerManager := New(ctx, id, fpt, fppf) sessionPeerManager.FindMorePeers(ctx, c) select { @@ -251,22 +214,15 @@ func TestUntaggingPeers(t *testing.T) { case <-ctx.Done(): t.Fatal("Did not finish finding providers") } - for range peers { - select { - case <-connect: - case <-ctx.Done(): - t.Fatal("Did not connect to peer") - } - } time.Sleep(2 * time.Millisecond) - if len(fcm.taggedPeers) != len(peers) { + if len(fpt.taggedPeers) != len(peers) { t.Fatal("Peers were not tagged!") } <-ctx.Done() - fcm.wait.Wait() + fpt.wait.Wait() - if len(fcm.taggedPeers) != 0 { + if len(fpt.taggedPeers) != 0 { t.Fatal("Peers were not untagged!") } } From 1eb28a223413168af69fdf5499a12db0cecec7a7 Mon Sep 17 00:00:00 2001 From: hannahhoward Date: Wed, 23 Jan 2019 14:01:53 -0800 Subject: [PATCH 4/9] fix(ProviderQueryManager): fix test + add logging Add debug logging for the provider query manager and make tests more reliable --- providerquerymanager/providerquerymanager.go | 22 ++++++++- .../providerquerymanager_test.go | 48 +++++++++++++------ 2 files changed, 54 insertions(+), 16 deletions(-) diff --git a/providerquerymanager/providerquerymanager.go b/providerquerymanager/providerquerymanager.go index d2ba9e72..21cfcd0d 100644 --- a/providerquerymanager/providerquerymanager.go +++ b/providerquerymanager/providerquerymanager.go @@ -2,6 +2,7 @@ package providerquerymanager import ( "context" + "fmt" "sync" "time" @@ -31,6 +32,7 @@ type ProviderQueryNetwork interface { } type providerQueryMessage interface { + debugMessage() string handle(pqm *ProviderQueryManager) } @@ -192,6 +194,7 @@ func (pqm *ProviderQueryManager) findProviderWorker() { return } + log.Debugf("Beginning Find Provider 
Request for cid: %s", k.String()) pqm.timeoutMutex.RLock() findProviderCtx, cancel := context.WithTimeout(pqm.ctx, pqm.findProviderTimeout) pqm.timeoutMutex.RUnlock() @@ -273,8 +276,6 @@ func (pqm *ProviderQueryManager) cleanupInProcessRequests() { } func (pqm *ProviderQueryManager) run() { - defer close(pqm.incomingFindProviderRequests) - defer close(pqm.providerRequestsProcessing) defer pqm.cleanupInProcessRequests() go pqm.providerRequestBufferWorker() @@ -285,6 +286,7 @@ func (pqm *ProviderQueryManager) run() { for { select { case nextMessage := <-pqm.providerQueryMessages: + log.Debug(nextMessage.debugMessage()) nextMessage.handle(pqm) case <-pqm.ctx.Done(): return @@ -292,6 +294,10 @@ func (pqm *ProviderQueryManager) run() { } } +func (rpm *receivedProviderMessage) debugMessage() string { + return fmt.Sprintf("Received provider (%s) for cid (%s)", rpm.p.String(), rpm.k.String()) +} + func (rpm *receivedProviderMessage) handle(pqm *ProviderQueryManager) { requestStatus, ok := pqm.inProgressRequestStatuses[rpm.k] if !ok { @@ -308,6 +314,10 @@ func (rpm *receivedProviderMessage) handle(pqm *ProviderQueryManager) { } } +func (fpqm *finishedProviderQueryMessage) debugMessage() string { + return fmt.Sprintf("Finished Provider Query on cid: %s", fpqm.k.String()) +} + func (fpqm *finishedProviderQueryMessage) handle(pqm *ProviderQueryManager) { requestStatus, ok := pqm.inProgressRequestStatuses[fpqm.k] if !ok { @@ -320,6 +330,10 @@ func (fpqm *finishedProviderQueryMessage) handle(pqm *ProviderQueryManager) { delete(pqm.inProgressRequestStatuses, fpqm.k) } +func (npqm *newProvideQueryMessage) debugMessage() string { + return fmt.Sprintf("New Provider Query on cid: %s from session: %d", npqm.k.String(), npqm.ses) +} + func (npqm *newProvideQueryMessage) handle(pqm *ProviderQueryManager) { requestStatus, ok := pqm.inProgressRequestStatuses[npqm.k] if !ok { @@ -343,6 +357,10 @@ func (npqm *newProvideQueryMessage) handle(pqm *ProviderQueryManager) { } } +func (crm *cancelRequestMessage) debugMessage() string { + return fmt.Sprintf("Cancel provider query on cid: %s from session: %d", crm.k.String(), crm.ses) +} + func (crm *cancelRequestMessage) handle(pqm *ProviderQueryManager) { requestStatus, ok := pqm.inProgressRequestStatuses[crm.k] if !ok { diff --git a/providerquerymanager/providerquerymanager_test.go b/providerquerymanager/providerquerymanager_test.go index f2e6f036..f5b6db1e 100644 --- a/providerquerymanager/providerquerymanager_test.go +++ b/providerquerymanager/providerquerymanager_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "reflect" + "sync" "testing" "time" @@ -14,11 +15,12 @@ import ( ) type fakeProviderNetwork struct { - peersFound []peer.ID - connectError error - delay time.Duration - connectDelay time.Duration - queriesMade int + peersFound []peer.ID + connectError error + delay time.Duration + connectDelay time.Duration + queriesMadeMutex sync.RWMutex + queriesMade int } func (fpn *fakeProviderNetwork) ConnectTo(context.Context, peer.ID) error { @@ -27,13 +29,20 @@ func (fpn *fakeProviderNetwork) ConnectTo(context.Context, peer.ID) error { } func (fpn *fakeProviderNetwork) FindProvidersAsync(ctx context.Context, k cid.Cid, max int) <-chan peer.ID { + fpn.queriesMadeMutex.Lock() fpn.queriesMade++ + fpn.queriesMadeMutex.Unlock() incomingPeers := make(chan peer.ID) go func() { defer close(incomingPeers) for _, p := range fpn.peersFound { time.Sleep(fpn.delay) select { + case <-ctx.Done(): + return + default: + } + select { case incomingPeers <- p: case <-ctx.Done(): return 
@@ -75,9 +84,12 @@ func TestNormalSimultaneousFetch(t *testing.T) {
 		t.Fatal("Did not collect all peers for request that was completed")
 	}
 
+	fpn.queriesMadeMutex.Lock()
+	defer fpn.queriesMadeMutex.Unlock()
 	if fpn.queriesMade != 2 {
 		t.Fatal("Did not dedup provider requests running simultaneously")
 	}
+
 }
 
 func TestDedupingProviderRequests(t *testing.T) {
@@ -93,7 +105,7 @@ func TestDedupingProviderRequests(t *testing.T) {
 	sessionID1 := testutil.GenerateSessionID()
 	sessionID2 := testutil.GenerateSessionID()
 
-	sessionCtx, cancel := context.WithTimeout(ctx, 20*time.Millisecond)
+	sessionCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
 	defer cancel()
 	firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, sessionID1)
 	secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, sessionID2)
@@ -115,7 +127,8 @@ func TestDedupingProviderRequests(t *testing.T) {
 	if !reflect.DeepEqual(firstPeersReceived, secondPeersReceived) {
 		t.Fatal("Did not receive the same response to both find provider requests")
 	}
-
+	fpn.queriesMadeMutex.Lock()
+	defer fpn.queriesMadeMutex.Unlock()
 	if fpn.queriesMade != 1 {
 		t.Fatal("Did not dedup provider requests running simultaneously")
 	}
@@ -139,7 +152,7 @@ func TestCancelOneRequestDoesNotTerminateAnother(t *testing.T) {
 	firstSessionCtx, firstCancel := context.WithTimeout(ctx, 3*time.Millisecond)
 	defer firstCancel()
 	firstRequestChan := providerQueryManager.FindProvidersAsync(firstSessionCtx, key, sessionID1)
-	secondSessionCtx, secondCancel := context.WithTimeout(ctx, 20*time.Millisecond)
+	secondSessionCtx, secondCancel := context.WithTimeout(ctx, 100*time.Millisecond)
 	defer secondCancel()
 	secondRequestChan := providerQueryManager.FindProvidersAsync(secondSessionCtx, key, sessionID2)
@@ -160,7 +173,8 @@ func TestCancelOneRequestDoesNotTerminateAnother(t *testing.T) {
 	if len(firstPeersReceived) >= len(peers) {
 		t.Fatal("Collected all peers on cancelled peer, should have been cancelled immediately")
 	}
-
+	fpn.queriesMadeMutex.Lock()
+	defer fpn.queriesMadeMutex.Unlock()
 	if fpn.queriesMade != 1 {
 		t.Fatal("Did not dedup provider requests running simultaneously")
 	}
@@ -248,26 +262,33 @@ func TestRateLimitingRequests(t *testing.T) {
 		delay:      1 * time.Millisecond,
 	}
 	ctx := context.Background()
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
 	providerQueryManager := New(ctx, fpn)
 	providerQueryManager.Startup()
 
 	keys := testutil.GenerateCids(maxInProcessRequests + 1)
 	sessionID := testutil.GenerateSessionID()
-	sessionCtx, cancel := context.WithTimeout(ctx, 20*time.Millisecond)
+	sessionCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
 	defer cancel()
 	var requestChannels []<-chan peer.ID
 	for i := 0; i < maxInProcessRequests+1; i++ {
 		requestChannels = append(requestChannels, providerQueryManager.FindProvidersAsync(sessionCtx, keys[i], sessionID))
 	}
-	time.Sleep(2 * time.Millisecond)
+	time.Sleep(9 * time.Millisecond)
+	fpn.queriesMadeMutex.Lock()
 	if fpn.queriesMade != maxInProcessRequests {
+		t.Logf("Queries made: %d\n", fpn.queriesMade)
 		t.Fatal("Did not limit parallel requests to rate limit")
 	}
+	fpn.queriesMadeMutex.Unlock()
 	for i := 0; i < maxInProcessRequests+1; i++ {
 		for range requestChannels[i] {
 		}
 	}
 
+	fpn.queriesMadeMutex.Lock()
+	defer fpn.queriesMadeMutex.Unlock()
 	if fpn.queriesMade != maxInProcessRequests+1 {
 		t.Fatal("Did not make all separate requests")
 	}
@@ -282,7 +303,7 @@ func TestFindProviderTimeout(t *testing.T) {
 	ctx := context.Background()
 	providerQueryManager := New(ctx, fpn)
 	providerQueryManager.Startup()
-	providerQueryManager.SetFindProviderTimeout(3 * time.Millisecond)
+	providerQueryManager.SetFindProviderTimeout(2 * time.Millisecond)
 	keys := testutil.GenerateCids(1)
 	sessionID1 := testutil.GenerateSessionID()
 
@@ -293,8 +314,7 @@ func TestFindProviderTimeout(t *testing.T) {
 	for p := range firstRequestChan {
 		firstPeersReceived = append(firstPeersReceived, p)
 	}
-	if len(firstPeersReceived) <= 0 ||
-		len(firstPeersReceived) >= len(peers) {
+	if len(firstPeersReceived) >= len(peers) {
 		t.Fatal("Find provider request should have timed out, did not")
 	}
 }

From 56d9e3fcf95a94dbb255e67c0a2fa8d6ace84dce Mon Sep 17 00:00:00 2001
From: hannahhoward
Date: Wed, 30 Jan 2019 13:16:51 -0800
Subject: [PATCH 5/9] fix(providerquerymanager): improve test stability

Remove a minor condition check that could fail in some cases purely due
to timing rather than an actual code issue

---
 providerquerymanager/providerquerymanager_test.go | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/providerquerymanager/providerquerymanager_test.go b/providerquerymanager/providerquerymanager_test.go
index f5b6db1e..21d7004c 100644
--- a/providerquerymanager/providerquerymanager_test.go
+++ b/providerquerymanager/providerquerymanager_test.go
@@ -211,9 +211,7 @@ func TestCancelManagerExitsGracefully(t *testing.T) {
 		secondPeersReceived = append(secondPeersReceived, p)
 	}
 
-	if len(firstPeersReceived) <= 0 ||
-		len(firstPeersReceived) >= len(peers) ||
-		len(secondPeersReceived) <= 0 ||
+	if len(firstPeersReceived) >= len(peers) ||
 		len(secondPeersReceived) >= len(peers) {
 		t.Fatal("Did not cancel requests in progress correctly")
 	}

From 92717dbb67953ebee5675555a273b375cbae13d4 Mon Sep 17 00:00:00 2001
From: hannahhoward
Date: Mon, 4 Feb 2019 11:50:52 -0800
Subject: [PATCH 6/9] refactor(providerquerymanager): don't use session ids

Removed session ID usage completely from providerquerymanager

---
 providerquerymanager/providerquerymanager.go      | 45 +++++++++----------
 .../providerquerymanager_test.go                  | 36 +++++----------
 sessionpeermanager/sessionpeermanager.go          |  6 +--
 sessionpeermanager/sessionpeermanager_test.go     |  2 +-
 4 files changed, 38 insertions(+), 51 deletions(-)

diff --git a/providerquerymanager/providerquerymanager.go b/providerquerymanager/providerquerymanager.go
index 21cfcd0d..8c20b022 100644
--- a/providerquerymanager/providerquerymanager.go
+++ b/providerquerymanager/providerquerymanager.go
@@ -21,7 +21,7 @@ const (
 
 type inProgressRequestStatus struct {
 	providersSoFar []peer.ID
-	listeners      map[uint64]chan peer.ID
+	listeners      map[chan peer.ID]struct{}
 }
 
 // ProviderQueryNetwork is an interface for finding providers and connecting to
@@ -46,14 +46,13 @@ type finishedProviderQueryMessage struct {
 }
 
 type newProvideQueryMessage struct {
-	ses                   uint64
 	k                     cid.Cid
 	inProgressRequestChan chan<- inProgressRequest
 }
 
 type cancelRequestMessage struct {
-	ses uint64
-	k   cid.Cid
+	incomingProviders chan peer.ID
+	k                 cid.Cid
 }
 
 // ProviderQueryManager manages requests to find more providers for blocks
@@ -98,7 +97,7 @@ func (pqm *ProviderQueryManager) Startup() {
 
 type inProgressRequest struct {
 	providersSoFar []peer.ID
-	incoming       <-chan peer.ID
+	incoming       chan peer.ID
 }
 
 // SetFindProviderTimeout changes the timeout for finding providers
@@ -109,12 +108,11 @@ func (pqm *ProviderQueryManager) SetFindProviderTimeout(findProviderTimeout time.Duration) {
 }
 
 // FindProvidersAsync finds providers for the given block.
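+// A minimal sketch of the updated call per the new signature below (editor's
+// illustration; session IDs are no longer passed):
+//
+//	for p := range pqm.FindProvidersAsync(sessionCtx, c) {
+//		// p is a connected provider for c
+//	}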
-func (pqm *ProviderQueryManager) FindProvidersAsync(sessionCtx context.Context, k cid.Cid, ses uint64) <-chan peer.ID { +func (pqm *ProviderQueryManager) FindProvidersAsync(sessionCtx context.Context, k cid.Cid) <-chan peer.ID { inProgressRequestChan := make(chan inProgressRequest) select { case pqm.providerQueryMessages <- &newProvideQueryMessage{ - ses: ses, k: k, inProgressRequestChan: inProgressRequestChan, }: @@ -131,10 +129,10 @@ func (pqm *ProviderQueryManager) FindProvidersAsync(sessionCtx context.Context, case receivedInProgressRequest = <-inProgressRequestChan: } - return pqm.receiveProviders(sessionCtx, k, ses, receivedInProgressRequest) + return pqm.receiveProviders(sessionCtx, k, receivedInProgressRequest) } -func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k cid.Cid, ses uint64, receivedInProgressRequest inProgressRequest) <-chan peer.ID { +func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k cid.Cid, receivedInProgressRequest inProgressRequest) <-chan peer.ID { // maintains an unbuffered queue for incoming providers for given request for a given session // essentially, as a provider comes in, for a given CID, we want to immediately broadcast to all // sessions that queried that CID, without worrying about whether the client code is actually @@ -162,8 +160,8 @@ func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k select { case <-sessionCtx.Done(): pqm.providerQueryMessages <- &cancelRequestMessage{ - ses: ses, - k: k, + incomingProviders: incomingProviders, + k: k, } // clear out any remaining providers for range incomingProviders { @@ -269,7 +267,7 @@ func (pqm *ProviderQueryManager) providerRequestBufferWorker() { func (pqm *ProviderQueryManager) cleanupInProcessRequests() { for _, requestStatus := range pqm.inProgressRequestStatuses { - for _, listener := range requestStatus.listeners { + for listener := range requestStatus.listeners { close(listener) } } @@ -305,7 +303,7 @@ func (rpm *receivedProviderMessage) handle(pqm *ProviderQueryManager) { return } requestStatus.providersSoFar = append(requestStatus.providersSoFar, rpm.p) - for _, listener := range requestStatus.listeners { + for listener := range requestStatus.listeners { select { case listener <- rpm.p: case <-pqm.ctx.Done(): @@ -324,21 +322,21 @@ func (fpqm *finishedProviderQueryMessage) handle(pqm *ProviderQueryManager) { log.Errorf("Ended request for cid (%s) not in progress", fpqm.k.String()) return } - for _, listener := range requestStatus.listeners { + for listener := range requestStatus.listeners { close(listener) } delete(pqm.inProgressRequestStatuses, fpqm.k) } func (npqm *newProvideQueryMessage) debugMessage() string { - return fmt.Sprintf("New Provider Query on cid: %s from session: %d", npqm.k.String(), npqm.ses) + return fmt.Sprintf("New Provider Query on cid: %s", npqm.k.String()) } func (npqm *newProvideQueryMessage) handle(pqm *ProviderQueryManager) { requestStatus, ok := pqm.inProgressRequestStatuses[npqm.k] if !ok { requestStatus = &inProgressRequestStatus{ - listeners: make(map[uint64]chan peer.ID), + listeners: make(map[chan peer.ID]struct{}), } pqm.inProgressRequestStatuses[npqm.k] = requestStatus select { @@ -347,31 +345,32 @@ func (npqm *newProvideQueryMessage) handle(pqm *ProviderQueryManager) { return } } - requestStatus.listeners[npqm.ses] = make(chan peer.ID) + inProgressChan := make(chan peer.ID) + requestStatus.listeners[inProgressChan] = struct{}{} select { case npqm.inProgressRequestChan <- 
inProgressRequest{
 		providersSoFar: requestStatus.providersSoFar,
-		incoming:       requestStatus.listeners[npqm.ses],
+		incoming:       inProgressChan,
 	}:
 	case <-pqm.ctx.Done():
 	}
 }
 
 func (crm *cancelRequestMessage) debugMessage() string {
-	return fmt.Sprintf("Cancel provider query on cid: %s from session: %d", crm.k.String(), crm.ses)
+	return fmt.Sprintf("Cancel provider query on cid: %s", crm.k.String())
 }
 
 func (crm *cancelRequestMessage) handle(pqm *ProviderQueryManager) {
 	requestStatus, ok := pqm.inProgressRequestStatuses[crm.k]
 	if !ok {
-		log.Errorf("Attempt to cancel request for session (%d) for cid (%s) not in progress", crm.ses, crm.k.String())
+		log.Errorf("Attempt to cancel request for cid (%s) not in progress", crm.k.String())
 		return
 	}
-	listener, ok := requestStatus.listeners[crm.ses]
+	listener := crm.incomingProviders
 	if !ok {
-		log.Errorf("Attempt to cancel request for session (%d) for cid (%s) this is not a listener", crm.ses, crm.k.String())
+		log.Errorf("Attempt to cancel request for cid (%s) that is not a listener", crm.k.String())
 		return
 	}
 	close(listener)
-	delete(requestStatus.listeners, crm.ses)
+	delete(requestStatus.listeners, listener)
 }

diff --git a/providerquerymanager/providerquerymanager_test.go b/providerquerymanager/providerquerymanager_test.go
index 21d7004c..3abe6b0e 100644
--- a/providerquerymanager/providerquerymanager_test.go
+++ b/providerquerymanager/providerquerymanager_test.go
@@ -62,13 +62,11 @@ func TestNormalSimultaneousFetch(t *testing.T) {
 	providerQueryManager := New(ctx, fpn)
 	providerQueryManager.Startup()
 	keys := testutil.GenerateCids(2)
-	sessionID1 := testutil.GenerateSessionID()
-	sessionID2 := testutil.GenerateSessionID()
 
 	sessionCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
 	defer cancel()
-	firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0], sessionID1)
-	secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[1], sessionID2)
+	firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0])
+	secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[1])
@@ -102,13 +100,11 @@ func TestDedupingProviderRequests(t *testing.T) {
 	providerQueryManager := New(ctx, fpn)
 	providerQueryManager.Startup()
 	key := testutil.GenerateCids(1)[0]
-	sessionID1 := testutil.GenerateSessionID()
-	sessionID2 := testutil.GenerateSessionID()
 
 	sessionCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
 	defer cancel()
-	firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, sessionID1)
-	secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, sessionID2)
+	firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key)
+	secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key)
@@ -145,16 +141,14 @@ func TestCancelOneRequestDoesNotTerminateAnother(t *testing.T) {
 	providerQueryManager.Startup()
 
 	key := testutil.GenerateCids(1)[0]
-	sessionID1 := testutil.GenerateSessionID()
-	sessionID2 := testutil.GenerateSessionID()
 
 	// first session will cancel before done
 	firstSessionCtx, firstCancel := context.WithTimeout(ctx, 3*time.Millisecond)
 	defer firstCancel()
-	firstRequestChan := providerQueryManager.FindProvidersAsync(firstSessionCtx, key, sessionID1)
+	firstRequestChan := providerQueryManager.FindProvidersAsync(firstSessionCtx, key)
 	secondSessionCtx, secondCancel 
:= context.WithTimeout(ctx, 100*time.Millisecond) defer secondCancel() - secondRequestChan := providerQueryManager.FindProvidersAsync(secondSessionCtx, key, sessionID2) + secondRequestChan := providerQueryManager.FindProvidersAsync(secondSessionCtx, key) var firstPeersReceived []peer.ID for p := range firstRequestChan { @@ -193,13 +187,11 @@ func TestCancelManagerExitsGracefully(t *testing.T) { providerQueryManager.Startup() key := testutil.GenerateCids(1)[0] - sessionID1 := testutil.GenerateSessionID() - sessionID2 := testutil.GenerateSessionID() sessionCtx, cancel := context.WithTimeout(ctx, 20*time.Millisecond) defer cancel() - firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, sessionID1) - secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, sessionID2) + firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key) + secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key) var firstPeersReceived []peer.ID for p := range firstRequestChan { @@ -229,13 +221,11 @@ func TestPeersWithConnectionErrorsNotAddedToPeerList(t *testing.T) { providerQueryManager.Startup() key := testutil.GenerateCids(1)[0] - sessionID1 := testutil.GenerateSessionID() - sessionID2 := testutil.GenerateSessionID() sessionCtx, cancel := context.WithTimeout(ctx, 20*time.Millisecond) defer cancel() - firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, sessionID1) - secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, sessionID2) + firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key) + secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key) var firstPeersReceived []peer.ID for p := range firstRequestChan { @@ -266,12 +256,11 @@ func TestRateLimitingRequests(t *testing.T) { providerQueryManager.Startup() keys := testutil.GenerateCids(maxInProcessRequests + 1) - sessionID := testutil.GenerateSessionID() sessionCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) defer cancel() var requestChannels []<-chan peer.ID for i := 0; i < maxInProcessRequests+1; i++ { - requestChannels = append(requestChannels, providerQueryManager.FindProvidersAsync(sessionCtx, keys[i], sessionID)) + requestChannels = append(requestChannels, providerQueryManager.FindProvidersAsync(sessionCtx, keys[i])) } time.Sleep(9 * time.Millisecond) fpn.queriesMadeMutex.Lock() @@ -303,11 +292,10 @@ func TestFindProviderTimeout(t *testing.T) { providerQueryManager.Startup() providerQueryManager.SetFindProviderTimeout(2 * time.Millisecond) keys := testutil.GenerateCids(1) - sessionID1 := testutil.GenerateSessionID() sessionCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) defer cancel() - firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0], sessionID1) + firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0]) var firstPeersReceived []peer.ID for p := range firstRequestChan { firstPeersReceived = append(firstPeersReceived, p) diff --git a/sessionpeermanager/sessionpeermanager.go b/sessionpeermanager/sessionpeermanager.go index 091e1c7e..0b02a2a2 100644 --- a/sessionpeermanager/sessionpeermanager.go +++ b/sessionpeermanager/sessionpeermanager.go @@ -26,7 +26,7 @@ type PeerTagger interface { // PeerProviderFinder is an interface for finding providers type PeerProviderFinder interface { - FindProvidersAsync(context.Context, cid.Cid, uint64) <-chan peer.ID + FindProvidersAsync(context.Context, cid.Cid) <-chan 
peer.ID } type peerMessage interface { @@ -108,8 +108,8 @@ func (spm *SessionPeerManager) GetOptimizedPeers() []peer.ID { // providers for the given Cid func (spm *SessionPeerManager) FindMorePeers(ctx context.Context, c cid.Cid) { go func(k cid.Cid) { - for p := range spm.providerFinder.FindProvidersAsync(ctx, k, spm.id) { - + for p := range spm.providerFinder.FindProvidersAsync(ctx, k) { + select { case spm.peerMessages <- &peerFoundMessage{p}: case <-ctx.Done(): diff --git a/sessionpeermanager/sessionpeermanager_test.go b/sessionpeermanager/sessionpeermanager_test.go index 68862942..d6d1440a 100644 --- a/sessionpeermanager/sessionpeermanager_test.go +++ b/sessionpeermanager/sessionpeermanager_test.go @@ -18,7 +18,7 @@ type fakePeerProviderFinder struct { completed chan struct{} } -func (fppf *fakePeerProviderFinder) FindProvidersAsync(ctx context.Context, c cid.Cid, ses uint64) <-chan peer.ID { +func (fppf *fakePeerProviderFinder) FindProvidersAsync(ctx context.Context, c cid.Cid) <-chan peer.ID { peerCh := make(chan peer.ID) go func() { From 51e82a6552f657f91cd28b91682e4ff456182336 Mon Sep 17 00:00:00 2001 From: hannahhoward Date: Mon, 4 Feb 2019 12:31:20 -0800 Subject: [PATCH 7/9] fix(providerquerymanager): minor fixes to capture all cancellations --- providerquerymanager/providerquerymanager.go | 36 +++++++++++++------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/providerquerymanager/providerquerymanager.go b/providerquerymanager/providerquerymanager.go index 8c20b022..26602bc5 100644 --- a/providerquerymanager/providerquerymanager.go +++ b/providerquerymanager/providerquerymanager.go @@ -124,6 +124,8 @@ func (pqm *ProviderQueryManager) FindProvidersAsync(sessionCtx context.Context, var receivedInProgressRequest inProgressRequest select { + case <-pqm.ctx.Done(): + return nil case <-sessionCtx.Done(): return nil case receivedInProgressRequest = <-inProgressRequestChan: @@ -158,15 +160,25 @@ func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k } for len(receivedProviders) > 0 || incomingProviders != nil { select { + case <-pqm.ctx.Done(): + return case <-sessionCtx.Done(): pqm.providerQueryMessages <- &cancelRequestMessage{ incomingProviders: incomingProviders, k: k, } - // clear out any remaining providers - for range incomingProviders { + // clear out any remaining providers, in case and "incoming provider" + // messages get processed before our cancel message + for { + select { + case _, ok := <-incomingProviders: + if !ok { + return + } + case <-pqm.ctx.Done(): + return + } } - return case provider, ok := <-incomingProviders: if !ok { incomingProviders = nil @@ -362,15 +374,15 @@ func (crm *cancelRequestMessage) debugMessage() string { func (crm *cancelRequestMessage) handle(pqm *ProviderQueryManager) { requestStatus, ok := pqm.inProgressRequestStatuses[crm.k] - if !ok { + if ok { + _, ok := requestStatus.listeners[crm.incomingProviders] + if ok { + delete(requestStatus.listeners, crm.incomingProviders) + } else { + log.Errorf("Attempt to cancel request for for cid (%s) this is not a listener", crm.k.String()) + } + } else { log.Errorf("Attempt to cancel request for cid (%s) not in progress", crm.k.String()) - return - } - listener := crm.incomingProviders - if !ok { - log.Errorf("Attempt to cancel request for for cid (%s) this is not a listener", crm.k.String()) - return } - close(listener) - delete(requestStatus.listeners, listener) + close(crm.incomingProviders) } From b48b3c33ee4ecacff165220fea06520efb21d45d Mon Sep 17 00:00:00 
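The drain loop this patch adds to receiveProviders is the standard guard against a send/cancel race: once a consumer has signalled cancellation, it must keep receiving until the producer closes the channel, or a producer blocked on an unbuffered send can hang forever. Below is a minimal, self-contained sketch of that pattern; the names (incoming, cancelled) are illustrative only and are not part of the manager's API.

package main

import "fmt"

func main() {
	incoming := make(chan int) // unbuffered, like the provider channels above
	cancelled := make(chan struct{})

	// Producer: checks for cancellation between sends, but a send already
	// in flight can only complete if somebody receives it.
	go func() {
		defer close(incoming)
		for i := 0; ; i++ {
			select {
			case <-cancelled:
				return
			default:
			}
			incoming <- i // plain send: blocks until received
		}
	}()

	for v := range incoming {
		fmt.Println("received", v)
		if v == 2 {
			break // consumer gives up early
		}
	}

	// The fix in miniature: signal cancellation, then keep draining until
	// the producer closes the channel, so its in-flight send cannot block.
	close(cancelled)
	for range incoming {
	}
}

Without the final drain loop, the producer could sit blocked on `incoming <- 3` and never observe the cancellation, which is exactly the leak the extra select cases in this patch defend against.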
From b48b3c33ee4ecacff165220fea06520efb21d45d Mon Sep 17 00:00:00 2001
From: hannahhoward
Date: Mon, 4 Feb 2019 14:58:46 -0800
Subject: [PATCH 8/9] feat(providerquerymanager): cancel FindProvidersAsync correctly

Make sure that if all requestors cancel their requests to find providers
for a block, the overall query gets cancelled

---
 providerquerymanager/providerquerymanager.go | 43 ++++++++++++++------
 1 file changed, 31 insertions(+), 12 deletions(-)

diff --git a/providerquerymanager/providerquerymanager.go b/providerquerymanager/providerquerymanager.go
index 26602bc5..b84463a7 100644
--- a/providerquerymanager/providerquerymanager.go
+++ b/providerquerymanager/providerquerymanager.go
@@ -20,10 +20,17 @@ const (
 )
 
 type inProgressRequestStatus struct {
+	ctx            context.Context
+	cancelFn       func()
 	providersSoFar []peer.ID
 	listeners      map[chan peer.ID]struct{}
 }
 
+type findProviderRequest struct {
+	k   cid.Cid
+	ctx context.Context
+}
+
 // ProviderQueryNetwork is an interface for finding providers and connecting to
 // peers.
 type ProviderQueryNetwork interface {
@@ -66,8 +73,8 @@ type ProviderQueryManager struct {
 	ctx                          context.Context
 	network                      ProviderQueryNetwork
 	providerQueryMessages        chan providerQueryMessage
-	providerRequestsProcessing   chan cid.Cid
-	incomingFindProviderRequests chan cid.Cid
+	providerRequestsProcessing   chan *findProviderRequest
+	incomingFindProviderRequests chan *findProviderRequest
 
 	findProviderTimeout time.Duration
 	timeoutMutex        sync.RWMutex
@@ -83,8 +90,8 @@ func New(ctx context.Context, network ProviderQueryNetwork) *ProviderQueryManage
 		ctx:                          ctx,
 		network:                      network,
 		providerQueryMessages:        make(chan providerQueryMessage, 16),
-		providerRequestsProcessing:   make(chan cid.Cid),
-		incomingFindProviderRequests: make(chan cid.Cid),
+		providerRequestsProcessing:   make(chan *findProviderRequest),
+		incomingFindProviderRequests: make(chan *findProviderRequest),
 		inProgressRequestStatuses:    make(map[cid.Cid]*inProgressRequestStatus),
 		findProviderTimeout:          defaultTimeout,
 	}
 }
@@ -199,14 +206,14 @@ func (pqm *ProviderQueryManager) findProviderWorker() {
 	// to let requests go in parallel but keep them rate limited
 	for {
 		select {
-		case k, ok := <-pqm.providerRequestsProcessing:
+		case fpr, ok := <-pqm.providerRequestsProcessing:
 			if !ok {
 				return
 			}
-
+			k := fpr.k
 			log.Debugf("Beginning Find Provider Request for cid: %s", k.String())
 			pqm.timeoutMutex.RLock()
-			findProviderCtx, cancel := context.WithTimeout(pqm.ctx, pqm.findProviderTimeout)
+			findProviderCtx, cancel := context.WithTimeout(fpr.ctx, pqm.findProviderTimeout)
 			pqm.timeoutMutex.RUnlock()
 			defer cancel()
 			providers := pqm.network.FindProvidersAsync(findProviderCtx, k, maxProviders)
@@ -248,14 +255,14 @@ func (pqm *ProviderQueryManager) providerRequestBufferWorker() {
 	// buffer for incoming provider queries and dispatches to the find
 	// provider workers as they become available
 	// based on: https://medium.com/capital-one-tech/building-an-unbounded-channel-in-go-789e175cd2cd
-	var providerQueryRequestBuffer []cid.Cid
-	nextProviderQuery := func() cid.Cid {
+	var providerQueryRequestBuffer []*findProviderRequest
+	nextProviderQuery := func() *findProviderRequest {
 		if len(providerQueryRequestBuffer) == 0 {
-			return cid.Cid{}
+			return nil
 		}
 		return providerQueryRequestBuffer[0]
 	}
-	outgoingRequests := func() chan<- cid.Cid {
+	outgoingRequests := func() chan<- *findProviderRequest {
 		if len(providerQueryRequestBuffer) == 0 {
 			return nil
 		}
@@ -282,6 +289,7 @@ func (pqm *ProviderQueryManager) cleanupInProcessRequests() {
 		for listener := range requestStatus.listeners {
 			close(listener)
 		}
+		requestStatus.cancelFn()
 	}
 }
@@ -338,6 +346,7 @@ func (fpqm *finishedProviderQueryMessage) handle(pqm *ProviderQueryManager) {
 		close(listener)
 	}
 	delete(pqm.inProgressRequestStatuses, fpqm.k)
+	requestStatus.cancelFn()
 }
 
 func (npqm *newProvideQueryMessage) debugMessage() string {
@@ -347,12 +356,18 @@ func (npqm *newProvideQueryMessage) debugMessage() string {
 func (npqm *newProvideQueryMessage) handle(pqm *ProviderQueryManager) {
 	requestStatus, ok := pqm.inProgressRequestStatuses[npqm.k]
 	if !ok {
+		ctx, cancelFn := context.WithCancel(pqm.ctx)
 		requestStatus = &inProgressRequestStatus{
 			listeners: make(map[chan peer.ID]struct{}),
+			ctx:       ctx,
+			cancelFn:  cancelFn,
 		}
 		pqm.inProgressRequestStatuses[npqm.k] = requestStatus
 		select {
-		case pqm.incomingFindProviderRequests <- npqm.k:
+		case pqm.incomingFindProviderRequests <- &findProviderRequest{
+			k:   npqm.k,
+			ctx: ctx,
+		}:
 		case <-pqm.ctx.Done():
 			return
 		}
@@ -378,6 +393,10 @@ func (crm *cancelRequestMessage) handle(pqm *ProviderQueryManager) {
 		_, ok := requestStatus.listeners[crm.incomingProviders]
 		if ok {
 			delete(requestStatus.listeners, crm.incomingProviders)
+			if len(requestStatus.listeners) == 0 {
+				delete(pqm.inProgressRequestStatuses, crm.k)
+				requestStatus.cancelFn()
+			}
 		} else {
 			log.Errorf("Attempt to cancel request for cid (%s) this is not a listener", crm.k.String())
 		}
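What this patch sets up is, in effect, reference counting: each in-flight query owns a child context, and its cancelFn fires when the query finishes, when the manager shuts down, or when the last listener cancels. A rough standalone sketch of that bookkeeping follows, with hypothetical names and plain string keys standing in for cids; it is not the manager's actual API.

package main

import (
	"context"
	"fmt"
)

// query mirrors inProgressRequestStatus: one cancellable context plus the
// set of listener channels attached to the same key.
type query struct {
	cancelFn  context.CancelFunc
	listeners map[chan string]struct{}
}

type manager struct {
	ctx     context.Context
	queries map[string]*query
}

// attach registers a listener, starting a query (and its child context)
// only if none is already in flight for this key.
func (m *manager) attach(key string, l chan string) {
	q, ok := m.queries[key]
	if !ok {
		_, cancel := context.WithCancel(m.ctx) // the real code hands this ctx to the network query
		q = &query{cancelFn: cancel, listeners: make(map[chan string]struct{})}
		m.queries[key] = q
	}
	q.listeners[l] = struct{}{}
}

// detach removes one listener; when the last one leaves, the whole query
// is torn down and its context cancelled -- the behaviour this patch adds.
func (m *manager) detach(key string, l chan string) {
	q, ok := m.queries[key]
	if !ok {
		return
	}
	delete(q.listeners, l)
	close(l)
	if len(q.listeners) == 0 {
		delete(m.queries, key)
		q.cancelFn()
	}
}

func main() {
	m := &manager{ctx: context.Background(), queries: make(map[string]*query)}
	a, b := make(chan string), make(chan string)
	m.attach("cid-1", a)
	m.attach("cid-1", b)
	m.detach("cid-1", a)
	fmt.Println("in flight:", len(m.queries)) // 1: b is still listening
	m.detach("cid-1", b)
	fmt.Println("in flight:", len(m.queries)) // 0: context cancelled
}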
From 30f40ecec4f34dd7637f78b0b90dff6e25208be2 Mon Sep 17 00:00:00 2001
From: hannahhoward
Date: Tue, 5 Feb 2019 10:56:16 -0800
Subject: [PATCH 9/9] fix(providerquerymanager): minor channel cleanup

Keep channels unblocked when cancelling a request -- refactored into a
helper function. Also cancel the find-provider context as soon as it is
no longer needed.

---
 providerquerymanager/providerquerymanager.go | 65 +++++++++++---------
 1 file changed, 36 insertions(+), 29 deletions(-)

diff --git a/providerquerymanager/providerquerymanager.go b/providerquerymanager/providerquerymanager.go
index b84463a7..38471479 100644
--- a/providerquerymanager/providerquerymanager.go
+++ b/providerquerymanager/providerquerymanager.go
@@ -170,22 +170,8 @@ func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k
 			case <-pqm.ctx.Done():
 				return
 			case <-sessionCtx.Done():
-				pqm.providerQueryMessages <- &cancelRequestMessage{
-					incomingProviders: incomingProviders,
-					k:                 k,
-				}
-				// clear out any remaining providers, in case any "incoming provider"
-				// messages get processed before our cancel message
-				for {
-					select {
-					case _, ok := <-incomingProviders:
-						if !ok {
-							return
-						}
-					case <-pqm.ctx.Done():
-						return
-					}
-				}
+				pqm.cancelProviderRequest(k, incomingProviders)
+				return
 			case provider, ok := <-incomingProviders:
 				if !ok {
 					incomingProviders = nil
@@ -200,6 +186,27 @@ func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k
 	return returnedProviders
 }
 
+func (pqm *ProviderQueryManager) cancelProviderRequest(k cid.Cid, incomingProviders chan peer.ID) {
+	cancelMessageChannel := pqm.providerQueryMessages
+	for {
+		select {
+		case cancelMessageChannel <- &cancelRequestMessage{
+			incomingProviders: incomingProviders,
+			k:                 k,
+		}:
+			cancelMessageChannel = nil
+			// clear out any remaining providers, in case any "incoming provider"
+			// messages get processed before our cancel message
+		case _, ok := <-incomingProviders:
+			if !ok {
+				return
+			}
+		case <-pqm.ctx.Done():
+			return
+		}
+	}
+}
+
 func (pqm *ProviderQueryManager) findProviderWorker() {
 	// findProviderWorker just cycles through incoming provider queries one
 	// at a time. We have six of these workers running at once
@@ -215,7 +222,6 @@ func (pqm *ProviderQueryManager) findProviderWorker() {
 			pqm.timeoutMutex.RLock()
 			findProviderCtx, cancel := context.WithTimeout(fpr.ctx, pqm.findProviderTimeout)
 			pqm.timeoutMutex.RUnlock()
-			defer cancel()
 			providers := pqm.network.FindProvidersAsync(findProviderCtx, k, maxProviders)
 			wg := &sync.WaitGroup{}
 			for p := range providers {
@@ -237,6 +243,7 @@ func (pqm *ProviderQueryManager) findProviderWorker() {
 				}(p)
 			}
+			cancel()
 			wg.Wait()
 			select {
 			case pqm.providerQueryMessages <- &finishedProviderQueryMessage{
@@ -389,19 +396,19 @@ func (crm *cancelRequestMessage) debugMessage() string {
 
 func (crm *cancelRequestMessage) handle(pqm *ProviderQueryManager) {
 	requestStatus, ok := pqm.inProgressRequestStatuses[crm.k]
-	if ok {
-		_, ok := requestStatus.listeners[crm.incomingProviders]
-		if ok {
-			delete(requestStatus.listeners, crm.incomingProviders)
-			if len(requestStatus.listeners) == 0 {
-				delete(pqm.inProgressRequestStatuses, crm.k)
-				requestStatus.cancelFn()
-			}
-		} else {
-			log.Errorf("Attempt to cancel request for cid (%s) this is not a listener", crm.k.String())
-		}
-	} else {
+	if !ok {
 		log.Errorf("Attempt to cancel request for cid (%s) not in progress", crm.k.String())
+		return
 	}
+	_, ok = requestStatus.listeners[crm.incomingProviders]
+	if !ok {
+		log.Errorf("Attempt to cancel request for cid (%s) this is not a listener", crm.k.String())
+		return
+	}
+	delete(requestStatus.listeners, crm.incomingProviders)
 	close(crm.incomingProviders)
+	if len(requestStatus.listeners) == 0 {
+		delete(pqm.inProgressRequestStatuses, crm.k)
+		requestStatus.cancelFn()
+	}
 }
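A closing note on cancelProviderRequest: it leans on a small Go idiom worth calling out. A send on a nil channel blocks forever, so setting cancelMessageChannel to nil after the first successful send disables that select case while the same loop keeps draining incomingProviders. The sketch below shows the idiom in isolation; the names are illustrative only, not the manager's API.

package main

import "fmt"

// sendOnceWhileDraining sends msg on out at most once, while continuously
// draining in until it is closed. Nil-ing out after the send disables that
// case on later iterations, because a send on a nil channel blocks forever.
func sendOnceWhileDraining(out chan<- string, in <-chan int, msg string) {
	for {
		select {
		case out <- msg:
			out = nil // sent once; this case can never fire again
		case v, ok := <-in:
			if !ok {
				return // producer closed the channel; nothing left to drain
			}
			fmt.Println("drained", v)
		}
	}
}

func main() {
	out := make(chan string, 1)
	in := make(chan int, 3)
	in <- 1
	in <- 2
	close(in)
	sendOnceWhileDraining(out, in, "cancel")
	// As in the real code, the drain can finish before the send wins the
	// race, so check for the message rather than block on it.
	select {
	case m := <-out:
		fmt.Println("sent:", m)
	default:
		fmt.Println("query ended before cancel was sent")
	}
}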