From b9cb861a98b464b1e9a4f697b6c0c6ad72116505 Mon Sep 17 00:00:00 2001 From: sukun Date: Thu, 17 Oct 2024 08:38:43 +0530 Subject: [PATCH] autonat: fix interaction with autorelay (#2967) * autonat: fix interaction with autorelay * Fix race in test * Use deadline from context if available for DialBack * Return hasNewAddrs correctly * nit: cleanup contains check * Shuffle peers * nits * Change comment to indicate the bug * holepuncher: pass address function in constructor (#2979) * holepunch: pass address function in constructor * nit * Remove getPublicAddrs --------- Co-authored-by: Marco Munizaga * Make a copy of the multiaddr slice in Addrs() --------- Co-authored-by: Marco Munizaga --- config/config.go | 41 ++--- core/host/host.go | 2 +- p2p/host/autonat/autonat.go | 209 ++++++++++++----------- p2p/host/autonat/autonat_test.go | 42 ++--- p2p/host/autonat/client.go | 9 +- p2p/host/autonat/options.go | 2 + p2p/host/autorelay/autorelay.go | 35 ++-- p2p/host/autorelay/relay.go | 2 + p2p/host/basic/basic_host.go | 52 +++--- p2p/protocol/holepunch/holepunch_test.go | 15 +- p2p/protocol/holepunch/holepuncher.go | 23 +-- p2p/protocol/holepunch/svc.go | 80 +++------ p2p/protocol/holepunch/util.go | 9 +- 13 files changed, 249 insertions(+), 272 deletions(-) diff --git a/config/config.go b/config/config.go index c7af9530f3..07bee93c60 100644 --- a/config/config.go +++ b/config/config.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net" + "slices" "time" "github.com/libp2p/go-libp2p/core/connmgr" @@ -430,16 +431,6 @@ func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus) (*bhost.B if err != nil { return nil, err } - if cfg.Relay { - // If we've enabled the relay, we should filter out relay - // addresses by default. - // - // TODO: We shouldn't be doing this here. - originalAddrFactory := h.AddrsFactory - h.AddrsFactory = func(addrs []ma.Multiaddr) []ma.Multiaddr { - return originalAddrFactory(autorelay.Filter(addrs)) - } - } return h, nil } @@ -512,17 +503,8 @@ func (cfg *Config) NewNode() (host.Host, error) { ) } - // originalAddrFactory is the AddrFactory before it's modified by autorelay - // we need this for checking reachability via autonat - originalAddrFactory := func(addrs []ma.Multiaddr) []ma.Multiaddr { - return addrs - } - // enable autorelay fxopts = append(fxopts, - fx.Invoke(func(h *bhost.BasicHost) { - originalAddrFactory = h.AddrsFactory - }), fx.Invoke(func(h *bhost.BasicHost, lifecycle fx.Lifecycle) error { if cfg.EnableAutoRelay { if !cfg.DisableMetrics { @@ -559,7 +541,7 @@ func (cfg *Config) NewNode() (host.Host, error) { return nil, err } - if err := cfg.addAutoNAT(bh, originalAddrFactory); err != nil { + if err := cfg.addAutoNAT(bh); err != nil { app.Stop(context.Background()) if cfg.Routing != nil { rh.Close() @@ -575,11 +557,20 @@ func (cfg *Config) NewNode() (host.Host, error) { return &closableBasicHost{App: app, BasicHost: bh}, nil } -func (cfg *Config) addAutoNAT(h *bhost.BasicHost, addrF AddrsFactory) error { +func (cfg *Config) addAutoNAT(h *bhost.BasicHost) error { + // Only use public addresses for autonat + addrFunc := func() []ma.Multiaddr { + return slices.DeleteFunc(h.AllAddrs(), func(a ma.Multiaddr) bool { return !manet.IsPublicAddr(a) }) + } + if cfg.AddrsFactory != nil { + addrFunc = func() []ma.Multiaddr { + return slices.DeleteFunc( + cfg.AddrsFactory(h.AllAddrs()), + func(a ma.Multiaddr) bool { return !manet.IsPublicAddr(a) }) + } + } autonatOpts := []autonat.Option{ - autonat.UsingAddresses(func() []ma.Multiaddr { - return addrF(h.AllAddrs()) - }), + autonat.UsingAddresses(addrFunc), } if !cfg.DisableMetrics { autonatOpts = append(autonatOpts, autonat.WithMetricsTracer( @@ -662,7 +653,7 @@ func (cfg *Config) addAutoNAT(h *bhost.BasicHost, addrF AddrsFactory) error { autonat, err := autonat.New(h, autonatOpts...) if err != nil { - return fmt.Errorf("cannot enable autorelay; autonat failed to start: %v", err) + return fmt.Errorf("autonat init failed: %w", err) } h.SetAutoNat(autonat) return nil diff --git a/core/host/host.go b/core/host/host.go index 7990f7f456..0a8dbe4b0c 100644 --- a/core/host/host.go +++ b/core/host/host.go @@ -41,7 +41,7 @@ type Host interface { // given peer.ID. Connect will absorb the addresses in pi into its internal // peerstore. If there is not an active connection, Connect will issue a // h.Network.Dial, and block until a connection is open, or an error is - // returned. // TODO: Relay + NAT. + // returned. Connect(ctx context.Context, pi peer.AddrInfo) error // SetStreamHandler sets the protocol handler on the Host's Mux. diff --git a/p2p/host/autonat/autonat.go b/p2p/host/autonat/autonat.go index 479f31ecfb..6d3c9e242f 100644 --- a/p2p/host/autonat/autonat.go +++ b/p2p/host/autonat/autonat.go @@ -3,6 +3,7 @@ package autonat import ( "context" "math/rand" + "slices" "sync/atomic" "time" @@ -33,6 +34,8 @@ type AmbientAutoNAT struct { inboundConn chan network.Conn dialResponses chan error + // Used when testing the autonat service + observations chan network.Reachability // status is an autoNATResult reflecting current status. status atomic.Pointer[network.Reachability] // Reflects the confidence on of the NATStatus being private, as a single @@ -40,11 +43,12 @@ type AmbientAutoNAT struct { // If it is <3, then multiple autoNAT peers may be contacted for dialback // If only a single autoNAT peer is known, then the confidence increases // for each failure until it reaches 3. - confidence int - lastInbound time.Time - lastProbeTry time.Time - lastProbe time.Time - recentProbes map[peer.ID]time.Time + confidence int + lastInbound time.Time + lastProbe time.Time + recentProbes map[peer.ID]time.Time + pendingProbes int + ourAddrs map[string]struct{} service *autoNATService @@ -70,7 +74,11 @@ func New(h host.Host, options ...Option) (AutoNAT, error) { return nil, err } if conf.addressFunc == nil { - conf.addressFunc = h.Addrs + if aa, ok := h.(interface{ AllAddrs() []ma.Multiaddr }); ok { + conf.addressFunc = aa.AllAddrs + } else { + conf.addressFunc = h.Addrs + } } for _, o := range options { @@ -108,10 +116,12 @@ func New(h host.Host, options ...Option) (AutoNAT, error) { config: conf, inboundConn: make(chan network.Conn, 5), dialResponses: make(chan error, 1), + observations: make(chan network.Reachability, 1), emitReachabilityChanged: emitReachabilityChanged, service: service, recentProbes: make(map[peer.ID]time.Time), + ourAddrs: make(map[string]struct{}), } reachability := network.ReachabilityUnknown as.status.Store(&reachability) @@ -125,7 +135,6 @@ func New(h host.Host, options ...Option) (AutoNAT, error) { } as.subscriber = subscriber - h.Network().Notify(as) go as.background() return as, nil @@ -165,117 +174,126 @@ func (as *AmbientAutoNAT) background() { defer as.subscriber.Close() defer as.emitReachabilityChanged.Close() + // Fallback timer to update address in case EvtLocalAddressesUpdated is not emitted. + // TODO: The event not emitting properly is a bug. This is a workaround. + addrChangeTicker := time.NewTicker(30 * time.Minute) + defer addrChangeTicker.Stop() + timer := time.NewTimer(delay) defer timer.Stop() timerRunning := true - retryProbe := false + forceProbe := false for { select { - // new inbound connection. case conn := <-as.inboundConn: localAddrs := as.host.Addrs() if manet.IsPublicAddr(conn.RemoteMultiaddr()) && !ipInList(conn.RemoteMultiaddr(), localAddrs) { as.lastInbound = time.Now() } - + case <-addrChangeTicker.C: + // schedule a new probe if addresses have changed case e := <-subChan: switch e := e.(type) { - case event.EvtLocalAddressesUpdated: - // On local address update, reduce confidence from maximum so that we schedule - // the next probe sooner - if as.confidence == maxConfidence { - as.confidence-- - } case event.EvtPeerIdentificationCompleted: - if s, err := as.host.Peerstore().SupportsProtocols(e.Peer, AutoNATProto); err == nil && len(s) > 0 { - currentStatus := *as.status.Load() - if currentStatus == network.ReachabilityUnknown { - as.tryProbe(e.Peer) - } + if proto, err := as.host.Peerstore().SupportsProtocols(e.Peer, AutoNATProto); err == nil && len(proto) > 0 { + forceProbe = true } + case event.EvtLocalAddressesUpdated: + // schedule a new probe if addresses have changed default: log.Errorf("unknown event type: %T", e) } - - // probe finished. + case obs := <-as.observations: + as.recordObservation(obs) + continue case err, ok := <-as.dialResponses: if !ok { return } + as.pendingProbes-- if IsDialRefused(err) { - retryProbe = true + forceProbe = true } else { as.handleDialResponse(err) } case <-timer.C: + timerRunning = false + forceProbe = false + // Update the last probe time. We use it to ensure + // that we don't spam the peerstore. + as.lastProbe = time.Now() peer := as.getPeerToProbe() as.tryProbe(peer) - timerRunning = false - retryProbe = false case <-as.ctx.Done(): return } + // On address update, reduce confidence from maximum so that we schedule + // the next probe sooner + hasNewAddr := as.checkAddrs() + if hasNewAddr && as.confidence == maxConfidence { + as.confidence-- + } - // Drain the timer channel if it hasn't fired in preparation for Resetting it. if timerRunning && !timer.Stop() { <-timer.C } - timer.Reset(as.scheduleProbe(retryProbe)) + timer.Reset(as.scheduleProbe(forceProbe)) timerRunning = true } } -func (as *AmbientAutoNAT) cleanupRecentProbes() { - fixedNow := time.Now() - for k, v := range as.recentProbes { - if fixedNow.Sub(v) > as.throttlePeerPeriod { - delete(as.recentProbes, k) +func (as *AmbientAutoNAT) checkAddrs() (hasNewAddr bool) { + currentAddrs := as.addressFunc() + hasNewAddr = slices.ContainsFunc(currentAddrs, func(a ma.Multiaddr) bool { + _, ok := as.ourAddrs[string(a.Bytes())] + return !ok + }) + clear(as.ourAddrs) + for _, a := range currentAddrs { + if !manet.IsPublicAddr(a) { + continue } + as.ourAddrs[string(a.Bytes())] = struct{}{} } + return hasNewAddr } // scheduleProbe calculates when the next probe should be scheduled for. -func (as *AmbientAutoNAT) scheduleProbe(retryProbe bool) time.Duration { - // Our baseline is a probe every 'AutoNATRefreshInterval' - // This is modulated by: - // * if we are in an unknown state, have low confidence, or we want to retry because a probe was refused that - // should drop to 'AutoNATRetryInterval' - // * recent inbound connections (implying continued connectivity) should decrease the retry when public - // * recent inbound connections when not public mean we should try more actively to see if we're public. - fixedNow := time.Now() +func (as *AmbientAutoNAT) scheduleProbe(forceProbe bool) time.Duration { + now := time.Now() currentStatus := *as.status.Load() - - nextProbe := fixedNow - // Don't look for peers in the peer store more than once per second. - if !as.lastProbeTry.IsZero() { - backoff := as.lastProbeTry.Add(time.Second) - if backoff.After(nextProbe) { - nextProbe = backoff - } + nextProbeAfter := as.config.refreshInterval + receivedInbound := as.lastInbound.After(as.lastProbe) + switch { + case forceProbe && currentStatus == network.ReachabilityUnknown: + // retry very quicky if forceProbe is true *and* we don't know our reachability + // limit all peers fetch from peerstore to 1 per second. + nextProbeAfter = 2 * time.Second + nextProbeAfter = 2 * time.Second + case currentStatus == network.ReachabilityUnknown, + as.confidence < maxConfidence, + currentStatus != network.ReachabilityPublic && receivedInbound: + // Retry quickly in case: + // 1. Our reachability is Unknown + // 2. We don't have enough confidence in our reachability. + // 3. We're private but we received an inbound connection. + nextProbeAfter = as.config.retryInterval + case currentStatus == network.ReachabilityPublic && receivedInbound: + // We are public and we received an inbound connection recently, + // wait a little longer + nextProbeAfter *= 2 + nextProbeAfter = min(nextProbeAfter, maxRefreshInterval) } - if !as.lastProbe.IsZero() { - untilNext := as.config.refreshInterval - if retryProbe { - untilNext = as.config.retryInterval - } else if currentStatus == network.ReachabilityUnknown { - untilNext = as.config.retryInterval - } else if as.confidence < maxConfidence { - untilNext = as.config.retryInterval - } else if currentStatus == network.ReachabilityPublic && as.lastInbound.After(as.lastProbe) { - untilNext *= 2 - } else if currentStatus != network.ReachabilityPublic && as.lastInbound.After(as.lastProbe) { - untilNext /= 5 - } - - if as.lastProbe.Add(untilNext).After(nextProbe) { - nextProbe = as.lastProbe.Add(untilNext) - } + nextProbeTime := as.lastProbe.Add(nextProbeAfter) + if nextProbeTime.Before(now) { + nextProbeTime = now } if as.metricsTracer != nil { - as.metricsTracer.NextProbeTime(nextProbe) + as.metricsTracer.NextProbeTime(nextProbeTime) } - return nextProbe.Sub(fixedNow) + + return nextProbeTime.Sub(now) } // handleDialResponse updates the current status based on dial response. @@ -354,28 +372,14 @@ func (as *AmbientAutoNAT) recordObservation(observation network.Reachability) { } } -func (as *AmbientAutoNAT) tryProbe(p peer.ID) bool { - as.lastProbeTry = time.Now() - if p.Validate() != nil { - return false - } - - if lastTime, ok := as.recentProbes[p]; ok { - if time.Since(lastTime) < as.throttlePeerPeriod { - return false - } +func (as *AmbientAutoNAT) tryProbe(p peer.ID) { + if p == "" || as.pendingProbes > 5 { + return } - as.cleanupRecentProbes() - info := as.host.Peerstore().PeerInfo(p) - - if !as.config.dialPolicy.skipPeer(info.Addrs) { - as.recentProbes[p] = time.Now() - as.lastProbe = time.Now() - go as.probe(&info) - return true - } - return false + as.recentProbes[p] = time.Now() + as.pendingProbes++ + go as.probe(&info) } func (as *AmbientAutoNAT) probe(pi *peer.AddrInfo) { @@ -399,7 +403,19 @@ func (as *AmbientAutoNAT) getPeerToProbe() peer.ID { return "" } - candidates := make([]peer.ID, 0, len(peers)) + // clean old probes + fixedNow := time.Now() + for k, v := range as.recentProbes { + if fixedNow.Sub(v) > as.throttlePeerPeriod { + delete(as.recentProbes, k) + } + } + + // Shuffle peers + for n := len(peers); n > 0; n-- { + randIndex := rand.Intn(n) + peers[n-1], peers[randIndex] = peers[randIndex], peers[n-1] + } for _, p := range peers { info := as.host.Peerstore().PeerInfo(p) @@ -408,24 +424,13 @@ func (as *AmbientAutoNAT) getPeerToProbe() peer.ID { continue } - // Exclude peers in backoff. - if lastTime, ok := as.recentProbes[p]; ok { - if time.Since(lastTime) < as.throttlePeerPeriod { - continue - } - } - if as.config.dialPolicy.skipPeer(info.Addrs) { continue } - candidates = append(candidates, p) - } - - if len(candidates) == 0 { - return "" + return p } - return candidates[rand.Intn(len(candidates))] + return "" } func (as *AmbientAutoNAT) Close() error { diff --git a/p2p/host/autonat/autonat_test.go b/p2p/host/autonat/autonat_test.go index d8cfcc51e4..6a5768cd5a 100644 --- a/p2p/host/autonat/autonat_test.go +++ b/p2p/host/autonat/autonat_test.go @@ -15,6 +15,7 @@ import ( "github.com/libp2p/go-msgio/pbio" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -223,7 +224,7 @@ func TestAutoNATIncomingEvents(t *testing.T) { require.Eventually(t, func() bool { return an.Status() != network.ReachabilityUnknown - }, 500*time.Millisecond, 10*time.Millisecond, "Expected probe due to identification of autonat service") + }, 5*time.Second, 100*time.Millisecond, "Expected probe due to identification of autonat service") } func TestAutoNATDialRefused(t *testing.T) { @@ -258,6 +259,10 @@ func TestAutoNATDialRefused(t *testing.T) { close(done) } +func recordObservation(an *AmbientAutoNAT, status network.Reachability) { + an.observations <- status +} + func TestAutoNATObservationRecording(t *testing.T) { hs := makeAutoNATServicePublic(t) defer hs.Close() @@ -271,39 +276,34 @@ func TestAutoNATObservationRecording(t *testing.T) { t.Fatalf("failed to subscribe to event EvtLocalRoutabilityPublic, err=%s", err) } - an.recordObservation(network.ReachabilityPublic) - if an.Status() != network.ReachabilityPublic { - t.Fatalf("failed to transition to public.") + expectStatus := func(expected network.Reachability, msg string, args ...any) { + require.EventuallyWithTf(t, func(collect *assert.CollectT) { + assert.Equal(collect, expected, an.Status()) + }, 2*time.Second, 100*time.Millisecond, msg, args...) } + recordObservation(an, network.ReachabilityPublic) + expectStatus(network.ReachabilityPublic, "failed to transition to public.") expectEvent(t, s, network.ReachabilityPublic, 3*time.Second) // a single recording should have confidence still at 0, and transition to private quickly. - an.recordObservation(network.ReachabilityPrivate) - if an.Status() != network.ReachabilityPrivate { - t.Fatalf("failed to transition to private.") - } + recordObservation(an, network.ReachabilityPrivate) + expectStatus(network.ReachabilityPrivate, "failed to transition to private.") expectEvent(t, s, network.ReachabilityPrivate, 3*time.Second) // stronger public confidence should be harder to undo. - an.recordObservation(network.ReachabilityPublic) - an.recordObservation(network.ReachabilityPublic) - if an.Status() != network.ReachabilityPublic { - t.Fatalf("failed to transition to public.") - } + recordObservation(an, network.ReachabilityPublic) + recordObservation(an, network.ReachabilityPublic) + expectStatus(network.ReachabilityPublic, "failed to transition to public.") expectEvent(t, s, network.ReachabilityPublic, 3*time.Second) - an.recordObservation(network.ReachabilityPrivate) - if an.Status() != network.ReachabilityPublic { - t.Fatalf("too-extreme private transition.") - } + recordObservation(an, network.ReachabilityPrivate) + expectStatus(network.ReachabilityPublic, "too-extreme private transition.") // Don't emit events if reachability hasn't changed - an.recordObservation(network.ReachabilityPublic) - if an.Status() != network.ReachabilityPublic { - t.Fatalf("reachability should stay public") - } + recordObservation(an, network.ReachabilityPublic) + expectStatus(network.ReachabilityPublic, "reachability should stay public") select { case <-s.Out(): t.Fatal("received event without state transition") diff --git a/p2p/host/autonat/client.go b/p2p/host/autonat/client.go index fa0e03bc51..7f419a72ff 100644 --- a/p2p/host/autonat/client.go +++ b/p2p/host/autonat/client.go @@ -53,7 +53,14 @@ func (c *client) DialBack(ctx context.Context, p peer.ID) error { } defer s.Scope().ReleaseMemory(maxMsgSize) - s.SetDeadline(time.Now().Add(streamTimeout)) + deadline := time.Now().Add(streamTimeout) + if ctxDeadline, ok := ctx.Deadline(); ok { + if ctxDeadline.Before(deadline) { + deadline = ctxDeadline + } + } + + s.SetDeadline(deadline) // Might as well just reset the stream. Once we get to this point, we // don't care about being nice. defer s.Close() diff --git a/p2p/host/autonat/options.go b/p2p/host/autonat/options.go index dec62c5f1d..b378da348d 100644 --- a/p2p/host/autonat/options.go +++ b/p2p/host/autonat/options.go @@ -51,6 +51,8 @@ var defaults = func(c *config) error { return nil } +const maxRefreshInterval = 24 * time.Hour + // EnableService specifies that AutoNAT should be allowed to run a NAT service to help // other peers determine their own NAT status. The provided Network should not be the // default network/dialer of the host passed to `New`, as the NAT system will need to diff --git a/p2p/host/autorelay/autorelay.go b/p2p/host/autorelay/autorelay.go index 5900798533..b31302098d 100644 --- a/p2p/host/autorelay/autorelay.go +++ b/p2p/host/autorelay/autorelay.go @@ -29,8 +29,7 @@ type AutoRelay struct { relayFinder *relayFinder - host host.Host - addrsF basic.AddrsFactory + host host.Host metricsTracer MetricsTracer } @@ -38,7 +37,6 @@ type AutoRelay struct { func NewAutoRelay(bhost *basic.BasicHost, opts ...Option) (*AutoRelay, error) { r := &AutoRelay{ host: bhost, - addrsF: bhost.AddrsFactory, status: network.ReachabilityUnknown, } conf := defaultConfig @@ -51,7 +49,22 @@ func NewAutoRelay(bhost *basic.BasicHost, opts ...Option) (*AutoRelay, error) { r.conf = &conf r.relayFinder = newRelayFinder(bhost, conf.peerSource, &conf) r.metricsTracer = &wrappedMetricsTracer{conf.metricsTracer} - bhost.AddrsFactory = r.hostAddrs + + // Update the host address factory to use autorelay addresses if we're private + // + // TODO: Don't update host address factory. Instead send our relay addresses on the eventbus. + // The host can decide how to handle those. + addrF := bhost.AddrsFactory + bhost.AddrsFactory = func(addrs []ma.Multiaddr) []ma.Multiaddr { + addrs = addrF(addrs) + r.mx.Lock() + defer r.mx.Unlock() + + if r.status != network.ReachabilityPrivate { + return addrs + } + return r.relayFinder.relayAddrs(addrs) + } return r, nil } @@ -103,20 +116,6 @@ func (r *AutoRelay) background() { } } -func (r *AutoRelay) hostAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { - return r.relayAddrs(r.addrsF(addrs)) -} - -func (r *AutoRelay) relayAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { - r.mx.Lock() - defer r.mx.Unlock() - - if r.status != network.ReachabilityPrivate { - return addrs - } - return r.relayFinder.relayAddrs(addrs) -} - func (r *AutoRelay) Close() error { r.ctxCancel() err := r.relayFinder.Stop() diff --git a/p2p/host/autorelay/relay.go b/p2p/host/autorelay/relay.go index db0d97ec01..2ae5bf240c 100644 --- a/p2p/host/autorelay/relay.go +++ b/p2p/host/autorelay/relay.go @@ -5,6 +5,8 @@ import ( ) // Filter filters out all relay addresses. +// +// Deprecated: It is trivial for a user to implement this if they need this. func Filter(addrs []ma.Multiaddr) []ma.Multiaddr { raddrs := make([]ma.Multiaddr, 0, len(addrs)) for _, addr := range addrs { diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 525aa3a0fb..f7d3c5275a 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -267,7 +267,16 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { opts.HolePunchingOptions = append(hpOpts, opts.HolePunchingOptions...) } - h.hps, err = holepunch.NewService(h, h.ids, opts.HolePunchingOptions...) + h.hps, err = holepunch.NewService(h, h.ids, func() []ma.Multiaddr { + addrs := h.AllAddrs() + if opts.AddrsFactory != nil { + addrs = opts.AddrsFactory(addrs) + } + // AllAddrs may ignore observed addresses in favour of NAT mappings. Use both for hole punching. + addrs = append(addrs, h.ids.OwnObservedAddrs()...) + addrs = ma.Unique(addrs) + return slices.DeleteFunc(addrs, func(a ma.Multiaddr) bool { return !manet.IsPublicAddr(a) }) + }, opts.HolePunchingOptions...) if err != nil { return nil, fmt.Errorf("failed to create hole punch service: %w", err) } @@ -280,20 +289,6 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { if opts.AddrsFactory != nil { h.AddrsFactory = opts.AddrsFactory } - // This is a terrible hack. - // We want to use this AddrsFactory for autonat. Wrapping AddrsFactory here ensures - // that autonat receives addresses with the correct certhashes. - // - // This logic cannot be in Addrs method as autonat cannot use the Addrs method directly. - // The autorelay package updates AddrsFactory to only provide p2p-circuit addresses when - // reachability is Private. - // - // Wrapping it here allows us to provide the wrapped AddrsFactory to autonat before - // autorelay updates it. - addrFactory := h.AddrsFactory - h.AddrsFactory = func(addrs []ma.Multiaddr) []ma.Multiaddr { - return h.addCertHashes(addrFactory(addrs)) - } if opts.NATManager != nil { h.natmgr = opts.NATManager(n) @@ -832,16 +827,17 @@ func (h *BasicHost) ConnManager() connmgr.ConnManager { return h.cmgr } -// Addrs returns listening addresses that are safe to announce to the network. -// The output is the same as AllAddrs, but processed by AddrsFactory. +// Addrs returns listening addresses. The output is the same as AllAddrs, but +// processed by AddrsFactory. +// When used with AutoRelay, and if the host is not publicly reachable, +// this will only have host's private, relay, and no public addresses. func (h *BasicHost) Addrs() []ma.Multiaddr { - // We don't need to append certhashes here, the user provided addrsFactory was - // wrapped with addCertHashes in the constructor. addrs := h.AddrsFactory(h.AllAddrs()) // Make a copy. Consumers can modify the slice elements res := make([]ma.Multiaddr, len(addrs)) copy(res, addrs) - return ma.Unique(res) + // Add certhashes for the addresses provided by the user via address factory. + return h.addCertHashes(ma.Unique(res)) } // NormalizeMultiaddr returns a multiaddr suitable for equality checks. @@ -861,9 +857,9 @@ func (h *BasicHost) NormalizeMultiaddr(addr ma.Multiaddr) ma.Multiaddr { return addr } +var p2pCircuitAddr = ma.StringCast("/p2p-circuit") + // AllAddrs returns all the addresses the host is listening on except circuit addresses. -// The output has webtransport addresses inferred from quic addresses. -// All the addresses have the correct func (h *BasicHost) AllAddrs() []ma.Multiaddr { listenAddrs := h.Network().ListenAddresses() if len(listenAddrs) == 0 { @@ -877,7 +873,7 @@ func (h *BasicHost) AllAddrs() []ma.Multiaddr { // Iterate over all _unresolved_ listen addresses, resolving our primary // interface only to avoid advertising too many addresses. - var finalAddrs []ma.Multiaddr + finalAddrs := make([]ma.Multiaddr, 0, 8) if resolved, err := manet.ResolveUnspecifiedAddresses(listenAddrs, filteredIfaceAddrs); err != nil { // This can happen if we're listening on no addrs, or listening // on IPv6 addrs, but only have IPv4 interface addrs. @@ -956,6 +952,16 @@ func (h *BasicHost) AllAddrs() []ma.Multiaddr { finalAddrs = append(finalAddrs, observedAddrs...) } finalAddrs = ma.Unique(finalAddrs) + // Remove /p2p-circuit addresses from the list. + // The p2p-circuit tranport listener reports its address as just /p2p-circuit + // This is useless for dialing. Users need to manage their circuit addresses themselves, + // or use AutoRelay. + finalAddrs = slices.DeleteFunc(finalAddrs, func(a ma.Multiaddr) bool { + return a.Equal(p2pCircuitAddr) + }) + // Add certhashes for /webrtc-direct, /webtransport, etc addresses discovered + // using identify. + finalAddrs = h.addCertHashes(finalAddrs) return finalAddrs } diff --git a/p2p/protocol/holepunch/holepunch_test.go b/p2p/protocol/holepunch/holepunch_test.go index 23593c7970..00a76023ed 100644 --- a/p2p/protocol/holepunch/holepunch_test.go +++ b/p2p/protocol/holepunch/holepunch_test.go @@ -3,6 +3,7 @@ package holepunch_test import ( "context" "net" + "slices" "sync" "testing" "time" @@ -94,7 +95,6 @@ func TestNoHolePunchIfDirectConnExists(t *testing.T) { require.GreaterOrEqual(t, nc1, 1) nc2 := len(h2.Network().ConnsToPeer(h1.ID())) require.GreaterOrEqual(t, nc2, 1) - require.NoError(t, hps.DirectConnect(h2.ID())) require.Len(t, h1.Network().ConnsToPeer(h2.ID()), nc1) require.Len(t, h2.Network().ConnsToPeer(h1.ID()), nc2) @@ -473,8 +473,7 @@ func makeRelayedHosts(t *testing.T, h1opt, h2opt []holepunch.Option, addHolePunc hps = addHolePunchService(t, h2, h2opt...) } - // h1 has a relay addr - // h2 should connect to the relay addr + // h2 has a relay addr var raddr ma.Multiaddr for _, a := range h2.Addrs() { if _, err := a.ValueForProtocol(ma.P_CIRCUIT); err == nil { @@ -483,6 +482,7 @@ func makeRelayedHosts(t *testing.T, h1opt, h2opt []holepunch.Option, addHolePunc } } require.NotEmpty(t, raddr) + // h1 should connect to the relay addr require.NoError(t, h1.Connect(context.Background(), peer.AddrInfo{ ID: h2.ID(), Addrs: []ma.Multiaddr{raddr}, @@ -492,7 +492,11 @@ func makeRelayedHosts(t *testing.T, h1opt, h2opt []holepunch.Option, addHolePunc func addHolePunchService(t *testing.T, h host.Host, opts ...holepunch.Option) *holepunch.Service { t.Helper() - hps, err := holepunch.NewService(h, newMockIDService(t, h), opts...) + hps, err := holepunch.NewService(h, newMockIDService(t, h), func() []ma.Multiaddr { + addrs := h.Addrs() + addrs = slices.DeleteFunc(addrs, func(a ma.Multiaddr) bool { return !manet.IsPublicAddr(a) }) + return append(addrs, ma.StringCast("/ip4/1.2.3.4/tcp/1234")) + }, opts...) require.NoError(t, err) return hps } @@ -505,7 +509,6 @@ func mkHostWithHolePunchSvc(t *testing.T, opts ...holepunch.Option) (host.Host, libp2p.ResourceManager(&network.NullResourceManager{}), ) require.NoError(t, err) - hps, err := holepunch.NewService(h, newMockIDService(t, h), opts...) - require.NoError(t, err) + hps := addHolePunchService(t, h, opts...) return h, hps } diff --git a/p2p/protocol/holepunch/holepuncher.go b/p2p/protocol/holepunch/holepuncher.go index a30e653761..20d0558fc5 100644 --- a/p2p/protocol/holepunch/holepuncher.go +++ b/p2p/protocol/holepunch/holepuncher.go @@ -37,7 +37,8 @@ type holePuncher struct { host host.Host refCount sync.WaitGroup - ids identify.IDService + ids identify.IDService + listenAddrs func() []ma.Multiaddr // active hole punches for deduplicating activeMx sync.Mutex @@ -50,13 +51,14 @@ type holePuncher struct { filter AddrFilter } -func newHolePuncher(h host.Host, ids identify.IDService, tracer *tracer, filter AddrFilter) *holePuncher { +func newHolePuncher(h host.Host, ids identify.IDService, listenAddrs func() []ma.Multiaddr, tracer *tracer, filter AddrFilter) *holePuncher { hp := &holePuncher{ - host: h, - ids: ids, - active: make(map[peer.ID]struct{}), - tracer: tracer, - filter: filter, + host: h, + ids: ids, + active: make(map[peer.ID]struct{}), + tracer: tracer, + filter: filter, + listenAddrs: listenAddrs, } hp.ctx, hp.ctxCancel = context.WithCancel(context.Background()) h.Network().Notify((*netNotifiee)(hp)) @@ -102,16 +104,15 @@ func (hp *holePuncher) directConnect(rp peer.ID) error { if getDirectConnection(hp.host, rp) != nil { return nil } - // short-circuit hole punching if a direct dial works. // attempt a direct connection ONLY if we have a public address for the remote peer for _, a := range hp.host.Peerstore().Addrs(rp) { - if manet.IsPublicAddr(a) && !isRelayAddress(a) { + if !isRelayAddress(a) && manet.IsPublicAddr(a) { forceDirectConnCtx := network.WithForceDirectDial(hp.ctx, "hole-punching") dialCtx, cancel := context.WithTimeout(forceDirectConnCtx, dialTimeout) tstart := time.Now() - // This dials *all* public addresses from the peerstore. + // This dials *all* addresses, public and private, from the peerstore. err := hp.host.Connect(dialCtx, peer.AddrInfo{ID: rp}) dt := time.Since(tstart) cancel() @@ -206,7 +207,7 @@ func (hp *holePuncher) initiateHolePunchImpl(str network.Stream) ([]ma.Multiaddr str.SetDeadline(time.Now().Add(StreamTimeout)) // send a CONNECT and start RTT measurement. - obsAddrs := removeRelayAddrs(hp.ids.OwnObservedAddrs()) + obsAddrs := removeRelayAddrs(hp.listenAddrs()) if hp.filter != nil { obsAddrs = hp.filter.FilterLocal(str.Conn().RemotePeer(), obsAddrs) } diff --git a/p2p/protocol/holepunch/svc.go b/p2p/protocol/holepunch/svc.go index eb8ad9fd38..2e6fdd1a6a 100644 --- a/p2p/protocol/holepunch/svc.go +++ b/p2p/protocol/holepunch/svc.go @@ -8,18 +8,15 @@ import ( "time" logging "github.com/ipfs/go-log/v2" - "github.com/libp2p/go-libp2p/core/event" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/protocol" - "github.com/libp2p/go-libp2p/p2p/host/eventbus" "github.com/libp2p/go-libp2p/p2p/protocol/holepunch/pb" "github.com/libp2p/go-libp2p/p2p/protocol/identify" "github.com/libp2p/go-msgio/pbio" ma "github.com/multiformats/go-multiaddr" - manet "github.com/multiformats/go-multiaddr/net" ) // Protocol is the libp2p protocol for Hole Punching. @@ -47,7 +44,13 @@ type Service struct { ctxCancel context.CancelFunc host host.Host - ids identify.IDService + // ids helps with connection reversal. We wait for identify to complete and attempt + // a direct connection to the peer if it's publicly reachable. + ids identify.IDService + // listenAddrs provides the addresses for the host to be used for hole punching. We use this + // and not host.Addrs because host.Addrs might remove public unreachable address and only advertise + // publicly reachable relay addresses. + listenAddrs func() []ma.Multiaddr holePuncherMx sync.Mutex holePuncher *holePuncher @@ -65,7 +68,9 @@ type Service struct { // no matter if they are behind a NAT / firewall or not. // The Service handles DCUtR streams (which are initiated from the node behind // a NAT / Firewall once we establish a connection to them through a relay. -func NewService(h host.Host, ids identify.IDService, opts ...Option) (*Service, error) { +// +// listenAddrs MUST only return public addresses. +func NewService(h host.Host, ids identify.IDService, listenAddrs func() []ma.Multiaddr, opts ...Option) (*Service, error) { if ids == nil { return nil, errors.New("identify service can't be nil") } @@ -76,6 +81,7 @@ func NewService(h host.Host, ids identify.IDService, opts ...Option) (*Service, ctxCancel: cancel, host: h, ids: ids, + listenAddrs: listenAddrs, hasPublicAddrsChan: make(chan struct{}), } @@ -88,18 +94,18 @@ func NewService(h host.Host, ids identify.IDService, opts ...Option) (*Service, s.tracer.Start() s.refCount.Add(1) - go s.watchForPublicAddr() + go s.waitForPublicAddr() return s, nil } -func (s *Service) watchForPublicAddr() { +func (s *Service) waitForPublicAddr() { defer s.refCount.Done() log.Debug("waiting until we have at least one public address", "peer", s.host.ID()) // TODO: We should have an event here that fires when identify discovers a new - // address (and when autonat confirms that address). + // address. // As we currently don't have an event like this, just check our observed addresses // regularly (exponential backoff starting at 250 ms, capped at 5s). duration := 250 * time.Millisecond @@ -107,7 +113,7 @@ func (s *Service) watchForPublicAddr() { t := time.NewTimer(duration) defer t.Stop() for { - if len(s.getPublicAddrs()) > 0 { + if len(s.listenAddrs()) > 0 { log.Debug("Host now has a public address. Starting holepunch protocol.") s.host.SetStreamHandler(Protocol, s.handleNewStream) break @@ -125,36 +131,20 @@ func (s *Service) watchForPublicAddr() { } } - // Only start the holePuncher if we're behind a NAT / firewall. - sub, err := s.host.EventBus().Subscribe(&event.EvtLocalReachabilityChanged{}, eventbus.Name("holepunch")) - if err != nil { - log.Debugf("failed to subscripe to Reachability event: %s", err) + s.holePuncherMx.Lock() + if s.ctx.Err() != nil { + // service is closed return } - defer sub.Close() - for { - select { - case <-s.ctx.Done(): - return - case e, ok := <-sub.Out(): - if !ok { - return - } - if e.(event.EvtLocalReachabilityChanged).Reachability != network.ReachabilityPrivate { - continue - } - s.holePuncherMx.Lock() - s.holePuncher = newHolePuncher(s.host, s.ids, s.tracer, s.filter) - s.holePuncherMx.Unlock() - close(s.hasPublicAddrsChan) - return - } - } + s.holePuncher = newHolePuncher(s.host, s.ids, s.listenAddrs, s.tracer, s.filter) + s.holePuncherMx.Unlock() + close(s.hasPublicAddrsChan) } // Close closes the Hole Punch Service. func (s *Service) Close() error { var err error + s.ctxCancel() s.holePuncherMx.Lock() if s.holePuncher != nil { err = s.holePuncher.Close() @@ -162,7 +152,6 @@ func (s *Service) Close() error { s.holePuncherMx.Unlock() s.tracer.Close() s.host.RemoveStreamHandler(Protocol) - s.ctxCancel() s.refCount.Wait() return err } @@ -172,7 +161,7 @@ func (s *Service) incomingHolePunch(str network.Stream) (rtt time.Duration, remo if !isRelayAddress(str.Conn().RemoteMultiaddr()) { return 0, nil, nil, fmt.Errorf("received hole punch stream: %s", str.Conn().RemoteMultiaddr()) } - ownAddrs = s.getPublicAddrs() + ownAddrs = s.listenAddrs() if s.filter != nil { ownAddrs = s.filter.FilterLocal(str.Conn().RemotePeer(), ownAddrs) } @@ -275,29 +264,6 @@ func (s *Service) handleNewStream(str network.Stream) { s.tracer.HolePunchFinished("receiver", 1, addrs, ownAddrs, getDirectConnection(s.host, rp)) } -// getPublicAddrs returns public observed and interface addresses -func (s *Service) getPublicAddrs() []ma.Multiaddr { - addrs := removeRelayAddrs(s.ids.OwnObservedAddrs()) - - interfaceListenAddrs, err := s.host.Network().InterfaceListenAddresses() - if err != nil { - log.Debugf("failed to get to get InterfaceListenAddresses: %s", err) - } else { - addrs = append(addrs, interfaceListenAddrs...) - } - - addrs = ma.Unique(addrs) - - publicAddrs := make([]ma.Multiaddr, 0, len(addrs)) - - for _, addr := range addrs { - if manet.IsPublicAddr(addr) { - publicAddrs = append(publicAddrs, addr) - } - } - return publicAddrs -} - // DirectConnect is only exposed for testing purposes. // TODO: find a solution for this. func (s *Service) DirectConnect(p peer.ID) error { diff --git a/p2p/protocol/holepunch/util.go b/p2p/protocol/holepunch/util.go index 947b1ffd82..c0f34d0928 100644 --- a/p2p/protocol/holepunch/util.go +++ b/p2p/protocol/holepunch/util.go @@ -2,6 +2,7 @@ package holepunch import ( "context" + "slices" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" @@ -11,13 +12,7 @@ import ( ) func removeRelayAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { - result := make([]ma.Multiaddr, 0, len(addrs)) - for _, addr := range addrs { - if !isRelayAddress(addr) { - result = append(result, addr) - } - } - return result + return slices.DeleteFunc(addrs, isRelayAddress) } func isRelayAddress(a ma.Multiaddr) bool {