diff --git a/tcp.go b/tcp.go index 1483d67..6c20410 100644 --- a/tcp.go +++ b/tcp.go @@ -70,10 +70,12 @@ func NewTCPTransport(upgrader *tptu.Upgrader) *TcpTransport { return &TcpTransport{Upgrader: upgrader, ConnectTimeout: DefaultConnectTimeout} } +var dialMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_TCP)) + // CanDial returns true if this transport believes it can dial the given // multiaddr. func (t *TcpTransport) CanDial(addr ma.Multiaddr) bool { - return mafmt.TCP.Matches(addr) + return dialMatcher.Matches(addr) } func (t *TcpTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (manet.Conn, error) { diff --git a/tcp_test.go b/tcp_test.go index a4a85bd..8dcad45 100644 --- a/tcp_test.go +++ b/tcp_test.go @@ -36,6 +36,27 @@ func TestTcpTransport(t *testing.T) { envReuseportVal = true } +func TestTcpTransportCantDialDNS(t *testing.T) { + for i := 0; i < 2; i++ { + dnsa, err := ma.NewMultiaddr("/dns4/example.com/tcp/1234") + if err != nil { + t.Fatal(err) + } + + tpt := NewTCPTransport(&tptu.Upgrader{ + Secure: makeInsecureTransport(t), + Muxer: new(mplex.Transport), + }) + + if tpt.CanDial(dnsa) { + t.Fatal("shouldn't be able to dial dns") + } + + envReuseportVal = false + } + envReuseportVal = true +} + func TestTcpTransportCantListenUtp(t *testing.T) { for i := 0; i < 2; i++ { utpa, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/utp")