diff --git a/p2p/transport/quicreuse/connmgr.go b/p2p/transport/quicreuse/connmgr.go index 6ce76b4b7e..73f9b56a16 100644 --- a/p2p/transport/quicreuse/connmgr.go +++ b/p2p/transport/quicreuse/connmgr.go @@ -164,7 +164,11 @@ func (c *ConnManager) transportForListen(network string, laddr *net.UDPAddr) (re if err != nil { return nil, err } - return &singleOwnerTransport{tr: quic.Transport{Conn: conn, StatelessResetKey: &c.srk}, packetConn: conn}, nil + tr := &singleOwnerTransport{Transport: quic.Transport{Conn: conn, StatelessResetKey: &c.srk}, packetConn: conn} + if c.metricsTracer != nil { + tr.Transport.Tracer = c.metricsTracer + } + return tr, nil } func (c *ConnManager) DialQUIC(ctx context.Context, raddr ma.Multiaddr, tlsConf *tls.Config, allowWindowIncrease func(conn quic.Connection, delta uint64) bool) (quic.Connection, error) { @@ -193,7 +197,7 @@ func (c *ConnManager) DialQUIC(ctx context.Context, raddr ma.Multiaddr, tlsConf if err != nil { return nil, err } - conn, err := tr.Transport().Dial(ctx, naddr, tlsConf, quicConf) + conn, err := tr.Dial(ctx, naddr, tlsConf, quicConf) if err != nil { tr.DecreaseCount() return nil, err @@ -221,7 +225,12 @@ func (c *ConnManager) TransportForDial(network string, raddr *net.UDPAddr) (refC if err != nil { return nil, err } - return &singleOwnerTransport{tr: quic.Transport{Conn: conn, StatelessResetKey: &c.srk}, packetConn: conn}, nil + tr := &singleOwnerTransport{Transport: quic.Transport{Conn: conn, StatelessResetKey: &c.srk}, packetConn: conn} + if c.metricsTracer != nil { + tr.Transport.Tracer = c.metricsTracer + } + + return tr, nil } func (c *ConnManager) Protocols() []int { diff --git a/p2p/transport/quicreuse/connmgr_test.go b/p2p/transport/quicreuse/connmgr_test.go index 81bb61f998..7c5aa5f16b 100644 --- a/p2p/transport/quicreuse/connmgr_test.go +++ b/p2p/transport/quicreuse/connmgr_test.go @@ -107,8 +107,8 @@ func TestConnectionPassedToQUICForListening(t *testing.T) { require.NoError(t, err) quicTr, err := cm.transportForListen(netw, naddr) require.NoError(t, err) - defer quicTr.Transport().Close() - if _, ok := quicTr.Transport().Conn.(quic.OOBCapablePacketConn); !ok { + defer quicTr.Close() + if _, ok := quicTr.(*singleOwnerTransport).Transport.Conn.(quic.OOBCapablePacketConn); !ok { t.Fatal("connection passed to quic-go cannot be type asserted to a *net.UDPConn") } } @@ -163,8 +163,8 @@ func TestConnectionPassedToQUICForDialing(t *testing.T) { quicTr, err := cm.TransportForDial(netw, naddr) require.NoError(t, err, "dial error") - defer quicTr.Transport().Conn.Close() - if _, ok := quicTr.Transport().Conn.(quic.OOBCapablePacketConn); !ok { + defer quicTr.Close() + if _, ok := quicTr.(*singleOwnerTransport).Transport.Conn.(quic.OOBCapablePacketConn); !ok { t.Fatal("connection passed to quic-go cannot be type asserted to a *net.UDPConn") } } diff --git a/p2p/transport/quicreuse/listener.go b/p2p/transport/quicreuse/listener.go index abd2f6b741..42ac6217a9 100644 --- a/p2p/transport/quicreuse/listener.go +++ b/p2p/transport/quicreuse/listener.go @@ -76,7 +76,7 @@ func newConnListener(c refCountedQuicTransport, quicConfig *quic.Config, enableD } quicConf := quicConfig.Clone() quicConf.AllowConnectionWindowIncrease = cl.allowWindowIncrease - ln, err := c.Transport().Listen(tlsConf, quicConf) + ln, err := c.Listen(tlsConf, quicConf) if err != nil { return nil, err } diff --git a/p2p/transport/quicreuse/reuse.go b/p2p/transport/quicreuse/reuse.go index 47936216bc..da3cecb0bb 100644 --- a/p2p/transport/quicreuse/reuse.go +++ b/p2p/transport/quicreuse/reuse.go @@ -1,6 +1,8 @@ package quicreuse import ( + "context" + "crypto/tls" "net" "sync" "time" @@ -11,7 +13,6 @@ import ( ) type refCountedQuicTransport interface { - Transport() *quic.Transport WriteTo([]byte, net.Addr) (int, error) LocalAddr() net.Addr @@ -20,10 +21,13 @@ type refCountedQuicTransport interface { // count conn reference DecreaseCount() IncreaseCount() + + Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *quic.Config) (quic.Connection, error) + Listen(tlsConf *tls.Config, conf *quic.Config) (*quic.Listener, error) } type singleOwnerTransport struct { - tr quic.Transport + quic.Transport // Used to write packets directly around QUIC. packetConn net.PacketConn @@ -31,15 +35,11 @@ type singleOwnerTransport struct { func (c *singleOwnerTransport) IncreaseCount() {} func (c *singleOwnerTransport) DecreaseCount() { - c.tr.Close() -} - -func (c *singleOwnerTransport) Transport() *quic.Transport { - return &c.tr + c.Transport.Close() } func (c *singleOwnerTransport) LocalAddr() net.Addr { - return c.tr.Conn.LocalAddr() + return c.Transport.Conn.LocalAddr() } func (c *singleOwnerTransport) WriteTo(b []byte, addr net.Addr) (int, error) { @@ -47,10 +47,6 @@ func (c *singleOwnerTransport) WriteTo(b []byte, addr net.Addr) (int, error) { return c.packetConn.WriteTo(b, addr) } -func (c *singleOwnerTransport) Close() error { - return c.tr.Close() -} - // Constant. Defined as variables to simplify testing. var ( garbageCollectInterval = 30 * time.Second @@ -58,7 +54,7 @@ var ( ) type refcountedTransport struct { - tr quic.Transport + quic.Transport // Used to write packets directly around QUIC. packetConn net.PacketConn @@ -75,21 +71,13 @@ func (c *refcountedTransport) IncreaseCount() { c.mutex.Unlock() } -func (c *refcountedTransport) Transport() *quic.Transport { - return &c.tr -} - -func (c *refcountedTransport) Close() error { - return c.tr.Close() -} - func (c *refcountedTransport) WriteTo(b []byte, addr net.Addr) (int, error) { // Safe because we called quic.OptimizeConn ourselves. return c.packetConn.WriteTo(b, addr) } func (c *refcountedTransport) LocalAddr() net.Addr { - return c.tr.Conn.LocalAddr() + return c.Transport.Conn.LocalAddr() } func (c *refcountedTransport) DecreaseCount() { @@ -119,7 +107,7 @@ type reuse struct { globalListeners map[int]*refcountedTransport // globalDialers contains transports that we've dialed out from. These transports are listening on 0.0.0.0 / :: // On Dial, transports are reused from this map if no transport is available in the globalListeners - // On Listen, transport are reused from this map if the requested port is 0, and then moved to globalListeners + // On Listen, transports are reused from this map if the requested port is 0, and then moved to globalListeners globalDialers map[int]*refcountedTransport statelessResetKey *quic.StatelessResetKey @@ -267,13 +255,15 @@ func (r *reuse) transportForDialLocked(network string, source *net.IP) (*refcoun if err != nil { return nil, err } - rconn := &refcountedTransport{tr: quic.Transport{ + tr := &refcountedTransport{Transport: quic.Transport{ Conn: conn, StatelessResetKey: r.statelessResetKey, - Tracer: r.metricsTracer, }, packetConn: conn} - r.globalDialers[conn.LocalAddr().(*net.UDPAddr).Port] = rconn - return rconn, nil + if r.metricsTracer != nil { + tr.Transport.Tracer = r.metricsTracer + } + r.globalDialers[conn.LocalAddr().(*net.UDPAddr).Port] = tr + return tr, nil } func (r *reuse) TransportForListen(network string, laddr *net.UDPAddr) (*refcountedTransport, error) { @@ -315,11 +305,13 @@ func (r *reuse) TransportForListen(network string, laddr *net.UDPAddr) (*refcoun return nil, err } localAddr := conn.LocalAddr().(*net.UDPAddr) - tr := &refcountedTransport{tr: quic.Transport{ + tr := &refcountedTransport{Transport: quic.Transport{ Conn: conn, StatelessResetKey: r.statelessResetKey, - Tracer: r.metricsTracer, }, packetConn: conn} + if r.metricsTracer != nil { + tr.Transport.Tracer = r.metricsTracer + } tr.IncreaseCount() diff --git a/p2p/transport/quicreuse/reuse_test.go b/p2p/transport/quicreuse/reuse_test.go index 5cd9fbf3fd..0b36a4c337 100644 --- a/p2p/transport/quicreuse/reuse_test.go +++ b/p2p/transport/quicreuse/reuse_test.go @@ -121,13 +121,13 @@ func TestReuseConnectionWhenListening(t *testing.T) { raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234") require.NoError(t, err) - conn, err := reuse.TransportForDial("udp4", raddr) + tr, err := reuse.TransportForDial("udp4", raddr) require.NoError(t, err) - laddr := &net.UDPAddr{IP: net.IPv4zero, Port: conn.Transport().Conn.LocalAddr().(*net.UDPAddr).Port} + laddr := &net.UDPAddr{IP: net.IPv4zero, Port: tr.LocalAddr().(*net.UDPAddr).Port} lconn, err := reuse.TransportForListen("udp4", laddr) require.NoError(t, err) require.Equal(t, lconn.GetCount(), 2) - require.Equal(t, conn.GetCount(), 2) + require.Equal(t, tr.GetCount(), 2) } func TestReuseConnectionWhenDialBeforeListen(t *testing.T) {