Skip to content

Commit

Permalink
Embed quic.Transport
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoPolo committed Jul 12, 2023
1 parent 6f1d91c commit 4dc8399
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 40 deletions.
15 changes: 12 additions & 3 deletions p2p/transport/quicreuse/connmgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,11 @@ func (c *ConnManager) transportForListen(network string, laddr *net.UDPAddr) (re
if err != nil {
return nil, err
}
return &singleOwnerTransport{tr: quic.Transport{Conn: conn, StatelessResetKey: &c.srk}, packetConn: conn}, nil
tr := &singleOwnerTransport{Transport: quic.Transport{Conn: conn, StatelessResetKey: &c.srk}, packetConn: conn}
if c.metricsTracer != nil {
tr.Transport.Tracer = c.metricsTracer
}
return tr, nil
}

func (c *ConnManager) DialQUIC(ctx context.Context, raddr ma.Multiaddr, tlsConf *tls.Config, allowWindowIncrease func(conn quic.Connection, delta uint64) bool) (quic.Connection, error) {
Expand Down Expand Up @@ -193,7 +197,7 @@ func (c *ConnManager) DialQUIC(ctx context.Context, raddr ma.Multiaddr, tlsConf
if err != nil {
return nil, err
}
conn, err := tr.Transport().Dial(ctx, naddr, tlsConf, quicConf)
conn, err := tr.Dial(ctx, naddr, tlsConf, quicConf)
if err != nil {
tr.DecreaseCount()
return nil, err
Expand Down Expand Up @@ -221,7 +225,12 @@ func (c *ConnManager) TransportForDial(network string, raddr *net.UDPAddr) (refC
if err != nil {
return nil, err
}
return &singleOwnerTransport{tr: quic.Transport{Conn: conn, StatelessResetKey: &c.srk}, packetConn: conn}, nil
tr := &singleOwnerTransport{Transport: quic.Transport{Conn: conn, StatelessResetKey: &c.srk}, packetConn: conn}
if c.metricsTracer != nil {
tr.Transport.Tracer = c.metricsTracer
}

return tr, nil
}

func (c *ConnManager) Protocols() []int {
Expand Down
8 changes: 4 additions & 4 deletions p2p/transport/quicreuse/connmgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ func TestConnectionPassedToQUICForListening(t *testing.T) {
require.NoError(t, err)
quicTr, err := cm.transportForListen(netw, naddr)
require.NoError(t, err)
defer quicTr.Transport().Close()
if _, ok := quicTr.Transport().Conn.(quic.OOBCapablePacketConn); !ok {
defer quicTr.Close()
if _, ok := quicTr.(*singleOwnerTransport).Transport.Conn.(quic.OOBCapablePacketConn); !ok {
t.Fatal("connection passed to quic-go cannot be type asserted to a *net.UDPConn")
}
}
Expand Down Expand Up @@ -163,8 +163,8 @@ func TestConnectionPassedToQUICForDialing(t *testing.T) {
quicTr, err := cm.TransportForDial(netw, naddr)

require.NoError(t, err, "dial error")
defer quicTr.Transport().Conn.Close()
if _, ok := quicTr.Transport().Conn.(quic.OOBCapablePacketConn); !ok {
defer quicTr.Close()
if _, ok := quicTr.(*singleOwnerTransport).Transport.Conn.(quic.OOBCapablePacketConn); !ok {
t.Fatal("connection passed to quic-go cannot be type asserted to a *net.UDPConn")
}
}
Expand Down
2 changes: 1 addition & 1 deletion p2p/transport/quicreuse/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func newConnListener(c refCountedQuicTransport, quicConfig *quic.Config, enableD
}
quicConf := quicConfig.Clone()
quicConf.AllowConnectionWindowIncrease = cl.allowWindowIncrease
ln, err := c.Transport().Listen(tlsConf, quicConf)
ln, err := c.Listen(tlsConf, quicConf)
if err != nil {
return nil, err
}
Expand Down
50 changes: 21 additions & 29 deletions p2p/transport/quicreuse/reuse.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package quicreuse

import (
"context"
"crypto/tls"
"net"
"sync"
"time"
Expand All @@ -11,7 +13,6 @@ import (
)

type refCountedQuicTransport interface {
Transport() *quic.Transport
WriteTo([]byte, net.Addr) (int, error)
LocalAddr() net.Addr

Expand All @@ -20,45 +21,40 @@ type refCountedQuicTransport interface {
// count conn reference
DecreaseCount()
IncreaseCount()

Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *quic.Config) (quic.Connection, error)
Listen(tlsConf *tls.Config, conf *quic.Config) (*quic.Listener, error)
}

type singleOwnerTransport struct {
tr quic.Transport
quic.Transport

// Used to write packets directly around QUIC.
packetConn net.PacketConn
}

func (c *singleOwnerTransport) IncreaseCount() {}
func (c *singleOwnerTransport) DecreaseCount() {
c.tr.Close()
}

func (c *singleOwnerTransport) Transport() *quic.Transport {
return &c.tr
c.Transport.Close()
}

func (c *singleOwnerTransport) LocalAddr() net.Addr {
return c.tr.Conn.LocalAddr()
return c.Transport.Conn.LocalAddr()
}

func (c *singleOwnerTransport) WriteTo(b []byte, addr net.Addr) (int, error) {
// Safe because we called quic.OptimizeConn ourselves.
return c.packetConn.WriteTo(b, addr)
}

func (c *singleOwnerTransport) Close() error {
return c.tr.Close()
}

// Constant. Defined as variables to simplify testing.
var (
garbageCollectInterval = 30 * time.Second
maxUnusedDuration = 10 * time.Second
)

type refcountedTransport struct {
tr quic.Transport
quic.Transport

// Used to write packets directly around QUIC.
packetConn net.PacketConn
Expand All @@ -75,21 +71,13 @@ func (c *refcountedTransport) IncreaseCount() {
c.mutex.Unlock()
}

func (c *refcountedTransport) Transport() *quic.Transport {
return &c.tr
}

func (c *refcountedTransport) Close() error {
return c.tr.Close()
}

func (c *refcountedTransport) WriteTo(b []byte, addr net.Addr) (int, error) {
// Safe because we called quic.OptimizeConn ourselves.
return c.packetConn.WriteTo(b, addr)
}

func (c *refcountedTransport) LocalAddr() net.Addr {
return c.tr.Conn.LocalAddr()
return c.Transport.Conn.LocalAddr()
}

func (c *refcountedTransport) DecreaseCount() {
Expand Down Expand Up @@ -119,7 +107,7 @@ type reuse struct {
globalListeners map[int]*refcountedTransport
// globalDialers contains transports that we've dialed out from. These transports are listening on 0.0.0.0 / ::
// On Dial, transports are reused from this map if no transport is available in the globalListeners
// On Listen, transport are reused from this map if the requested port is 0, and then moved to globalListeners
// On Listen, transports are reused from this map if the requested port is 0, and then moved to globalListeners
globalDialers map[int]*refcountedTransport

statelessResetKey *quic.StatelessResetKey
Expand Down Expand Up @@ -267,13 +255,15 @@ func (r *reuse) transportForDialLocked(network string, source *net.IP) (*refcoun
if err != nil {
return nil, err
}
rconn := &refcountedTransport{tr: quic.Transport{
tr := &refcountedTransport{Transport: quic.Transport{
Conn: conn,
StatelessResetKey: r.statelessResetKey,
Tracer: r.metricsTracer,
}, packetConn: conn}
r.globalDialers[conn.LocalAddr().(*net.UDPAddr).Port] = rconn
return rconn, nil
if r.metricsTracer != nil {
tr.Transport.Tracer = r.metricsTracer
}
r.globalDialers[conn.LocalAddr().(*net.UDPAddr).Port] = tr
return tr, nil
}

func (r *reuse) TransportForListen(network string, laddr *net.UDPAddr) (*refcountedTransport, error) {
Expand Down Expand Up @@ -315,11 +305,13 @@ func (r *reuse) TransportForListen(network string, laddr *net.UDPAddr) (*refcoun
return nil, err
}
localAddr := conn.LocalAddr().(*net.UDPAddr)
tr := &refcountedTransport{tr: quic.Transport{
tr := &refcountedTransport{Transport: quic.Transport{
Conn: conn,
StatelessResetKey: r.statelessResetKey,
Tracer: r.metricsTracer,
}, packetConn: conn}
if r.metricsTracer != nil {
tr.Transport.Tracer = r.metricsTracer
}

tr.IncreaseCount()

Expand Down
6 changes: 3 additions & 3 deletions p2p/transport/quicreuse/reuse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,13 @@ func TestReuseConnectionWhenListening(t *testing.T) {

raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
require.NoError(t, err)
conn, err := reuse.TransportForDial("udp4", raddr)
tr, err := reuse.TransportForDial("udp4", raddr)
require.NoError(t, err)
laddr := &net.UDPAddr{IP: net.IPv4zero, Port: conn.Transport().Conn.LocalAddr().(*net.UDPAddr).Port}
laddr := &net.UDPAddr{IP: net.IPv4zero, Port: tr.LocalAddr().(*net.UDPAddr).Port}
lconn, err := reuse.TransportForListen("udp4", laddr)
require.NoError(t, err)
require.Equal(t, lconn.GetCount(), 2)
require.Equal(t, conn.GetCount(), 2)
require.Equal(t, tr.GetCount(), 2)
}

func TestReuseConnectionWhenDialBeforeListen(t *testing.T) {
Expand Down

0 comments on commit 4dc8399

Please sign in to comment.