Skip to content

Commit

Permalink
webtransport: use the rcmgr to control flow control window increases
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Oct 24, 2022
1 parent c0a0aa0 commit b50b460
Show file tree
Hide file tree
Showing 6 changed files with 187 additions and 16 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ require (
github.com/libp2p/zeroconf/v2 v2.2.0
github.com/lucas-clemente/quic-go v0.30.0
github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd
github.com/marten-seemann/webtransport-go v0.1.1
github.com/marten-seemann/webtransport-go v0.2.0
github.com/mikioh/tcpinfo v0.0.0-20190314235526-30a79bb1804b
github.com/minio/sha256-simd v1.0.0
github.com/mr-tron/base58 v1.2.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,8 @@ github.com/marten-seemann/qtls-go1-19 v0.1.1 h1:mnbxeq3oEyQxQXwI4ReCgW9DPoPR94sN
github.com/marten-seemann/qtls-go1-19 v0.1.1/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI=
github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd h1:br0buuQ854V8u83wA0rVZ8ttrq5CpaPZdvrK0LP2lOk=
github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd/go.mod h1:QuCEs1Nt24+FYQEqAAncTDPJIuGs+LxK1MCiFL25pMU=
github.com/marten-seemann/webtransport-go v0.1.1 h1:TnyKp3pEXcDooTaNn4s9dYpMJ7kMnTp7k5h+SgYP/mc=
github.com/marten-seemann/webtransport-go v0.1.1/go.mod h1:kBEh5+RSvOA4troP1vyOVBWK4MIMzDICXVrvCPrYcrM=
github.com/marten-seemann/webtransport-go v0.2.0 h1:987jPVqcyE3vF+CHNIxDhT0P21O+bI4fVF+0NoRujSo=
github.com/marten-seemann/webtransport-go v0.2.0/go.mod h1:XmnWYsWXaxUF7kjeIIzLWPyS+q0OcBY5vA64NuyK0ps=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
Expand Down
17 changes: 14 additions & 3 deletions p2p/transport/webtransport/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ func (c *connMultiaddrs) RemoteMultiaddr() ma.Multiaddr { return c.remote }
type conn struct {
*connSecurityMultiaddrs

transport tpt.Transport
transport *transport
session *webtransport.Session

scope network.ConnScope
}

var _ tpt.CapableConn = &conn{}

func newConn(tr tpt.Transport, sess *webtransport.Session, sconn *connSecurityMultiaddrs, scope network.ConnScope) *conn {
func newConn(tr *transport, sess *webtransport.Session, sconn *connSecurityMultiaddrs, scope network.ConnScope) *conn {
return &conn{
connSecurityMultiaddrs: sconn,
transport: tr,
Expand All @@ -60,7 +60,18 @@ func (c *conn) AcceptStream() (network.MuxedStream, error) {
return &stream{str}, nil
}

func (c *conn) Close() error { return c.session.Close() }
func (c *conn) allowWindowIncrease(size uint64) bool {
return c.scope.ReserveMemory(int(size), network.ReservationPriorityMedium) == nil
}

// Close closes the connection.
// It must be called even if the peer closed the connection in order for
// garbage collection to properly work in this package.
func (c *conn) Close() error {
c.transport.removeConn(c.session)
return c.session.Close()
}

func (c *conn) IsClosed() bool { return c.session.Context().Err() != nil }
func (c *conn) Scope() network.ConnScope { return c.scope }
func (c *conn) Transport() tpt.Transport { return c.transport }
4 changes: 3 additions & 1 deletion p2p/transport/webtransport/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,10 @@ func (l *listener) httpHandler(w http.ResponseWriter, r *http.Request) {
return
}

conn := newConn(l.transport, sess, sconn, connScope)
l.transport.addConn(sess, conn)
select {
case l.queue <- newConn(l.transport, sess, sconn, connScope):
case l.queue <- conn:
default:
log.Debugw("accept queue full, dropping incoming connection", "peer", sconn.RemotePeer(), "addr", r.RemoteAddr, "error", err)
sess.Close()
Expand Down
46 changes: 37 additions & 9 deletions p2p/transport/webtransport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ type transport struct {
tlsClientConf *tls.Config

noise *noise.Transport

connMx sync.Mutex
conns map[uint64]*conn // using quic-go's ConnectionTracingKey as map key
}

var _ tpt.Transport = &transport{}
Expand All @@ -94,13 +97,14 @@ func New(key ic.PrivKey, gater connmgr.ConnectionGater, rcmgr network.ResourceMa
return nil, err
}
t := &transport{
pid: id,
privKey: key,
rcmgr: rcmgr,
gater: gater,
clock: clock.New(),
quicConfig: &quic.Config{},
}
pid: id,
privKey: key,
rcmgr: rcmgr,
gater: gater,
clock: clock.New(),
conns: map[uint64]*conn{},
}
t.quicConfig = &quic.Config{AllowConnectionWindowIncrease: t.allowWindowIncrease}
for _, opt := range opts {
if err := opt(t); err != nil {
return nil, err
Expand Down Expand Up @@ -157,8 +161,9 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp
scope.Done()
return nil, fmt.Errorf("secured connection gated")
}

return newConn(t, sess, sconn, scope), nil
conn := newConn(t, sess, sconn, scope)
t.addConn(sess, conn)
return conn, nil
}

func (t *transport) dial(ctx context.Context, addr string, sni string, certHashes []multihash.DecodedMultihash) (*webtransport.Session, error) {
Expand Down Expand Up @@ -313,6 +318,29 @@ func (t *transport) Close() error {
return nil
}

func (t *transport) allowWindowIncrease(conn quic.Connection, size uint64) bool {
t.connMx.Lock()
defer t.connMx.Unlock()

c, ok := t.conns[conn.Context().Value(quic.ConnectionTracingKey).(uint64)]
if !ok {
return false
}
return c.allowWindowIncrease(size)
}

func (t *transport) addConn(sess *webtransport.Session, c *conn) {
t.connMx.Lock()
t.conns[sess.Context().Value(quic.ConnectionTracingKey).(uint64)] = c
t.connMx.Unlock()
}

func (t *transport) removeConn(sess *webtransport.Session) {
t.connMx.Lock()
delete(t.conns, sess.Context().Value(quic.ConnectionTracingKey).(uint64))
t.connMx.Unlock()
}

// extractSNI returns what the SNI should be for the given maddr. If there is an
// SNI component in the multiaddr, then it will be returned and
// foundSniComponent will be true. If there's no SNI component, but there is a
Expand Down
130 changes: 130 additions & 0 deletions p2p/transport/webtransport/transport_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package libp2pwebtransport_test

import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
Expand All @@ -14,7 +15,10 @@ import (
"io"
"math/big"
"net"
"os"
"runtime"
"strings"
"sync/atomic"
"testing"
"time"

Expand All @@ -26,6 +30,7 @@ import (
libp2pwebtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport"

"github.com/golang/mock/gomock"
quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/multiformats/go-multibase"
Expand Down Expand Up @@ -582,5 +587,130 @@ func TestSNIIsSent(t *testing.T) {
case <-time.After(time.Minute):
t.Fatalf("Expected to get server name")
}
}

type reportingRcmgr struct {
network.NullResourceManager
report chan<- int
}

func (m *reportingRcmgr) OpenConnection(dir network.Direction, usefd bool, endpoint ma.Multiaddr) (network.ConnManagementScope, error) {
return &reportingScope{report: m.report}, nil
}

type reportingScope struct {
network.NullScope
report chan<- int
}

func (s *reportingScope) ReserveMemory(size int, _ uint8) error {
s.report <- size
return nil
}

func TestFlowControlWindowIncrease(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("this test is flaky on Windows")
}

rtt := 10 * time.Millisecond
timeout := 5 * time.Second

if os.Getenv("CI") != "" {
rtt = 40 * time.Millisecond
timeout = 15 * time.Second
}

serverID, serverKey := newIdentity(t)
serverWindowIncreases := make(chan int, 100)
serverRcmgr := &reportingRcmgr{report: serverWindowIncreases}
tr, err := libp2pwebtransport.New(serverKey, nil, serverRcmgr)
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport"))
require.NoError(t, err)
defer ln.Close()

go func() {
conn, err := ln.Accept()
require.NoError(t, err)
str, err := conn.AcceptStream()
require.NoError(t, err)
_, err = io.CopyBuffer(str, str, make([]byte, 2<<10))
require.NoError(t, err)
str.CloseWrite()
}()

proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: ln.Addr().String(),
DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 },
})
require.NoError(t, err)
defer proxy.Close()

_, clientKey := newIdentity(t)
clientWindowIncreases := make(chan int, 100)
clientRcmgr := &reportingRcmgr{report: clientWindowIncreases}
tr2, err := libp2pwebtransport.New(clientKey, nil, clientRcmgr)
require.NoError(t, err)
defer tr2.(io.Closer).Close()

var addr ma.Multiaddr
for _, comp := range ma.Split(ln.Multiaddr()) {
if _, err := comp.ValueForProtocol(ma.P_UDP); err == nil {
addr = addr.Encapsulate(ma.StringCast(fmt.Sprintf("/udp/%d", proxy.LocalPort())))
continue
}
if addr == nil {
addr = comp
continue
}
addr = addr.Encapsulate(comp)
}

conn, err := tr2.Dial(context.Background(), addr, serverID)
require.NoError(t, err)
str, err := conn.OpenStream(context.Background())
require.NoError(t, err)
var increasesDone uint32 // to be used atomically
go func() {
for {
_, err := str.Write(bytes.Repeat([]byte{0x42}, 1<<10))
require.NoError(t, err)
if atomic.LoadUint32(&increasesDone) > 0 {
str.CloseWrite()
return
}
}
}()
done := make(chan struct{})
go func() {
defer close(done)
_, err := io.ReadAll(str)
require.NoError(t, err)
}()

var numServerIncreases, numClientIncreases int
timer := time.NewTimer(timeout)
defer timer.Stop()
for {
select {
case <-serverWindowIncreases:
numServerIncreases++
case <-clientWindowIncreases:
numClientIncreases++
case <-timer.C:
t.Fatalf("didn't receive enough window increases (client: %d, server: %d)", numClientIncreases, numServerIncreases)
}
if numClientIncreases >= 1 && numServerIncreases >= 1 {
atomic.AddUint32(&increasesDone, 1)
break
}
}

select {
case <-done:
case <-time.After(timeout):
t.Fatal("timeout")
}
}

0 comments on commit b50b460

Please sign in to comment.