Skip to content

Commit

Permalink
Merge pull request #115 from libp2p/feat/wss-dialing
Browse files Browse the repository at this point in the history
add support for wss dialing
  • Loading branch information
marten-seemann authored Feb 17, 2022
2 parents 2d53c8f + a55402d commit dd44db2
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 82 deletions.
98 changes: 81 additions & 17 deletions p2p/transport/websocket/addrs.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"net"
"net/url"
"strconv"

ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
Expand All @@ -21,22 +22,36 @@ func (addr *Addr) Network() string {
return "websocket"
}

// NewAddr creates a new Addr using the given host string
// NewAddr creates an Addr with `ws` scheme (insecure).
//
// Deprecated. Use NewAddrWithScheme.
func NewAddr(host string) *Addr {
// Older versions of the transport only supported insecure connections (i.e.
// WS instead of WSS). Assume that is the case here.
return NewAddrWithScheme(host, false)
}

// NewAddrWithScheme creates a new Addr using the given host string. isSecure
// should be true for WSS connections and false for WS.
func NewAddrWithScheme(host string, isSecure bool) *Addr {
scheme := "ws"
if isSecure {
scheme = "wss"
}
return &Addr{
URL: &url.URL{
Host: host,
Scheme: scheme,
Host: host,
},
}
}

func ConvertWebsocketMultiaddrToNetAddr(maddr ma.Multiaddr) (net.Addr, error) {
_, host, err := manet.DialArgs(maddr)
url, err := parseMultiaddr(maddr)
if err != nil {
return nil, err
}

return NewAddr(host), nil
return &Addr{URL: url}, nil
}

