diff --git a/p2p/transport/websocket/addrs.go b/p2p/transport/websocket/addrs.go index e789399da3..608eb2d0da 100644 --- a/p2p/transport/websocket/addrs.go +++ b/p2p/transport/websocket/addrs.go @@ -4,6 +4,7 @@ import ( "fmt" "net" "net/url" + "strconv" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" @@ -21,22 +22,36 @@ func (addr *Addr) Network() string { return "websocket" } -// NewAddr creates a new Addr using the given host string +// NewAddr creates an Addr with `ws` scheme (insecure). +// +// Deprecated. Use NewAddrWithScheme. func NewAddr(host string) *Addr { + // Older versions of the transport only supported insecure connections (i.e. + // WS instead of WSS). Assume that is the case here. + return NewAddrWithScheme(host, false) +} + +// NewAddrWithScheme creates a new Addr using the given host string. isSecure +// should be true for WSS connections and false for WS. +func NewAddrWithScheme(host string, isSecure bool) *Addr { + scheme := "ws" + if isSecure { + scheme = "wss" + } return &Addr{ URL: &url.URL{ - Host: host, + Scheme: scheme, + Host: host, }, } } func ConvertWebsocketMultiaddrToNetAddr(maddr ma.Multiaddr) (net.Addr, error) { - _, host, err := manet.DialArgs(maddr) + url, err := parseMultiaddr(maddr) if err != nil { return nil, err } - - return NewAddr(host), nil + return &Addr{URL: url}, nil } func ParseWebsocketNetAddr(a net.Addr) (ma.Multiaddr, error) { @@ -45,17 +60,43 @@ func ParseWebsocketNetAddr(a net.Addr) (ma.Multiaddr, error) { return nil, fmt.Errorf("not a websocket address") } - tcpaddr, err := net.ResolveTCPAddr("tcp", wsa.Host) - if err != nil { - return nil, err + var ( + tcpma ma.Multiaddr + err error + port int + host = wsa.Hostname() + ) + + // Get the port + if portStr := wsa.Port(); portStr != "" { + port, err = strconv.Atoi(portStr) + if err != nil { + return nil, fmt.Errorf("failed to parse port '%q': %s", portStr, err) + } + } else { + return nil, fmt.Errorf("invalid port in url: '%q'", wsa.URL) } - tcpma, err := manet.FromNetAddr(tcpaddr) - if err != nil { - return nil, err + // NOTE: Ignoring IPv6 zones... + // Detect if host is IP address or DNS + if ip := net.ParseIP(host); ip != nil { + // Assume IP address + tcpma, err = manet.FromNetAddr(&net.TCPAddr{ + IP: ip, + Port: port, + }) + if err != nil { + return nil, err + } + } else { + // Assume DNS name + tcpma, err = ma.NewMultiaddr(fmt.Sprintf("/dns/%s/tcp/%d", host, port)) + if err != nil { + return nil, err + } } - wsma, err := ma.NewMultiaddr("/ws") + wsma, err := ma.NewMultiaddr("/" + wsa.Scheme) if err != nil { return nil, err } @@ -63,11 +104,34 @@ func ParseWebsocketNetAddr(a net.Addr) (ma.Multiaddr, error) { return tcpma.Encapsulate(wsma), nil } -func parseMultiaddr(a ma.Multiaddr) (string, error) { - _, host, err := manet.DialArgs(a) - if err != nil { - return "", err +func parseMultiaddr(maddr ma.Multiaddr) (*url.URL, error) { + // Only look at the _last_ component. + maddr, wscomponent := ma.SplitLast(maddr) + if maddr == nil || wscomponent == nil { + return nil, fmt.Errorf("websocket addrs need at least two components") + } + + var scheme string + switch wscomponent.Protocol().Code { + case ma.P_WS: + scheme = "ws" + case ma.P_WSS: + scheme = "wss" + default: + return nil, fmt.Errorf("not a websocket multiaddr") } - return "ws://" + host, nil + network, host, err := manet.DialArgs(maddr) + if err != nil { + return nil, err + } + switch network { + case "tcp", "tcp4", "tcp6": + default: + return nil, fmt.Errorf("unsupported websocket network %s", network) + } + return &url.URL{ + Scheme: scheme, + Host: host, + }, nil } diff --git a/p2p/transport/websocket/addrs_test.go b/p2p/transport/websocket/addrs_test.go index d962760088..1a73c28762 100644 --- a/p2p/transport/websocket/addrs_test.go +++ b/p2p/transport/websocket/addrs_test.go @@ -17,7 +17,7 @@ func TestMultiaddrParsing(t *testing.T) { if err != nil { t.Fatal(err) } - if wsaddr != "ws://127.0.0.1:5555" { + if wsaddr.String() != "ws://127.0.0.1:5555" { t.Fatalf("expected ws://127.0.0.1:5555, got %s", wsaddr) } } @@ -37,7 +37,7 @@ func TestParseWebsocketNetAddr(t *testing.T) { t.Fatalf("expect \"not a websocket address\", got \"%s\"", err) } - wsAddr := NewAddr("127.0.0.1:5555") + wsAddr := NewAddrWithScheme("127.0.0.1:5555", false) parsed, err := ParseWebsocketNetAddr(wsAddr) if err != nil { t.Fatal(err) @@ -58,8 +58,8 @@ func TestConvertWebsocketMultiaddrToNetAddr(t *testing.T) { if err != nil { t.Fatal(err) } - if wsaddr.String() != "//127.0.0.1:5555" { - t.Fatalf("expected //127.0.0.1:5555, got %s", wsaddr) + if wsaddr.String() != "ws://127.0.0.1:5555" { + t.Fatalf("expected ws://127.0.0.1:5555, got %s", wsaddr) } if wsaddr.Network() != "websocket" { t.Fatalf("expected network: \"websocket\", got \"%s\"", wsaddr.Network()) diff --git a/p2p/transport/websocket/conn.go b/p2p/transport/websocket/conn.go index 5f520a1e76..6f2e0a7667 100644 --- a/p2p/transport/websocket/conn.go +++ b/p2p/transport/websocket/conn.go @@ -16,6 +16,7 @@ var GracefulCloseTimeout = 100 * time.Millisecond // Conn implements net.Conn interface for gorilla/websocket. type Conn struct { *ws.Conn + secure bool DefaultMessageType int reader io.Reader closeOnce sync.Once @@ -25,6 +26,15 @@ type Conn struct { var _ net.Conn = (*Conn)(nil) +// NewConn creates a Conn given a regular gorilla/websocket Conn. +func NewConn(raw *ws.Conn, secure bool) *Conn { + return &Conn{ + Conn: raw, + secure: secure, + DefaultMessageType: ws.BinaryMessage, + } +} + func (c *Conn) Read(b []byte) (int, error) { c.readLock.Lock() defer c.readLock.Unlock() @@ -109,11 +119,11 @@ func (c *Conn) Close() error { } func (c *Conn) LocalAddr() net.Addr { - return NewAddr(c.Conn.LocalAddr().String()) + return NewAddrWithScheme(c.Conn.LocalAddr().String(), c.secure) } func (c *Conn) RemoteAddr() net.Addr { - return NewAddr(c.Conn.RemoteAddr().String()) + return NewAddrWithScheme(c.Conn.RemoteAddr().String(), c.secure) } func (c *Conn) SetDeadline(t time.Time) error { @@ -139,11 +149,3 @@ func (c *Conn) SetWriteDeadline(t time.Time) error { return c.Conn.SetWriteDeadline(t) } - -// NewConn creates a Conn given a regular gorilla/websocket Conn. -func NewConn(raw *ws.Conn) *Conn { - return &Conn{ - Conn: raw, - DefaultMessageType: ws.BinaryMessage, - } -} diff --git a/p2p/transport/websocket/listener.go b/p2p/transport/websocket/listener.go index a1f3fd465e..d2cf5403de 100644 --- a/p2p/transport/websocket/listener.go +++ b/p2p/transport/websocket/listener.go @@ -11,7 +11,6 @@ import ( type listener struct { net.Listener - laddr ma.Multiaddr closed chan struct{} @@ -31,7 +30,7 @@ func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) { } select { - case l.incoming <- NewConn(c): + case l.incoming <- NewConn(c, false): case <-l.closed: c.Close() } diff --git a/p2p/transport/websocket/websocket.go b/p2p/transport/websocket/websocket.go index 5f3d69594b..77d1f65905 100644 --- a/p2p/transport/websocket/websocket.go +++ b/p2p/transport/websocket/websocket.go @@ -3,6 +3,7 @@ package websocket import ( "context" + "crypto/tls" "net" "net/http" "net/url" @@ -22,11 +23,12 @@ var WsFmt = mafmt.And(mafmt.TCP, mafmt.Base(ma.P_WS)) // This is _not_ WsFmt because we want the transport to stick to dialing fully // resolved addresses. -var dialMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_TCP), mafmt.Base(ma.P_WS)) +var dialMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_TCP), mafmt.Or(mafmt.Base(ma.P_WS), mafmt.Base(ma.P_WSS))) func init() { manet.RegisterFromNetAddr(ParseWebsocketNetAddr, "websocket") manet.RegisterToNetAddr(ConvertWebsocketMultiaddrToNetAddr, "ws") + manet.RegisterToNetAddr(ConvertWebsocketMultiaddrToNetAddr, "wss") } // Default gorilla upgrader @@ -37,22 +39,44 @@ var upgrader = ws.Upgrader{ }, } +type Option func(*WebsocketTransport) error + +// WithTLSClientConfig sets a TLS client configuration on the WebSocket Dialer. Only +// relevant for non-browser usages. +// +// Some useful use cases include setting InsecureSkipVerify to `true`, or +// setting user-defined trusted CA certificates. +func WithTLSClientConfig(c *tls.Config) Option { + return func(t *WebsocketTransport) error { + t.tlsClientConf = c + return nil + } +} + // WebsocketTransport is the actual go-libp2p transport type WebsocketTransport struct { upgrader transport.Upgrader rcmgr network.ResourceManager + + tlsClientConf *tls.Config } var _ transport.Transport = (*WebsocketTransport)(nil) -func New(u transport.Upgrader, rcmgr network.ResourceManager) *WebsocketTransport { +func New(u transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (*WebsocketTransport, error) { if rcmgr == nil { rcmgr = network.NullResourceManager } - return &WebsocketTransport{ + t := &WebsocketTransport{ upgrader: u, rcmgr: rcmgr, } + for _, opt := range opts { + if err := opt(t); err != nil { + return nil, err + } + } + return t, nil } func (t *WebsocketTransport) CanDial(a ma.Multiaddr) bool { @@ -60,7 +84,7 @@ func (t *WebsocketTransport) CanDial(a ma.Multiaddr) bool { } func (t *WebsocketTransport) Protocols() []int { - return []int{ma.ProtocolWithCode(ma.P_WS).Code} + return []int{ma.P_WS, ma.P_WSS} } func (t *WebsocketTransport) Proxy() bool { @@ -86,12 +110,12 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma return nil, err } - wscon, _, err := ws.DefaultDialer.Dial(wsurl, nil) + wscon, _, err := ws.DefaultDialer.Dial(wsurl.String(), nil) if err != nil { return nil, err } - mnc, err := manet.WrapNetConn(NewConn(wscon)) + mnc, err := manet.WrapNetConn(NewConn(wscon, wsurl.Scheme == "wss")) if err != nil { wscon.Close() return nil, err diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index e83e3a859b..addd56d737 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -3,85 +3,115 @@ package websocket import ( "bytes" "context" + "crypto/tls" "io" "io/ioutil" + "net" "testing" - "testing/iotest" - csms "github.com/libp2p/go-conn-security-multistream" "github.com/libp2p/go-libp2p-core/crypto" + "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/sec" "github.com/libp2p/go-libp2p-core/sec/insecure" "github.com/libp2p/go-libp2p-core/test" + "github.com/libp2p/go-libp2p-core/transport" + csms "github.com/libp2p/go-conn-security-multistream" mplex "github.com/libp2p/go-libp2p-mplex" ttransport "github.com/libp2p/go-libp2p-testing/suites/transport" tptu "github.com/libp2p/go-libp2p-transport-upgrader" ma "github.com/multiformats/go-multiaddr" ) -//lint:ignore U1000 // see https://github.com/dominikh/go-tools/issues/633 -func newSecureMuxer(t *testing.T, id peer.ID) sec.SecureMuxer { +func newUpgrader(t *testing.T) (peer.ID, transport.Upgrader) { + t.Helper() + id, m := newSecureMuxer(t) + u, err := tptu.New(m, new(mplex.Transport)) + if err != nil { + t.Fatal(err) + } + return id, u +} + +func newSecureMuxer(t *testing.T) (peer.ID, sec.SecureMuxer) { t.Helper() priv, _, err := test.RandTestKeyPair(crypto.Ed25519, 256) if err != nil { t.Fatal(err) } + id, err := peer.IDFromPrivateKey(priv) + if err != nil { + t.Fatal(err) + } var secMuxer csms.SSMuxer secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(id, priv)) - return &secMuxer + return id, &secMuxer } func TestCanDial(t *testing.T) { - addrWs, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/5555/ws") + d := &WebsocketTransport{} + if !d.CanDial(ma.StringCast("/ip4/127.0.0.1/tcp/5555/ws")) { + t.Fatal("expected to match websocket maddr, but did not") + } + if !d.CanDial(ma.StringCast("/ip4/127.0.0.1/tcp/5555/wss")) { + t.Fatal("expected to match secure websocket maddr, but did not") + } + if d.CanDial(ma.StringCast("/ip4/127.0.0.1/tcp/5555")) { + t.Fatal("expected to not match tcp maddr, but did") + } +} + +func TestDialWss(t *testing.T) { + if _, err := net.LookupIP("nyc-1.bootstrap.libp2p.io"); err != nil { + t.Skip("this test requries an internet connection and it seems like we currently don't have one") + } + raddr := ma.StringCast("/dns4/nyc-1.bootstrap.libp2p.io/tcp/443/wss") + rid, err := peer.Decode("QmSoLueR4xBeUbY9WZ9xGUUxunbKWcrNFTDAadQJmocnWm") if err != nil { t.Fatal(err) } - addrTCP, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/5555") + tlsConfig := &tls.Config{InsecureSkipVerify: true} + _, u := newUpgrader(t) + tpt, err := New(u, network.NullResourceManager, WithTLSClientConfig(tlsConfig)) if err != nil { t.Fatal(err) } - - d := &WebsocketTransport{} - matchTrue := d.CanDial(addrWs) - matchFalse := d.CanDial(addrTCP) - - if !matchTrue { - t.Fatal("expected to match websocket maddr, but did not") + conn, err := tpt.Dial(context.Background(), raddr, rid) + if err != nil { + t.Fatal(err) } - - if matchFalse { - t.Fatal("expected to not match tcp maddr, but did") + stream, err := conn.OpenStream(context.Background()) + if err != nil { + t.Fatal(err) } + defer stream.Close() } func TestWebsocketTransport(t *testing.T) { t.Skip("This test is failing, see https://github.com/libp2p/go-ws-transport/issues/99") - ua, err := tptu.New(newSecureMuxer(t, "peerA"), new(mplex.Transport)) + _, ua := newUpgrader(t) + ta, err := New(ua, nil) if err != nil { t.Fatal(err) } - ta := New(ua, nil) - ub, err := tptu.New(newSecureMuxer(t, "peerB"), new(mplex.Transport)) + _, ub := newUpgrader(t) + tb, err := New(ub, nil) if err != nil { t.Fatal(err) } - tb := New(ub, nil) - zero := "/ip4/127.0.0.1/tcp/0/ws" - ttransport.SubtestTransport(t, ta, tb, zero, "peerA") + ttransport.SubtestTransport(t, ta, tb, "/ip4/127.0.0.1/tcp/0/ws", "peerA") } func TestWebsocketListen(t *testing.T) { - zero, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/0/ws") + id, u := newUpgrader(t) + tpt, err := New(u, network.NullResourceManager) if err != nil { t.Fatal(err) } - - tpt := &WebsocketTransport{} - l, err := tpt.maListen(zero) + l, err := tpt.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) if err != nil { t.Fatal(err) } @@ -90,19 +120,20 @@ func TestWebsocketListen(t *testing.T) { msg := []byte("HELLO WORLD") go func() { - c, err := tpt.maDial(context.Background(), l.Multiaddr()) + c, err := tpt.Dial(context.Background(), l.Multiaddr(), id) if err != nil { t.Error(err) return } - - _, err = c.Write(msg) + str, err := c.OpenStream(context.Background()) if err != nil { t.Error(err) } - err = c.Close() - if err != nil { + defer str.Close() + + if _, err = str.Write(msg); err != nil { t.Error(err) + return } }() @@ -111,10 +142,13 @@ func TestWebsocketListen(t *testing.T) { t.Fatal(err) } defer c.Close() + str, err := c.AcceptStream() + if err != nil { + t.Fatal(err) + } + defer str.Close() - obr := iotest.OneByteReader(c) - - out, err := ioutil.ReadAll(obr) + out, err := ioutil.ReadAll(str) if err != nil { t.Fatal(err) } @@ -125,13 +159,12 @@ func TestWebsocketListen(t *testing.T) { } func TestConcurrentClose(t *testing.T) { - zero, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/0/ws") + _, u := newUpgrader(t) + tpt, err := New(u, network.NullResourceManager) if err != nil { t.Fatal(err) } - - tpt := &WebsocketTransport{} - l, err := tpt.maListen(zero) + l, err := tpt.maListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) if err != nil { t.Fatal(err) } @@ -166,13 +199,12 @@ func TestConcurrentClose(t *testing.T) { } func TestWriteZero(t *testing.T) { - zero, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/0/ws") + _, u := newUpgrader(t) + tpt, err := New(u, network.NullResourceManager) if err != nil { t.Fatal(err) } - - tpt := &WebsocketTransport{} - l, err := tpt.maListen(zero) + l, err := tpt.maListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) if err != nil { t.Fatal(err) }