Skip to content
This repository has been archived by the owner on Feb 1, 2023. It is now read-only.

Commit

Permalink
fix(ProviderQueryManager): fix test + add logging
Browse files Browse the repository at this point in the history
Add debug logging for the provider query manager and make tests more reliable
  • Loading branch information
hannahhoward committed Jan 23, 2019
1 parent 464ce48 commit 7478032
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 16 deletions.
22 changes: 20 additions & 2 deletions providerquerymanager/providerquerymanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package providerquerymanager

import (
"context"
"fmt"
"sync"
"time"

Expand Down Expand Up @@ -31,6 +32,7 @@ type ProviderQueryNetwork interface {
}

type providerQueryMessage interface {
debugMessage() string
handle(pqm *ProviderQueryManager)
}

Expand Down Expand Up @@ -192,6 +194,7 @@ func (pqm *ProviderQueryManager) findProviderWorker() {
return
}

log.Debugf("Beginning Find Provider Request for cid: %s", k.String())
pqm.timeoutMutex.RLock()
findProviderCtx, cancel := context.WithTimeout(pqm.ctx, pqm.findProviderTimeout)
pqm.timeoutMutex.RUnlock()
Expand Down Expand Up @@ -273,8 +276,6 @@ func (pqm *ProviderQueryManager) cleanupInProcessRequests() {
}

func (pqm *ProviderQueryManager) run() {
defer close(pqm.incomingFindProviderRequests)
defer close(pqm.providerRequestsProcessing)
defer pqm.cleanupInProcessRequests()

go pqm.providerRequestBufferWorker()
Expand All @@ -285,13 +286,18 @@ func (pqm *ProviderQueryManager) run() {
for {
select {
case nextMessage := <-pqm.providerQueryMessages:
log.Debug(nextMessage.debugMessage())
nextMessage.handle(pqm)
case <-pqm.ctx.Done():
return
}
}
}

func (rpm *receivedProviderMessage) debugMessage() string {
return fmt.Sprintf("Received provider (%s) for cid (%s)", rpm.p.String(), rpm.k.String())
}

func (rpm *receivedProviderMessage) handle(pqm *ProviderQueryManager) {
requestStatus, ok := pqm.inProgressRequestStatuses[rpm.k]
if !ok {
Expand All @@ -308,6 +314,10 @@ func (rpm *receivedProviderMessage) handle(pqm *ProviderQueryManager) {
}
}

func (fpqm *finishedProviderQueryMessage) debugMessage() string {
return fmt.Sprintf("Finished Provider Query on cid: %s", fpqm.k.String())
}

func (fpqm *finishedProviderQueryMessage) handle(pqm *ProviderQueryManager) {
requestStatus, ok := pqm.inProgressRequestStatuses[fpqm.k]
if !ok {
Expand All @@ -320,6 +330,10 @@ func (fpqm *finishedProviderQueryMessage) handle(pqm *ProviderQueryManager) {
delete(pqm.inProgressRequestStatuses, fpqm.k)
}

func (npqm *newProvideQueryMessage) debugMessage() string {
return fmt.Sprintf("New Provider Query on cid: %s from session: %d", npqm.k.String(), npqm.ses)
}

func (npqm *newProvideQueryMessage) handle(pqm *ProviderQueryManager) {
requestStatus, ok := pqm.inProgressRequestStatuses[npqm.k]
if !ok {
Expand All @@ -343,6 +357,10 @@ func (npqm *newProvideQueryMessage) handle(pqm *ProviderQueryManager) {
}
}

func (crm *cancelRequestMessage) debugMessage() string {
return fmt.Sprintf("Cancel provider query on cid: %s from session: %d", crm.k.String(), crm.ses)
}

func (crm *cancelRequestMessage) handle(pqm *ProviderQueryManager) {
requestStatus, ok := pqm.inProgressRequestStatuses[crm.k]
if !ok {
Expand Down
48 changes: 34 additions & 14 deletions providerquerymanager/providerquerymanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"reflect"
"sync"
"testing"
"time"

Expand All @@ -14,11 +15,12 @@ import (
)

type fakeProviderNetwork struct {
peersFound []peer.ID
connectError error
delay time.Duration
connectDelay time.Duration
queriesMade int
peersFound []peer.ID
connectError error
delay time.Duration
connectDelay time.Duration
queriesMadeMutex sync.RWMutex
queriesMade int
}

func (fpn *fakeProviderNetwork) ConnectTo(context.Context, peer.ID) error {
Expand All @@ -27,13 +29,20 @@ func (fpn *fakeProviderNetwork) ConnectTo(context.Context, peer.ID) error {
}

func (fpn *fakeProviderNetwork) FindProvidersAsync(ctx context.Context, k cid.Cid, max int) <-chan peer.ID {
fpn.queriesMadeMutex.Lock()
fpn.queriesMade++
fpn.queriesMadeMutex.Unlock()
incomingPeers := make(chan peer.ID)
go func() {
defer close(incomingPeers)
for _, p := range fpn.peersFound {
time.Sleep(fpn.delay)
select {
case <-ctx.Done():
return
default:
}
select {
case incomingPeers <- p:
case <-ctx.Done():
return
Expand Down Expand Up @@ -75,9 +84,12 @@ func TestNormalSimultaneousFetch(t *testing.T) {
t.Fatal("Did not collect all peers for request that was completed")
}

fpn.queriesMadeMutex.Lock()
defer fpn.queriesMadeMutex.Unlock()
if fpn.queriesMade != 2 {
t.Fatal("Did not dedup provider requests running simultaneously")
}

}

func TestDedupingProviderRequests(t *testing.T) {
Expand All @@ -93,7 +105,7 @@ func TestDedupingProviderRequests(t *testing.T) {
sessionID1 := testutil.GenerateSessionID()
sessionID2 := testutil.GenerateSessionID()

sessionCtx, cancel := context.WithTimeout(ctx, 20*time.Millisecond)
sessionCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer cancel()
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, sessionID1)
secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, sessionID2)
Expand All @@ -115,7 +127,8 @@ func TestDedupingProviderRequests(t *testing.T) {
if !reflect.DeepEqual(firstPeersReceived, secondPeersReceived) {
t.Fatal("Did not receive the same response to both find provider requests")
}

fpn.queriesMadeMutex.Lock()
defer fpn.queriesMadeMutex.Unlock()
if fpn.queriesMade != 1 {
t.Fatal("Did not dedup provider requests running simultaneously")
}
Expand All @@ -139,7 +152,7 @@ func TestCancelOneRequestDoesNotTerminateAnother(t *testing.T) {
firstSessionCtx, firstCancel := context.WithTimeout(ctx, 3*time.Millisecond)
defer firstCancel()
firstRequestChan := providerQueryManager.FindProvidersAsync(firstSessionCtx, key, sessionID1)
secondSessionCtx, secondCancel := context.WithTimeout(ctx, 20*time.Millisecond)
secondSessionCtx, secondCancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer secondCancel()
secondRequestChan := providerQueryManager.FindProvidersAsync(secondSessionCtx, key, sessionID2)

Expand All @@ -160,7 +173,8 @@ func TestCancelOneRequestDoesNotTerminateAnother(t *testing.T) {
if len(firstPeersReceived) >= len(peers) {
t.Fatal("Collected all peers on cancelled peer, should have been cancelled immediately")
}

fpn.queriesMadeMutex.Lock()
defer fpn.queriesMadeMutex.Unlock()
if fpn.queriesMade != 1 {
t.Fatal("Did not dedup provider requests running simultaneously")
}
Expand Down Expand Up @@ -248,26 +262,33 @@ func TestRateLimitingRequests(t *testing.T) {
delay: 1 * time.Millisecond,
}
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
providerQueryManager := New(ctx, fpn)
providerQueryManager.Startup()

keys := testutil.GenerateCids(maxInProcessRequests + 1)
sessionID := testutil.GenerateSessionID()
sessionCtx, cancel := context.WithTimeout(ctx, 20*time.Millisecond)
sessionCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer cancel()
var requestChannels []<-chan peer.ID
for i := 0; i < maxInProcessRequests+1; i++ {
requestChannels = append(requestChannels, providerQueryManager.FindProvidersAsync(sessionCtx, keys[i], sessionID))
}
time.Sleep(2 * time.Millisecond)
time.Sleep(9 * time.Millisecond)
fpn.queriesMadeMutex.Lock()
if fpn.queriesMade != maxInProcessRequests {
t.Logf("Queries made: %d\n", fpn.queriesMade)
t.Fatal("Did not limit parallel requests to rate limit")
}
fpn.queriesMadeMutex.Unlock()
for i := 0; i < maxInProcessRequests+1; i++ {
for range requestChannels[i] {
}
}

fpn.queriesMadeMutex.Lock()
defer fpn.queriesMadeMutex.Unlock()
if fpn.queriesMade != maxInProcessRequests+1 {
t.Fatal("Did not make all seperate requests")
}
Expand All @@ -282,7 +303,7 @@ func TestFindProviderTimeout(t *testing.T) {
ctx := context.Background()
providerQueryManager := New(ctx, fpn)
providerQueryManager.Startup()
providerQueryManager.SetFindProviderTimeout(3 * time.Millisecond)
providerQueryManager.SetFindProviderTimeout(2 * time.Millisecond)
keys := testutil.GenerateCids(1)
sessionID1 := testutil.GenerateSessionID()

Expand All @@ -293,8 +314,7 @@ func TestFindProviderTimeout(t *testing.T) {
for p := range firstRequestChan {
firstPeersReceived = append(firstPeersReceived, p)
}
if len(firstPeersReceived) <= 0 ||
len(firstPeersReceived) >= len(peers) {
if len(firstPeersReceived) >= len(peers) {
t.Fatal("Find provider request should have timed out, did not")
}
}

0 comments on commit 7478032

Please sign in to comment.