func ParseWebsocketNetAddr(a net.Addr) (ma.Multiaddr, error) {
Expand All @@ -45,29 +60,78 @@ func ParseWebsocketNetAddr(a net.Addr) (ma.Multiaddr, error) {
return nil, fmt.Errorf("not a websocket address")
}

tcpaddr, err := net.ResolveTCPAddr("tcp", wsa.Host)
if err != nil {
return nil, err
var (
tcpma ma.Multiaddr
err error
port int
host = wsa.Hostname()
)

// Get the port
if portStr := wsa.Port(); portStr != "" {
port, err = strconv.Atoi(portStr)
if err != nil {
return nil, fmt.Errorf("failed to parse port '%q': %s", portStr, err)
}
} else {
return nil, fmt.Errorf("invalid port in url: '%q'", wsa.URL)
}

tcpma, err := manet.FromNetAddr(tcpaddr)
if err != nil {
return nil, err
// NOTE: Ignoring IPv6 zones...
// Detect if host is IP address or DNS
if ip := net.ParseIP(host); ip != nil {
// Assume IP address
tcpma, err = manet.FromNetAddr(&net.TCPAddr{
IP: ip,
Port: port,
})
if err != nil {
return nil, err
}
} else {
// Assume DNS name
tcpma, err = ma.NewMultiaddr(fmt.Sprintf("/dns/%s/tcp/%d", host, port))
if err != nil {
return nil, err
}
}

wsma, err := ma.NewMultiaddr("/ws")
wsma, err := ma.NewMultiaddr("/" + wsa.Scheme)
if err != nil {
return nil, err
}

return tcpma.Encapsulate(wsma), nil
}

func parseMultiaddr(a ma.Multiaddr) (string, error) {
_, host, err := manet.DialArgs(a)
if err != nil {
return "", err
func parseMultiaddr(maddr ma.Multiaddr) (*url.URL, error) {
// Only look at the _last_ component.
maddr, wscomponent := ma.SplitLast(maddr)
if maddr == nil || wscomponent == nil {
return nil, fmt.Errorf("websocket addrs need at least two components")
}

var scheme string
switch wscomponent.Protocol().Code {
case ma.P_WS:
scheme = "ws"
case ma.P_WSS:
scheme = "wss"
default:
return nil, fmt.Errorf("not a websocket multiaddr")
}

return "ws://" + host, nil
network, host, err := manet.DialArgs(maddr)
if err != nil {
return nil, err
}
switch network {
case "tcp", "tcp4", "tcp6":
default:
return nil, fmt.Errorf("unsupported websocket network %s", network)
}
return &url.URL{
Scheme: scheme,
Host: host,
}, nil
}
8 changes: 4 additions & 4 deletions p2p/transport/websocket/addrs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func TestMultiaddrParsing(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if wsaddr != "ws://127.0.0.1:5555" {
if wsaddr.String() != "ws://127.0.0.1:5555" {
t.Fatalf("expected ws://127.0.0.1:5555, got %s", wsaddr)
}
}
Expand All @@ -37,7 +37,7 @@ func TestParseWebsocketNetAddr(t *testing.T) {
t.Fatalf("expect \"not a websocket address\", got \"%s\"", err)
}

wsAddr := NewAddr("127.0.0.1:5555")
wsAddr := NewAddrWithScheme("127.0.0.1:5555", false)
parsed, err := ParseWebsocketNetAddr(wsAddr)
if err != nil {
t.Fatal(err)
Expand All @@ -58,8 +58,8 @@ func TestConvertWebsocketMultiaddrToNetAddr(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if wsaddr.String() != "//127.0.0.1:5555" {
t.Fatalf("expected //127.0.0.1:5555, got %s", wsaddr)
if wsaddr.String() != "ws://127.0.0.1:5555" {
t.Fatalf("expected ws://127.0.0.1:5555, got %s", wsaddr)
}
if wsaddr.Network() != "websocket" {
t.Fatalf("expected network: \"websocket\", got \"%s\"", wsaddr.Network())
Expand Down
22 changes: 12 additions & 10 deletions p2p/transport/websocket/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ var GracefulCloseTimeout = 100 * time.Millisecond
// Conn implements net.Conn interface for gorilla/websocket.
type Conn struct {
*ws.Conn
secure bool
DefaultMessageType int
reader io.Reader
closeOnce sync.Once
Expand All @@ -25,6 +26,15 @@ type Conn struct {

var _ net.Conn = (*Conn)(nil)

// NewConn creates a Conn given a regular gorilla/websocket Conn.
func NewConn(raw *ws.Conn, secure bool) *Conn {
return &Conn{
Conn: raw,
secure: secure,
DefaultMessageType: ws.BinaryMessage,
}
}

func (c *Conn) Read(b []byte) (int, error) {
c.readLock.Lock()
defer c.readLock.Unlock()
Expand Down Expand Up @@ -109,11 +119,11 @@ func (c *Conn) Close() error {
}

func (c *Conn) LocalAddr() net.Addr {
return NewAddr(c.Conn.LocalAddr().String())
return NewAddrWithScheme(c.Conn.LocalAddr().String(), c.secure)
}

func (c *Conn) RemoteAddr() net.Addr {
return NewAddr(c.Conn.RemoteAddr().String())
return NewAddrWithScheme(c.Conn.RemoteAddr().String(), c.secure)
}

func (c *Conn) SetDeadline(t time.Time) error {
Expand All @@ -139,11 +149,3 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {

return c.Conn.SetWriteDeadline(t)
}

// NewConn creates a Conn given a regular gorilla/websocket Conn.
func NewConn(raw *ws.Conn) *Conn {
return &Conn{
Conn: raw,
DefaultMessageType: ws.BinaryMessage,
}
}
3 changes: 1 addition & 2 deletions p2p/transport/websocket/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (

type listener struct {
net.Listener

laddr ma.Multiaddr

closed chan struct{}
Expand All @@ -31,7 +30,7 @@ func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

select {
case l.incoming <- NewConn(c):
case l.incoming <- NewConn(c, false):
case <-l.closed:
c.Close()
}
Expand Down
36 changes: 30 additions & 6 deletions p2p/transport/websocket/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package websocket

import (
"context"
"crypto/tls"
"net"
"net/http"
"net/url"
Expand All @@ -22,11 +23,12 @@ var WsFmt = mafmt.And(mafmt.TCP, mafmt.Base(ma.P_WS))

// This is _not_ WsFmt because we want the transport to stick to dialing fully
// resolved addresses.
var dialMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_TCP), mafmt.Base(ma.P_WS))
var dialMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_TCP), mafmt.Or(mafmt.Base(ma.P_WS), mafmt.Base(ma.P_WSS)))

func init() {
manet.RegisterFromNetAddr(ParseWebsocketNetAddr, "websocket")
manet.RegisterToNetAddr(ConvertWebsocketMultiaddrToNetAddr, "ws")
manet.RegisterToNetAddr(ConvertWebsocketMultiaddrToNetAddr, "wss")
}

// Default gorilla upgrader
Expand All @@ -37,30 +39,52 @@ var upgrader = ws.Upgrader{
},
}

type Option func(*WebsocketTransport) error

// WithTLSClientConfig sets a TLS client configuration on the WebSocket Dialer. Only
// relevant for non-browser usages.
//
// Some useful use cases include setting InsecureSkipVerify to `true`, or
// setting user-defined trusted CA certificates.
func WithTLSClientConfig(c *tls.Config) Option {
return func(t *WebsocketTransport) error {
t.tlsClientConf = c
return nil
}
}

// WebsocketTransport is the actual go-libp2p transport
type WebsocketTransport struct {
upgrader transport.Upgrader
rcmgr network.ResourceManager

tlsClientConf *tls.Config
}

var _ transport.Transport = (*WebsocketTransport)(nil)

func New(u transport.Upgrader, rcmgr network.ResourceManager) *WebsocketTransport {
func New(u transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (*WebsocketTransport, error) {
if rcmgr == nil {
rcmgr = network.NullResourceManager
}
return &WebsocketTransport{
t := &WebsocketTransport{
upgrader: u,
rcmgr: rcmgr,
}
for _, opt := range opts {
if err := opt(t); err != nil {
return nil, err
}
}
return t, nil
}

func (t *WebsocketTransport) CanDial(a ma.Multiaddr) bool {
return dialMatcher.Matches(a)
}

func (t *WebsocketTransport) Protocols() []int {
return []int{ma.ProtocolWithCode(ma.P_WS).Code}
return []int{ma.P_WS, ma.P_WSS}
}

func (t *WebsocketTransport) Proxy() bool {
Expand All @@ -86,12 +110,12 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma
return nil, err
}

wscon, _, err := ws.DefaultDialer.Dial(wsurl, nil)
wscon, _, err := ws.DefaultDialer.Dial(wsurl.String(), nil)
if err != nil {
return nil, err
}

mnc, err := manet.WrapNetConn(NewConn(wscon))
mnc, err := manet.WrapNetConn(NewConn(wscon, wsurl.Scheme == "wss"))
if err != nil {
wscon.Close()
return nil, err
Expand Down
Loading

0 comments on commit dd44db2

Please sign in to comment.