// BackoffFactory creates a new, independent BackoffStrategy each time it is called.
type BackoffFactory func() BackoffStrategy

// BackoffStrategy describes how backoff will be implemented. BackoffStrategies are stateful.
type BackoffStrategy interface {
	// Delay calculates how long the next backoff duration should be, given the prior calls to Delay
	Delay() time.Duration
	// Reset clears the internal state of the BackoffStrategy
	Reset()
}

// Jitter implementations taken roughly from https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/

// Jitter must return a duration between min and max. Min must be lower than, or equal to, max.
type Jitter func(duration, min, max time.Duration, rng *rand.Rand) time.Duration

const (
	// backoffMaxDuration and backoffMinDuration are the extreme values a
	// time.Duration can hold; computed delays saturate at them.
	backoffMaxDuration = time.Duration(math.MaxInt64)
	backoffMinDuration = time.Duration(math.MinInt64)
)

// saturatingDuration converts a nanosecond count held in a float64 into a
// time.Duration, clamping values that do not fit in an int64. The result of
// converting an out-of-range float to an integer is implementation-dependent
// in Go, so without the clamp a very large delay can turn into a negative one.
// NaN is mapped to 0.
func saturatingDuration(f float64) time.Duration {
	switch {
	case math.IsNaN(f):
		return 0
	case f >= float64(backoffMaxDuration):
		return backoffMaxDuration
	case f <= float64(backoffMinDuration):
		return backoffMinDuration
	}
	return time.Duration(f)
}

// saturatingAdd returns a+b, clamped to the time.Duration range instead of
// wrapping around on overflow.
func saturatingAdd(a, b time.Duration) time.Duration {
	sum := a + b
	switch {
	case a > 0 && b > 0 && sum < 0:
		return backoffMaxDuration
	case a < 0 && b < 0 && sum >= 0:
		return backoffMinDuration
	}
	return sum
}

// FullJitter returns a random duration uniformly chosen from the range [min, boundedDur),
// where boundedDur is the duration bounded between min and max.
// When that range is empty (duration <= min, or max <= min) it returns min and
// does not consult rng.
func FullJitter(duration, min, max time.Duration, rng *rand.Rand) time.Duration {
	if duration <= min {
		return min
	}

	normalizedDur := boundedDuration(duration, min, max) - min
	if normalizedDur <= 0 {
		// rand.Int63n panics when its argument is <= 0, which happens here
		// whenever max <= min.
		return min
	}

	return boundedDuration(time.Duration(rng.Int63n(int64(normalizedDur)))+min, min, max)
}

// NoJitter returns the duration bounded between min and max. rng is unused.
func NoJitter(duration, min, max time.Duration, rng *rand.Rand) time.Duration {
	return boundedDuration(duration, min, max)
}

// randomizedBackoff holds the bounds and random source shared by the
// backoff strategies that apply jitter.
type randomizedBackoff struct {
	min time.Duration
	max time.Duration
	rng *rand.Rand
}

// BoundedDelay clamps duration to the strategy's [min, max] range.
func (b *randomizedBackoff) BoundedDelay(duration time.Duration) time.Duration {
	return boundedDuration(duration, b.min, b.max)
}

// boundedDuration clamps d to [min, max]. The lower bound is checked first.
func boundedDuration(d, min, max time.Duration) time.Duration {
	if d < min {
		return min
	}
	if d > max {
		return max
	}
	return d
}

// attemptBackoff is the shared state of strategies whose delay is a function
// of the number of Delay calls made since the last Reset.
type attemptBackoff struct {
	attempt int
	jitter  Jitter
	randomizedBackoff
}

// Reset restarts the attempt counter at zero.
func (b *attemptBackoff) Reset() {
	b.attempt = 0
}

// NewFixedBackoff creates a BackoffFactory with a constant backoff duration
func NewFixedBackoff(delay time.Duration) BackoffFactory {
	return func() BackoffStrategy {
		return &fixedBackoff{delay: delay}
	}
}

// fixedBackoff always reports the same delay.
type fixedBackoff struct {
	delay time.Duration
}

// Delay returns the configured constant delay.
func (b *fixedBackoff) Delay() time.Duration {
	return b.delay
}

// Reset is a no-op: a fixed backoff has no state.
func (b *fixedBackoff) Reset() {}

// NewPolynomialBackoff creates a BackoffFactory with backoff of the form c0*x^0, c1*x^1, ...cn*x^n where x is the attempt number
// jitter is the function for adding randomness around the backoff
// timeUnits are the units of time the polynomial is evaluated in
// polyCoefs is the array of polynomial coefficients from [c0, c1, ... cn]
func NewPolynomialBackoff(min, max time.Duration, jitter Jitter,
	timeUnits time.Duration, polyCoefs []float64, rng *rand.Rand) BackoffFactory {
	return func() BackoffStrategy {
		return &polynomialBackoff{
			attemptBackoff: attemptBackoff{
				randomizedBackoff: randomizedBackoff{
					min: min,
					max: max,
					rng: rng,
				},
				jitter: jitter,
			},
			timeUnits: timeUnits,
			poly:      polyCoefs,
		}
	}
}

// polynomialBackoff evaluates a polynomial of the attempt number, in timeUnits.
type polynomialBackoff struct {
	attemptBackoff
	timeUnits time.Duration
	poly      []float64
}

