diff --git a/config/config.go b/config/config.go index 90fa9c16cf..0634ef4a36 100644 --- a/config/config.go +++ b/config/config.go @@ -155,6 +155,9 @@ func (cfg *Config) makeSwarm() (*swarm.Swarm, error) { if cfg.ResourceManager != nil { opts = append(opts, swarm.WithResourceManager(cfg.ResourceManager)) } + if cfg.MultiaddrResolver != nil { + opts = append(opts, swarm.WithMultiaddrResolver(cfg.MultiaddrResolver)) + } // TODO: Make the swarm implementation configurable. return swarm.NewSwarm(pid, cfg.Peerstore, opts...) } @@ -229,7 +232,6 @@ func (cfg *Config) NewNode() (host.Host, error) { EnablePing: !cfg.DisablePing, UserAgent: cfg.UserAgent, ProtocolVersion: cfg.ProtocolVersion, - MultiaddrResolver: cfg.MultiaddrResolver, EnableHolePunching: cfg.EnableHolePunching, HolePunchingOptions: cfg.HolePunchingOptions, EnableRelayService: cfg.EnableRelayService, diff --git a/core/transport/transport.go b/core/transport/transport.go index 379e9d6d4a..ad2ee66496 100644 --- a/core/transport/transport.go +++ b/core/transport/transport.go @@ -77,6 +77,12 @@ type Transport interface { Proxy() bool } +// Resolver can be optionally implemented by transports that want to resolve or transform the +// multiaddr. +type Resolver interface { + Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error) +} + // Listener is an interface closely resembling the net.Listener interface. The // only real difference is that Accept() returns Conn's of the type in this // package, and also exposes a Multiaddr method as opposed to a regular Addr diff --git a/go.mod b/go.mod index 20221b2805..54ff872afa 100644 --- a/go.mod +++ b/go.mod @@ -38,7 +38,7 @@ require ( github.com/minio/sha256-simd v1.0.0 github.com/mr-tron/base58 v1.2.0 github.com/multiformats/go-base32 v0.0.4 - github.com/multiformats/go-multiaddr v0.6.0 + github.com/multiformats/go-multiaddr v0.7.0 github.com/multiformats/go-multiaddr-dns v0.3.1 github.com/multiformats/go-multiaddr-fmt v0.1.0 github.com/multiformats/go-multibase v0.1.1 diff --git a/go.sum b/go.sum index 3bc941c299..55fafd71e0 100644 --- a/go.sum +++ b/go.sum @@ -369,8 +369,8 @@ github.com/multiformats/go-base36 v0.1.0 h1:JR6TyF7JjGd3m6FbLU2cOxhC0Li8z8dLNGQ8 github.com/multiformats/go-base36 v0.1.0/go.mod h1:kFGE83c6s80PklsHO9sRn2NCoffoRdUUOENyW/Vv6sM= github.com/multiformats/go-multiaddr v0.1.1/go.mod h1:aMKBKNEYmzmDmxfX88/vz+J5IU55txyt0p4aiWVohjo= github.com/multiformats/go-multiaddr v0.2.0/go.mod h1:0nO36NvPpyV4QzvTLi/lafl2y95ncPj0vFwVF6k6wJ4= -github.com/multiformats/go-multiaddr v0.6.0 h1:qMnoOPj2s8xxPU5kZ57Cqdr0hHhARz7mFsPMIiYNqzg= -github.com/multiformats/go-multiaddr v0.6.0/go.mod h1:F4IpaKZuPP360tOMn2Tpyu0At8w23aRyVqeK0DbFeGM= +github.com/multiformats/go-multiaddr v0.7.0 h1:gskHcdaCyPtp9XskVwtvEeQOG465sCohbQIirSyqxrc= +github.com/multiformats/go-multiaddr v0.7.0/go.mod h1:Fs50eBDWvZu+l3/9S6xAE7ZYj6yhxlvaVZjakWN7xRs= github.com/multiformats/go-multiaddr-dns v0.3.1 h1:QgQgR+LQVt3NPTjbrLLpsaT2ufAA2y0Mkk+QRVJbW3A= github.com/multiformats/go-multiaddr-dns v0.3.1/go.mod h1:G/245BRQ6FJGmryJCrOuTdB37AMA5AMOVuO6NY3JwTk= github.com/multiformats/go-multiaddr-fmt v0.1.0 h1:WLEFClPycPkp4fnIzoFoV9FVd49/eQsuaL3/CWe167E= diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 9997876395..19fff7d2b4 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -38,10 +38,6 @@ import ( msmux "github.com/multiformats/go-multistream" ) -// The maximum number of address resolution steps we'll perform for a single -// peer (for all addresses). -const maxAddressResolution = 32 - // addrChangeTickrInterval is the interval between two address change ticks. var addrChangeTickrInterval = 5 * time.Second @@ -713,77 +709,9 @@ func (h *BasicHost) Connect(ctx context.Context, pi peer.AddrInfo) error { } } - resolved, err := h.resolveAddrs(ctx, h.Peerstore().PeerInfo(pi.ID)) - if err != nil { - return err - } - h.Peerstore().AddAddrs(pi.ID, resolved, peerstore.TempAddrTTL) - return h.dialPeer(ctx, pi.ID) } -func (h *BasicHost) resolveAddrs(ctx context.Context, pi peer.AddrInfo) ([]ma.Multiaddr, error) { - proto := ma.ProtocolWithCode(ma.P_P2P).Name - p2paddr, err := ma.NewMultiaddr("/" + proto + "/" + pi.ID.Pretty()) - if err != nil { - return nil, err - } - - resolveSteps := 0 - - // Recursively resolve all addrs. - // - // While the toResolve list is non-empty: - // * Pop an address off. - // * If the address is fully resolved, add it to the resolved list. - // * Otherwise, resolve it and add the results to the "to resolve" list. - toResolve := append(([]ma.Multiaddr)(nil), pi.Addrs...) - resolved := make([]ma.Multiaddr, 0, len(pi.Addrs)) - for len(toResolve) > 0 { - // pop the last addr off. - addr := toResolve[len(toResolve)-1] - toResolve = toResolve[:len(toResolve)-1] - - // if it's resolved, add it to the resolved list. - if !madns.Matches(addr) { - resolved = append(resolved, addr) - continue - } - - resolveSteps++ - - // We've resolved too many addresses. We can keep all the fully - // resolved addresses but we'll need to skip the rest. - if resolveSteps >= maxAddressResolution { - log.Warnf( - "peer %s asked us to resolve too many addresses: %s/%s", - pi.ID, - resolveSteps, - maxAddressResolution, - ) - continue - } - - // otherwise, resolve it - reqaddr := addr.Encapsulate(p2paddr) - resaddrs, err := h.maResolver.Resolve(ctx, reqaddr) - if err != nil { - log.Infof("error resolving %s: %s", reqaddr, err) - } - - // add the results to the toResolve list. - for _, res := range resaddrs { - pi, err := peer.AddrInfoFromP2pAddr(res) - if err != nil { - log.Infof("error parsing %s: %s", res, err) - } - toResolve = append(toResolve, pi.Addrs...) - } - } - - return resolved, nil -} - // dialPeer opens a connection to peer, and makes sure to identify // the connection once it has been opened. func (h *BasicHost) dialPeer(ctx context.Context, p peer.ID) error { diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go index 02243c23b4..b1c2cf6203 100644 --- a/p2p/host/basic/basic_host_test.go +++ b/p2p/host/basic/basic_host_test.go @@ -17,14 +17,12 @@ import ( "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/record" - "github.com/libp2p/go-libp2p/core/test" "github.com/libp2p/go-libp2p/p2p/host/autonat" "github.com/libp2p/go-libp2p/p2p/host/eventbus" swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" "github.com/libp2p/go-libp2p/p2p/protocol/identify" ma "github.com/multiformats/go-multiaddr" - madns "github.com/multiformats/go-multiaddr-dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -526,111 +524,6 @@ func TestProtoDowngrade(t *testing.T) { assertWait(t, connectedOn, "/testing") } -func TestAddrResolution(t *testing.T) { - ctx := context.Background() - - p1 := test.RandPeerIDFatal(t) - p2 := test.RandPeerIDFatal(t) - addr1 := ma.StringCast("/dnsaddr/example.com") - addr2 := ma.StringCast("/ip4/192.0.2.1/tcp/123") - p2paddr1 := ma.StringCast("/dnsaddr/example.com/p2p/" + p1.Pretty()) - p2paddr2 := ma.StringCast("/ip4/192.0.2.1/tcp/123/p2p/" + p1.Pretty()) - p2paddr3 := ma.StringCast("/ip4/192.0.2.1/tcp/123/p2p/" + p2.Pretty()) - - backend := &madns.MockResolver{ - TXT: map[string][]string{"_dnsaddr.example.com": { - "dnsaddr=" + p2paddr2.String(), "dnsaddr=" + p2paddr3.String(), - }}, - } - resolver, err := madns.NewResolver(madns.WithDefaultResolver(backend)) - require.NoError(t, err) - - h, err := NewHost(swarmt.GenSwarm(t), &HostOpts{MultiaddrResolver: resolver}) - require.NoError(t, err) - defer h.Close() - - pi, err := peer.AddrInfoFromP2pAddr(p2paddr1) - require.NoError(t, err) - - tctx, cancel := context.WithTimeout(ctx, time.Millisecond*100) - defer cancel() - _ = h.Connect(tctx, *pi) - - addrs := h.Peerstore().Addrs(pi.ID) - - require.Len(t, addrs, 2) - require.Contains(t, addrs, addr1) - require.Contains(t, addrs, addr2) -} - -func TestAddrResolutionRecursive(t *testing.T) { - ctx := context.Background() - - p1, err := test.RandPeerID() - if err != nil { - t.Error(err) - } - p2, err := test.RandPeerID() - if err != nil { - t.Error(err) - } - addr1 := ma.StringCast("/dnsaddr/example.com") - addr2 := ma.StringCast("/ip4/192.0.2.1/tcp/123") - p2paddr1 := ma.StringCast("/dnsaddr/example.com/p2p/" + p1.Pretty()) - p2paddr2 := ma.StringCast("/dnsaddr/example.com/p2p/" + p2.Pretty()) - p2paddr1i := ma.StringCast("/dnsaddr/foo.example.com/p2p/" + p1.Pretty()) - p2paddr2i := ma.StringCast("/dnsaddr/bar.example.com/p2p/" + p2.Pretty()) - p2paddr1f := ma.StringCast("/ip4/192.0.2.1/tcp/123/p2p/" + p1.Pretty()) - - backend := &madns.MockResolver{ - TXT: map[string][]string{ - "_dnsaddr.example.com": { - "dnsaddr=" + p2paddr1i.String(), - "dnsaddr=" + p2paddr2i.String(), - }, - "_dnsaddr.foo.example.com": { - "dnsaddr=" + p2paddr1f.String(), - }, - "_dnsaddr.bar.example.com": { - "dnsaddr=" + p2paddr2i.String(), - }, - }, - } - resolver, err := madns.NewResolver(madns.WithDefaultResolver(backend)) - if err != nil { - t.Fatal(err) - } - - h, err := NewHost(swarmt.GenSwarm(t), &HostOpts{MultiaddrResolver: resolver}) - require.NoError(t, err) - defer h.Close() - - pi1, err := peer.AddrInfoFromP2pAddr(p2paddr1) - if err != nil { - t.Error(err) - } - - tctx, cancel := context.WithTimeout(ctx, time.Millisecond*100) - defer cancel() - _ = h.Connect(tctx, *pi1) - - addrs1 := h.Peerstore().Addrs(pi1.ID) - require.Len(t, addrs1, 2) - require.Contains(t, addrs1, addr1) - require.Contains(t, addrs1, addr2) - - pi2, err := peer.AddrInfoFromP2pAddr(p2paddr2) - if err != nil { - t.Error(err) - } - - _ = h.Connect(tctx, *pi2) - - addrs2 := h.Peerstore().Addrs(pi2.ID) - require.Len(t, addrs2, 1) - require.Contains(t, addrs2, addr1) -} - func TestAddrChangeImmediatelyIfAddressNonEmpty(t *testing.T) { ctx := context.Background() taddrs := []ma.Multiaddr{ma.StringCast("/ip4/1.2.3.4/tcp/1234")} diff --git a/p2p/net/swarm/dial_test.go b/p2p/net/swarm/dial_test.go index aed15d4ad3..9b8c9e93a7 100644 --- a/p2p/net/swarm/dial_test.go +++ b/p2p/net/swarm/dial_test.go @@ -11,17 +11,18 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" testutil "github.com/libp2p/go-libp2p/core/test" - . "github.com/libp2p/go-libp2p/p2p/net/swarm" + "github.com/libp2p/go-libp2p/p2p/net/swarm" swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" "github.com/libp2p/go-libp2p-testing/ci" ma "github.com/multiformats/go-multiaddr" + madns "github.com/multiformats/go-multiaddr-dns" manet "github.com/multiformats/go-multiaddr/net" "github.com/stretchr/testify/require" ) -func closeSwarms(swarms []*Swarm) { +func closeSwarms(swarms []*swarm.Swarm) { for _, s := range swarms { s.Close() } @@ -45,6 +46,44 @@ func TestBasicDialPeer(t *testing.T) { s.Close() } +func TestBasicDialPeerWithResolver(t *testing.T) { + t.Parallel() + + mockResolver := madns.MockResolver{IP: make(map[string][]net.IPAddr)} + ipaddr, err := net.ResolveIPAddr("ip4", "127.0.0.1") + require.NoError(t, err) + mockResolver.IP["example.com"] = []net.IPAddr{*ipaddr} + resolver, err := madns.NewResolver(madns.WithDomainResolver("example.com", &mockResolver)) + require.NoError(t, err) + + swarms := makeSwarms(t, 2, swarmt.WithSwarmOpts(swarm.WithMultiaddrResolver(resolver))) + defer closeSwarms(swarms) + s1 := swarms[0] + s2 := swarms[1] + + // Change the multiaddr from /ip4/127.0.0.1/... to /dns4/example.com/... so + // that the resovler has to resolve this + var s2Addrs []ma.Multiaddr + for _, a := range s2.ListenAddresses() { + _, rest := ma.SplitFunc(a, func(c ma.Component) bool { + return c.Protocol().Code == ma.P_TCP || c.Protocol().Code == ma.P_UDP + }, + ) + if rest != nil { + s2Addrs = append(s2Addrs, ma.StringCast("/dns4/example.com").Encapsulate(rest)) + } + } + + s1.Peerstore().AddAddrs(s2.LocalPeer(), s2Addrs, peerstore.PermanentAddrTTL) + + c, err := s1.DialPeer(context.Background(), s2.LocalPeer()) + require.NoError(t, err) + + s, err := c.NewStream(context.Background()) + require.NoError(t, err) + s.Close() +} + func TestDialWithNoListeners(t *testing.T) { t.Parallel() @@ -90,7 +129,7 @@ func TestSimultDials(t *testing.T) { { var wg sync.WaitGroup errs := make(chan error, 20) // 2 connect calls in each of the 10 for-loop iterations - connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) { + connect := func(s *swarm.Swarm, dst peer.ID, addr ma.Multiaddr) { // copy for other peer log.Debugf("TestSimultOpen: connecting: %s --> %s (%s)", s.LocalPeer(), dst, addr) s.Peerstore().AddAddr(dst, addr, peerstore.TempAddrTTL) @@ -157,7 +196,7 @@ func newSilentPeer(t *testing.T) (peer.ID, ma.Multiaddr, net.Listener) { func TestDialWait(t *testing.T) { const dialTimeout = 250 * time.Millisecond - swarms := makeSwarms(t, 1, swarmt.DialTimeout(dialTimeout)) + swarms := makeSwarms(t, 1, swarmt.WithSwarmOpts(swarm.WithDialTimeout(dialTimeout))) s1 := swarms[0] defer s1.Close() @@ -176,11 +215,11 @@ func TestDialWait(t *testing.T) { } duration := time.Since(before) - if duration < dialTimeout*DialAttempts { - t.Error("< dialTimeout * DialAttempts not being respected", duration, dialTimeout*DialAttempts) + if duration < dialTimeout*swarm.DialAttempts { + t.Error("< dialTimeout * DialAttempts not being respected", duration, dialTimeout*swarm.DialAttempts) } - if duration > 2*dialTimeout*DialAttempts { - t.Error("> 2*dialTimeout * DialAttempts not being respected", duration, 2*dialTimeout*DialAttempts) + if duration > 2*dialTimeout*swarm.DialAttempts { + t.Error("> 2*dialTimeout * DialAttempts not being respected", duration, 2*dialTimeout*swarm.DialAttempts) } if !s1.Backoff().Backoff(s2p, s2addr) { @@ -197,7 +236,7 @@ func TestDialBackoff(t *testing.T) { const dialTimeout = 100 * time.Millisecond ctx := context.Background() - swarms := makeSwarms(t, 2, swarmt.DialTimeout(dialTimeout)) + swarms := makeSwarms(t, 2, swarmt.WithSwarmOpts(swarm.WithDialTimeout(dialTimeout))) defer closeSwarms(swarms) s1 := swarms[0] s2 := swarms[1] @@ -400,7 +439,7 @@ func TestDialBackoffClears(t *testing.T) { t.Parallel() const dialTimeout = 250 * time.Millisecond - swarms := makeSwarms(t, 2, swarmt.DialTimeout(dialTimeout)) + swarms := makeSwarms(t, 2, swarmt.WithSwarmOpts(swarm.WithDialTimeout(dialTimeout))) defer closeSwarms(swarms) s1 := swarms[0] s2 := swarms[1] @@ -418,11 +457,11 @@ func TestDialBackoffClears(t *testing.T) { require.Error(t, err, "dialing to broken addr worked...") duration := time.Since(before) - if duration < dialTimeout*DialAttempts { - t.Error("< dialTimeout * DialAttempts not being respected", duration, dialTimeout*DialAttempts) + if duration < dialTimeout*swarm.DialAttempts { + t.Error("< dialTimeout * DialAttempts not being respected", duration, dialTimeout*swarm.DialAttempts) } - if duration > 2*dialTimeout*DialAttempts { - t.Error("> 2*dialTimeout * DialAttempts not being respected", duration, 2*dialTimeout*DialAttempts) + if duration > 2*dialTimeout*swarm.DialAttempts { + t.Error("> 2*dialTimeout * DialAttempts not being respected", duration, 2*dialTimeout*swarm.DialAttempts) } require.True(t, s1.Backoff().Backoff(s2.LocalPeer(), s2bad), "s2 should now be on backoff") @@ -441,7 +480,7 @@ func TestDialBackoffClears(t *testing.T) { func TestDialPeerFailed(t *testing.T) { t.Parallel() - swarms := makeSwarms(t, 2, swarmt.DialTimeout(100*time.Millisecond)) + swarms := makeSwarms(t, 2, swarmt.WithSwarmOpts(swarm.WithDialTimeout(100*time.Millisecond))) defer closeSwarms(swarms) testedSwarm, targetSwarm := swarms[0], swarms[1] @@ -462,7 +501,7 @@ func TestDialPeerFailed(t *testing.T) { // * [/ip4/127.0.0.1/tcp/34881] failed to negotiate security protocol: context deadline exceeded // ... - dialErr, ok := err.(*DialError) + dialErr, ok := err.(*swarm.DialError) if !ok { t.Fatalf("expected *DialError, got %T", err) } @@ -515,7 +554,7 @@ func newSilentListener(t *testing.T) ([]ma.Multiaddr, net.Listener) { func TestDialSimultaneousJoin(t *testing.T) { const dialTimeout = 250 * time.Millisecond - swarms := makeSwarms(t, 2, swarmt.DialTimeout(dialTimeout)) + swarms := makeSwarms(t, 2, swarmt.WithSwarmOpts(swarm.WithDialTimeout(dialTimeout))) defer closeSwarms(swarms) s1 := swarms[0] s2 := swarms[1] @@ -617,5 +656,5 @@ func TestDialSelf(t *testing.T) { s1 := swarms[0] _, err := s1.DialPeer(context.Background(), s1.LocalPeer()) - require.ErrorIs(t, err, ErrDialToSelf, "expected error from self dial") + require.ErrorIs(t, err, swarm.ErrDialToSelf, "expected error from self dial") } diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index 80d67eb979..ece0986dfa 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -19,6 +19,7 @@ import ( logging "github.com/ipfs/go-log/v2" ma "github.com/multiformats/go-multiaddr" + madns "github.com/multiformats/go-multiaddr-dns" ) const ( @@ -54,6 +55,14 @@ func WithConnectionGater(gater connmgr.ConnectionGater) Option { } } +// WithMultiaddrResolver sets a custom multiaddress resolver +func WithMultiaddrResolver(maResolver *madns.Resolver) Option { + return func(s *Swarm) error { + s.maResolver = maResolver + return nil + } +} + // WithMetrics sets a metrics reporter func WithMetrics(reporter metrics.Reporter) Option { return func(s *Swarm) error { @@ -127,6 +136,8 @@ type Swarm struct { m map[int]transport.Transport } + maResolver *madns.Resolver + // stream handlers streamh atomic.Value @@ -153,6 +164,7 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, opts ...Option) (*Swarm, ctxCancel: cancel, dialTimeout: defaultDialTimeout, dialTimeoutLocal: defaultDialTimeoutLocal, + maResolver: madns.DefaultResolver, } s.conns.m = make(map[peer.ID][]*Conn) diff --git a/p2p/net/swarm/swarm_dial.go b/p2p/net/swarm/swarm_dial.go index 0e232a1881..bba2c462f8 100644 --- a/p2p/net/swarm/swarm_dial.go +++ b/p2p/net/swarm/swarm_dial.go @@ -10,12 +10,18 @@ import ( "github.com/libp2p/go-libp2p/core/canonicallog" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/transport" ma "github.com/multiformats/go-multiaddr" + madns "github.com/multiformats/go-multiaddr-dns" manet "github.com/multiformats/go-multiaddr/net" ) +// The maximum number of address resolution steps we'll perform for a single +// peer (for all addresses). +const maxAddressResolution = 32 + // Diagram of dial sync: // // many callers of Dial() synched w. dials many addrs results to callers @@ -292,7 +298,32 @@ func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) ([]ma.Multiaddr, er return nil, ErrNoAddresses } - goodAddrs := s.filterKnownUndialables(p, peerAddrs) + peerAddrsAfterTransportResolved := make([]ma.Multiaddr, 0, len(peerAddrs)) + for _, a := range peerAddrs { + tpt := s.TransportForDialing(a) + resolver, ok := tpt.(transport.Resolver) + if ok { + resolvedAddrs, err := resolver.Resolve(ctx, a) + if err != nil { + log.Warnf("Failed to resolve multiaddr %s by transport %v: %v", a, tpt, err) + continue + } + peerAddrsAfterTransportResolved = append(peerAddrsAfterTransportResolved, resolvedAddrs...) + } else { + peerAddrsAfterTransportResolved = append(peerAddrsAfterTransportResolved, a) + } + } + + // Resolve dns or dnsaddrs + resolved, err := s.resolveAddrs(ctx, peer.AddrInfo{ + ID: p, + Addrs: peerAddrsAfterTransportResolved, + }) + if err != nil { + return nil, err + } + + goodAddrs := s.filterKnownUndialables(p, resolved) if forceDirect, _ := network.GetForceDirectDial(ctx); forceDirect { goodAddrs = ma.FilterAddrs(goodAddrs, s.nonProxyAddr) } @@ -301,7 +332,71 @@ func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) ([]ma.Multiaddr, er return nil, ErrNoGoodAddresses } - return goodAddrs, nil + s.peers.AddAddrs(p, goodAddrs, peerstore.TempAddrTTL) + + return resolved, nil +} + +func (s *Swarm) resolveAddrs(ctx context.Context, pi peer.AddrInfo) ([]ma.Multiaddr, error) { + proto := ma.ProtocolWithCode(ma.P_P2P).Name + p2paddr, err := ma.NewMultiaddr("/" + proto + "/" + pi.ID.Pretty()) + if err != nil { + return nil, err + } + + resolveSteps := 0 + + // Recursively resolve all addrs. + // + // While the toResolve list is non-empty: + // * Pop an address off. + // * If the address is fully resolved, add it to the resolved list. + // * Otherwise, resolve it and add the results to the "to resolve" list. + toResolve := append(([]ma.Multiaddr)(nil), pi.Addrs...) + resolved := make([]ma.Multiaddr, 0, len(pi.Addrs)) + for len(toResolve) > 0 { + // pop the last addr off. + addr := toResolve[len(toResolve)-1] + toResolve = toResolve[:len(toResolve)-1] + + // if it's resolved, add it to the resolved list. + if !madns.Matches(addr) { + resolved = append(resolved, addr) + continue + } + + resolveSteps++ + + // We've resolved too many addresses. We can keep all the fully + // resolved addresses but we'll need to skip the rest. + if resolveSteps >= maxAddressResolution { + log.Warnf( + "peer %s asked us to resolve too many addresses: %s/%s", + pi.ID, + resolveSteps, + maxAddressResolution, + ) + continue + } + + // otherwise, resolve it + reqaddr := addr.Encapsulate(p2paddr) + resaddrs, err := s.maResolver.Resolve(ctx, reqaddr) + if err != nil { + log.Infof("error resolving %s: %s", reqaddr, err) + } + + // add the results to the toResolve list. + for _, res := range resaddrs { + pi, err := peer.AddrInfoFromP2pAddr(res) + if err != nil { + log.Infof("error parsing %s: %s", res, err) + } + toResolve = append(toResolve, pi.Addrs...) + } + } + + return resolved, nil } func (s *Swarm) dialNextAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr, resch chan dialResult) error { diff --git a/p2p/net/swarm/swarm_dial_test.go b/p2p/net/swarm/swarm_dial_test.go new file mode 100644 index 0000000000..8979433669 --- /dev/null +++ b/p2p/net/swarm/swarm_dial_test.go @@ -0,0 +1,193 @@ +package swarm + +import ( + "context" + "crypto/rand" + "net" + "testing" + "time" + + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/peerstore" + "github.com/libp2p/go-libp2p/core/test" + "github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem" + "github.com/libp2p/go-libp2p/p2p/transport/tcp" + "github.com/libp2p/go-libp2p/p2p/transport/websocket" + "github.com/multiformats/go-multiaddr" + madns "github.com/multiformats/go-multiaddr-dns" + "github.com/stretchr/testify/require" +) + +func TestAddrsForDial(t *testing.T) { + mockResolver := madns.MockResolver{IP: make(map[string][]net.IPAddr)} + ipaddr, err := net.ResolveIPAddr("ip4", "1.2.3.4") + if err != nil { + t.Fatal(err) + } + mockResolver.IP["example.com"] = []net.IPAddr{*ipaddr} + + resolver, err := madns.NewResolver(madns.WithDomainResolver("example.com", &mockResolver)) + if err != nil { + t.Fatal(err) + } + + priv, _, err := crypto.GenerateEd25519Key(rand.Reader) + require.NoError(t, err) + id, err := peer.IDFromPrivateKey(priv) + require.NoError(t, err) + + ps, err := pstoremem.NewPeerstore() + require.NoError(t, err) + ps.AddPubKey(id, priv.GetPublic()) + ps.AddPrivKey(id, priv) + t.Cleanup(func() { ps.Close() }) + + tpt, err := websocket.New(nil, network.NullResourceManager) + require.NoError(t, err) + s, err := NewSwarm(id, ps, WithMultiaddrResolver(resolver)) + require.NoError(t, err) + defer s.Close() + err = s.AddTransport(tpt) + require.NoError(t, err) + + otherPeer := test.RandPeerIDFatal(t) + + ps.AddAddr(otherPeer, multiaddr.StringCast("/dns4/example.com/tcp/1234/wss"), time.Hour) + + ctx := context.Background() + mas, err := s.addrsForDial(ctx, otherPeer) + require.NoError(t, err) + + require.NotZero(t, len(mas)) +} + +func newTestSwarmWithResolver(t *testing.T, resolver *madns.Resolver) *Swarm { + priv, _, err := crypto.GenerateEd25519Key(rand.Reader) + require.NoError(t, err) + id, err := peer.IDFromPrivateKey(priv) + require.NoError(t, err) + ps, err := pstoremem.NewPeerstore() + require.NoError(t, err) + ps.AddPubKey(id, priv.GetPublic()) + ps.AddPrivKey(id, priv) + t.Cleanup(func() { ps.Close() }) + s, err := NewSwarm(id, ps, WithMultiaddrResolver(resolver)) + require.NoError(t, err) + t.Cleanup(func() { + s.Close() + }) + + // Add a tcp transport so that we know we can dial a tcp multiaddr and we don't filter it out. + tpt, err := tcp.NewTCPTransport(nil, network.NullResourceManager) + require.NoError(t, err) + err = s.AddTransport(tpt) + require.NoError(t, err) + + return s +} + +func TestAddrResolution(t *testing.T) { + ctx := context.Background() + + p1 := test.RandPeerIDFatal(t) + p2 := test.RandPeerIDFatal(t) + addr1 := multiaddr.StringCast("/dnsaddr/example.com") + addr2 := multiaddr.StringCast("/ip4/192.0.2.1/tcp/123") + + p2paddr2 := multiaddr.StringCast("/ip4/192.0.2.1/tcp/123/p2p/" + p1.Pretty()) + p2paddr3 := multiaddr.StringCast("/ip4/192.0.2.1/tcp/123/p2p/" + p2.Pretty()) + + backend := &madns.MockResolver{ + TXT: map[string][]string{"_dnsaddr.example.com": { + "dnsaddr=" + p2paddr2.String(), "dnsaddr=" + p2paddr3.String(), + }}, + } + resolver, err := madns.NewResolver(madns.WithDefaultResolver(backend)) + require.NoError(t, err) + + s := newTestSwarmWithResolver(t, resolver) + + s.peers.AddAddr(p1, addr1, time.Hour) + + tctx, cancel := context.WithTimeout(ctx, time.Millisecond*100) + defer cancel() + mas, err := s.addrsForDial(tctx, p1) + require.NoError(t, err) + + require.Len(t, mas, 1) + require.Contains(t, mas, addr2) + + addrs := s.peers.Addrs(p1) + require.Len(t, addrs, 2) + require.Contains(t, addrs, addr1) + require.Contains(t, addrs, addr2) +} + +func TestAddrResolutionRecursive(t *testing.T) { + ctx := context.Background() + + p1, err := test.RandPeerID() + if err != nil { + t.Error(err) + } + p2, err := test.RandPeerID() + if err != nil { + t.Error(err) + } + addr1 := multiaddr.StringCast("/dnsaddr/example.com") + addr2 := multiaddr.StringCast("/ip4/192.0.2.1/tcp/123") + p2paddr1 := multiaddr.StringCast("/dnsaddr/example.com/p2p/" + p1.Pretty()) + p2paddr2 := multiaddr.StringCast("/dnsaddr/example.com/p2p/" + p2.Pretty()) + p2paddr1i := multiaddr.StringCast("/dnsaddr/foo.example.com/p2p/" + p1.Pretty()) + p2paddr2i := multiaddr.StringCast("/dnsaddr/bar.example.com/p2p/" + p2.Pretty()) + p2paddr1f := multiaddr.StringCast("/ip4/192.0.2.1/tcp/123/p2p/" + p1.Pretty()) + + backend := &madns.MockResolver{ + TXT: map[string][]string{ + "_dnsaddr.example.com": { + "dnsaddr=" + p2paddr1i.String(), + "dnsaddr=" + p2paddr2i.String(), + }, + "_dnsaddr.foo.example.com": { + "dnsaddr=" + p2paddr1f.String(), + }, + "_dnsaddr.bar.example.com": { + "dnsaddr=" + p2paddr2i.String(), + }, + }, + } + resolver, err := madns.NewResolver(madns.WithDefaultResolver(backend)) + if err != nil { + t.Fatal(err) + } + + s := newTestSwarmWithResolver(t, resolver) + + pi1, err := peer.AddrInfoFromP2pAddr(p2paddr1) + require.NoError(t, err) + + tctx, cancel := context.WithTimeout(ctx, time.Millisecond*100) + defer cancel() + s.Peerstore().AddAddrs(pi1.ID, pi1.Addrs, peerstore.TempAddrTTL) + _, err = s.addrsForDial(tctx, p1) + require.NoError(t, err) + + addrs1 := s.Peerstore().Addrs(pi1.ID) + require.Len(t, addrs1, 2) + require.Contains(t, addrs1, addr1) + require.Contains(t, addrs1, addr2) + + pi2, err := peer.AddrInfoFromP2pAddr(p2paddr2) + require.NoError(t, err) + + s.Peerstore().AddAddrs(pi2.ID, pi2.Addrs, peerstore.TempAddrTTL) + _, err = s.addrsForDial(tctx, p2) + // This never resolves to a good address + require.Equal(t, ErrNoGoodAddresses, err) + + addrs2 := s.Peerstore().Addrs(pi2.ID) + require.Len(t, addrs2, 1) + require.Contains(t, addrs2, addr1) +} diff --git a/p2p/net/swarm/swarm_test.go b/p2p/net/swarm/swarm_test.go index f25ee1f1c8..95aa71d756 100644 --- a/p2p/net/swarm/swarm_test.go +++ b/p2p/net/swarm/swarm_test.go @@ -17,6 +17,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/protocol" + "github.com/libp2p/go-libp2p/core/test" "github.com/libp2p/go-libp2p/p2p/net/swarm" . "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" @@ -407,7 +408,7 @@ func TestPreventDialListenAddr(t *testing.T) { break } } - remote := peer.ID("foobar") + remote := test.RandPeerIDFatal(t) s.Peerstore().AddAddr(remote, addr, time.Hour) _, err = s.DialPeer(context.Background(), remote) if !errors.Is(err, swarm.ErrNoGoodAddresses) { @@ -461,11 +462,11 @@ func TestResourceManager(t *testing.T) { defer ctrl.Finish() rcmgr1 := mocknetwork.NewMockResourceManager(ctrl) - s1 := GenSwarm(t, OptResourceManager(rcmgr1)) + s1 := GenSwarm(t, WithSwarmOpts(swarm.WithResourceManager(rcmgr1))) defer s1.Close() rcmgr2 := mocknetwork.NewMockResourceManager(ctrl) - s2 := GenSwarm(t, OptResourceManager(rcmgr2)) + s2 := GenSwarm(t, WithSwarmOpts(swarm.WithResourceManager(rcmgr2))) defer s2.Close() connectSwarms(t, context.Background(), []*swarm.Swarm{s1, s2}) @@ -496,7 +497,7 @@ func TestResourceManagerNewStream(t *testing.T) { defer ctrl.Finish() rcmgr1 := mocknetwork.NewMockResourceManager(ctrl) - s1 := GenSwarm(t, OptResourceManager(rcmgr1)) + s1 := GenSwarm(t, WithSwarmOpts(swarm.WithResourceManager(rcmgr1))) defer s1.Close() s2 := GenSwarm(t) @@ -515,11 +516,11 @@ func TestResourceManagerAcceptStream(t *testing.T) { defer ctrl.Finish() rcmgr1 := mocknetwork.NewMockResourceManager(ctrl) - s1 := GenSwarm(t, OptResourceManager(rcmgr1)) + s1 := GenSwarm(t, WithSwarmOpts(swarm.WithResourceManager(rcmgr1))) defer s1.Close() rcmgr2 := mocknetwork.NewMockResourceManager(ctrl) - s2 := GenSwarm(t, OptResourceManager(rcmgr2)) + s2 := GenSwarm(t, WithSwarmOpts(swarm.WithResourceManager(rcmgr2))) defer s2.Close() s2.SetStreamHandler(func(str network.Stream) { t.Fatal("didn't expect to accept a stream") }) diff --git a/p2p/net/swarm/testing/testing.go b/p2p/net/swarm/testing/testing.go index c3704e5d28..604d1d0a47 100644 --- a/p2p/net/swarm/testing/testing.go +++ b/p2p/net/swarm/testing/testing.go @@ -32,10 +32,9 @@ type config struct { dialOnly bool disableTCP bool disableQUIC bool - dialTimeout time.Duration connectionGater connmgr.ConnectionGater - rcmgr network.ResourceManager sk crypto.PrivKey + swarmOpts []swarm.Option clock } @@ -59,6 +58,12 @@ func WithClock(clock clock) Option { } } +func WithSwarmOpts(swarmOpts ...swarm.Option) Option { + return func(_ *testing.T, c *config) { + c.swarmOpts = swarmOpts + } +} + // OptDisableReuseport disables reuseport in this test swarm. var OptDisableReuseport Option = func(_ *testing.T, c *config) { c.disableReuseport = true @@ -86,12 +91,6 @@ func OptConnGater(cg connmgr.ConnectionGater) Option { } } -func OptResourceManager(rcmgr network.ResourceManager) Option { - return func(_ *testing.T, c *config) { - c.rcmgr = rcmgr - } -} - // OptPeerPrivateKey configures the peer private key which is then used to derive the public key and peer ID. func OptPeerPrivateKey(sk crypto.PrivKey) Option { return func(_ *testing.T, c *config) { @@ -99,12 +98,6 @@ func OptPeerPrivateKey(sk crypto.PrivKey) Option { } } -func DialTimeout(t time.Duration) Option { - return func(_ *testing.T, c *config) { - c.dialTimeout = t - } -} - // GenUpgrader creates a new connection upgrader for use with this swarm. func GenUpgrader(t *testing.T, n *swarm.Swarm, opts ...tptu.Option) transport.Upgrader { id := n.LocalPeer() @@ -144,16 +137,11 @@ func GenSwarm(t *testing.T, opts ...Option) *swarm.Swarm { ps.AddPrivKey(id, priv) t.Cleanup(func() { ps.Close() }) - swarmOpts := []swarm.Option{swarm.WithMetrics(metrics.NewBandwidthCounter())} + swarmOpts := cfg.swarmOpts + swarmOpts = append(swarmOpts, swarm.WithMetrics(metrics.NewBandwidthCounter())) if cfg.connectionGater != nil { swarmOpts = append(swarmOpts, swarm.WithConnectionGater(cfg.connectionGater)) } - if cfg.rcmgr != nil { - swarmOpts = append(swarmOpts, swarm.WithResourceManager(cfg.rcmgr)) - } - if cfg.dialTimeout != 0 { - swarmOpts = append(swarmOpts, swarm.WithDialTimeout(cfg.dialTimeout)) - } s, err := swarm.NewSwarm(id, ps, swarmOpts...) require.NoError(t, err) diff --git a/p2p/transport/websocket/addrs.go b/p2p/transport/websocket/addrs.go index 608eb2d0da..fed649dcbc 100644 --- a/p2p/transport/websocket/addrs.go +++ b/p2p/transport/websocket/addrs.go @@ -105,23 +105,17 @@ func ParseWebsocketNetAddr(a net.Addr) (ma.Multiaddr, error) { } func parseMultiaddr(maddr ma.Multiaddr) (*url.URL, error) { - // Only look at the _last_ component. - maddr, wscomponent := ma.SplitLast(maddr) - if maddr == nil || wscomponent == nil { - return nil, fmt.Errorf("websocket addrs need at least two components") + parsed, err := parseWebsocketMultiaddr(maddr) + if err != nil { + return nil, err } - var scheme string - switch wscomponent.Protocol().Code { - case ma.P_WS: - scheme = "ws" - case ma.P_WSS: + scheme := "ws" + if parsed.isWSS { scheme = "wss" - default: - return nil, fmt.Errorf("not a websocket multiaddr") } - network, host, err := manet.DialArgs(maddr) + network, host, err := manet.DialArgs(parsed.restMultiaddr) if err != nil { return nil, err } @@ -135,3 +129,47 @@ func parseMultiaddr(maddr ma.Multiaddr) (*url.URL, error) { Host: host, }, nil } + +type parsedWebsocketMultiaddr struct { + isWSS bool + // sni is the SNI value for the TLS handshake + sni *ma.Component + // the rest of the multiaddr before the /tls/sni/example.com/ws or /ws or /wss + restMultiaddr ma.Multiaddr +} + +func parseWebsocketMultiaddr(a ma.Multiaddr) (parsedWebsocketMultiaddr, error) { + out := parsedWebsocketMultiaddr{} + // First check if we have a WSS component. If so we'll canonicalize it into a /tls/ws + withoutWss := a.Decapsulate(wssComponent) + if !withoutWss.Equal(a) { + a = withoutWss.Encapsulate(tlsWsComponent) + } + + // Remove the ws component + withoutWs := a.Decapsulate(wsComponent) + if withoutWs.Equal(a) { + return out, fmt.Errorf("not a websocket multiaddr") + } + + rest := withoutWs + // If this is not a wss then withoutWs is the rest of the multiaddr + out.restMultiaddr = withoutWs + for { + var head *ma.Component + rest, head = ma.SplitLast(rest) + if head == nil || rest == nil { + break + } + + if head.Protocol().Code == ma.P_SNI { + out.sni = head + } else if head.Protocol().Code == ma.P_TLS { + out.isWSS = true + out.restMultiaddr = rest + break + } + } + + return out, nil +} diff --git a/p2p/transport/websocket/listener.go b/p2p/transport/websocket/listener.go index b94bed798c..128fdf5eb5 100644 --- a/p2p/transport/websocket/listener.go +++ b/p2p/transport/websocket/listener.go @@ -10,11 +10,6 @@ import ( manet "github.com/multiformats/go-multiaddr/net" ) -var ( - wsma = ma.StringCast("/ws") - wssma = ma.StringCast("/wss") -) - type listener struct { nl net.Listener server http.Server @@ -25,16 +20,31 @@ type listener struct { incoming chan *Conn } +func (pwma *parsedWebsocketMultiaddr) toMultiaddr() ma.Multiaddr { + if !pwma.isWSS { + return pwma.restMultiaddr.Encapsulate(wsComponent) + } + + if pwma.sni == nil { + return pwma.restMultiaddr.Encapsulate(tlsComponent).Encapsulate(wsComponent) + } + + return pwma.restMultiaddr.Encapsulate(tlsComponent).Encapsulate(pwma.sni).Encapsulate(wsComponent) +} + // newListener creates a new listener from a raw net.Listener. // tlsConf may be nil (for unencrypted websockets). func newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) { - // Only look at the _last_ component. - maddr, wscomponent := ma.SplitLast(a) - isWSS := wscomponent.Equal(wssma) - if isWSS && tlsConf == nil { + parsed, err := parseWebsocketMultiaddr(a) + if err != nil { + return nil, err + } + + if parsed.isWSS && tlsConf == nil { return nil, fmt.Errorf("cannot listen on wss address %s without a tls.Config", a) } - lnet, lnaddr, err := manet.DialArgs(maddr) + + lnet, lnaddr, err := manet.DialArgs(parsed.restMultiaddr) if err != nil { return nil, err } @@ -54,15 +64,16 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) { _, last := ma.SplitFirst(laddr) laddr = first.Encapsulate(last) } + parsed.restMultiaddr = laddr ln := &listener{ nl: nl, - laddr: laddr.Encapsulate(wscomponent), + laddr: parsed.toMultiaddr(), incoming: make(chan *Conn), closed: make(chan struct{}), } ln.server = http.Server{Handler: ln} - if isWSS { + if parsed.isWSS { ln.server.TLSConfig = tlsConf } return ln, nil diff --git a/p2p/transport/websocket/websocket.go b/p2p/transport/websocket/websocket.go index 42c4c618f4..04af9b8c8e 100644 --- a/p2p/transport/websocket/websocket.go +++ b/p2p/transport/websocket/websocket.go @@ -21,9 +21,27 @@ import ( // WsFmt is multiaddr formatter for WsProtocol var WsFmt = mafmt.And(mafmt.TCP, mafmt.Base(ma.P_WS)) -// This is _not_ WsFmt because we want the transport to stick to dialing fully -// resolved addresses. -var dialMatcher = mafmt.And(mafmt.Or(mafmt.IP, mafmt.DNS), mafmt.Base(ma.P_TCP), mafmt.Or(mafmt.Base(ma.P_WS), mafmt.Base(ma.P_WSS))) +var dialMatcher = mafmt.And( + mafmt.Or(mafmt.IP, mafmt.DNS), + mafmt.Base(ma.P_TCP), + mafmt.Or( + mafmt.Base(ma.P_WS), + mafmt.And( + mafmt.Or( + mafmt.And( + mafmt.Base(ma.P_TLS), + mafmt.Base(ma.P_SNI)), + mafmt.Base(ma.P_TLS), + ), + mafmt.Base(ma.P_WS)), + mafmt.Base(ma.P_WSS))) + +var ( + wssComponent = ma.StringCast("/wss") + tlsWsComponent = ma.StringCast("/tls/ws") + tlsComponent = ma.StringCast("/tls") + wsComponent = ma.StringCast("/ws") +) func init() { manet.RegisterFromNetAddr(ParseWebsocketNetAddr, "websocket") @@ -100,6 +118,42 @@ func (t *WebsocketTransport) Proxy() bool { return false } +func (t *WebsocketTransport) Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error) { + parsed, err := parseWebsocketMultiaddr(maddr) + if err != nil { + return nil, err + } + + if !parsed.isWSS { + // No /tls/ws component, this isn't a secure websocket multiaddr. We can just return it here + return []ma.Multiaddr{maddr}, nil + } + + if parsed.sni == nil { + var err error + // We don't have an sni component, we'll use dns/dnsaddr + ma.ForEach(parsed.restMultiaddr, func(c ma.Component) bool { + switch c.Protocol().Code { + case ma.P_DNS, ma.P_DNS4, ma.P_DNS6, ma.P_DNSADDR: + // err shouldn't happen since this means we couldn't parse a dns hostname for an sni value. + parsed.sni, err = ma.NewComponent("sni", c.Value()) + return false + } + return true + }) + if err != nil { + return nil, err + } + } + + if parsed.sni == nil { + // we didn't find anything to set the sni with. So we just return the given multiaddr + return []ma.Multiaddr{maddr}, nil + } + + return []ma.Multiaddr{parsed.toMultiaddr()}, nil +} + func (t *WebsocketTransport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (transport.CapableConn, error) { connScope, err := t.rcmgr.OpenConnection(network.DirOutbound, true, raddr) if err != nil { @@ -121,9 +175,21 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma isWss := wsurl.Scheme == "wss" dialer := ws.Dialer{HandshakeTimeout: 30 * time.Second} if isWss { - dialer.TLSClientConfig = t.tlsClientConf + sni := "" + sni, err = raddr.ValueForProtocol(ma.P_SNI) + if err != nil { + sni = "" + } + if sni != "" { + copytlsClientConf := t.tlsClientConf.Clone() + copytlsClientConf.ServerName = sni + dialer.TLSClientConfig = copytlsClientConf + } else { + dialer.TLSClientConfig = t.tlsClientConf + } } + wscon, _, err := dialer.DialContext(ctx, wsurl.String(), nil) if err != nil { return nil, err diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index 77dda801f8..9ac71a7796 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -2,6 +2,8 @@ package websocket import ( "context" + "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" "crypto/rsa" "crypto/tls" @@ -24,6 +26,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/muxer/yamux" csms "github.com/libp2p/go-libp2p/p2p/net/conn-security-multistream" tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader" + "github.com/libp2p/go-libp2p/p2p/security/noise" ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite" ma "github.com/multiformats/go-multiaddr" @@ -31,6 +34,16 @@ import ( ) func newUpgrader(t *testing.T) (peer.ID, transport.Upgrader) { + t.Helper() + id, m := newInsecureMuxer(t) + u, err := tptu.New(m, yamux.DefaultTransport) + if err != nil { + t.Fatal(err) + } + return id, u +} + +func newSecureUpgrader(t *testing.T) (peer.ID, transport.Upgrader) { t.Helper() id, m := newSecureMuxer(t) u, err := tptu.New(m, yamux.DefaultTransport) @@ -40,7 +53,7 @@ func newUpgrader(t *testing.T) (peer.ID, transport.Upgrader) { return id, u } -func newSecureMuxer(t *testing.T) (peer.ID, sec.SecureMuxer) { +func newInsecureMuxer(t *testing.T) (peer.ID, sec.SecureMuxer) { t.Helper() priv, _, err := test.RandTestKeyPair(crypto.Ed25519, 256) if err != nil { @@ -55,15 +68,32 @@ func newSecureMuxer(t *testing.T) (peer.ID, sec.SecureMuxer) { return id, &secMuxer } +func newSecureMuxer(t *testing.T) (peer.ID, sec.SecureMuxer) { + t.Helper() + priv, _, err := test.RandTestKeyPair(crypto.Ed25519, 256) + if err != nil { + t.Fatal(err) + } + id, err := peer.IDFromPrivateKey(priv) + if err != nil { + t.Fatal(err) + } + var secMuxer csms.SSMuxer + noiseTpt, err := noise.New(priv) + require.NoError(t, err) + secMuxer.AddTransport(noise.ID, noiseTpt) + return id, &secMuxer +} + func lastComponent(t *testing.T, a ma.Multiaddr) ma.Multiaddr { t.Helper() _, wscomponent := ma.SplitLast(a) require.NotNil(t, wscomponent) - if wscomponent.Equal(wsma) { - return wsma + if wscomponent.Equal(wsComponent) { + return wsComponent } - if wscomponent.Equal(wssma) { - return wssma + if wscomponent.Equal(wssComponent) { + return wssComponent } t.Fatal("expected a ws or wss component") return nil @@ -102,33 +132,117 @@ func TestCanDial(t *testing.T) { if d.CanDial(ma.StringCast("/ip4/127.0.0.1/tcp/5555")) { t.Fatal("expected to not match tcp maddr, but did") } + if !d.CanDial(ma.StringCast("/ip4/127.0.0.1/tcp/5555/tls/ws")) { + t.Fatal("expected to match secure websocket maddr, but did not") + } + if !d.CanDial(ma.StringCast("/ip4/127.0.0.1/tcp/5555/tls/sni/example.com/ws")) { + t.Fatal("expected to match secure websocket maddr with sni, but did not") + } + if !d.CanDial(ma.StringCast("/dns4/example.com/tcp/5555/tls/sni/example.com/ws")) { + t.Fatal("expected to match secure websocket maddr with sni, but did not") + } + if !d.CanDial(ma.StringCast("/dnsaddr/example.com/tcp/5555/tls/sni/example.com/ws")) { + t.Fatal("expected to match secure websocket maddr with sni, but did not") + } } -func TestDialWss(t *testing.T) { - if _, err := net.LookupIP("nyc-1.bootstrap.libp2p.io"); err != nil { - t.Skip("this test requries an internet connection and it seems like we currently don't have one") +// testWSSServer returns a client hello info +func testWSSServer(t *testing.T, listenAddr ma.Multiaddr) (ma.Multiaddr, peer.ID, chan error) { + errChan := make(chan error, 1) + + ip := net.ParseIP("::") + tlsConf := getTLSConf(t, ip, time.Now(), time.Now().Add(time.Hour)) + tlsConf.GetConfigForClient = func(chi *tls.ClientHelloInfo) (*tls.Config, error) { + if chi.ServerName != "example.com" { + errChan <- fmt.Errorf("didn't get the expected sni") + } + return tlsConf, nil } - raddr := ma.StringCast("/dns4/nyc-1.bootstrap.libp2p.io/tcp/443/wss") - rid, err := peer.Decode("QmSoLueR4xBeUbY9WZ9xGUUxunbKWcrNFTDAadQJmocnWm") + + id, u := newSecureUpgrader(t) + tpt, err := New(u, network.NullResourceManager, WithTLSConfig(tlsConf)) if err != nil { t.Fatal(err) } - tlsConfig := &tls.Config{InsecureSkipVerify: true} - _, u := newUpgrader(t) - tpt, err := New(u, network.NullResourceManager, WithTLSClientConfig(tlsConfig)) - if err != nil { - t.Fatal(err) + // l, err := tpt.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/wss")) + // l, err := tpt.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/tls/ws")) + l, err := tpt.Listen(listenAddr) + fmt.Println("here", listenAddr) + require.NoError(t, err) + t.Cleanup(func() { + l.Close() + }) + go func() { + conn, err := l.Accept() + if err != nil { + errChan <- fmt.Errorf("error in accepting conn: %w", err) + return + } + defer conn.Close() + + strm, err := conn.AcceptStream() + if err != nil { + errChan <- fmt.Errorf("error in accepting stream: %w", err) + return + } + defer strm.Close() + close(errChan) + }() + + return l.Multiaddr(), id, errChan +} + +func getTLSConf(t *testing.T, ip net.IP, start, end time.Time) *tls.Config { + t.Helper() + certTempl := &x509.Certificate{ + SerialNumber: big.NewInt(1234), + Subject: pkix.Name{Organization: []string{"websocket"}}, + NotBefore: start, + NotAfter: end, + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + IPAddresses: []net.IP{ip}, } - conn, err := tpt.Dial(context.Background(), raddr, rid) - if err != nil { - t.Fatal(err) + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, &priv.PublicKey, priv) + require.NoError(t, err) + cert, err := x509.ParseCertificate(caBytes) + require.NoError(t, err) + return &tls.Config{ + Certificates: []tls.Certificate{{ + Certificate: [][]byte{cert.Raw}, + PrivateKey: priv, + Leaf: cert, + }}, } +} + +func TestDialWss(t *testing.T) { + serverMA, rid, errChan := testWSSServer(t, ma.StringCast("/ip4/127.0.0.1/tcp/0/tls/sni/example.com/ws")) + require.Contains(t, serverMA.String(), "tls") + + tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA + _, u := newSecureUpgrader(t) + tpt, err := New(u, network.NullResourceManager, WithTLSClientConfig(tlsConfig)) + require.NoError(t, err) + + masToDial, err := tpt.Resolve(context.Background(), serverMA) + require.NoError(t, err) + + conn, err := tpt.Dial(context.Background(), masToDial[0], rid) + require.NoError(t, err) + defer conn.Close() + stream, err := conn.OpenStream(context.Background()) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer stream.Close() + + err = <-errChan + require.NoError(t, err) } func TestWebsocketTransport(t *testing.T) { @@ -160,9 +274,9 @@ func connectAndExchangeData(t *testing.T, laddr ma.Multiaddr, secure bool) { l, err := tpt.Listen(laddr) require.NoError(t, err) if secure { - require.Equal(t, lastComponent(t, l.Multiaddr()), wssma) + require.Contains(t, l.Multiaddr().String(), "tls") } else { - require.Equal(t, lastComponent(t, l.Multiaddr()), wsma) + require.Equal(t, lastComponent(t, l.Multiaddr()), wsComponent) } defer l.Close() @@ -234,8 +348,8 @@ func TestWebsocketListenSecureAndInsecure(t *testing.T) { conn, err := client.Dial(context.Background(), lnInsecure.Multiaddr(), serverID) require.NoError(t, err) defer conn.Close() - require.Equal(t, lastComponent(t, conn.RemoteMultiaddr()).String(), wsma.String()) - require.Equal(t, lastComponent(t, conn.LocalMultiaddr()).String(), wsma.String()) + require.Equal(t, lastComponent(t, conn.RemoteMultiaddr()).String(), wsComponent.String()) + require.Equal(t, lastComponent(t, conn.LocalMultiaddr()).String(), wsComponent.String()) // dialing the secure address should fail _, err = client.Dial(context.Background(), lnSecure.Multiaddr(), serverID) @@ -251,8 +365,8 @@ func TestWebsocketListenSecureAndInsecure(t *testing.T) { conn, err := client.Dial(context.Background(), lnSecure.Multiaddr(), serverID) require.NoError(t, err) defer conn.Close() - require.Equal(t, lastComponent(t, conn.RemoteMultiaddr()), wssma) - require.Equal(t, lastComponent(t, conn.LocalMultiaddr()), wssma) + require.Equal(t, lastComponent(t, conn.RemoteMultiaddr()), wssComponent) + require.Equal(t, lastComponent(t, conn.LocalMultiaddr()), wssComponent) // dialing the insecure address should fail _, err = client.Dial(context.Background(), lnInsecure.Multiaddr(), serverID) @@ -346,3 +460,30 @@ func TestWriteZero(t *testing.T) { t.Errorf("expected EOF, got err: %s", err) } } + +func TestResolveMultiaddr(t *testing.T) { + // map[unresolved]resolved + testCases := map[string]string{ + "/dns4/example.com/tcp/1234/wss": "/dns4/example.com/tcp/1234/tls/sni/example.com/ws", + "/dns6/example.com/tcp/1234/wss": "/dns6/example.com/tcp/1234/tls/sni/example.com/ws", + "/dnsaddr/example.com/tcp/1234/wss": "/dnsaddr/example.com/tcp/1234/tls/sni/example.com/ws", + "/dns4/example.com/tcp/1234/tls/ws": "/dns4/example.com/tcp/1234/tls/sni/example.com/ws", + "/dns6/example.com/tcp/1234/tls/ws": "/dns6/example.com/tcp/1234/tls/sni/example.com/ws", + "/dnsaddr/example.com/tcp/1234/tls/ws": "/dnsaddr/example.com/tcp/1234/tls/sni/example.com/ws", + } + + for unresolved, expectedMA := range testCases { + t.Run(unresolved, func(t *testing.T) { + + m1 := ma.StringCast(unresolved) + wsTpt := WebsocketTransport{} + ctx := context.Background() + + addrs, err := wsTpt.Resolve(ctx, m1) + require.NoError(t, err) + require.Len(t, addrs, 1) + + require.Equal(t, expectedMA, addrs[0].String()) + }) + } +}