diff --git a/config/config.go b/config/config.go index 8be5a43999..69f2936612 100644 --- a/config/config.go +++ b/config/config.go @@ -34,6 +34,8 @@ import ( relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" + libp2pwebrtcprivate "github.com/libp2p/go-libp2p/p2p/transport/webrtcprivate" + "github.com/pion/webrtc/v3" "github.com/prometheus/client_golang/prometheus" ma "github.com/multiformats/go-multiaddr" @@ -65,6 +67,8 @@ type Security struct { Constructor interface{} } +type ICEServer = webrtc.ICEServer + // Config describes a set of settings for a libp2p node // // This is *not* a stable interface. Use the options defined in the root @@ -128,6 +132,9 @@ type Config struct { DialRanker network.DialRanker SwarmOpts []swarm.Option + + WebRTCPrivate bool + WebRTCStunServers []ICEServer } func (cfg *Config) makeSwarm(eventBus event.Bus, enableMetrics bool) (*swarm.Swarm, error) { @@ -208,6 +215,7 @@ func (cfg *Config) addTransports(h host.Host) error { fx.Provide(func() pnet.PSK { return cfg.PSK }), fx.Provide(func() network.ResourceManager { return cfg.ResourceManager }), fx.Provide(func() *madns.Resolver { return cfg.MultiaddrResolver }), + fx.Provide(func() []ICEServer { return cfg.WebRTCStunServers }), } fxopts = append(fxopts, cfg.Transports...) if cfg.Insecure { @@ -284,6 +292,9 @@ func (cfg *Config) addTransports(h host.Host) error { if cfg.Relay { fxopts = append(fxopts, fx.Invoke(circuitv2.AddTransport)) } + if cfg.WebRTCPrivate { + fxopts = append(fxopts, fx.Invoke(libp2pwebrtcprivate.AddTransport)) + } app := fx.New(fxopts...) if err := app.Err(); err != nil { h.Close() diff --git a/core/network/context.go b/core/network/context.go index 7fabfb53e0..41c72b99e7 100644 --- a/core/network/context.go +++ b/core/network/context.go @@ -29,6 +29,13 @@ func WithForceDirectDial(ctx context.Context, reason string) context.Context { return context.WithValue(ctx, forceDirectDial, reason) } +// WithoutForceDirectDial constructs a new context with the ForceDirectDial option dropped. +// This is useful in case establishing a direct connection first requires establishing a +// relayed connection e.g. dialing /webrtc addresses. +func WithoutForceDirectDial(ctx context.Context) context.Context { + return context.WithValue(ctx, forceDirectDial, nil) +} + // EXPERIMENTAL // GetForceDirectDial returns true if the force direct dial option is set in the context. func GetForceDirectDial(ctx context.Context) (forceDirect bool, reason string) { diff --git a/options.go b/options.go index 1a1e9d3982..4f9ac2597b 100644 --- a/options.go +++ b/options.go @@ -598,3 +598,11 @@ func SwarmOpts(opts ...swarm.Option) Option { return nil } } + +func EnableWebRTCPrivate(stunServers []config.ICEServer) Option { + return func(cfg *Config) error { + cfg.WebRTCPrivate = true + cfg.WebRTCStunServers = stunServers + return nil + } +} diff --git a/p2p/host/autorelay/relay_finder.go b/p2p/host/autorelay/relay_finder.go index ef79950b7b..de041b9ad9 100644 --- a/p2p/host/autorelay/relay_finder.go +++ b/p2p/host/autorelay/relay_finder.go @@ -726,7 +726,7 @@ func (rf *relayFinder) relayAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { // only keep private addrs from the original addr set for _, addr := range addrs { - if manet.IsPrivateAddr(addr) { + if !manet.IsPublicAddr(addr) { raddrs = append(raddrs, addr) } } diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 6c3ba53e5b..5659da48ec 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -27,6 +27,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" "github.com/libp2p/go-libp2p/p2p/protocol/identify" "github.com/libp2p/go-libp2p/p2p/protocol/ping" + libp2pwebrtcprivate "github.com/libp2p/go-libp2p/p2p/transport/webrtcprivate" libp2pwebtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport" "github.com/prometheus/client_golang/prometheus" @@ -801,9 +802,38 @@ func (h *BasicHost) Addrs() []ma.Multiaddr { addrs[i] = addrWithCerthash } } + + // Append webrtc addresses to circuit-v2 addresses + hasWebRTCPrivate := false + for _, addr := range addrs { + if addr.Equal(libp2pwebrtcprivate.WebRTCAddr) { + hasWebRTCPrivate = true + break + } + } + if hasWebRTCPrivate { + for _, addr := range addrs { + if _, err := addr.ValueForProtocol(ma.P_CIRCUIT); err == nil { + if isBrowserDialableAddr(addr) { + addrs = append(addrs, addr.Encapsulate(libp2pwebrtcprivate.WebRTCAddr)) + } + } + } + } return addrs } +var browserProtocols = []int{ma.P_WEBTRANSPORT, ma.P_WEBRTC_DIRECT, ma.P_WSS} + +func isBrowserDialableAddr(addr ma.Multiaddr) bool { + for _, p := range browserProtocols { + if _, err := addr.ValueForProtocol(p); err == nil { + return true + } + } + return false +} + // NormalizeMultiaddr returns a multiaddr suitable for equality checks. // If the multiaddr is a webtransport component, it removes the certhashes. func (h *BasicHost) NormalizeMultiaddr(addr ma.Multiaddr) ma.Multiaddr { diff --git a/p2p/net/swarm/dial_ranker.go b/p2p/net/swarm/dial_ranker.go index 7e58876b91..bd806f8316 100644 --- a/p2p/net/swarm/dial_ranker.go +++ b/p2p/net/swarm/dial_ranker.go @@ -43,7 +43,10 @@ func NoDelayDialRanker(addrs []ma.Multiaddr) []network.AddrDelay { // no additional latency in the vast majority of cases. // // Private and public address groups are dialed in parallel. +// // Dialing relay addresses is delayed by 500 ms, if we have any non-relay alternatives. +// We treat webrtc addresses the same as relay addresses as we need a relay connection to establish a +// webrtc connection. So any available direct addresses are preferred over webrtc addresses. // // Within each group (private, public, relay addresses) we apply the following ranking logic: // @@ -72,7 +75,8 @@ func NoDelayDialRanker(addrs []ma.Multiaddr) []network.AddrDelay { // // We dial lowest ports first as they are more likely to be the listen port. func DefaultDialRanker(addrs []ma.Multiaddr) []network.AddrDelay { - relay, addrs := filterAddrs(addrs, isRelayAddr) + // includes /webrtc addresses too + relay, addrs := filterAddrs(addrs, func(a ma.Multiaddr) bool { return isProtocolAddr(a, ma.P_CIRCUIT) }) pvt, addrs := filterAddrs(addrs, manet.IsPrivateAddr) public, addrs := filterAddrs(addrs, func(a ma.Multiaddr) bool { return isProtocolAddr(a, ma.P_IP4) || isProtocolAddr(a, ma.P_IP6) }) diff --git a/p2p/net/swarm/swarm_dial.go b/p2p/net/swarm/swarm_dial.go index 288ad9cc7d..9543ac0db1 100644 --- a/p2p/net/swarm/swarm_dial.go +++ b/p2p/net/swarm/swarm_dial.go @@ -597,11 +597,6 @@ func isFdConsumingAddr(addr ma.Multiaddr) bool { return err1 == nil || err2 == nil } -func isRelayAddr(addr ma.Multiaddr) bool { - _, err := addr.ValueForProtocol(ma.P_CIRCUIT) - return err == nil -} - // filterLowPriorityAddresses removes addresses inplace for which we have a better alternative // 1. If a /quic-v1 address is present, filter out /quic and /webtransport address on the same 2-tuple: // QUIC v1 is preferred over the deprecated QUIC draft-29, and given the choice, we prefer using diff --git a/p2p/net/swarm/swarm_transport.go b/p2p/net/swarm/swarm_transport.go index 924f0384aa..e36b7d0925 100644 --- a/p2p/net/swarm/swarm_transport.go +++ b/p2p/net/swarm/swarm_transport.go @@ -27,7 +27,10 @@ func (s *Swarm) TransportForDialing(a ma.Multiaddr) transport.Transport { } return nil } - if isRelayAddr(a) { + if isProtocolAddr(a, ma.P_WEBRTC) { + return s.transports.m[ma.P_WEBRTC] + } + if isProtocolAddr(a, ma.P_CIRCUIT) { return s.transports.m[ma.P_CIRCUIT] } for _, t := range s.transports.m { diff --git a/p2p/protocol/holepunch/holepunch_test.go b/p2p/protocol/holepunch/holepunch_test.go index 29d589cd7a..3275dc49cc 100644 --- a/p2p/protocol/holepunch/holepunch_test.go +++ b/p2p/protocol/holepunch/holepunch_test.go @@ -13,6 +13,7 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" + "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/proto" relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" @@ -511,3 +512,57 @@ func mkHostWithHolePunchSvc(t *testing.T, opts ...holepunch.Option) (host.Host, require.NoError(t, err) return h, hps } + +func TestWebRTCDirectConnect(t *testing.T) { + relay1, err := libp2p.New() + require.NoError(t, err) + + _, err = relayv2.New(relay1) + require.NoError(t, err) + + relay1info := peer.AddrInfo{ + ID: relay1.ID(), + Addrs: relay1.Addrs(), + } + + h1, err := libp2p.New( + libp2p.NoListenAddrs, + libp2p.EnableRelay(), + libp2p.EnableWebRTCPrivate(nil), + libp2p.EnableHolePunching(), + ) + require.NoError(t, err) + + h2, err := libp2p.New( + libp2p.NoListenAddrs, + libp2p.EnableRelay(), + libp2p.EnableWebRTCPrivate(nil), + ) + require.NoError(t, err) + + err = h2.Connect(context.Background(), relay1info) + require.NoError(t, err) + + _, err = client.Reserve(context.Background(), h2, relay1info) + require.NoError(t, err) + + webrtcAddr := ma.StringCast(relay1info.Addrs[0].String() + "/p2p/" + relay1info.ID.String() + "/p2p-circuit/webrtc") + relayAddrs := ma.StringCast(relay1info.Addrs[0].String() + "/p2p/" + relay1info.ID.String() + "/p2p-circuit/") + h1.Peerstore().AddAddrs(h2.ID(), []ma.Multiaddr{webrtcAddr, relayAddrs}, peerstore.TempAddrTTL) + + err = h1.Connect(context.Background(), peer.AddrInfo{ID: h2.ID()}) + require.NoError(t, err) + require.Eventually( + t, + func() bool { + for _, c := range h1.Network().ConnsToPeer(h2.ID()) { + if !c.Stat().Transient { + return true + } + } + return false + }, + 5*time.Second, + 100*time.Millisecond, + ) +} diff --git a/p2p/protocol/holepunch/holepuncher.go b/p2p/protocol/holepunch/holepuncher.go index b651bd7822..cdcba85186 100644 --- a/p2p/protocol/holepunch/holepuncher.go +++ b/p2p/protocol/holepunch/holepuncher.go @@ -108,6 +108,9 @@ func (hp *holePuncher) directConnect(rp peer.ID) error { // short-circuit hole punching if a direct dial works. // attempt a direct connection ONLY if we have a public address for the remote peer for _, a := range hp.host.Peerstore().Addrs(rp) { + // Here we consider /webrtc addresses as relay addresses and skip them as they're + // also holepunched. We will dial the /webrtc addresses along with other addresses + // obtained in DCUtR if manet.IsPublicAddr(a) && !isRelayAddress(a) { forceDirectConnCtx := network.WithForceDirectDial(hp.ctx, "hole-punching") dialCtx, cancel := context.WithTimeout(forceDirectConnCtx, dialTimeout) @@ -136,6 +139,7 @@ func (hp *holePuncher) directConnect(rp peer.ID) error { if err != nil { log.Debugw("hole punching failed", "peer", rp, "error", err) hp.tracer.ProtocolError(rp, err) + hp.maybeDialWebRTC(rp) return err } synTime := rtt / 2 @@ -171,6 +175,20 @@ func (hp *holePuncher) directConnect(rp peer.ID) error { return fmt.Errorf("all retries for hole punch with peer %s failed", rp) } +func (hp *holePuncher) maybeDialWebRTC(p peer.ID) { + addrs := hp.host.Peerstore().Addrs(p) + for _, a := range addrs { + if _, err := a.ValueForProtocol(ma.P_WEBRTC); err == nil { + ctx := network.WithForceDirectDial(hp.ctx, "webrtc holepunch") + err := hp.host.Connect(ctx, peer.AddrInfo{ID: p}) // address is already in peerstore + if err != nil { + log.Debugf("holepunch attempt to %s over /webrtc failed: %s", p, err) + } + return + } + } +} + // initiateHolePunch opens a new hole punching coordination stream, // exchanges the addresses and measures the RTT. func (hp *holePuncher) initiateHolePunch(rp peer.ID) ([]ma.Multiaddr, []ma.Multiaddr, time.Duration, error) { diff --git a/p2p/protocol/holepunch/svc.go b/p2p/protocol/holepunch/svc.go index 47bf434fb1..1796ec9701 100644 --- a/p2p/protocol/holepunch/svc.go +++ b/p2p/protocol/holepunch/svc.go @@ -84,6 +84,8 @@ func NewService(h host.Host, ids identify.IDService, opts ...Option) (*Service, return nil, err } } + s.host.Network().Notify(s) + s.tracer.Start() s.refCount.Add(1) @@ -283,3 +285,42 @@ func (s *Service) DirectConnect(p peer.ID) error { s.holePuncherMx.Unlock() return holePuncher.DirectConnect(p) } + +var _ network.Notifiee = &Service{} + +func (s *Service) Connected(_ network.Network, conn network.Conn) { + // Dial /webrtc address if it's a relay connection to a browser node + if conn.Stat().Direction == network.DirOutbound && conn.Stat().Transient { + s.refCount.Add(1) + go func() { + defer s.refCount.Done() + select { + // waiting for Identify here will allow us to access the peer's public and observed addresses + // that we can dial to for a hole punch. + case <-s.ids.IdentifyWait(conn): + case <-s.ctx.Done(): + return + } + p := conn.RemotePeer() + // Peer supports DCUtR, let it trigger holepunch + if protos, err := s.host.Peerstore().SupportsProtocols(p, Protocol); err == nil && len(protos) > 0 { + return + } + // No DCUtR support, connect with peer over /webrtc + for _, addr := range s.host.Peerstore().Addrs(p) { + if _, err := addr.ValueForProtocol(ma.P_WEBRTC); err == nil { + ctx := network.WithForceDirectDial(s.ctx, "webrtc holepunch") + err := s.host.Connect(ctx, peer.AddrInfo{ID: p}) // address is already in peerstore + if err != nil { + log.Debugf("holepunch attempt to %s over /webrtc failed: %s", p, err) + } + return + } + } + }() + } +} + +func (*Service) Disconnected(_ network.Network, v network.Conn) {} +func (*Service) Listen(n network.Network, a ma.Multiaddr) {} +func (*Service) ListenClose(n network.Network, a ma.Multiaddr) {} diff --git a/p2p/test/basichost/basic_host_test.go b/p2p/test/basichost/basic_host_test.go index d00dd9d5dc..0d621884d7 100644 --- a/p2p/test/basichost/basic_host_test.go +++ b/p2p/test/basichost/basic_host_test.go @@ -10,6 +10,7 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" + "github.com/libp2p/go-libp2p/p2p/host/autorelay" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" ma "github.com/multiformats/go-multiaddr" @@ -158,3 +159,43 @@ func TestNewStreamTransientConnection(t *testing.T) { <-done <-done } + +func TestWebRTCPrivateAddressAdvertisement(t *testing.T) { + r, err := libp2p.New( + // We need a public address for the relay + libp2p.AddrsFactory(func(addrs []ma.Multiaddr) []ma.Multiaddr { + return append(addrs, ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1/webtransport")) + }), + libp2p.EnableRelayService(), + libp2p.ForceReachabilityPublic(), + ) + require.NoError(t, err) + + relay1info := peer.AddrInfo{ + ID: r.ID(), + Addrs: r.Addrs(), + } + + h, err := libp2p.New( + libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1"), + libp2p.EnableRelay(), + libp2p.EnableWebRTCPrivate(nil), + libp2p.EnableAutoRelayWithStaticRelays( + []peer.AddrInfo{relay1info}, + autorelay.WithBootDelay(0), + ), + libp2p.ForceReachabilityPrivate(), + ) + require.NoError(t, err) + + require.Eventually(t, func() bool { + for _, a := range h.Addrs() { + _, rerr := a.ValueForProtocol(ma.P_CIRCUIT) + _, werr := a.ValueForProtocol(ma.P_WEBRTC) + if rerr == nil && werr == nil { + return true + } + } + return false + }, 5*time.Second, 50*time.Millisecond) +} diff --git a/p2p/test/swarm/swarm_test.go b/p2p/test/swarm/swarm_test.go index 9874431441..5fdafb2913 100644 --- a/p2p/test/swarm/swarm_test.go +++ b/p2p/test/swarm/swarm_test.go @@ -243,3 +243,78 @@ func TestLimitStreamsWhenHangingHandlers(t *testing.T) { return false }, 5*time.Second, 100*time.Millisecond) } + +func TestDialPeerWebRTC(t *testing.T) { + h1, err := libp2p.New( + libp2p.NoListenAddrs, + libp2p.EnableRelay(), + libp2p.EnableWebRTCPrivate(nil), + ) + require.NoError(t, err) + + h2, err := libp2p.New( + libp2p.NoListenAddrs, + libp2p.EnableRelay(), + libp2p.EnableWebRTCPrivate(nil), + ) + require.NoError(t, err) + + relay1, err := libp2p.New() + require.NoError(t, err) + + _, err = relay.New(relay1) + require.NoError(t, err) + + relay1info := peer.AddrInfo{ + ID: relay1.ID(), + Addrs: relay1.Addrs(), + } + + err = h2.Connect(context.Background(), relay1info) + require.NoError(t, err) + + _, err = client.Reserve(context.Background(), h2, relay1info) + require.NoError(t, err) + + webrtcAddr := ma.StringCast(relay1info.Addrs[0].String() + "/p2p/" + relay1info.ID.String() + "/p2p-circuit/webrtc/p2p/" + h2.ID().String()) + relayAddrs := ma.StringCast(relay1info.Addrs[0].String() + "/p2p/" + relay1info.ID.String() + "/p2p-circuit/p2p/" + h2.ID().String()) + + h1.Peerstore().AddAddrs(h2.ID(), []ma.Multiaddr{webrtcAddr, relayAddrs}, peerstore.TempAddrTTL) + + // swarm.DialPeer should connect over transient connections + conn1, err := h1.Network().DialPeer(context.Background(), h2.ID()) + require.NoError(t, err) + require.NotNil(t, conn1) + require.Condition(t, func() bool { + _, err1 := conn1.RemoteMultiaddr().ValueForProtocol(ma.P_CIRCUIT) + _, err2 := conn1.RemoteMultiaddr().ValueForProtocol(ma.P_WEBRTC) + return err1 == nil && err2 != nil + }) + + // should connect to webrtc address + ctx := network.WithForceDirectDial(context.Background(), "test") + conn, err := h1.Network().DialPeer(ctx, h2.ID()) + require.NoError(t, err) + require.NotNil(t, conn) + require.Condition(t, func() bool { + _, err1 := conn.RemoteMultiaddr().ValueForProtocol(ma.P_CIRCUIT) + _, err2 := conn.RemoteMultiaddr().ValueForProtocol(ma.P_WEBRTC) + return err1 != nil && err2 == nil + }) + + done := make(chan struct{}) + h2.SetStreamHandler("test-addr", func(s network.Stream) { + s.Conn().LocalMultiaddr() + _, err1 := conn.RemoteMultiaddr().ValueForProtocol(ma.P_CIRCUIT) + assert.Error(t, err1) + _, err2 := conn.RemoteMultiaddr().ValueForProtocol(ma.P_WEBRTC) + assert.NoError(t, err2) + s.Reset() + close(done) + }) + + s, err := h1.NewStream(context.Background(), h2.ID(), "test-addr") + require.NoError(t, err) + s.Write([]byte("test")) + <-done +} diff --git a/p2p/test/transport/gating_test.go b/p2p/test/transport/gating_test.go index df53da6eeb..c914b04d31 100644 --- a/p2p/test/transport/gating_test.go +++ b/p2p/test/transport/gating_test.go @@ -83,6 +83,10 @@ func TestInterceptSecuredOutgoing(t *testing.T) { } for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { + if strings.Contains(tc.Name, "WebRTCPrivate") { + t.Skip("webrtcprivate needs different handling because of the relay required for connection setup") + } + ctrl := gomock.NewController(t) defer ctrl.Finish() connGater := NewMockConnectionGater(ctrl) @@ -92,7 +96,6 @@ func TestInterceptSecuredOutgoing(t *testing.T) { defer h1.Close() defer h2.Close() require.Len(t, h2.Addrs(), 1) - require.Len(t, h2.Addrs(), 1) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -104,6 +107,7 @@ func TestInterceptSecuredOutgoing(t *testing.T) { require.Equal(t, stripCertHash(h2.Addrs()[0]).String(), addrs.RemoteMultiaddr().String()) }), ) + err := h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()}) require.Error(t, err) require.NotErrorIs(t, err, context.DeadlineExceeded) @@ -117,6 +121,9 @@ func TestInterceptUpgradedOutgoing(t *testing.T) { } for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { + if strings.Contains(tc.Name, "WebRTCPrivate") { + t.Skip("webrtc private to private needs special handling because the listener makes a dial to the relay") + } ctrl := gomock.NewController(t) defer ctrl.Finish() connGater := NewMockConnectionGater(ctrl) @@ -153,6 +160,9 @@ func TestInterceptAccept(t *testing.T) { } for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { + if strings.Contains(tc.Name, "WebRTCPrivate") { + t.Skip("webrtc private to private needs special handling because the listener makes a dial to the relay") + } ctrl := gomock.NewController(t) defer ctrl.Finish() connGater := NewMockConnectionGater(ctrl) @@ -198,6 +208,10 @@ func TestInterceptSecuredIncoming(t *testing.T) { } for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { + if strings.Contains(tc.Name, "WebRTCPrivate") { + t.Skip("webrtc private to private needs special handling because the listener makes a dial to the relay") + } + ctrl := gomock.NewController(t) defer ctrl.Finish() connGater := NewMockConnectionGater(ctrl) @@ -231,6 +245,9 @@ func TestInterceptUpgradedIncoming(t *testing.T) { } for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { + if strings.Contains(tc.Name, "WebRTCPrivate") { + t.Skip("webrtc private to private needs special handling because the listener makes a dial to the relay") + } ctrl := gomock.NewController(t) defer ctrl.Finish() connGater := NewMockConnectionGater(ctrl) diff --git a/p2p/test/transport/rcmgr_test.go b/p2p/test/transport/rcmgr_test.go index 20f34de799..1ea2a0a69e 100644 --- a/p2p/test/transport/rcmgr_test.go +++ b/p2p/test/transport/rcmgr_test.go @@ -24,7 +24,9 @@ func TestResourceManagerIsUsed(t *testing.T) { t.Run(tc.Name, func(t *testing.T) { for _, testDialer := range []bool{true, false} { t.Run(tc.Name+fmt.Sprintf(" test_dialer=%v", testDialer), func(t *testing.T) { - + if strings.Contains(tc.Name, "WebRTCPrivate") { + t.Skip("webrtcprivate needs different handling because of the relay required for connection setup") + } var reservedMemory, releasedMemory atomic.Int32 defer func() { require.Equal(t, reservedMemory.Load(), releasedMemory.Load()) diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index a7e98a0d85..16e45e6a5f 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -25,12 +25,14 @@ import ( rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager" "github.com/libp2p/go-libp2p/p2p/muxer/yamux" "github.com/libp2p/go-libp2p/p2p/net/swarm" + "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" "github.com/libp2p/go-libp2p/p2p/protocol/ping" "github.com/libp2p/go-libp2p/p2p/security/noise" tls "github.com/libp2p/go-libp2p/p2p/security/tls" + libp2pquic "github.com/libp2p/go-libp2p/p2p/transport/quic" libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" - "github.com/multiformats/go-multiaddr" + ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/require" ) @@ -152,6 +154,56 @@ var transportsToTest = []TransportTestCase{ return h }, }, + { + Name: "WebRTCPrivate", + HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { + libp2pOpts := transformOpts(opts) + // NoListenAddrs helps ensure that we are not listening for TCP, QUIC etc. We do need + // those transports to dial the relay for signaling stream + libp2pOpts = append(libp2pOpts, libp2p.EnableWebRTCPrivate(nil), libp2p.EnableRelay(), libp2p.NoListenAddrs) + + if !opts.NoListen { + r, err := libp2p.New( + libp2p.EnableRelayService(), + libp2p.ForceReachabilityPublic(), + libp2p.Transport(libp2pquic.NewTransport), + libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1")) + require.NoError(t, err) + libp2pOpts = append( + libp2pOpts, + libp2p.AddrsFactory(func(_ []ma.Multiaddr) []ma.Multiaddr { + raddrs := r.Addrs() + addrs := make([]ma.Multiaddr, len(raddrs)) + for i := 0; i < len(raddrs); i++ { + + addrs[i] = ma.StringCast(fmt.Sprintf("%s/p2p/%s/p2p-circuit/webrtc/", raddrs[i], r.ID())) + } + return addrs + })) + h, err := libp2p.New(libp2pOpts...) + require.NoError(t, err) + _, err = client.Reserve(context.Background(), h, peer.AddrInfo{ID: r.ID(), Addrs: r.Addrs()}) + require.NoError(t, err) + return &webrtcHost{Host: h, r: r} + } + h, err := libp2p.New(libp2pOpts...) + require.NoError(t, err) + return &webrtcHost{Host: h} + }, + }, +} + +type webrtcHost struct { + host.Host + r host.Host +} + +func (h *webrtcHost) Close() error { + h.Host.Close() + if h.r != nil { + h.r.Close() + } + return nil } func TestPing(t *testing.T) { @@ -656,6 +708,10 @@ func TestDiscoverPeerIDFromSecurityNegotiation(t *testing.T) { for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { + if strings.Contains(tc.Name, "WebRTCPrivate") { + t.Skip("webrtcprivate needs different handling because of the relay required for connection setup") + } + h1 := tc.HostGenerator(t, TransportTestCaseOpts{}) h2 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true}) defer h1.Close() @@ -673,7 +729,7 @@ func TestDiscoverPeerIDFromSecurityNegotiation(t *testing.T) { ai := &peer.AddrInfo{ ID: bogusPeerId, - Addrs: []multiaddr.Multiaddr{h1.Addrs()[0]}, + Addrs: []ma.Multiaddr{h1.Addrs()[0]}, } // Try connecting with the bogus peer ID diff --git a/p2p/transport/webrtc/connection.go b/p2p/transport/webrtc/connection.go index fd31f8351a..4cf853b08d 100644 --- a/p2p/transport/webrtc/connection.go +++ b/p2p/transport/webrtc/connection.go @@ -37,14 +37,14 @@ func (errConnectionTimeout) Error() string { return "connection timeout" } func (errConnectionTimeout) Timeout() bool { return true } func (errConnectionTimeout) Temporary() bool { return false } -type dataChannel struct { +type DetachedDataChannel struct { stream datachannel.ReadWriteCloser channel *webrtc.DataChannel } type connection struct { pc *webrtc.PeerConnection - transport *WebRTCTransport + transport tpt.Transport scope network.ConnManagementScope closeErr error @@ -56,20 +56,22 @@ type connection struct { remoteKey ic.PubKey remoteMultiaddr ma.Multiaddr + connectionState network.ConnectionState + m sync.Mutex streams map[uint16]*stream nextStreamID atomic.Int32 - acceptQueue chan dataChannel - - ctx context.Context - cancel context.CancelFunc + acceptQueue chan DetachedDataChannel + ctx context.Context + cancel context.CancelFunc } -func newConnection( +// NewWebRTCConnection creates a transport.CapableConn from a webrtc.PeerConnection +func NewWebRTCConnection( direction network.Direction, pc *webrtc.PeerConnection, - transport *WebRTCTransport, + transport tpt.Transport, scope network.ConnManagementScope, localPeer peer.ID, @@ -78,8 +80,13 @@ func newConnection( remotePeer peer.ID, remoteKey ic.PubKey, remoteMultiaddr ma.Multiaddr, + datachannelQueue chan DetachedDataChannel, ) (*connection, error) { ctx, cancel := context.WithCancel(context.Background()) + connectionState := network.ConnectionState{Transport: "webrtc"} + if _, ok := transport.(*WebRTCTransport); ok { + connectionState = network.ConnectionState{Transport: "webrtc-direct"} + } c := &connection{ pc: pc, transport: transport, @@ -91,11 +98,13 @@ func newConnection( remotePeer: remotePeer, remoteKey: remoteKey, remoteMultiaddr: remoteMultiaddr, - ctx: ctx, - cancel: cancel, - streams: make(map[uint16]*stream), - acceptQueue: make(chan dataChannel, maxAcceptQueueLen), + connectionState: connectionState, + + ctx: ctx, + cancel: cancel, + streams: make(map[uint16]*stream), + acceptQueue: datachannelQueue, } switch direction { case network.DirInbound: @@ -106,40 +115,21 @@ func newConnection( } pc.OnConnectionStateChange(c.onConnectionStateChange) - pc.OnDataChannel(func(dc *webrtc.DataChannel) { - if c.IsClosed() { - return - } - // Limit the number of streams, since we're not able to actually properly close them. - // See https://github.com/libp2p/specs/issues/575 for details. - if *dc.ID() > maxDataChannelID { - c.Close() - return - } - dc.OnOpen(func() { - rwc, err := dc.Detach() - if err != nil { - log.Warnf("could not detach datachannel: id: %d", *dc.ID()) - return - } - select { - case c.acceptQueue <- dataChannel{rwc, dc}: - default: - log.Warnf("connection busy, rejecting stream") - b, _ := proto.Marshal(&pb.Message{Flag: pb.Message_RESET.Enum()}) - w := msgio.NewWriter(rwc) - w.WriteMsg(b) - rwc.Close() - } - }) - }) + + // Between the connection establishing and the callback update in the above line, the + // connection may have been closed + state := pc.ConnectionState() + if state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateClosed { + pc.Close() + return nil, errors.New("connection closed") + } return c, nil } // ConnState implements transport.CapableConn func (c *connection) ConnState() network.ConnectionState { - return network.ConnectionState{Transport: "webrtc-direct"} + return c.connectionState } // Close closes the underlying peerconnection. @@ -315,3 +305,38 @@ func (c *connection) setRemotePeer(id peer.ID) { func (c *connection) setRemotePublicKey(key ic.PubKey) { c.remoteKey = key } + +// SetupDataChannelQueue sets callback on the peer connection to push incoming +// data channels on to the returned queue after detaching the data channel. +// +// We need to ensure that the data channel is enqueued from the onOpen callback +// to avoid a race condition in pion: https://github.com/pion/webrtc/issues/2586 +func SetupDataChannelQueue(pc *webrtc.PeerConnection, queueLen int) chan DetachedDataChannel { + queue := make(chan DetachedDataChannel, queueLen) + pc.OnDataChannel(func(dc *webrtc.DataChannel) { + // Limit the number of streams, since we're not able to actually properly close them. + // See https://github.com/libp2p/specs/issues/575 for details. + if *dc.ID() > maxDataChannelID { + dc.Close() + return + } + dc.OnOpen(func() { + rwc, err := dc.Detach() + if err != nil { + log.Warnf("could not detach datachannel: id: %d", *dc.ID()) + return + } + select { + case queue <- DetachedDataChannel{rwc, dc}: + default: + log.Warnf("connection busy, rejecting stream") + b, _ := proto.Marshal(&pb.Message{Flag: pb.Message_RESET.Enum()}) + w := msgio.NewWriter(rwc) + w.WriteMsg(b) + rwc.Close() + } + }) + + }) + return queue +} diff --git a/p2p/transport/webrtc/listener.go b/p2p/transport/webrtc/listener.go index 0b29bf655d..bec3dbba35 100644 --- a/p2p/transport/webrtc/listener.go +++ b/p2p/transport/webrtc/listener.go @@ -25,14 +25,14 @@ import ( "go.uber.org/zap/zapcore" ) -type connMultiaddrs struct { - local, remote ma.Multiaddr +type ConnMultiaddrs struct { + Local, Remote ma.Multiaddr } -var _ network.ConnMultiaddrs = &connMultiaddrs{} +var _ network.ConnMultiaddrs = &ConnMultiaddrs{} -func (c *connMultiaddrs) LocalMultiaddr() ma.Multiaddr { return c.local } -func (c *connMultiaddrs) RemoteMultiaddr() ma.Multiaddr { return c.remote } +func (c *ConnMultiaddrs) LocalMultiaddr() ma.Multiaddr { return c.Local } +func (c *ConnMultiaddrs) RemoteMultiaddr() ma.Multiaddr { return c.Remote } const ( candidateSetupTimeout = 20 * time.Second @@ -158,7 +158,7 @@ func (l *listener) handleCandidate(ctx context.Context, candidate udpmux.Candida } if l.transport.gater != nil { localAddr, _ := ma.SplitFunc(l.localMultiaddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_CERTHASH }) - if !l.transport.gater.InterceptAccept(&connMultiaddrs{local: localAddr, remote: remoteMultiaddr}) { + if !l.transport.gater.InterceptAccept(&ConnMultiaddrs{Local: localAddr, Remote: remoteMultiaddr}) { // The connection attempt is rejected before we can send the client an error. // This means that the connection attempt will time out. return nil, errors.New("connection gated") @@ -229,6 +229,7 @@ func (l *listener) setupConnection( if err != nil { return nil, err } + dataChannelQueue := SetupDataChannelQueue(pc, maxAcceptQueueLen) negotiated, id := handshakeChannelNegotiated, handshakeChannelID rawDatachannel, err := pc.CreateDataChannel("", &webrtc.DataChannelInit{ @@ -275,7 +276,7 @@ func (l *listener) setupConnection( // The connection is instantiated before performing the Noise handshake. This is // to handle the case where the remote is faster and attempts to initiate a stream // before the ondatachannel callback can be set. - conn, err := newConnection( + conn, err := NewWebRTCConnection( network.DirInbound, pc, l.transport, @@ -285,6 +286,7 @@ func (l *listener) setupConnection( "", // remotePeer nil, // remoteKey remoteMultiaddr, + dataChannelQueue, ) if err != nil { return nil, err diff --git a/p2p/transport/webrtc/transport.go b/p2p/transport/webrtc/transport.go index dd4028d1f2..8b2a722087 100644 --- a/p2p/transport/webrtc/transport.go +++ b/p2p/transport/webrtc/transport.go @@ -321,6 +321,7 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement if err != nil { return nil, fmt.Errorf("instantiate peerconnection: %w", err) } + dataChannelQueue := SetupDataChannelQueue(pc, maxAcceptQueueLen) errC := addOnConnectionStateChangeCallback(pc) // We need to set negotiated = true for this channel on both @@ -392,7 +393,7 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement // we can only know the remote public key after the noise handshake, // but need to set up the callbacks on the peerconnection - conn, err := newConnection( + conn, err := NewWebRTCConnection( network.DirOutbound, pc, t, @@ -402,6 +403,7 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement p, nil, remoteMultiaddrWithoutCerthash, + dataChannelQueue, ) if err != nil { return nil, err diff --git a/p2p/transport/webrtcprivate/listener.go b/p2p/transport/webrtcprivate/listener.go new file mode 100644 index 0000000000..97807e3ef5 --- /dev/null +++ b/p2p/transport/webrtcprivate/listener.go @@ -0,0 +1,319 @@ +package libp2pwebrtcprivate + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "time" + + "github.com/libp2p/go-libp2p/core/network" + tpt "github.com/libp2p/go-libp2p/core/transport" + libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" + "github.com/libp2p/go-libp2p/p2p/transport/webrtcprivate/pb" + "github.com/libp2p/go-msgio/pbio" + ma "github.com/multiformats/go-multiaddr" + "github.com/pion/webrtc/v3" +) + +type listener struct { + transport *transport + connQueue chan tpt.CapableConn + inflightQueue chan struct{} + ctx context.Context + cancel context.CancelFunc +} + +var _ tpt.Listener = &listener{} + +type NetAddr struct{} + +var _ net.Addr = NetAddr{} + +func (n NetAddr) Network() string { + return "libp2p-webrtc" +} + +func (n NetAddr) String() string { + return "/webrtc" +} + +// Accept implements transport.Listener. +func (l *listener) Accept() (tpt.CapableConn, error) { + if l.ctx.Err() != nil { + return nil, tpt.ErrListenerClosed + } + select { + case c := <-l.connQueue: + return c, nil + case <-l.ctx.Done(): + return nil, tpt.ErrListenerClosed + } +} + +// Addr implements transport.Listener. The returned address always returns libp2p-webrtc:/webrtc +func (l *listener) Addr() net.Addr { + return NetAddr{} +} + +func (l *listener) Close() error { + l.transport.RemoveListener(l) + l.cancel() + return nil +} + +func (*listener) Multiaddr() ma.Multiaddr { + return ma.StringCast("/webrtc") +} + +func (l *listener) handleSignalingStream(s network.Stream) { + select { + case l.inflightQueue <- struct{}{}: + defer func() { <-l.inflightQueue }() + case <-l.ctx.Done(): + s.Reset() + return + } + + ctx, cancel := context.WithTimeout(context.Background(), connectTimeout) + defer cancel() + defer s.Close() + + scope, err := l.transport.rcmgr.OpenConnection(network.DirInbound, true, ma.StringCast("/webrtc")) // we don't have a better remote adress right now + if err != nil { + s.Reset() + log.Debug("failed to create connection scope:", err) + return + } + if err := scope.SetPeer(s.Conn().RemotePeer()); err != nil { + log.Debugf("resource manager blocked incoming conn from peer %s: %s", s.Conn().RemotePeer(), err) + return + } + + if err := s.Scope().SetService(name); err != nil { + log.Debugf("error attaching stream to /webrtc listener: %s", err) + s.Reset() + return + } + + if err := s.Scope().ReserveMemory(2*maxMsgSize, network.ReservationPriorityAlways); err != nil { + log.Debugf("error reserving memory for /webrtc signaling stream: %s", err) + s.Reset() + return + } + defer s.Scope().ReleaseMemory(maxMsgSize) + + s.SetDeadline(time.Now().Add(connectTimeout)) + + if l.transport.gater != nil { + localAddr := s.Conn().LocalMultiaddr().Encapsulate(WebRTCAddr) + remoteAddr := s.Conn().RemoteMultiaddr().Encapsulate(WebRTCAddr) + if !l.transport.gater.InterceptAccept(&libp2pwebrtc.ConnMultiaddrs{Local: localAddr, Remote: remoteAddr}) { + log.Debug("gater disallowed accepting connection from %s at %s", s.Conn().RemotePeer(), remoteAddr) + s.Reset() + } + } + + conn, err := l.setupConnection(ctx, s, scope) + if err != nil { + s.Reset() + scope.Done() + log.Debug("failed to establish connection with %s: %s", s.Conn().RemotePeer(), err) + return + } + + if l.transport.gater != nil && !l.transport.gater.InterceptSecured(network.DirInbound, s.Conn().RemotePeer(), conn) { + conn.Close() + log.Debugf("conn gater refused connection to addr: %s", conn.RemoteMultiaddr()) + } + // Close the stream before we wait for the connection to be accepted + s.Close() + select { + case l.connQueue <- conn: + case <-l.ctx.Done(): + conn.Close() + log.Debug("listener closed: dropping conn from %s", s.Conn().RemotePeer()) + } +} + +func (l *listener) setupConnection(ctx context.Context, s network.Stream, scope network.ConnManagementScope) (tpt.CapableConn, error) { + + pc, err := l.transport.NewPeerConnection() + if err != nil { + err = fmt.Errorf("error creating a webrtc.PeerConnection: %w", err) + log.Debug(err) + return nil, err + } + dataChannelQueue := libp2pwebrtc.SetupDataChannelQueue(pc, maxAcceptQueueLen) + + r := pbio.NewDelimitedReader(s, maxMsgSize) + w := pbio.NewDelimitedWriter(s) + + // register peerconnection state update callback + connectionState := make(chan webrtc.PeerConnectionState, 1) + pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { + switch state { + case webrtc.PeerConnectionStateConnected, webrtc.PeerConnectionStateFailed, webrtc.PeerConnectionStateClosed: + // We only use the first state written to connectionState. + select { + case connectionState <- state: + default: + } + } + }) + + // register local ICE Candidate found callback + writeErr := make(chan error, 1) + pc.OnICECandidate(func(candidate *webrtc.ICECandidate) { + if candidate == nil { + return + } + b, err := json.Marshal(candidate.ToJSON()) + if err != nil { + // We only want to write a single error on this channel + select { + case writeErr <- fmt.Errorf("failed to marshal candidate to JSON: %w", err): + default: + } + return + } + data := string(b) + + msg := &pb.Message{ + Type: pb.Message_ICE_CANDIDATE.Enum(), + Data: &data, + } + if err := w.WriteMsg(msg); err != nil { + // We only want to write a single error on this channel + select { + case writeErr <- fmt.Errorf("write candidate failed: %w", err): + default: + } + } + }) + + // de-register candidate callback + defer pc.OnICECandidate(func(_ *webrtc.ICECandidate) {}) + + // read an incoming offer + var msg pb.Message + if err := r.ReadMsg(&msg); err != nil { + err = fmt.Errorf("failed to read offer: %w", err) + return nil, err + } + if msg.Type == nil || *msg.Type != pb.Message_SDP_OFFER { + err = fmt.Errorf("invalid message: msg.Type expected %s got %s", pb.Message_SDP_OFFER, msg.Type) + return nil, err + } + if msg.Data == nil || *msg.Data == "" { + err = errors.New("invalid message: empty data") + return nil, err + } + offer := webrtc.SessionDescription{Type: webrtc.SDPTypeOffer, SDP: *msg.Data} + if err := pc.SetRemoteDescription(offer); err != nil { + err = fmt.Errorf("failed to set remote description: %w", err) + return nil, err + } + + // send an answer + answer, err := pc.CreateAnswer(nil) + if err != nil { + return nil, fmt.Errorf("failed to create answer: %w", err) + } + msg = pb.Message{ + Type: pb.Message_SDP_ANSWER.Enum(), + Data: &answer.SDP, + } + if err := w.WriteMsg(&msg); err != nil { + return nil, fmt.Errorf("failed to write answer: %w", err) + } + if err := pc.SetLocalDescription(answer); err != nil { + return nil, fmt.Errorf("failed to set local description: %w", err) + } + + readErr := make(chan error, 1) + // start a goroutine to read candidates + go func() { + for { + if ctx.Err() != nil { + return + } + err := r.ReadMsg(&msg) + if err == io.EOF { + // remote has done writing + return + } + if err != nil { + readErr <- fmt.Errorf("failed to read candidate: %w", err) + return + } + + if msg.Type == nil || *msg.Type != pb.Message_ICE_CANDIDATE { + readErr <- fmt.Errorf("invalid message: msg.Type expected %s got %s", pb.Message_ICE_CANDIDATE, msg.Type) + return + } + // Ignore without Debuging on empty message. + // Pion has a case where OnCandidate callback may be called with a nil + // candidate + if msg.Data == nil || *msg.Data == "" { + log.Debugf("received empty candidate from %s", s.Conn().RemotePeer()) + continue + } + + var init webrtc.ICECandidateInit + if err := json.Unmarshal([]byte(*msg.Data), &init); err != nil { + readErr <- fmt.Errorf("failed to unmarshal ice candidate %w", err) + return + } + if err := pc.AddICECandidate(init); err != nil { + readErr <- fmt.Errorf("failed to add ice candidate: %w", err) + return + } + } + }() + + select { + case <-ctx.Done(): + pc.Close() + return nil, ctx.Err() + case err := <-writeErr: + pc.Close() + return nil, fmt.Errorf("error writing candidate: %w", err) + case err := <-readErr: + pc.Close() + return nil, fmt.Errorf("error reading candidate: %w", err) + case state := <-connectionState: + switch state { + default: + pc.Close() + return nil, fmt.Errorf("failed to establish webrtc.PeerConnection, state: %s", state) + case webrtc.PeerConnectionStateConnected: + } + } + + localAddr, remoteAddr, err := getConnectionAddresses(pc) + if err != nil { + pc.Close() + return nil, fmt.Errorf("failed to get connection addresses: %w", err) + } + + conn, err := libp2pwebrtc.NewWebRTCConnection( + network.DirInbound, + pc, + l.transport, + scope, + l.transport.host.ID(), + localAddr, + s.Conn().RemotePeer(), + l.transport.host.Peerstore().PubKey(s.Conn().RemotePeer()), // we have the public key from the relayed connection + remoteAddr, + dataChannelQueue, + ) + if err != nil { + pc.Close() + return nil, fmt.Errorf("error establishing tpt.CapableConn: %w", err) + } + return conn, nil +} diff --git a/p2p/transport/webrtcprivate/pb/generate.go b/p2p/transport/webrtcprivate/pb/generate.go new file mode 100644 index 0000000000..657f02bd6a --- /dev/null +++ b/p2p/transport/webrtcprivate/pb/generate.go @@ -0,0 +1,3 @@ +package pb + +//go:generate protoc --go_out=. --go_opt=paths=source_relative -I . msg.proto diff --git a/p2p/transport/webrtcprivate/pb/msg.pb.go b/p2p/transport/webrtcprivate/pb/msg.pb.go new file mode 100644 index 0000000000..337b4d7a19 --- /dev/null +++ b/p2p/transport/webrtcprivate/pb/msg.pb.go @@ -0,0 +1,220 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.30.0 +// protoc v3.21.12 +// source: msg.proto + +package pb + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// Specifies type in `data` field. +type Message_Type int32 + +const ( + // String of `RTCSessionDescription.sdp` + Message_SDP_OFFER Message_Type = 0 + // String of `RTCSessionDescription.sdp` + Message_SDP_ANSWER Message_Type = 1 + // String of `RTCIceCandidate.toJSON()` + Message_ICE_CANDIDATE Message_Type = 2 +) + +// Enum value maps for Message_Type. +var ( + Message_Type_name = map[int32]string{ + 0: "SDP_OFFER", + 1: "SDP_ANSWER", + 2: "ICE_CANDIDATE", + } + Message_Type_value = map[string]int32{ + "SDP_OFFER": 0, + "SDP_ANSWER": 1, + "ICE_CANDIDATE": 2, + } +) + +func (x Message_Type) Enum() *Message_Type { + p := new(Message_Type) + *p = x + return p +} + +func (x Message_Type) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (Message_Type) Descriptor() protoreflect.EnumDescriptor { + return file_msg_proto_enumTypes[0].Descriptor() +} + +func (Message_Type) Type() protoreflect.EnumType { + return &file_msg_proto_enumTypes[0] +} + +func (x Message_Type) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use Message_Type.Descriptor instead. +func (Message_Type) EnumDescriptor() ([]byte, []int) { + return file_msg_proto_rawDescGZIP(), []int{0, 0} +} + +type Message struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Type *Message_Type `protobuf:"varint,1,opt,name=type,proto3,enum=libp2pwebrtcprivate.pb.Message_Type,oneof" json:"type,omitempty"` + Data *string `protobuf:"bytes,2,opt,name=data,proto3,oneof" json:"data,omitempty"` +} + +func (x *Message) Reset() { + *x = Message{} + if protoimpl.UnsafeEnabled { + mi := &file_msg_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Message) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Message) ProtoMessage() {} + +func (x *Message) ProtoReflect() protoreflect.Message { + mi := &file_msg_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Message.ProtoReflect.Descriptor instead. +func (*Message) Descriptor() ([]byte, []int) { + return file_msg_proto_rawDescGZIP(), []int{0} +} + +func (x *Message) GetType() Message_Type { + if x != nil && x.Type != nil { + return *x.Type + } + return Message_SDP_OFFER +} + +func (x *Message) GetData() string { + if x != nil && x.Data != nil { + return *x.Data + } + return "" +} + +var File_msg_proto protoreflect.FileDescriptor + +var file_msg_proto_rawDesc = []byte{ + 0x0a, 0x09, 0x6d, 0x73, 0x67, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x16, 0x6c, 0x69, 0x62, + 0x70, 0x32, 0x70, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65, + 0x2e, 0x70, 0x62, 0x22, 0xad, 0x01, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, + 0x3d, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x24, 0x2e, + 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x70, 0x72, 0x69, 0x76, + 0x61, 0x74, 0x65, 0x2e, 0x70, 0x62, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x2e, 0x54, + 0x79, 0x70, 0x65, 0x48, 0x00, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x88, 0x01, 0x01, 0x12, 0x17, + 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x01, 0x52, 0x04, + 0x64, 0x61, 0x74, 0x61, 0x88, 0x01, 0x01, 0x22, 0x38, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, + 0x0d, 0x0a, 0x09, 0x53, 0x44, 0x50, 0x5f, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0e, + 0x0a, 0x0a, 0x53, 0x44, 0x50, 0x5f, 0x41, 0x4e, 0x53, 0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x11, + 0x0a, 0x0d, 0x49, 0x43, 0x45, 0x5f, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41, 0x54, 0x45, 0x10, + 0x02, 0x42, 0x07, 0x0a, 0x05, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x42, 0x07, 0x0a, 0x05, 0x5f, 0x64, + 0x61, 0x74, 0x61, 0x42, 0x3c, 0x5a, 0x3a, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, + 0x6d, 0x2f, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2f, 0x67, 0x6f, 0x2d, 0x6c, 0x69, 0x62, 0x70, + 0x32, 0x70, 0x2f, 0x70, 0x32, 0x70, 0x2f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, + 0x2f, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x2f, 0x70, + 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_msg_proto_rawDescOnce sync.Once + file_msg_proto_rawDescData = file_msg_proto_rawDesc +) + +func file_msg_proto_rawDescGZIP() []byte { + file_msg_proto_rawDescOnce.Do(func() { + file_msg_proto_rawDescData = protoimpl.X.CompressGZIP(file_msg_proto_rawDescData) + }) + return file_msg_proto_rawDescData +} + +var file_msg_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_msg_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_msg_proto_goTypes = []interface{}{ + (Message_Type)(0), // 0: libp2pwebrtcprivate.pb.Message.Type + (*Message)(nil), // 1: libp2pwebrtcprivate.pb.Message +} +var file_msg_proto_depIdxs = []int32{ + 0, // 0: libp2pwebrtcprivate.pb.Message.type:type_name -> libp2pwebrtcprivate.pb.Message.Type + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_msg_proto_init() } +func file_msg_proto_init() { + if File_msg_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_msg_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Message); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + file_msg_proto_msgTypes[0].OneofWrappers = []interface{}{} + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_msg_proto_rawDesc, + NumEnums: 1, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_msg_proto_goTypes, + DependencyIndexes: file_msg_proto_depIdxs, + EnumInfos: file_msg_proto_enumTypes, + MessageInfos: file_msg_proto_msgTypes, + }.Build() + File_msg_proto = out.File + file_msg_proto_rawDesc = nil + file_msg_proto_goTypes = nil + file_msg_proto_depIdxs = nil +} diff --git a/p2p/transport/webrtcprivate/pb/msg.proto b/p2p/transport/webrtcprivate/pb/msg.proto new file mode 100644 index 0000000000..3674833ca2 --- /dev/null +++ b/p2p/transport/webrtcprivate/pb/msg.proto @@ -0,0 +1,20 @@ +syntax = "proto3"; + +package libp2pwebrtcprivate.pb; + +option go_package = "github.com/libp2p/go-libp2p/p2p/transport/webrtcprivate/pb"; + +message Message { + // Specifies type in `data` field. + enum Type { + // String of `RTCSessionDescription.sdp` + SDP_OFFER = 0; + // String of `RTCSessionDescription.sdp` + SDP_ANSWER = 1; + // String of `RTCIceCandidate.toJSON()` + ICE_CANDIDATE = 2; + } + + optional Type type = 1; + optional string data = 2; +} \ No newline at end of file diff --git a/p2p/transport/webrtcprivate/transport.go b/p2p/transport/webrtcprivate/transport.go new file mode 100644 index 0000000000..63d562863d --- /dev/null +++ b/p2p/transport/webrtcprivate/transport.go @@ -0,0 +1,501 @@ +package libp2pwebrtcprivate + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "sync" + "time" + + logging "github.com/ipfs/go-log/v2" + "github.com/libp2p/go-libp2p/core/connmgr" + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + pionlogger "github.com/pion/logging" + + "github.com/libp2p/go-libp2p/core/peer" + tpt "github.com/libp2p/go-libp2p/core/transport" + libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" + "github.com/libp2p/go-libp2p/p2p/transport/webrtcprivate/pb" + "github.com/libp2p/go-msgio/pbio" + "github.com/pion/webrtc/v3" + "go.uber.org/zap/zapcore" + + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +const ( + name = "webrtcprivate" + maxMsgSize = 4096 + connectTimeout = time.Minute + SignalingProtocol = "/webrtc-signaling" + disconnectedTimeout = 20 * time.Second + failedTimeout = 30 * time.Second + keepaliveTimeout = 15 * time.Second + maxAcceptQueueLen = 10 +) + +var ( + log = logging.Logger("webrtcprivate") + WebRTCAddr = ma.StringCast("/webrtc") +) + +type transport struct { + host host.Host + rcmgr network.ResourceManager + webrtcConfig webrtc.Configuration + gater connmgr.ConnectionGater + maxInFlightConnections int + + mu sync.Mutex + listener *listener +} + +var _ tpt.Transport = &transport{} + +func AddTransport(h host.Host, gater connmgr.ConnectionGater, stunServers []webrtc.ICEServer) (*transport, error) { + n, ok := h.Network().(tpt.TransportNetwork) + if !ok { + return nil, fmt.Errorf("%v is not a transport network", h.Network()) + } + + t, err := newTransport(h, gater, stunServers) + if err != nil { + return nil, err + } + + if err := n.AddTransport(t); err != nil { + return nil, fmt.Errorf("failed to add transport to network: %w", err) + } + + if err := n.Listen(ma.StringCast("/webrtc")); err != nil { + return nil, err + } + + return t, nil +} + +func newTransport(h host.Host, gater connmgr.ConnectionGater, stunServers []webrtc.ICEServer) (*transport, error) { + // We use elliptic P-256 since it is widely supported by browsers. + // + // Implementation note: Testing with the browser, + // it seems like Chromium only supports ECDSA P-256 or RSA key signatures in the webrtc TLS certificate. + // We tried using P-228 and P-384 which caused the DTLS handshake to fail with Illegal Parameter + // + // Please refer to this is a list of suggested algorithms for the WebCrypto API. + // The algorithm for generating a certificate for an RTCPeerConnection + // must adhere to the WebCrpyto API. From my observation, + // RSA and ECDSA P-256 is supported on almost all browsers. + // Ed25519 is not present on the list. + pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, fmt.Errorf("generate key for cert: %w", err) + } + cert, err := webrtc.GenerateCertificate(pk) + if err != nil { + return nil, fmt.Errorf("generate certificate: %w", err) + } + config := webrtc.Configuration{ + Certificates: []webrtc.Certificate{*cert}, + ICEServers: stunServers, + } + + return &transport{ + host: h, + rcmgr: h.Network().ResourceManager(), + webrtcConfig: config, + maxInFlightConnections: 16, + gater: gater, + }, nil +} + +// CanDial determines if we can dial to an address +func (t *transport) CanDial(addr ma.Multiaddr) bool { + circuit := false + webrtc := false + ma.ForEach(addr, func(c ma.Component) bool { + if c.Protocol().Code == ma.P_CIRCUIT { + circuit = true + return true + } + // next element after p2p-circuit should be webrtc + if circuit { + webrtc = c.Protocol().Code == ma.P_WEBRTC + return false + } + return true + }) + return circuit && webrtc +} + +func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { + // Connect to the peer on the circuit address + relayAddr := getRelayAddr(raddr) + // We drop the ForceDirectDial option as we need a relayed connection before we can + // setup a direct connection + ctx = network.WithoutForceDirectDial(ctx) + // We need this for the signaling stream + ctx = network.WithUseTransient(ctx, "webrtcprivate dial") + err := t.host.Connect(ctx, peer.AddrInfo{ID: p, Addrs: []ma.Multiaddr{relayAddr}}) + if err != nil { + return nil, fmt.Errorf("failed to open %s stream: %w", SignalingProtocol, err) + } + + scope, err := t.rcmgr.OpenConnection(network.DirOutbound, true, raddr) + if err != nil { + log.Debugw("resource manager blocked outgoing connection", "peer", p, "addr", raddr, "error", err) + return nil, err + } + if err := scope.SetPeer(p); err != nil { + return nil, err + } + + c, err := t.dialWithScope(ctx, p, scope) + if err != nil { + scope.Done() + log.Debug(err) + return nil, err + } + return c, nil +} + +func (t *transport) dialWithScope(ctx context.Context, p peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) { + s, err := t.host.NewStream(ctx, p, SignalingProtocol) + if err != nil { + return nil, fmt.Errorf("error opening stream %s: %w", SignalingProtocol, err) + } + + if err := s.Scope().SetService(name); err != nil { + s.Reset() + return nil, fmt.Errorf("error attaching signaling stream to %s transport: %w", name, err) + } + + if err := s.Scope().ReserveMemory(2*maxMsgSize, network.ReservationPriorityAlways); err != nil { + s.Reset() + return nil, fmt.Errorf("error reserving memory for signaling stream: %w", err) + } + defer s.Scope().ReleaseMemory(maxMsgSize) + defer s.Close() + + deadline := time.Now().Add(connectTimeout) + if d, ok := ctx.Deadline(); ok && d.Before(deadline) { + deadline = d + } + s.SetDeadline(deadline) + + conn, err := t.setupConnection(ctx, s, scope) + if err != nil { + s.Reset() + return nil, fmt.Errorf("error establishing webrtc.PeerConnection: %w", err) + } + + if t.gater != nil && !t.gater.InterceptSecured(network.DirOutbound, p, conn) { + conn.Close() + s.Reset() + return nil, fmt.Errorf("conn gater refused connection to addr: %s", conn.RemoteMultiaddr()) + } + return conn, nil +} + +func (t *transport) setupConnection(ctx context.Context, s network.Stream, scope network.ConnManagementScope) (tpt.CapableConn, error) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + w := pbio.NewDelimitedWriter(s) + + pc, err := t.NewPeerConnection() + if err != nil { + return nil, fmt.Errorf("failed to create webrtc.PeerConnection: %w", err) + } + + dataChannelQueue := libp2pwebrtc.SetupDataChannelQueue(pc, maxAcceptQueueLen) + + // register peerconnection state update callback + connectionState := make(chan webrtc.PeerConnectionState, 1) + pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { + switch state { + case webrtc.PeerConnectionStateConnected, webrtc.PeerConnectionStateFailed, webrtc.PeerConnectionStateClosed: + // We only use the first state written to connectionState. + select { + case connectionState <- state: + default: + } + } + }) + + // register local ICE Candidate found callback + writeErr := make(chan error, 1) + pc.OnICECandidate(func(candidate *webrtc.ICECandidate) { + // The callback can be called with a nil pointer + if candidate == nil { + return + } + b, err := json.Marshal(candidate.ToJSON()) + if err != nil { + // We only want to write a single error on this channel + select { + case writeErr <- fmt.Errorf("failed to marshal candidate to JSON: %w", err): + default: + } + return + } + data := string(b) + msg := pb.Message{ + Type: pb.Message_ICE_CANDIDATE.Enum(), + Data: &data, + } + if err = w.WriteMsg(&msg); err != nil { + // We only want to write a single error on this channel + select { + case writeErr <- fmt.Errorf("failed to write candidate: %w", err): + default: + } + } + }) + + // de-register candidate callback + defer pc.OnICECandidate(func(_ *webrtc.ICECandidate) {}) + + // We initialise a data channel otherwise the offer will have no ICE components + // https://stackoverflow.com/a/38872920/759687 + // We use out-of-band negotiation(negotiated=true), to ensure that this channel doesn't + // get accepted as a stream in AcceptStream on the remote side + negotiated := true + // Any value here is fine since this will be closed on connection establishment. We use 0 as + // it is also used for the /webrtc-direct handshake channel + var initStreamID uint16 + dc, err := pc.CreateDataChannel("init", &webrtc.DataChannelInit{Negotiated: &negotiated, ID: &initStreamID}) + if err != nil { + return nil, fmt.Errorf("failed to create data channel: %w", err) + } + defer dc.Close() + + // create an offer + offer, err := pc.CreateOffer(nil) + if err != nil { + return nil, fmt.Errorf("failed to create offer: %w", err) + } + msg := pb.Message{ + Type: pb.Message_SDP_OFFER.Enum(), + Data: &offer.SDP, + } + // send offer to peer + if err := w.WriteMsg(&msg); err != nil { + return nil, fmt.Errorf("failed to write to stream: %w", err) + } + if err := pc.SetLocalDescription(offer); err != nil { + return nil, fmt.Errorf("failed to set local description: %w", err) + } + + // read an incoming answer + if err := r.ReadMsg(&msg); err != nil { + return nil, fmt.Errorf("failed to read from stream: %w", err) + } + if msg.Type == nil || *msg.Type != pb.Message_SDP_ANSWER { + return nil, fmt.Errorf("invalid message: expected %s, got %s", pb.Message_SDP_ANSWER, msg.Type) + } + if msg.Data == nil || *msg.Data == "" { + return nil, fmt.Errorf("invalid message: empty answer") + } + answer := webrtc.SessionDescription{ + Type: webrtc.SDPTypeAnswer, + SDP: *msg.Data, + } + if err := pc.SetRemoteDescription(answer); err != nil { + return nil, fmt.Errorf("failed to set remote description: %w", err) + } + + readErr := make(chan error, 1) + ctx, cancel := context.WithTimeout(ctx, connectTimeout) + defer cancel() + // start a goroutine to read candidates + go func() { + for { + if ctx.Err() != nil { + return + } + + err := r.ReadMsg(&msg) + if err == io.EOF { + return + } + if err != nil { + readErr <- fmt.Errorf("read failed: %w", err) + return + } + if msg.Type == nil || *msg.Type != pb.Message_ICE_CANDIDATE { + readErr <- fmt.Errorf("invalid message: expected %s got %s", pb.Message_ICE_CANDIDATE, msg.Type) + return + } + // Ignore without erroring on empty message. + // Pion has a case where OnCandidate callback may be called with a nil + // candidate + if msg.Data == nil || *msg.Data == "" { + log.Debugf("received empty candidate from %s", s.Conn().RemotePeer()) + continue + } + + var init webrtc.ICECandidateInit + if err := json.Unmarshal([]byte(*msg.Data), &init); err != nil { + readErr <- fmt.Errorf("failed to unmarshal ice candidate %w", err) + return + } + if err := pc.AddICECandidate(init); err != nil { + readErr <- fmt.Errorf("failed to add ice candidate: %w", err) + return + } + } + }() + + select { + case <-ctx.Done(): + pc.Close() + return nil, ctx.Err() + case err := <-readErr: + pc.Close() + return nil, err + case err := <-writeErr: + pc.Close() + return nil, err + case state := <-connectionState: + switch state { + default: + pc.Close() + return nil, fmt.Errorf("conn establishment failed, state: %s", state) + case webrtc.PeerConnectionStateConnected: + } + } + localAddr, remoteAddr, err := getConnectionAddresses(pc) + if err != nil { + pc.Close() + return nil, fmt.Errorf("failed to get connection addresses: %w", err) + } + + conn, err := libp2pwebrtc.NewWebRTCConnection( + network.DirOutbound, + pc, + t, + scope, + t.host.ID(), + localAddr, + s.Conn().RemotePeer(), + t.host.Network().Peerstore().PubKey(s.Conn().RemotePeer()), // we have the pubkey from the relayed connection + remoteAddr, + dataChannelQueue, + ) + if err != nil { + pc.Close() + return nil, fmt.Errorf("failed to create tpt.CapableConn: %w", err) + } + return conn, nil +} + +func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { + if _, err := laddr.ValueForProtocol(ma.P_WEBRTC); err != nil { + return nil, fmt.Errorf("invalid listen multiaddr: %s", laddr) + } + t.mu.Lock() + defer t.mu.Unlock() + if t.listener != nil { + return nil, errors.New("already listening on /webrtc") + } + ctx, cancel := context.WithCancel(context.Background()) + l := &listener{ + transport: t, + connQueue: make(chan tpt.CapableConn), + inflightQueue: make(chan struct{}, t.maxInFlightConnections), + ctx: ctx, + cancel: cancel, + } + t.listener = l + t.host.SetStreamHandler(SignalingProtocol, l.handleSignalingStream) + return l, nil +} + +func (t *transport) RemoveListener(l *listener) { + t.mu.Lock() + defer t.mu.Unlock() + if t.listener == l { + t.listener = nil + t.host.RemoveStreamHandler(SignalingProtocol) + } +} + +func (*transport) Protocols() []int { + return []int{ma.P_WEBRTC} +} + +func (*transport) Proxy() bool { + return false +} + +func (t *transport) NewPeerConnection() (*webrtc.PeerConnection, error) { + loggerFactory := pionlogger.NewDefaultLoggerFactory() + logLevel := pionlogger.LogLevelDisabled + switch log.Level() { + case zapcore.DebugLevel: + logLevel = pionlogger.LogLevelDebug + case zapcore.InfoLevel: + logLevel = pionlogger.LogLevelInfo + case zapcore.WarnLevel: + logLevel = pionlogger.LogLevelWarn + case zapcore.ErrorLevel: + logLevel = pionlogger.LogLevelError + } + loggerFactory.DefaultLogLevel = logLevel + s := webrtc.SettingEngine{LoggerFactory: loggerFactory} + s.SetICETimeouts(disconnectedTimeout, failedTimeout, keepaliveTimeout) + s.DetachDataChannels() + s.SetIncludeLoopbackCandidate(true) + api := webrtc.NewAPI(webrtc.WithSettingEngine(s)) + return api.NewPeerConnection(t.webrtcConfig) +} + +// getRelayAddr removes /webrtc from addr and returns a circuit v2 only address +func getRelayAddr(addr ma.Multiaddr) ma.Multiaddr { + first, rest := ma.SplitFunc(addr, func(c ma.Component) bool { + return c.Protocol().Code == ma.P_WEBRTC + }) + // remove /webrtc prefix + _, rest = ma.SplitFirst(rest) + if rest == nil { + return first + } + return first.Encapsulate(rest) +} + +// getConnectionAddresses provides multiaddresses on the two sides of the connection pc +func getConnectionAddresses(pc *webrtc.PeerConnection) (local ma.Multiaddr, remote ma.Multiaddr, err error) { + if pc.SCTP() == nil { + return nil, nil, errors.New("no sctp transport") + } + if pc.SCTP().Transport() == nil { + return nil, nil, errors.New("no dtls transport") + } + if pc.SCTP().Transport().ICETransport() == nil { + return nil, nil, errors.New("no ice transport") + } + cp, err := pc.SCTP().Transport().ICETransport().GetSelectedCandidatePair() + if cp == nil || err != nil { + return nil, nil, fmt.Errorf("invalid candidate pair %s: %w", cp, err) + } + + localAddr, err := manet.FromNetAddr(&net.UDPAddr{IP: net.ParseIP(cp.Local.Address), Port: int(cp.Local.Port)}) + if err != nil { + return nil, nil, fmt.Errorf("failed to infer local address from candidate %s: %w", cp, err) + } + localAddr = localAddr.Encapsulate(WebRTCAddr) + + remoteAddr, err := manet.FromNetAddr(&net.UDPAddr{IP: net.ParseIP(cp.Remote.Address), Port: int(cp.Remote.Port)}) + if err != nil { + return nil, nil, fmt.Errorf("failed to infer remote address from candidate %s: %w", cp, err) + } + remoteAddr = remoteAddr.Encapsulate(WebRTCAddr) + + return localAddr, remoteAddr, nil +} diff --git a/p2p/transport/webrtcprivate/transport_test.go b/p2p/transport/webrtcprivate/transport_test.go new file mode 100644 index 0000000000..f9d889cee2 --- /dev/null +++ b/p2p/transport/webrtcprivate/transport_test.go @@ -0,0 +1,632 @@ +package libp2pwebrtcprivate + +import ( + "context" + "fmt" + "os" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + blankhost "github.com/libp2p/go-libp2p/p2p/host/blank" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" + "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" + "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// relayedHost is a webrtc enabled host with a relay reservation +type relayedHost struct { + webrtcHost + // R is the relay host + R host.Host + // Addr is the reachable /webrtc address + Addr ma.Multiaddr +} + +func (r *relayedHost) Close() { + r.R.Close() + r.webrtcHost.Close() +} + +type webrtcHost struct { + host.Host + // T is the webrtc transport used by the host + T *transport +} + +func newWebRTCHost(t *testing.T) *webrtcHost { + as := swarmt.GenSwarm(t) + a := blankhost.NewBlankHost(as) + upg := swarmt.GenUpgrader(t, as, nil) + err := client.AddTransport(a, upg) + require.NoError(t, err) + ta, err := newTransport(a, nil, nil) + require.NoError(t, err) + return &webrtcHost{ + Host: a, + T: ta, + } +} + +func newRelayedHost(t *testing.T) *relayedHost { + rh := blankhost.NewBlankHost(swarmt.GenSwarm(t)) + rr := relay.DefaultResources() + rr.MaxCircuits = 100 + _, err := relay.New(rh, relay.WithResources(rr)) + require.NoError(t, err) + + ps := swarmt.GenSwarm(t) + p := blankhost.NewBlankHost(ps) + upg := swarmt.GenUpgrader(t, ps, nil) + client.AddTransport(p, upg) + _, err = client.Reserve(context.Background(), p, peer.AddrInfo{ID: rh.ID(), Addrs: rh.Addrs()}) + require.NoError(t, err) + tp, err := newTransport(p, nil, nil) + require.NoError(t, err) + return &relayedHost{ + webrtcHost: webrtcHost{ + Host: p, + T: tp, + }, + R: rh, + Addr: ma.StringCast(fmt.Sprintf("%s/p2p/%s/p2p-circuit/webrtc/", rh.Addrs()[0], rh.ID())), + } +} + +func TestSingleDial(t *testing.T) { + a := newWebRTCHost(t) + b := newRelayedHost(t) + defer b.Close() + defer a.Close() + + l, err := b.T.Listen(ma.StringCast("/webrtc")) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + ca, err := a.T.Dial(ctx, b.Addr, b.ID()) + require.NoError(t, err) + + cb, err := l.Accept() + require.NoError(t, err) + sa, err := ca.OpenStream(ctx) + require.NoError(t, err) + sb, err := cb.AcceptStream() + require.NoError(t, err) + sa.Write([]byte("hello world")) + recv := make([]byte, 24) + n, err := sb.Read(recv) + require.NoError(t, err) + require.Equal(t, "hello world", string(recv[:n])) + + ca.Close() + cb.Close() +} + +func TestConnectionProperties(t *testing.T) { + a := newWebRTCHost(t) + b := newRelayedHost(t) + defer b.Close() + defer a.Close() + + l, err := b.T.Listen(ma.StringCast("/webrtc")) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + ca, err := a.T.Dial(ctx, b.Addr, b.ID()) + require.NoError(t, err) + + cb, err := l.Accept() + require.NoError(t, err) + + t.Run("Addresses", func(t *testing.T) { + testAddr := func(addr ma.Multiaddr) { + _, err := addr.ValueForProtocol(ma.P_UDP) + require.NoError(t, err) + _, err = addr.ValueForProtocol(ma.P_WEBRTC) + require.NoError(t, err) + } + testAddr(ca.LocalMultiaddr()) + testAddr(ca.RemoteMultiaddr()) + testAddr(cb.LocalMultiaddr()) + testAddr(cb.RemoteMultiaddr()) + }) + + t.Run("ConnectionState", func(t *testing.T) { + require.Equal(t, network.ConnectionState{Transport: "webrtc"}, ca.ConnState()) + require.Equal(t, network.ConnectionState{Transport: "webrtc"}, cb.ConnState()) + }) + +} + +func TestMultipleDials(t *testing.T) { + a := newWebRTCHost(t) + var wg sync.WaitGroup + + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + b := newRelayedHost(t) + defer b.Close() + + l, err := b.T.Listen(ma.StringCast("/webrtc")) + if !assert.NoError(t, err) { + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + ca, err := a.T.Dial(ctx, b.Addr, b.ID()) + if !assert.NoError(t, err) { + return + } + + cb, err := l.Accept() + if !assert.NoError(t, err) { + return + } + + sa, err := ca.OpenStream(ctx) + if !assert.NoError(t, err) { + return + } + sb, err := cb.AcceptStream() + if !assert.NoError(t, err) { + return + } + sa.Write([]byte("hello world")) + recv := make([]byte, 24) + n, err := sb.Read(recv) + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, "hello world", string(recv[:n])) { + return + } + wg.Done() + }() + } + wg.Wait() +} + +func TestMultipleDialsAndListeners(t *testing.T) { + const N = 5 + var hosts []*relayedHost + for i := 0; i < N; i++ { + hosts = append(hosts, newRelayedHost(t)) + l, err := hosts[i].T.Listen(ma.StringCast("/webrtc")) + require.NoError(t, err) + defer hosts[i].Close() + defer l.Close() + } + var wg sync.WaitGroup + + dialAndPing := func(h *relayedHost, raddr ma.Multiaddr, p peer.ID) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + ca, err := h.T.Dial(ctx, raddr, p) + if !assert.NoError(t, err) { + return + } + defer ca.Close() + sa, err := ca.OpenStream(ctx) + if !assert.NoError(t, err) { + return + } + defer sa.Close() + sa.Write([]byte("hello world")) + recv := make([]byte, 24) + n, err := sa.Read(recv) + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, "hello world", string(recv[:n])) { + return + } + } + + acceptAndPong := func(r *relayedHost) { + cb, err := r.T.listener.Accept() + if !assert.NoError(t, err) { + return + } + + sb, err := cb.AcceptStream() + if !assert.NoError(t, err) { + return + } + defer sb.Close() + + recv := make([]byte, 24) + n, err := sb.Read(recv) + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, "hello world", string(recv[:n])) { + return + } + sb.Write(recv[:n]) + } + + for i := 0; i < N; i++ { + for j := i + 1; j < N; j++ { + wg.Add(1) + go func(i, j int) { + go acceptAndPong(hosts[j]) + dialAndPing(hosts[i], hosts[j].Addr, hosts[j].ID()) + wg.Done() + }(i, j) + } + } + wg.Wait() +} + +func TestDialerCanCreateStreams(t *testing.T) { + a := newWebRTCHost(t) + b := newRelayedHost(t) + listener, err := b.T.Listen(ma.StringCast("/webrtc")) + require.NoError(t, err) + + aC := make(chan bool) + go func() { + defer close(aC) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, err := a.T.Dial(ctx, b.Addr, b.ID()) + if !assert.NoError(t, err) { + return + } + s, err := conn.AcceptStream() + if !assert.NoError(t, err) { + return + } + recv := make([]byte, 24) + n, err := s.Read(recv) + if !assert.NoError(t, err) { + return + } + _, err = s.Write(recv[:n]) + if !assert.NoError(t, err) { + return + } + s.Close() + }() + + bC := make(chan bool) + go func() { + defer close(bC) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, err := listener.Accept() + if !assert.NoError(t, err) { + return + } + s, err := conn.OpenStream(ctx) + if !assert.NoError(t, err) { + return + } + defer s.Close() + + _, err = s.Write([]byte("hello world")) + if !assert.NoError(t, err) { + return + } + + recv := make([]byte, 24) + n, err := s.Read(recv) + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, "hello world", string(recv[:n])) { + return + } + }() + + select { + case <-aC: + case <-time.After(10 * time.Second): + t.Fatal("timeout") + } + select { + case <-bC: + case <-time.After(10 * time.Second): + t.Fatal("timeout") + } +} + +func TestDialerCanCreateStreamsMultiple(t *testing.T) { + count := 5 + a := newWebRTCHost(t) + b := newRelayedHost(t) + listener, err := b.T.Listen(WebRTCAddr) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + lconn, err := listener.Accept() + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, a.ID(), lconn.RemotePeer()) { + return + } + var wg sync.WaitGroup + + for i := 0; i < count; i++ { + stream, err := lconn.AcceptStream() + if !assert.NoError(t, err) { + return + } + wg.Add(1) + go func() { + defer wg.Done() + buf := make([]byte, 100) + n, err := stream.Read(buf) + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, "test", string(buf[:n])) { + return + } + _, err = stream.Write([]byte("test")) + if !assert.NoError(t, err) { + return + } + }() + } + + wg.Wait() + done <- struct{}{} + }() + + conn, err := a.T.Dial(context.Background(), b.Addr, b.ID()) + require.NoError(t, err) + + for i := 0; i < count; i++ { + idx := i + go func() { + stream, err := conn.OpenStream(context.Background()) + if !assert.NoError(t, err) { + return + } + t.Logf("dialer opened stream: %d", idx) + buf := make([]byte, 100) + _, err = stream.Write([]byte("test")) + if !assert.NoError(t, err) { + return + } + n, err := stream.Read(buf) + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, "test", string(buf[:n])) { + return + } + }() + } + select { + case <-done: + case <-time.After(20 * time.Second): + t.Fatal("timed out") + } +} + +func TestMaxInflightQueue(t *testing.T) { + b := newRelayedHost(t) + defer b.Close() + count := 3 + b.T.maxInFlightConnections = count + listener, err := b.T.Listen(WebRTCAddr) + require.NoError(t, err) + defer listener.Close() + + var success, failure atomic.Int32 + var wg sync.WaitGroup + for i := 0; i < count+1; i++ { + wg.Add(1) + go func() { + a := newWebRTCHost(t) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, err := a.T.Dial(ctx, b.Addr, b.ID()) + if err == nil { + success.Add(1) + } else { + failure.Add(1) + } + wg.Done() + }() + } + wg.Wait() + require.Equal(t, 1, int(failure.Load())) + require.Equal(t, count, int(success.Load())) +} + +func TestRemoteReadsAfterClose(t *testing.T) { + b := newRelayedHost(t) + listener, err := b.T.Listen(WebRTCAddr) + require.NoError(t, err) + + a := newWebRTCHost(t) + + done := make(chan error) + go func() { + lconn, err := listener.Accept() + if err != nil { + done <- err + return + } + stream, err := lconn.AcceptStream() + if err != nil { + done <- err + return + } + _, err = stream.Write([]byte{1, 2, 3, 4}) + if err != nil { + done <- err + return + } + err = stream.Close() + if err != nil { + done <- err + return + } + close(done) + }() + + conn, err := a.T.Dial(context.Background(), b.Addr, b.ID()) + require.NoError(t, err) + // create a stream + stream, err := conn.OpenStream(context.Background()) + + require.NoError(t, err) + // require write and close to complete + require.NoError(t, <-done) + + stream.SetReadDeadline(time.Now().Add(5 * time.Second)) + + buf := make([]byte, 10) + n, err := stream.Read(buf) + require.NoError(t, err) + require.Equal(t, n, 4) +} + +func TestStreamDeadline(t *testing.T) { + b := newRelayedHost(t) + listener, err := b.T.Listen(WebRTCAddr) + require.NoError(t, err) + a := newWebRTCHost(t) + + t.Run("SetReadDeadline", func(t *testing.T) { + go func() { + lconn, err := listener.Accept() + if !assert.NoError(t, err) { + return + } + _, err = lconn.AcceptStream() + if !assert.NoError(t, err) { + return + } + }() + + conn, err := a.T.Dial(context.Background(), b.Addr, b.ID()) + require.NoError(t, err) + stream, err := conn.OpenStream(context.Background()) + require.NoError(t, err) + + // deadline set to the past + stream.SetReadDeadline(time.Now().Add(-200 * time.Millisecond)) + _, err = stream.Read([]byte{0, 0}) + require.ErrorIs(t, err, os.ErrDeadlineExceeded) + + // future deadline exceeded + stream.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + _, err = stream.Read([]byte{0, 0}) + require.ErrorIs(t, err, os.ErrDeadlineExceeded) + }) + + t.Run("SetWriteDeadline", func(t *testing.T) { + go func() { + lconn, err := listener.Accept() + if !assert.NoError(t, err) { + return + } + _, err = lconn.AcceptStream() + if !assert.NoError(t, err) { + return + } + }() + + conn, err := a.T.Dial(context.Background(), b.Addr, b.ID()) + require.NoError(t, err) + stream, err := conn.OpenStream(context.Background()) + require.NoError(t, err) + + stream.SetWriteDeadline(time.Now().Add(200 * time.Millisecond)) + time.Sleep(201 * time.Millisecond) + largeBuffer := make([]byte, 2*1024*1024) + _, err = stream.Write(largeBuffer) + require.ErrorIs(t, err, os.ErrDeadlineExceeded) + }) +} + +func TestCanDial(t *testing.T) { + a := newWebRTCHost(t) + defer a.Close() + b := newWebRTCHost(t) + + tests := []struct { + addr ma.Multiaddr + canDial bool + }{ + { + addr: ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/tcp/1234/p2p/%s/p2p-circuit/", b.ID())), + canDial: false, + }, + { + addr: ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/tcp/1234/p2p/%s/p2p-circuit/webrtc", b.ID())), + canDial: true, + }, + { + addr: ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/1234/quic-v1/p2p/%s/p2p-circuit/", b.ID())), + canDial: false, + }, + { + addr: ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/1234/quic-v1/p2p/%s/p2p-circuit/webrtc/", b.ID())), + canDial: true, + }, + { + addr: ma.StringCast("/ip4/1.2.3.4/tcp/1234/webrtc"), + canDial: false, + }, + { + addr: ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/tcp/1234/p2p/%s/webrtc/", b.ID())), + canDial: false, + }, + { + addr: ma.StringCast("/ip4/1.2.3.4/tcp/1234/"), + canDial: false, + }, + } + for i, tt := range tests { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + require.Equal(t, tt.canDial, a.T.CanDial(tt.addr), "args: %s", tt.addr) + }) + } +} + +func TestCanListenTwice(t *testing.T) { + b := newRelayedHost(t) + defer b.Close() + listener, err := b.T.Listen(WebRTCAddr) + require.NoError(t, err) + a := newWebRTCHost(t) + defer a.Close() + + ca, err := a.T.Dial(context.Background(), b.Addr, b.ID()) + require.NoError(t, err) + cb, err := listener.Accept() + require.NoError(t, err) + ca.Close() + cb.Close() + listener.Close() + _, err = listener.Accept() + require.Error(t, err) + + listener, err = b.T.Listen(WebRTCAddr) + require.NoError(t, err) + ca, err = a.T.Dial(context.Background(), b.Addr, b.ID()) + require.NoError(t, err) + cb, err = listener.Accept() + require.NoError(t, err) + ca.Close() + cb.Close() +}