// Delay evaluates the polynomial at the current attempt number, advances the
// attempt counter, and passes the result through the jitter function.
// With no coefficients it returns 0 without applying jitter or bounds; with a
// single coefficient the attempt counter is not advanced because the result
// does not depend on it.
func (b *polynomialBackoff) Delay() time.Duration {
	var polySum float64
	switch len(b.poly) {
	case 0:
		return 0
	case 1:
		polySum = b.poly[0]
	default:
		polySum = b.poly[0]
		attempt := float64(b.attempt)
		b.attempt++

		// Powers are accumulated in floating point: an integer power of the
		// attempt number overflows after enough attempts.
		power := 1.0
		for _, c := range b.poly[1:] {
			power *= attempt
			if c != 0 { // also avoids Inf*0 = NaN once power overflows
				polySum += power * c
			}
		}
	}
	return b.jitter(saturatingDuration(float64(b.timeUnits)*polySum), b.min, b.max, b.rng)
}

// NewExponentialBackoff creates a BackoffFactory with backoff of the form base^x + offset where x is the attempt number
// jitter is the function for adding randomness around the backoff
// timeUnits are the units of time the base^x is evaluated in
func NewExponentialBackoff(min, max time.Duration, jitter Jitter,
	timeUnits time.Duration, base float64, offset time.Duration, rng *rand.Rand) BackoffFactory {
	return func() BackoffStrategy {
		return &exponentialBackoff{
			attemptBackoff: attemptBackoff{
				randomizedBackoff: randomizedBackoff{
					min: min,
					max: max,
					rng: rng,
				},
				jitter: jitter,
			},
			timeUnits: timeUnits,
			base:      base,
			offset:    offset,
		}
	}
}

// exponentialBackoff computes base^attempt (in timeUnits) plus a fixed offset.
type exponentialBackoff struct {
	attemptBackoff
	timeUnits time.Duration
	base      float64
	offset    time.Duration
}

// Delay returns jitter(base^attempt*timeUnits + offset) and advances the
// attempt counter. The intermediate value saturates instead of overflowing,
// so a long-running strategy stays at max rather than falling back to min.
func (b *exponentialBackoff) Delay() time.Duration {
	attempt := b.attempt
	b.attempt++

	scaled := saturatingDuration(math.Pow(b.base, float64(attempt)) * float64(b.timeUnits))
	return b.jitter(saturatingAdd(scaled, b.offset), b.min, b.max, b.rng)
}

// NewExponentialDecorrelatedJitter creates a BackoffFactory with backoff roughly of the form base^x where x is the attempt number.
// Delays start at the minimum duration and after each attempt delay = rand(min, delay * base), bounded by the max
// See https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ for more information
func NewExponentialDecorrelatedJitter(min, max time.Duration, base float64, rng *rand.Rand) BackoffFactory {
	return func() BackoffStrategy {
		return &exponentialDecorrelatedJitter{
			randomizedBackoff: randomizedBackoff{
				min: min,
				max: max,
				rng: rng,
			},
			base: base,
		}
	}
}

// exponentialDecorrelatedJitter derives each delay from the previous one
// rather than from an attempt counter.
type exponentialDecorrelatedJitter struct {
	randomizedBackoff
	base      float64
	lastDelay time.Duration
}

// Delay returns min on the first call after construction or Reset, and
// afterwards a random value in [min, lastDelay*base) bounded by max.
// If that interval is empty (for example min == 0 or base <= 1) it returns
// min bounded by max instead of calling rand.Int63n with a non-positive
// argument, which would panic.
func (b *exponentialDecorrelatedJitter) Delay() time.Duration {
	if b.lastDelay < b.min {
		b.lastDelay = b.min
		return b.lastDelay
	}

	nextMax := int64(saturatingDuration(float64(b.lastDelay) * b.base))
	span := nextMax - int64(b.min)
	if span <= 0 {
		b.lastDelay = boundedDuration(b.min, b.min, b.max)
		return b.lastDelay
	}

	b.lastDelay = boundedDuration(time.Duration(b.rng.Int63n(span))+b.min, b.min, b.max)
	return b.lastDelay
}

// Reset makes the next Delay call start again from min.
func (b *exponentialDecorrelatedJitter) Reset() { b.lastDelay = 0 }

// checkDelay fails the test immediately unless the next delay produced by bkf
// equals expected.
func checkDelay(bkf BackoffStrategy, expected time.Duration, t *testing.T) {
	t.Helper()
	if calculated := bkf.Delay(); calculated != expected {
		t.Fatalf("expected %v, got %v", expected, calculated)
	}
}

// TestFixedBackoff checks that the factory captures the delay by value, that
// repeated Delay calls return the same value, and that Reset changes nothing.
func TestFixedBackoff(t *testing.T) {
	startDelay := time.Second
	delay := startDelay

	bkf := NewFixedBackoff(delay)
	// Mutating the local after the factory was built must not affect it.
	delay *= 2
	b1 := bkf()
	delay *= 2
	b2 := bkf()

	if b1.Delay() != startDelay || b2.Delay() != startDelay {
		t.Fatal("incorrect delay time")
	}

	if b1.Delay() != startDelay {
		t.Fatal("backoff is stateful")
	}

	if b1.Reset(); b1.Delay() != startDelay {
		t.Fatalf("Reset does something")
	}
}

// TestPolynomialBackoff checks the delay sequence of 0.5 + 2x + 3x^2 seconds
// bounded to [1s, 33s], the independence of two strategies from the same
// factory, and that Reset restarts the sequence.
func TestPolynomialBackoff(t *testing.T) {
	rng := rand.New(rand.NewSource(0))
	bkf := NewPolynomialBackoff(time.Second, time.Second*33, NoJitter, time.Second, []float64{0.5, 2, 3}, rng)
	b1 := bkf()
	b2 := bkf()

	if b1.Delay() != time.Second || b2.Delay() != time.Second {
		t.Fatal("incorrect delay time")
	}

	checkDelay(b1, time.Millisecond*5500, t)
	checkDelay(b1, time.Millisecond*16500, t)
	checkDelay(b1, time.Millisecond*33000, t)
	checkDelay(b2, time.Millisecond*5500, t)

	b1.Reset()
	b1.Delay()
	checkDelay(b1, time.Millisecond*5500, t)
}

// TestExponentialBackoff checks the delay sequence of 1.5^x seconds minus
// 400ms bounded to [650ms, 7s], the independence of two strategies from the
// same factory, and that Reset restarts the sequence.
func TestExponentialBackoff(t *testing.T) {
	rng := rand.New(rand.NewSource(0))
	bkf := NewExponentialBackoff(time.Millisecond*650, time.Second*7, NoJitter, time.Second, 1.5, -time.Millisecond*400, rng)
	b1 := bkf()
	b2 := bkf()

	if b1.Delay() != time.Millisecond*650 || b2.Delay() != time.Millisecond*650 {
		t.Fatal("incorrect delay time")
	}

	checkDelay(b1, time.Millisecond*1100, t)
	checkDelay(b1, time.Millisecond*1850, t)
	checkDelay(b1, time.Millisecond*2975, t)
	checkDelay(b1, time.Microsecond*4662500, t)
	checkDelay(b1, time.Second*7, t)
	checkDelay(b2, time.Millisecond*1100, t)

	b1.Reset()
	b1.Delay()
	checkDelay(b1, time.Millisecond*1100, t)
}

// minMaxJitterTest checks that jitter never returns less than min or more
// than max when the input duration lies outside the range.
func minMaxJitterTest(jitter Jitter, t *testing.T) {
	rng := rand.New(rand.NewSource(0))
	if jitter(time.Nanosecond, time.Hour*10, time.Hour*20, rng) < time.Hour*10 {
		t.Fatal("Min not working")
	}
	if jitter(time.Hour, time.Nanosecond, time.Nanosecond*10, rng) > time.Nanosecond*10 {
		t.Fatal("Max not working")
	}
}

// TestNoJitter checks that NoJitter respects the bounds and otherwise returns
// its input unchanged.
func TestNoJitter(t *testing.T) {
	minMaxJitterTest(NoJitter, t)
	for i := 0; i < 10; i++ {
		expected := time.Second * time.Duration(i)
		if calculated := NoJitter(expected, time.Duration(0), time.Second*100, nil); calculated != expected {
			t.Fatalf("expected %v, got %v", expected, calculated)
		}
	}
}

// TestFullJitter checks that FullJitter respects the bounds, spreads its
// results roughly evenly, and never returns the full input duration.
func TestFullJitter(t *testing.T) {
	rng := rand.New(rand.NewSource(0))
	minMaxJitterTest(FullJitter, t)
	const numBuckets = 51
	const multiplier = 10
	const threshold = 20

	// One bucket per possible nanosecond value in [0, 50].
	histogram := make([]int, numBuckets)

	for i := 0; i < (numBuckets-1)*multiplier; i++ {
		started := time.Nanosecond * 50
		calculated := FullJitter(started, 0, 100, rng)
		histogram[calculated]++
	}

	for _, count := range histogram {
		if count > threshold {
			t.Fatal("jitter is not close to evenly spread")
		}
	}

	// The last bucket is the un-jittered duration itself.
	if histogram[numBuckets-1] > 0 {
		t.Fatal("jitter increased overall time")
	}
}
t) + const numBuckets = 51 + const multiplier = 10 + const threshold = 20 + + histogram := make([]int, numBuckets) + + for i := 0; i < (numBuckets-1)*multiplier; i++ { + started := time.Nanosecond * 50 + calculated := FullJitter(started, 0, 100, rng) + histogram[calculated]++ + } + + for _, count := range histogram { + if count > threshold { + t.Fatal("jitter is not close to evenly spread") + } + } + + if histogram[numBuckets-1] > 0 { + t.Fatal("jitter increased overall time") + } +} diff --git a/backoffcache.go b/backoffcache.go new file mode 100644 index 0000000..0e255e2 --- /dev/null +++ b/backoffcache.go @@ -0,0 +1,297 @@ +package discovery + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/libp2p/go-libp2p-core/discovery" + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-peerstore/addr" +) + +// BackoffDiscovery is an implementation of discovery that caches peer data and attenuates repeated queries +type BackoffDiscovery struct { + disc discovery.Discovery + stratFactory BackoffFactory + peerCache map[string]*backoffCache + peerCacheMux sync.RWMutex + + parallelBufSz int + returnedBufSz int +} + +type BackoffDiscoveryOption func(*BackoffDiscovery) error + +func NewBackoffDiscovery(disc discovery.Discovery, stratFactory BackoffFactory, opts ...BackoffDiscoveryOption) (discovery.Discovery, error) { + b := &BackoffDiscovery{ + disc: disc, + stratFactory: stratFactory, + peerCache: make(map[string]*backoffCache), + + parallelBufSz: 32, + returnedBufSz: 32, + } + + for _, opt := range opts { + if err := opt(b); err != nil { + return nil, err + } + } + + return b, nil +} + +// WithBackoffDiscoverySimultaneousQueryBufferSize sets the buffer size for the channels between the main FindPeers query +// for a given namespace and all simultaneous FindPeers queries for the namespace +func WithBackoffDiscoverySimultaneousQueryBufferSize(size int) BackoffDiscoveryOption { + return func(b *BackoffDiscovery) error { + if size < 0 { + return 
fmt.Errorf("cannot set size to be smaller than 0") + } + b.parallelBufSz = size + return nil + } +} + +// WithBackoffDiscoveryReturnedChannelSize sets the size of the buffer to be used during a FindPeer query. +// Note: This does not apply if the query occurs during the backoff time +func WithBackoffDiscoveryReturnedChannelSize(size int) BackoffDiscoveryOption { + return func(b *BackoffDiscovery) error { + if size < 0 { + return fmt.Errorf("cannot set size to be smaller than 0") + } + b.returnedBufSz = size + return nil + } +} + +type backoffCache struct { + nextDiscover time.Time + prevPeers map[peer.ID]peer.AddrInfo + + peers map[peer.ID]peer.AddrInfo + sendingChs map[chan peer.AddrInfo]int + + ongoing bool + strat BackoffStrategy + mux sync.Mutex +} + +func (d *BackoffDiscovery) Advertise(ctx context.Context, ns string, opts ...discovery.Option) (time.Duration, error) { + return d.disc.Advertise(ctx, ns, opts...) +} + +func (d *BackoffDiscovery) FindPeers(ctx context.Context, ns string, opts ...discovery.Option) (<-chan peer.AddrInfo, error) { + // Get options + var options discovery.Options + err := options.Apply(opts...) + if err != nil { + return nil, err + } + + // Get cached peers + d.peerCacheMux.RLock() + c, ok := d.peerCache[ns] + d.peerCacheMux.RUnlock() + + /* + Overall plan: + If it's time to look for peers, look for peers, then return them + If it's not time then return cache + If it's time to look for peers, but we have already started looking. 
Get up to speed with ongoing request + */ + + // Setup cache if we don't have one yet + if !ok { + pc := &backoffCache{ + nextDiscover: time.Time{}, + prevPeers: make(map[peer.ID]peer.AddrInfo), + peers: make(map[peer.ID]peer.AddrInfo), + sendingChs: make(map[chan peer.AddrInfo]int), + strat: d.stratFactory(), + } + d.peerCacheMux.Lock() + c, ok = d.peerCache[ns] + + if !ok { + d.peerCache[ns] = pc + c = pc + } + + d.peerCacheMux.Unlock() + } + + c.mux.Lock() + defer c.mux.Unlock() + + timeExpired := time.Now().After(c.nextDiscover) + + // If it's not yet time to search again and no searches are in progress then return cached peers + if !(timeExpired || c.ongoing) { + chLen := options.Limit + + if chLen == 0 { + chLen = len(c.prevPeers) + } else if chLen > len(c.prevPeers) { + chLen = len(c.prevPeers) + } + pch := make(chan peer.AddrInfo, chLen) + for _, ai := range c.prevPeers { + pch <- ai + } + close(pch) + return pch, nil + } + + // If a request is not already in progress setup a dispatcher channel for dispatching incoming peers + if !c.ongoing { + pch, err := d.disc.FindPeers(ctx, ns, opts...) 
+ if err != nil { + return nil, err + } + + c.ongoing = true + go findPeerDispatcher(ctx, c, pch) + } + + // Setup receiver channel for receiving peers from ongoing requests + evtCh := make(chan peer.AddrInfo, d.parallelBufSz) + pch := make(chan peer.AddrInfo, d.returnedBufSz) + rcvPeers := make([]peer.AddrInfo, 0, 32) + for _, ai := range c.peers { + rcvPeers = append(rcvPeers, ai) + } + c.sendingChs[evtCh] = options.Limit + + go findPeerReceiver(ctx, pch, evtCh, rcvPeers) + + return pch, nil +} + +func findPeerDispatcher(ctx context.Context, c *backoffCache, pch <-chan peer.AddrInfo) { + defer func() { + c.mux.Lock() + + for ch := range c.sendingChs { + close(ch) + } + + // If the peer addresses have changed reset the backoff + if checkUpdates(c.prevPeers, c.peers) { + c.strat.Reset() + c.prevPeers = c.peers + } + c.nextDiscover = time.Now().Add(c.strat.Delay()) + + c.ongoing = false + c.peers = make(map[peer.ID]peer.AddrInfo) + c.sendingChs = make(map[chan peer.AddrInfo]int) + c.mux.Unlock() + }() + + for { + select { + case ai, ok := <-pch: + if !ok { + return + } + c.mux.Lock() + + // If we receive the same peer multiple times return the address union + var sendAi peer.AddrInfo + if prevAi, ok := c.peers[ai.ID]; ok { + if combinedAi := mergeAddrInfos(prevAi, ai); combinedAi != nil { + sendAi = *combinedAi + } else { + c.mux.Unlock() + continue + } + } else { + sendAi = ai + } + + c.peers[ai.ID] = sendAi + + for ch, rem := range c.sendingChs { + ch <- sendAi + if rem == 1 { + close(ch) + delete(c.sendingChs, ch) + break + } else if rem > 0 { + rem-- + } + } + + c.mux.Unlock() + case <-ctx.Done(): + return + } + } +} + +func findPeerReceiver(ctx context.Context, pch, evtCh chan peer.AddrInfo, rcvPeers []peer.AddrInfo) { + defer close(pch) + + for { + select { + case ai, ok := <-evtCh: + if ok { + rcvPeers = append(rcvPeers, ai) + + sentAll := true + sendPeers: + for i, p := range rcvPeers { + select { + case pch <- p: + default: + rcvPeers = rcvPeers[i:] + 
sentAll = false + break sendPeers + } + } + if sentAll { + rcvPeers = []peer.AddrInfo{} + } + } else { + for _, p := range rcvPeers { + select { + case pch <- p: + case <-ctx.Done(): + return + } + } + return + } + case <-ctx.Done(): + return + } + } +} + +func mergeAddrInfos(prevAi, newAi peer.AddrInfo) *peer.AddrInfo { + combinedAddrs := addr.UniqueSource(addr.Slice(prevAi.Addrs), addr.Slice(newAi.Addrs)).Addrs() + if len(combinedAddrs) > len(prevAi.Addrs) { + combinedAi := &peer.AddrInfo{ID: prevAi.ID, Addrs: combinedAddrs} + return combinedAi + } + return nil +} + +func checkUpdates(orig, update map[peer.ID]peer.AddrInfo) bool { + if len(orig) != len(update) { + return true + } + for p, ai := range update { + if prevAi, ok := orig[p]; ok { + if combinedAi := mergeAddrInfos(prevAi, ai); combinedAi != nil { + return true + } + } else { + return true + } + } + return false +} diff --git a/backoffcache_test.go b/backoffcache_test.go new file mode 100644 index 0000000..5d7e5dd --- /dev/null +++ b/backoffcache_test.go @@ -0,0 +1,194 @@ +package discovery + +import ( + "context" + "testing" + "time" + + bhost "github.com/libp2p/go-libp2p-blankhost" + "github.com/libp2p/go-libp2p-core/discovery" + "github.com/libp2p/go-libp2p-core/peer" + swarmt "github.com/libp2p/go-libp2p-swarm/testing" +) + +type delayedDiscovery struct { + disc discovery.Discovery + delay time.Duration +} + +func (d *delayedDiscovery) Advertise(ctx context.Context, ns string, opts ...discovery.Option) (time.Duration, error) { + return d.disc.Advertise(ctx, ns, opts...) +} + +func (d *delayedDiscovery) FindPeers(ctx context.Context, ns string, opts ...discovery.Option) (<-chan peer.AddrInfo, error) { + dch, err := d.disc.FindPeers(ctx, ns, opts...) 
// assertNumPeers runs FindPeers on d for ns with a limit of 10, drains the
// returned channel and fails the test unless exactly count distinct peer IDs
// were received.
func assertNumPeers(t *testing.T, ctx context.Context, d discovery.Discovery, ns string, count int) {
	t.Helper()
	peerCh, err := d.FindPeers(ctx, ns, discovery.Limit(10))
	if err != nil {
		t.Fatal(err)
	}

	// Count distinct IDs: the same peer may be delivered more than once.
	peerset := make(map[peer.ID]struct{})
	for p := range peerCh {
		peerset[p.ID] = struct{}{}
	}

	if len(peerset) != count {
		t.Fatalf("Was supposed to find %d, found %d instead", count, len(peerset))
	}
}

// TestBackoffDiscoverySingleBackoff checks that a peer advertised during the
// backoff period stays hidden by the cache and becomes visible once the
// first backoff (100ms) has expired.
func TestBackoffDiscoverySingleBackoff(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	discServer := newDiscoveryServer()

	h1 := bhost.NewBlankHost(swarmt.GenSwarm(t, ctx))
	h2 := bhost.NewBlankHost(swarmt.GenSwarm(t, ctx))
	d1 := &mockDiscoveryClient{h1, discServer}
	d2 := &mockDiscoveryClient{h2, discServer}

	bkf := NewExponentialBackoff(time.Millisecond*100, time.Second*10, NoJitter,
		time.Millisecond*100, 2.5, 0, nil)
	dCache, err := NewBackoffDiscovery(d1, bkf)
	if err != nil {
		t.Fatal(err)
	}

	const ns = "test"

	// try adding a peer then find it
	d1.Advertise(ctx, ns, discovery.TTL(time.Hour))
	assertNumPeers(t, ctx, dCache, ns, 1)

	// add a new peer and make sure it is still hidden by the caching layer
	d2.Advertise(ctx, ns, discovery.TTL(time.Hour))
	assertNumPeers(t, ctx, dCache, ns, 1)

	// wait for cache to expire and check for the new peer
	time.Sleep(time.Millisecond * 110)
	assertNumPeers(t, ctx, dCache, ns, 2)
}

// TestBackoffDiscoveryMultipleBackoff checks that the backoff grows while the
// peer set is unchanged (100ms, then 250ms), that it is reset when the peer
// set changes, and that the cached peer set can shrink as well as grow.
// The sleeps are tuned to the backoff parameters; their order and lengths
// are the substance of the test.
func TestBackoffDiscoveryMultipleBackoff(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	discServer := newDiscoveryServer()

	h1 := bhost.NewBlankHost(swarmt.GenSwarm(t, ctx))
	h2 := bhost.NewBlankHost(swarmt.GenSwarm(t, ctx))
	d1 := &mockDiscoveryClient{h1, discServer}
	d2 := &mockDiscoveryClient{h2, discServer}

	// Startup delay is 0ms. First backoff after finding data is 100ms, second backoff is 250ms.
	bkf := NewExponentialBackoff(time.Millisecond*100, time.Second*10, NoJitter,
		time.Millisecond*100, 2.5, 0, nil)
	dCache, err := NewBackoffDiscovery(d1, bkf)
	if err != nil {
		t.Fatal(err)
	}

	const ns = "test"

	// try adding a peer then find it
	d1.Advertise(ctx, ns, discovery.TTL(time.Hour))
	assertNumPeers(t, ctx, dCache, ns, 1)

	// wait a little to make sure the extra request doesn't modify the backoff
	time.Sleep(time.Millisecond * 50) //50 < 100
	assertNumPeers(t, ctx, dCache, ns, 1)

	// wait for backoff to expire and check if we increase it
	time.Sleep(time.Millisecond * 60) // 50+60 > 100
	assertNumPeers(t, ctx, dCache, ns, 1)

	// The second peer's registration expires after 400ms.
	d2.Advertise(ctx, ns, discovery.TTL(time.Millisecond*400))

	time.Sleep(time.Millisecond * 150) //150 < 250
	assertNumPeers(t, ctx, dCache, ns, 1)

	time.Sleep(time.Millisecond * 150) //150 + 150 > 250
	assertNumPeers(t, ctx, dCache, ns, 2)

	// check that the backoff has been reset
	// also checks that we can decrease our peer count (i.e. not just growing a set)
	time.Sleep(time.Millisecond * 110) //110 > 100, also 150+150+110>400
	assertNumPeers(t, ctx, dCache, ns, 1)
}
not just growing a set) + time.Sleep(time.Millisecond * 110) //110 > 100, also 150+150+110>400 + assertNumPeers(t, ctx, dCache, ns, 1) +} + +func TestBackoffDiscoverySimultaneousQuery(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + discServer := newDiscoveryServer() + + // Testing with n larger than most internal buffer sizes (32) + n := 40 + advertisers := make([]discovery.Discovery, n) + + for i := 0; i < n; i++ { + h := bhost.NewBlankHost(swarmt.GenSwarm(t, ctx)) + advertisers[i] = &mockDiscoveryClient{h, discServer} + } + + d1 := &delayedDiscovery{advertisers[0], time.Millisecond * 10} + + bkf := NewFixedBackoff(time.Millisecond * 200) + dCache, err := NewBackoffDiscovery(d1, bkf) + if err != nil { + t.Fatal(err) + } + + const ns = "test" + + for _, a := range advertisers { + if _, err := a.Advertise(ctx, ns, discovery.TTL(time.Hour)); err != nil { + t.Fatal(err) + } + } + + ch1, err := dCache.FindPeers(ctx, ns) + if err != nil { + t.Fatal(err) + } + + _ = <-ch1 + ch2, err := dCache.FindPeers(ctx, ns) + if err != nil { + t.Fatal(err) + } + + szCh2 := 0 + for ai := range ch2 { + _ = ai + szCh2++ + } + + szCh1 := 1 + for _ = range ch1 { + szCh1++ + } + + if szCh1 != n && szCh2 != n { + t.Fatalf("Channels returned %d, %d elements instead of %d", szCh1, szCh2, n) + } +} diff --git a/backoffconnector.go b/backoffconnector.go new file mode 100644 index 0000000..6f6c58f --- /dev/null +++ b/backoffconnector.go @@ -0,0 +1,95 @@ +package discovery + +import ( + "context" + lru "github.com/hashicorp/golang-lru" + "sync" + "time" + + "github.com/libp2p/go-libp2p-core/host" + "github.com/libp2p/go-libp2p-core/peer" +) + +// BackoffConnector is a utility to connect to peers, but only if we have not recently tried connecting to them already +type BackoffConnector struct { + cache *lru.TwoQueueCache + host host.Host + connTryDur time.Duration + backoff BackoffFactory + mux sync.Mutex +} + +// NewBackoffConnector creates a utility 
to connect to peers, but only if we have not recently tried connecting to them already +// cacheSize is the size of a TwoQueueCache +// connectionTryDuration is how long we attempt to connect to a peer before giving up +// backoff describes the strategy used to decide how long to backoff after previously attempting to connect to a peer +func NewBackoffConnector(h host.Host, cacheSize int, connectionTryDuration time.Duration, backoff BackoffFactory) (*BackoffConnector, error) { + cache, err := lru.New2Q(cacheSize) + if err != nil { + return nil, err + } + + return &BackoffConnector{ + cache: cache, + host: h, + connTryDur: connectionTryDuration, + backoff: backoff, + }, nil +} + +type connCacheData struct { + nextTry time.Time + strat BackoffStrategy +} + +// Connect attempts to connect to the peers passed in by peerCh. Will not connect to peers if they are within the backoff period. +// As Connect will attempt to dial peers as soon as it learns about them, the caller should try to keep the number, +// and rate, of inbound peers manageable. 
+func (c *BackoffConnector) Connect(ctx context.Context, peerCh <-chan peer.AddrInfo) { + for { + select { + case pi, ok := <-peerCh: + if !ok { + return + } + + if pi.ID == c.host.ID() || pi.ID == "" { + continue + } + + c.mux.Lock() + val, ok := c.cache.Get(pi.ID) + var cachedPeer *connCacheData + if ok { + tv := val.(*connCacheData) + now := time.Now() + if now.Before(tv.nextTry) { + c.mux.Unlock() + continue + } + + tv.nextTry = now.Add(tv.strat.Delay()) + } else { + cachedPeer = &connCacheData{strat: c.backoff()} + cachedPeer.nextTry = time.Now().Add(cachedPeer.strat.Delay()) + c.cache.Add(pi.ID, cachedPeer) + } + c.mux.Unlock() + + go func(pi peer.AddrInfo) { + ctx, cancel := context.WithTimeout(ctx, c.connTryDur) + defer cancel() + + err := c.host.Connect(ctx, pi) + if err != nil { + log.Debugf("Error connecting to pubsub peer %s: %s", pi.ID, err.Error()) + return + } + }(pi) + + case <-ctx.Done(): + log.Infof("discovery: backoff connector context error %v", ctx.Err()) + return + } + } +} diff --git a/backoffconnector_test.go b/backoffconnector_test.go new file mode 100644 index 0000000..e88a9c2 --- /dev/null +++ b/backoffconnector_test.go @@ -0,0 +1,108 @@ +package discovery + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + bhost "github.com/libp2p/go-libp2p-blankhost" + "github.com/libp2p/go-libp2p-core/host" + "github.com/libp2p/go-libp2p-core/peer" + swarmt "github.com/libp2p/go-libp2p-swarm/testing" +) + +type maxDialHost struct { + host.Host + + mux sync.Mutex + timesDialed map[peer.ID]int + maxTimesToDial map[peer.ID]int +} + +func (h *maxDialHost) Connect(ctx context.Context, ai peer.AddrInfo) error { + pid := ai.ID + + h.mux.Lock() + defer h.mux.Unlock() + numDials := h.timesDialed[pid] + numDials += 1 + h.timesDialed[pid] = numDials + + if maxDials, ok := h.maxTimesToDial[pid]; ok && numDials > maxDials { + return fmt.Errorf("should not be dialing peer %s", pid.String()) + } + + return h.Host.Connect(ctx, ai) +} + +func 
// getNetHosts creates n blank hosts, each on its own test swarm.
func getNetHosts(t *testing.T, ctx context.Context, n int) []host.Host {
	var out []host.Host

	for i := 0; i < n; i++ {
		netw := swarmt.GenSwarm(t, ctx)
		h := bhost.NewBlankHost(netw)
		out = append(out, h)
	}

	return out
}

// loadCh returns a closed channel pre-filled with the address info of every
// host in peers, as recorded in that host's own peerstore.
func loadCh(peers []host.Host) <-chan peer.AddrInfo {
	ch := make(chan peer.AddrInfo, len(peers))
	for _, p := range peers {
		ch <- p.Peerstore().PeerInfo(p.ID())
	}
	close(ch)
	return ch
}

// TestBackoffConnector checks that the connector dials each new peer once,
// does not redial during the 1.5s backoff, and dials again once it expired.
// The sleeps give the asynchronous dials time to finish and wait out the
// backoff; their order and lengths are the substance of the test.
func TestBackoffConnector(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	hosts := getNetHosts(t, ctx, 5)
	// hosts[1] may be dialed once and hosts[2] twice; further dials fail.
	primary := &maxDialHost{
		Host:        hosts[0],
		mux:         sync.Mutex{},
		timesDialed: make(map[peer.ID]int),
		maxTimesToDial: map[peer.ID]int{
			hosts[1].ID(): 1,
			hosts[2].ID(): 2,
		},
	}

	bc, err := NewBackoffConnector(primary, 10, time.Minute, NewFixedBackoff(time.Millisecond*1500))
	if err != nil {
		t.Fatal(err)
	}

	// First round: every host except the primary itself gets connected.
	bc.Connect(ctx, loadCh(hosts))

	time.Sleep(time.Millisecond * 100)
	if expected, actual := len(hosts)-1, len(primary.Network().Conns()); actual != expected {
		t.Fatalf("wrong number of connections. expected %d, actual %d", expected, actual)
	}

	for _, c := range primary.Network().Conns() {
		c.Close()
	}

	// Wait until all connections are really gone.
	for len(primary.Network().Conns()) > 0 {
		time.Sleep(time.Millisecond * 100)
	}

	// Second round, still inside the backoff period: nothing is dialed.
	bc.Connect(ctx, loadCh(hosts))
	if numConns := len(primary.Network().Conns()); numConns != 0 {
		t.Fatal("shouldn't be connected to any peers")
	}

	// Third round, after the backoff expired: everything is redialed, but
	// hosts[1] has used up its single permitted dial.
	time.Sleep(time.Millisecond * 1600)
	bc.Connect(ctx, loadCh(hosts))

	time.Sleep(time.Millisecond * 100)
	if expected, actual := len(hosts)-2, len(primary.Network().Conns()); actual != expected {
		t.Fatalf("wrong number of connections. expected %d, actual %d", expected, actual)
	}
}
expected %d, actual %d", expected, actual) + } +} diff --git a/go.mod b/go.mod index f8ea495..64912bf 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,14 @@ module github.com/libp2p/go-libp2p-discovery require ( + github.com/hashicorp/golang-lru v0.5.1 github.com/ipfs/go-cid v0.0.2 github.com/ipfs/go-log v0.0.1 github.com/libp2p/go-libp2p-blankhost v0.1.1 github.com/libp2p/go-libp2p-core v0.0.1 + github.com/libp2p/go-libp2p-peerstore v0.1.0 github.com/libp2p/go-libp2p-swarm v0.1.0 github.com/multiformats/go-multihash v0.0.5 ) + +go 1.13 diff --git a/go.sum b/go.sum index eed16ef..effd089 100644 --- a/go.sum +++ b/go.sum @@ -33,6 +33,7 @@ github.com/gxed/hashland/keccakpg v0.0.1 h1:wrk3uMNaMxbXiHibbPO4S0ymqJMm41WiudyF github.com/gxed/hashland/keccakpg v0.0.1/go.mod h1:kRzw3HkwxFU1mpmPP8v1WyQzwdGfmKFJ6tItnhQ67kU= github.com/gxed/hashland/murmur3 v0.0.1 h1:SheiaIt0sda5K+8FLz952/1iWS9zrnKsEJaOJu4ZbSc= github.com/gxed/hashland/murmur3 v0.0.1/go.mod h1:KjXop02n4/ckmZSnY2+HKcLud/tcmvhST0bie/0lS48= +github.com/hashicorp/golang-lru v0.5.1 h1:0hERBMJE1eitiLkihrMvRVBYAkpHzc/J3QdDN+dAcgU= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= diff --git a/mocks_test.go b/mocks_test.go new file mode 100644 index 0000000..11f62b0 --- /dev/null +++ b/mocks_test.go @@ -0,0 +1,112 @@ +package discovery + +import ( + "context" + "sync" + "time" + + "github.com/libp2p/go-libp2p-core/discovery" + "github.com/libp2p/go-libp2p-core/host" + "github.com/libp2p/go-libp2p-core/peer" +) + +type mockDiscoveryServer struct { + mx sync.Mutex + db map[string]map[peer.ID]*discoveryRegistration +} + +type discoveryRegistration struct { + info peer.AddrInfo + expiration time.Time +} + +func newDiscoveryServer() *mockDiscoveryServer { + return &mockDiscoveryServer{ + db: 
make(map[string]map[peer.ID]*discoveryRegistration), + } +} + +func (s *mockDiscoveryServer) Advertise(ns string, info peer.AddrInfo, ttl time.Duration) (time.Duration, error) { + s.mx.Lock() + defer s.mx.Unlock() + + peers, ok := s.db[ns] + if !ok { + peers = make(map[peer.ID]*discoveryRegistration) + s.db[ns] = peers + } + peers[info.ID] = &discoveryRegistration{info, time.Now().Add(ttl)} + return ttl, nil +} + +func (s *mockDiscoveryServer) FindPeers(ns string, limit int) (<-chan peer.AddrInfo, error) { + s.mx.Lock() + defer s.mx.Unlock() + + peers, ok := s.db[ns] + if !ok || len(peers) == 0 { + emptyCh := make(chan peer.AddrInfo) + close(emptyCh) + return emptyCh, nil + } + + count := len(peers) + if limit != 0 && count > limit { + count = limit + } + + iterTime := time.Now() + ch := make(chan peer.AddrInfo, count) + numSent := 0 + for p, reg := range peers { + if numSent == count { + break + } + if iterTime.After(reg.expiration) { + delete(peers, p) + continue + } + + numSent++ + ch <- reg.info + } + close(ch) + + return ch, nil +} + +func (s *mockDiscoveryServer) hasPeerRecord(ns string, pid peer.ID) bool { + s.mx.Lock() + defer s.mx.Unlock() + + if peers, ok := s.db[ns]; ok { + _, ok := peers[pid] + return ok + } + return false +} + +type mockDiscoveryClient struct { + host host.Host + server *mockDiscoveryServer +} + +func (d *mockDiscoveryClient) Advertise(ctx context.Context, ns string, opts ...discovery.Option) (time.Duration, error) { + var options discovery.Options + err := options.Apply(opts...) + if err != nil { + return 0, err + } + + return d.server.Advertise(ns, *host.InfoFromHost(d.host), options.Ttl) +} + +func (d *mockDiscoveryClient) FindPeers(ctx context.Context, ns string, opts ...discovery.Option) (<-chan peer.AddrInfo, error) { + var options discovery.Options + err := options.Apply(opts...) 
+ if err != nil { + return nil, err + } + + return d.server.FindPeers(ns, options.Limit) +} diff --git a/routing_test.go b/routing_test.go index c07a5d6..1a95682 100644 --- a/routing_test.go +++ b/routing_test.go @@ -71,107 +71,6 @@ func (m *mockRouting) FindProvidersAsync(ctx context.Context, cid cid.Cid, limit return ch } -type mockDiscoveryServer struct { - mx sync.Mutex - db map[string]map[peer.ID]*discoveryRegistration -} - -type discoveryRegistration struct { - info peer.AddrInfo - expiration time.Time -} - -func newDiscoveryServer() *mockDiscoveryServer { - return &mockDiscoveryServer{ - db: make(map[string]map[peer.ID]*discoveryRegistration), - } -} - -func (s *mockDiscoveryServer) Advertise(ns string, info peer.AddrInfo, ttl time.Duration) (time.Duration, error) { - s.mx.Lock() - defer s.mx.Unlock() - - peers, ok := s.db[ns] - if !ok { - peers = make(map[peer.ID]*discoveryRegistration) - s.db[ns] = peers - } - peers[info.ID] = &discoveryRegistration{info, time.Now().Add(ttl)} - return ttl, nil -} - -func (s *mockDiscoveryServer) FindPeers(ns string, limit int) (<-chan peer.AddrInfo, error) { - s.mx.Lock() - defer s.mx.Unlock() - - peers, ok := s.db[ns] - if !ok || len(peers) == 0 { - emptyCh := make(chan peer.AddrInfo) - close(emptyCh) - return emptyCh, nil - } - - count := len(peers) - if limit != 0 && count > limit { - count = limit - } - - iterTime := time.Now() - ch := make(chan peer.AddrInfo, count) - numSent := 0 - for p, reg := range peers { - if numSent == count { - break - } - if iterTime.After(reg.expiration) { - delete(peers, p) - continue - } - - numSent++ - ch <- reg.info - } - close(ch) - - return ch, nil -} - -func (s *mockDiscoveryServer) hasPeerRecord(ns string, pid peer.ID) bool { - s.mx.Lock() - defer s.mx.Unlock() - - if peers, ok := s.db[ns]; ok { - _, ok := peers[pid] - return ok - } - return false -} - -type mockDiscoveryClient struct { - host host.Host - server *mockDiscoveryServer -} - -func (d *mockDiscoveryClient) Advertise(ctx 
context.Context, ns string, opts ...discovery.Option) (time.Duration, error) { - var options discovery.Options - err := options.Apply(opts...) - if err != nil { - return 0, err - } - - return d.server.Advertise(ns, *host.InfoFromHost(d.host), options.Ttl) -} - -func (d *mockDiscoveryClient) FindPeers(ctx context.Context, ns string, opts ...discovery.Option) (<-chan peer.AddrInfo, error) { - var options discovery.Options - err := options.Apply(opts...) - if err != nil { - return nil, err - } - - return d.server.FindPeers(ns, options.Limit) -} - func TestRoutingDiscovery(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel()