diff --git a/core/network/rcmgr.go b/core/network/rcmgr.go index b609257566..524a28a8c4 100644 --- a/core/network/rcmgr.go +++ b/core/network/rcmgr.go @@ -271,60 +271,57 @@ type ScopeStat struct { } // NullResourceManager is a stub for tests and initialization of default values -var NullResourceManager ResourceManager = &nullResourceManager{} - -type nullResourceManager struct{} -type nullScope struct{} - -var _ ResourceScope = (*nullScope)(nil) -var _ ResourceScopeSpan = (*nullScope)(nil) -var _ ServiceScope = (*nullScope)(nil) -var _ ProtocolScope = (*nullScope)(nil) -var _ PeerScope = (*nullScope)(nil) -var _ ConnManagementScope = (*nullScope)(nil) -var _ ConnScope = (*nullScope)(nil) -var _ StreamManagementScope = (*nullScope)(nil) -var _ StreamScope = (*nullScope)(nil) +type NullResourceManager struct{} + +var _ ResourceScope = (*NullScope)(nil) +var _ ResourceScopeSpan = (*NullScope)(nil) +var _ ServiceScope = (*NullScope)(nil) +var _ ProtocolScope = (*NullScope)(nil) +var _ PeerScope = (*NullScope)(nil) +var _ ConnManagementScope = (*NullScope)(nil) +var _ ConnScope = (*NullScope)(nil) +var _ StreamManagementScope = (*NullScope)(nil) +var _ StreamScope = (*NullScope)(nil) // NullScope is a stub for tests and initialization of default values -var NullScope = &nullScope{} +type NullScope struct{} -func (n *nullResourceManager) ViewSystem(f func(ResourceScope) error) error { - return f(NullScope) +func (n *NullResourceManager) ViewSystem(f func(ResourceScope) error) error { + return f(&NullScope{}) } -func (n *nullResourceManager) ViewTransient(f func(ResourceScope) error) error { - return f(NullScope) +func (n *NullResourceManager) ViewTransient(f func(ResourceScope) error) error { + return f(&NullScope{}) } -func (n *nullResourceManager) ViewService(svc string, f func(ServiceScope) error) error { - return f(NullScope) +func (n *NullResourceManager) ViewService(svc string, f func(ServiceScope) error) error { + return f(&NullScope{}) } -func (n *nullResourceManager) ViewProtocol(p protocol.ID, f func(ProtocolScope) error) error { - return f(NullScope) +func (n *NullResourceManager) ViewProtocol(p protocol.ID, f func(ProtocolScope) error) error { + return f(&NullScope{}) } -func (n *nullResourceManager) ViewPeer(p peer.ID, f func(PeerScope) error) error { - return f(NullScope) +func (n *NullResourceManager) ViewPeer(p peer.ID, f func(PeerScope) error) error { + return f(&NullScope{}) } -func (n *nullResourceManager) OpenConnection(dir Direction, usefd bool, endpoint multiaddr.Multiaddr) (ConnManagementScope, error) { - return NullScope, nil +func (n *NullResourceManager) OpenConnection(dir Direction, usefd bool, endpoint multiaddr.Multiaddr) (ConnManagementScope, error) { + return &NullScope{}, nil } -func (n *nullResourceManager) OpenStream(p peer.ID, dir Direction) (StreamManagementScope, error) { - return NullScope, nil +func (n *NullResourceManager) OpenStream(p peer.ID, dir Direction) (StreamManagementScope, error) { + return &NullScope{}, nil } -func (n *nullResourceManager) Close() error { +func (n *NullResourceManager) Close() error { return nil } -func (n *nullScope) ReserveMemory(size int, prio uint8) error { return nil } -func (n *nullScope) ReleaseMemory(size int) {} -func (n *nullScope) Stat() ScopeStat { return ScopeStat{} } -func (n *nullScope) BeginSpan() (ResourceScopeSpan, error) { return NullScope, nil } -func (n *nullScope) Done() {} -func (n *nullScope) Name() string { return "" } -func (n *nullScope) Protocol() protocol.ID { return "" } -func (n *nullScope) Peer() peer.ID { return "" } -func (n *nullScope) PeerScope() PeerScope { return NullScope } -func (n *nullScope) SetPeer(peer.ID) error { return nil } -func (n *nullScope) ProtocolScope() ProtocolScope { return NullScope } -func (n *nullScope) SetProtocol(proto protocol.ID) error { return nil } -func (n *nullScope) ServiceScope() ServiceScope { return NullScope } -func (n *nullScope) SetService(srv string) error { return nil } +func (n *NullScope) ReserveMemory(size int, prio uint8) error { return nil } +func (n *NullScope) ReleaseMemory(size int) {} +func (n *NullScope) Stat() ScopeStat { return ScopeStat{} } +func (n *NullScope) BeginSpan() (ResourceScopeSpan, error) { return &NullScope{}, nil } +func (n *NullScope) Done() {} +func (n *NullScope) Name() string { return "" } +func (n *NullScope) Protocol() protocol.ID { return "" } +func (n *NullScope) Peer() peer.ID { return "" } +func (n *NullScope) PeerScope() PeerScope { return &NullScope{} } +func (n *NullScope) SetPeer(peer.ID) error { return nil } +func (n *NullScope) ProtocolScope() ProtocolScope { return &NullScope{} } +func (n *NullScope) SetProtocol(proto protocol.ID) error { return nil } +func (n *NullScope) ServiceScope() ServiceScope { return &NullScope{} } +func (n *NullScope) SetService(srv string) error { return nil } diff --git a/go.mod b/go.mod index 66952dc1d7..3c6128f8d9 100644 --- a/go.mod +++ b/go.mod @@ -33,7 +33,7 @@ require ( github.com/libp2p/zeroconf/v2 v2.2.0 github.com/lucas-clemente/quic-go v0.30.0 github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd - github.com/marten-seemann/webtransport-go v0.1.1 + github.com/marten-seemann/webtransport-go v0.2.0 github.com/mikioh/tcpinfo v0.0.0-20190314235526-30a79bb1804b github.com/minio/sha256-simd v1.0.0 github.com/mr-tron/base58 v1.2.0 diff --git a/go.sum b/go.sum index 825792c33d..c9161da572 100644 --- a/go.sum +++ b/go.sum @@ -333,8 +333,8 @@ github.com/marten-seemann/qtls-go1-19 v0.1.1 h1:mnbxeq3oEyQxQXwI4ReCgW9DPoPR94sN github.com/marten-seemann/qtls-go1-19 v0.1.1/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd h1:br0buuQ854V8u83wA0rVZ8ttrq5CpaPZdvrK0LP2lOk= github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd/go.mod h1:QuCEs1Nt24+FYQEqAAncTDPJIuGs+LxK1MCiFL25pMU= -github.com/marten-seemann/webtransport-go v0.1.1 h1:TnyKp3pEXcDooTaNn4s9dYpMJ7kMnTp7k5h+SgYP/mc= -github.com/marten-seemann/webtransport-go v0.1.1/go.mod h1:kBEh5+RSvOA4troP1vyOVBWK4MIMzDICXVrvCPrYcrM= +github.com/marten-seemann/webtransport-go v0.2.0 h1:987jPVqcyE3vF+CHNIxDhT0P21O+bI4fVF+0NoRujSo= +github.com/marten-seemann/webtransport-go v0.2.0/go.mod h1:XmnWYsWXaxUF7kjeIIzLWPyS+q0OcBY5vA64NuyK0ps= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= diff --git a/p2p/net/mock/mock_conn.go b/p2p/net/mock/mock_conn.go index 48015a4c61..487db8c7ed 100644 --- a/p2p/net/mock/mock_conn.go +++ b/p2p/net/mock/mock_conn.go @@ -189,5 +189,5 @@ func (c *conn) Stat() network.ConnStats { } func (c *conn) Scope() network.ConnScope { - return network.NullScope + return &network.NullScope{} } diff --git a/p2p/net/mock/mock_peernet.go b/p2p/net/mock/mock_peernet.go index 2e931a660d..3fb0701322 100644 --- a/p2p/net/mock/mock_peernet.go +++ b/p2p/net/mock/mock_peernet.go @@ -370,5 +370,5 @@ func (pn *peernet) notifyAll(notification func(f network.Notifiee)) { } func (pn *peernet) ResourceManager() network.ResourceManager { - return network.NullResourceManager + return &network.NullResourceManager{} } diff --git a/p2p/net/mock/mock_stream.go b/p2p/net/mock/mock_stream.go index 61303daba5..a78329d306 100644 --- a/p2p/net/mock/mock_stream.go +++ b/p2p/net/mock/mock_stream.go @@ -287,7 +287,7 @@ func (s *stream) transport() { } func (s *stream) Scope() network.StreamScope { - return network.NullScope + return &network.NullScope{} } func (s *stream) cancelWrite(err error) { diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index ece0986dfa..213f0fc658 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -178,7 +178,7 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, opts ...Option) (*Swarm, } } if s.rcmgr == nil { - s.rcmgr = network.NullResourceManager + s.rcmgr = &network.NullResourceManager{} } s.dsync = newDialSync(s.dialWorkerLoop) diff --git a/p2p/net/swarm/swarm_dial_test.go b/p2p/net/swarm/swarm_dial_test.go index 959d51c226..e39a78dd7b 100644 --- a/p2p/net/swarm/swarm_dial_test.go +++ b/p2p/net/swarm/swarm_dial_test.go @@ -45,7 +45,7 @@ func TestAddrsForDial(t *testing.T) { ps.AddPrivKey(id, priv) t.Cleanup(func() { ps.Close() }) - tpt, err := websocket.New(nil, network.NullResourceManager) + tpt, err := websocket.New(nil, &network.NullResourceManager{}) require.NoError(t, err) s, err := NewSwarm(id, ps, WithMultiaddrResolver(resolver)) require.NoError(t, err) @@ -81,7 +81,7 @@ func newTestSwarmWithResolver(t *testing.T, resolver *madns.Resolver) *Swarm { }) // Add a tcp transport so that we know we can dial a tcp multiaddr and we don't filter it out. - tpt, err := tcp.NewTCPTransport(nil, network.NullResourceManager) + tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{}) require.NoError(t, err) err = s.AddTransport(tpt) require.NoError(t, err) diff --git a/p2p/net/upgrader/listener_test.go b/p2p/net/upgrader/listener_test.go index 82c3952ef8..fcb7bf9c07 100644 --- a/p2p/net/upgrader/listener_test.go +++ b/p2p/net/upgrader/listener_test.go @@ -56,7 +56,7 @@ func TestAcceptSingleConn(t *testing.T) { ln := createListener(t, u) defer ln.Close() - cconn, err := dial(t, u, ln.Multiaddr(), id, network.NullScope) + cconn, err := dial(t, u, ln.Multiaddr(), id, &network.NullScope{}) require.NoError(err) sconn, err := ln.Accept() @@ -80,7 +80,7 @@ func TestAcceptMultipleConns(t *testing.T) { }() for i := 0; i < 10; i++ { - cconn, err := dial(t, u, ln.Multiaddr(), id, network.NullScope) + cconn, err := dial(t, u, ln.Multiaddr(), id, &network.NullScope{}) require.NoError(err) toClose = append(toClose, cconn) @@ -104,7 +104,7 @@ func TestConnectionsClosedIfNotAccepted(t *testing.T) { ln := createListener(t, u) defer ln.Close() - conn, err := dial(t, u, ln.Multiaddr(), id, network.NullScope) + conn, err := dial(t, u, ln.Multiaddr(), id, &network.NullScope{}) require.NoError(err) errCh := make(chan error) @@ -143,7 +143,7 @@ func TestFailedUpgradeOnListen(t *testing.T) { errCh <- err }() - _, err := dial(t, u, ln.Multiaddr(), id, network.NullScope) + _, err := dial(t, u, ln.Multiaddr(), id, &network.NullScope{}) require.Error(err) // close the listener. @@ -177,7 +177,7 @@ func TestListenerClose(t *testing.T) { require.Contains(err.Error(), "use of closed network connection") // doesn't accept new connections when it is closed - _, err = dial(t, u, ln.Multiaddr(), peer.ID("1"), network.NullScope) + _, err = dial(t, u, ln.Multiaddr(), peer.ID("1"), &network.NullScope{}) require.Error(err) } @@ -189,7 +189,7 @@ func TestListenerCloseClosesQueued(t *testing.T) { var conns []transport.CapableConn for i := 0; i < 10; i++ { - conn, err := dial(t, upgrader, ln.Multiaddr(), id, network.NullScope) + conn, err := dial(t, upgrader, ln.Multiaddr(), id, &network.NullScope{}) require.NoError(err) conns = append(conns, conn) } @@ -249,7 +249,7 @@ func TestConcurrentAccept(t *testing.T) { go func() { defer wg.Done() - conn, err := dial(t, u, ln.Multiaddr(), id, network.NullScope) + conn, err := dial(t, u, ln.Multiaddr(), id, &network.NullScope{}) if err != nil { errCh <- err return @@ -279,7 +279,7 @@ func TestAcceptQueueBacklogged(t *testing.T) { // setup AcceptQueueLength connections, but don't accept any of them var counter int32 // to be used atomically doDial := func() { - conn, err := dial(t, u, ln.Multiaddr(), id, network.NullScope) + conn, err := dial(t, u, ln.Multiaddr(), id, &network.NullScope{}) require.NoError(err) atomic.AddInt32(&counter, 1) t.Cleanup(func() { conn.Close() }) @@ -315,7 +315,7 @@ func TestListenerConnectionGater(t *testing.T) { defer ln.Close() // no gating. - conn, err := dial(t, u, ln.Multiaddr(), id, network.NullScope) + conn, err := dial(t, u, ln.Multiaddr(), id, &network.NullScope{}) require.NoError(err) require.False(conn.IsClosed()) _ = conn.Close() @@ -323,28 +323,28 @@ func TestListenerConnectionGater(t *testing.T) { // rejecting after handshake. testGater.BlockSecured(true) testGater.BlockAccept(false) - conn, err = dial(t, u, ln.Multiaddr(), "invalid", network.NullScope) + conn, err = dial(t, u, ln.Multiaddr(), "invalid", &network.NullScope{}) require.Error(err) require.Nil(conn) // rejecting on accept will trigger firupgrader. testGater.BlockSecured(true) testGater.BlockAccept(true) - conn, err = dial(t, u, ln.Multiaddr(), "invalid", network.NullScope) + conn, err = dial(t, u, ln.Multiaddr(), "invalid", &network.NullScope{}) require.Error(err) require.Nil(conn) // rejecting only on acceptance. testGater.BlockSecured(false) testGater.BlockAccept(true) - conn, err = dial(t, u, ln.Multiaddr(), "invalid", network.NullScope) + conn, err = dial(t, u, ln.Multiaddr(), "invalid", &network.NullScope{}) require.Error(err) require.Nil(conn) // back to normal testGater.BlockSecured(false) testGater.BlockAccept(false) - conn, err = dial(t, u, ln.Multiaddr(), id, network.NullScope) + conn, err = dial(t, u, ln.Multiaddr(), id, &network.NullScope{}) require.NoError(err) require.False(conn.IsClosed()) _ = conn.Close() @@ -366,7 +366,7 @@ func TestListenerResourceManagement(t *testing.T) { connScope.EXPECT().PeerScope(), ) - cconn, err := dial(t, upgrader, ln.Multiaddr(), id, network.NullScope) + cconn, err := dial(t, upgrader, ln.Multiaddr(), id, &network.NullScope{}) require.NoError(t, err) defer cconn.Close() @@ -384,7 +384,7 @@ func TestListenerResourceManagementDenied(t *testing.T) { ln := createListener(t, upgrader) rcmgr.EXPECT().OpenConnection(network.DirInbound, true, gomock.Not(ln.Multiaddr())).Return(nil, errors.New("nope")) - _, err := dial(t, upgrader, ln.Multiaddr(), id, network.NullScope) + _, err := dial(t, upgrader, ln.Multiaddr(), id, &network.NullScope{}) require.Error(t, err) done := make(chan struct{}) diff --git a/p2p/net/upgrader/upgrader.go b/p2p/net/upgrader/upgrader.go index 2ef60b82ac..ea19d2f2d6 100644 --- a/p2p/net/upgrader/upgrader.go +++ b/p2p/net/upgrader/upgrader.go @@ -90,7 +90,7 @@ func New(secureMuxer sec.SecureMuxer, muxer network.Multiplexer, opts ...Option) } } if u.rcmgr == nil { - u.rcmgr = network.NullResourceManager + u.rcmgr = &network.NullResourceManager{} } return u, nil } diff --git a/p2p/net/upgrader/upgrader_test.go b/p2p/net/upgrader/upgrader_test.go index 5d62e5f9c2..16188bd5bb 100644 --- a/p2p/net/upgrader/upgrader_test.go +++ b/p2p/net/upgrader/upgrader_test.go @@ -121,21 +121,21 @@ func TestOutboundConnectionGating(t *testing.T) { testGater := &testGater{} _, dialUpgrader := createUpgrader(t, upgrader.WithConnectionGater(testGater)) - conn, err := dial(t, dialUpgrader, ln.Multiaddr(), id, network.NullScope) + conn, err := dial(t, dialUpgrader, ln.Multiaddr(), id, &network.NullScope{}) require.NoError(err) require.NotNil(conn) _ = conn.Close() // blocking accepts doesn't affect the dialling side, only the listener. testGater.BlockAccept(true) - conn, err = dial(t, dialUpgrader, ln.Multiaddr(), id, network.NullScope) + conn, err = dial(t, dialUpgrader, ln.Multiaddr(), id, &network.NullScope{}) require.NoError(err) require.NotNil(conn) _ = conn.Close() // now let's block all connections after being secured. testGater.BlockSecured(true) - conn, err = dial(t, dialUpgrader, ln.Multiaddr(), id, network.NullScope) + conn, err = dial(t, dialUpgrader, ln.Multiaddr(), id, &network.NullScope{}) require.Error(err) require.Contains(err.Error(), "gater rejected connection") require.Nil(conn) @@ -153,7 +153,7 @@ func TestOutboundResourceManagement(t *testing.T) { gomock.InOrder( connScope.EXPECT().PeerScope(), connScope.EXPECT().SetPeer(id), - connScope.EXPECT().PeerScope().Return(network.NullScope), + connScope.EXPECT().PeerScope().Return(&network.NullScope{}), ) _, dialUpgrader := createUpgrader(t) conn, err := dial(t, dialUpgrader, ln.Multiaddr(), id, connScope) @@ -174,7 +174,7 @@ func TestOutboundResourceManagement(t *testing.T) { gomock.InOrder( connScope.EXPECT().PeerScope(), connScope.EXPECT().SetPeer(id), - connScope.EXPECT().PeerScope().Return(network.NullScope), + connScope.EXPECT().PeerScope().Return(&network.NullScope{}), connScope.EXPECT().Done(), ) _, dialUpgrader := createUpgrader(t) diff --git a/p2p/protocol/circuitv2/client/reservation_test.go b/p2p/protocol/circuitv2/client/reservation_test.go index c50a46160a..02a2f858db 100644 --- a/p2p/protocol/circuitv2/client/reservation_test.go +++ b/p2p/protocol/circuitv2/client/reservation_test.go @@ -89,7 +89,7 @@ func TestReservationFailures(t *testing.T) { host.SetStreamHandler(proto.ProtoIDv2Hop, tc.streamHandler) } - cl, err := libp2p.New(libp2p.ResourceManager(network.NullResourceManager)) + cl, err := libp2p.New(libp2p.ResourceManager(&network.NullResourceManager{})) require.NoError(t, err) defer cl.Close() _, err = client.Reserve(context.Background(), cl, peer.AddrInfo{ID: host.ID(), Addrs: host.Addrs()}) diff --git a/p2p/protocol/internal/circuitv1-deprecated/dial.go b/p2p/protocol/internal/circuitv1-deprecated/dial.go index acc1217c94..23c03b57ea 100644 --- a/p2p/protocol/internal/circuitv1-deprecated/dial.go +++ b/p2p/protocol/internal/circuitv1-deprecated/dial.go @@ -17,7 +17,7 @@ func (d *RelayTransport) Dial(ctx context.Context, a ma.Multiaddr, p peer.ID) (t return nil, err } c.tagHop() - scope, _ := network.NullResourceManager.OpenConnection(network.DirOutbound, false, a) + scope, _ := (&network.NullResourceManager{}).OpenConnection(network.DirOutbound, false, a) return d.upgrader.Upgrade(ctx, d, c, network.DirOutbound, p, scope) } diff --git a/p2p/transport/quic/transport.go b/p2p/transport/quic/transport.go index 4a70d80783..0a0af9fccb 100644 --- a/p2p/transport/quic/transport.go +++ b/p2p/transport/quic/transport.go @@ -189,7 +189,7 @@ func NewTransport(key ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater, r return nil, err } if rcmgr == nil { - rcmgr = network.NullResourceManager + rcmgr = &network.NullResourceManager{} } qconfig := quicConfig.Clone() keyBytes, err := key.Raw() diff --git a/p2p/transport/tcp/tcp.go b/p2p/transport/tcp/tcp.go index 61bb38e233..b41fe7bf1a 100644 --- a/p2p/transport/tcp/tcp.go +++ b/p2p/transport/tcp/tcp.go @@ -136,7 +136,7 @@ var _ transport.Transport = &TcpTransport{} // created. It represents an entire TCP stack (though it might not necessarily be). func NewTCPTransport(upgrader transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (*TcpTransport, error) { if rcmgr == nil { - rcmgr = network.NullResourceManager + rcmgr = &network.NullResourceManager{} } tr := &TcpTransport{ upgrader: upgrader, diff --git a/p2p/transport/tcp/tcp_test.go b/p2p/transport/tcp/tcp_test.go index f44c16c83e..3d49a253aa 100644 --- a/p2p/transport/tcp/tcp_test.go +++ b/p2p/transport/tcp/tcp_test.go @@ -86,7 +86,7 @@ func TestResourceManager(t *testing.T) { scope := mocknetwork.NewMockConnManagementScope(ctrl) rcmgr.EXPECT().OpenConnection(network.DirOutbound, true, ln.Multiaddr()).Return(scope, nil) scope.EXPECT().SetPeer(peerA) - scope.EXPECT().PeerScope().Return(network.NullScope).AnyTimes() // called by the upgrader + scope.EXPECT().PeerScope().Return(&network.NullScope{}).AnyTimes() // called by the upgrader conn, err := tb.Dial(context.Background(), ln.Multiaddr(), peerA) require.NoError(t, err) scope.EXPECT().Done() diff --git a/p2p/transport/websocket/websocket.go b/p2p/transport/websocket/websocket.go index f1294a5702..b693ef4330 100644 --- a/p2p/transport/websocket/websocket.go +++ b/p2p/transport/websocket/websocket.go @@ -93,7 +93,7 @@ var _ transport.Transport = (*WebsocketTransport)(nil) func New(u transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (*WebsocketTransport, error) { if rcmgr == nil { - rcmgr = network.NullResourceManager + rcmgr = &network.NullResourceManager{} } t := &WebsocketTransport{ upgrader: u, diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index 70e122d821..1c08baccb2 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -163,7 +163,7 @@ func testWSSServer(t *testing.T, listenAddr ma.Multiaddr) (ma.Multiaddr, peer.ID } id, u := newSecureUpgrader(t) - tpt, err := New(u, network.NullResourceManager, WithTLSConfig(tlsConf)) + tpt, err := New(u, &network.NullResourceManager{}, WithTLSConfig(tlsConf)) if err != nil { t.Fatal(err) } @@ -246,7 +246,7 @@ func TestHostHeaderWss(t *testing.T) { tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA _, u := newSecureUpgrader(t) - tpt, err := New(u, network.NullResourceManager, WithTLSClientConfig(tlsConfig)) + tpt, err := New(u, &network.NullResourceManager{}, WithTLSClientConfig(tlsConfig)) require.NoError(t, err) masToDial, err := tpt.Resolve(context.Background(), serverMA) @@ -265,7 +265,7 @@ func TestDialWss(t *testing.T) { tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA _, u := newSecureUpgrader(t) - tpt, err := New(u, network.NullResourceManager, WithTLSClientConfig(tlsConfig)) + tpt, err := New(u, &network.NullResourceManager{}, WithTLSClientConfig(tlsConfig)) require.NoError(t, err) masToDial, err := tpt.Resolve(context.Background(), serverMA) @@ -288,7 +288,7 @@ func TestDialWssNoClientCert(t *testing.T) { require.Contains(t, serverMA.String(), "tls") _, u := newSecureUpgrader(t) - tpt, err := New(u, network.NullResourceManager) + tpt, err := New(u, &network.NullResourceManager{}) require.NoError(t, err) masToDial, err := tpt.Resolve(context.Background(), serverMA) @@ -325,7 +325,7 @@ func connectAndExchangeData(t *testing.T, laddr ma.Multiaddr, secure bool) { opts = append(opts, WithTLSConfig(tlsConf)) } server, u := newUpgrader(t) - tpt, err := New(u, network.NullResourceManager, opts...) + tpt, err := New(u, &network.NullResourceManager{}, opts...) require.NoError(t, err) l, err := tpt.Listen(laddr) require.NoError(t, err) @@ -344,7 +344,7 @@ func connectAndExchangeData(t *testing.T, laddr ma.Multiaddr, secure bool) { opts = append(opts, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) } _, u := newUpgrader(t) - tpt, err := New(u, network.NullResourceManager, opts...) + tpt, err := New(u, &network.NullResourceManager{}, opts...) require.NoError(t, err) c, err := tpt.Dial(context.Background(), l.Multiaddr(), server) require.NoError(t, err) @@ -378,7 +378,7 @@ func TestWebsocketConnection(t *testing.T) { func TestWebsocketListenSecureFailWithoutTLSConfig(t *testing.T) { _, u := newUpgrader(t) - tpt, err := New(u, network.NullResourceManager) + tpt, err := New(u, &network.NullResourceManager{}) require.NoError(t, err) addr := ma.StringCast("/ip4/127.0.0.1/tcp/0/wss") _, err = tpt.Listen(addr) @@ -387,7 +387,7 @@ func TestWebsocketListenSecureFailWithoutTLSConfig(t *testing.T) { func TestWebsocketListenSecureAndInsecure(t *testing.T) { serverID, serverUpgrader := newUpgrader(t) - server, err := New(serverUpgrader, network.NullResourceManager, WithTLSConfig(generateTLSConfig(t))) + server, err := New(serverUpgrader, &network.NullResourceManager{}, WithTLSConfig(generateTLSConfig(t))) require.NoError(t, err) lnInsecure, err := server.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) @@ -397,7 +397,7 @@ func TestWebsocketListenSecureAndInsecure(t *testing.T) { t.Run("insecure", func(t *testing.T) { _, clientUpgrader := newUpgrader(t) - client, err := New(clientUpgrader, network.NullResourceManager, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) + client, err := New(clientUpgrader, &network.NullResourceManager{}, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) require.NoError(t, err) // dialing the insecure address should succeed @@ -414,7 +414,7 @@ func TestWebsocketListenSecureAndInsecure(t *testing.T) { t.Run("secure", func(t *testing.T) { _, clientUpgrader := newUpgrader(t) - client, err := New(clientUpgrader, network.NullResourceManager, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) + client, err := New(clientUpgrader, &network.NullResourceManager{}, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) require.NoError(t, err) // dialing the insecure address should succeed @@ -432,7 +432,7 @@ func TestWebsocketListenSecureAndInsecure(t *testing.T) { func TestConcurrentClose(t *testing.T) { _, u := newUpgrader(t) - tpt, err := New(u, network.NullResourceManager) + tpt, err := New(u, &network.NullResourceManager{}) require.NoError(t, err) l, err := tpt.maListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) if err != nil { @@ -470,7 +470,7 @@ func TestConcurrentClose(t *testing.T) { func TestWriteZero(t *testing.T) { _, u := newUpgrader(t) - tpt, err := New(u, network.NullResourceManager) + tpt, err := New(u, &network.NullResourceManager{}) if err != nil { t.Fatal(err) } diff --git a/p2p/transport/webtransport/conn.go b/p2p/transport/webtransport/conn.go index 059258c752..d8a366a0eb 100644 --- a/p2p/transport/webtransport/conn.go +++ b/p2p/transport/webtransport/conn.go @@ -27,7 +27,7 @@ func (c *connMultiaddrs) RemoteMultiaddr() ma.Multiaddr { return c.remote } type conn struct { *connSecurityMultiaddrs - transport tpt.Transport + transport *transport session *webtransport.Session scope network.ConnScope @@ -35,7 +35,7 @@ type conn struct { var _ tpt.CapableConn = &conn{} -func newConn(tr tpt.Transport, sess *webtransport.Session, sconn *connSecurityMultiaddrs, scope network.ConnScope) *conn { +func newConn(tr *transport, sess *webtransport.Session, sconn *connSecurityMultiaddrs, scope network.ConnScope) *conn { return &conn{ connSecurityMultiaddrs: sconn, transport: tr, @@ -60,7 +60,18 @@ func (c *conn) AcceptStream() (network.MuxedStream, error) { return &stream{str}, nil } -func (c *conn) Close() error { return c.session.Close() } +func (c *conn) allowWindowIncrease(size uint64) bool { + return c.scope.ReserveMemory(int(size), network.ReservationPriorityMedium) == nil +} + +// Close closes the connection. +// It must be called even if the peer closed the connection in order for +// garbage collection to properly work in this package. +func (c *conn) Close() error { + c.transport.removeConn(c.session) + return c.session.Close() +} + func (c *conn) IsClosed() bool { return c.session.Context().Err() != nil } func (c *conn) Scope() network.ConnScope { return c.scope } func (c *conn) Transport() tpt.Transport { return c.transport } diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index 15905a95a6..a5fae40da3 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -157,8 +157,10 @@ func (l *listener) httpHandler(w http.ResponseWriter, r *http.Request) { return } + conn := newConn(l.transport, sess, sconn, connScope) + l.transport.addConn(sess, conn) select { - case l.queue <- newConn(l.transport, sess, sconn, connScope): + case l.queue <- conn: default: log.Debugw("accept queue full, dropping incoming connection", "peer", sconn.RemotePeer(), "addr", r.RemoteAddr, "error", err) sess.Close() diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index a65f723458..922e915ec3 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -82,6 +82,9 @@ type transport struct { tlsClientConf *tls.Config noise *noise.Transport + + connMx sync.Mutex + conns map[uint64]*conn // using quic-go's ConnectionTracingKey as map key } var _ tpt.Transport = &transport{} @@ -94,13 +97,14 @@ func New(key ic.PrivKey, gater connmgr.ConnectionGater, rcmgr network.ResourceMa return nil, err } t := &transport{ - pid: id, - privKey: key, - rcmgr: rcmgr, - gater: gater, - clock: clock.New(), - quicConfig: &quic.Config{}, - } + pid: id, + privKey: key, + rcmgr: rcmgr, + gater: gater, + clock: clock.New(), + conns: map[uint64]*conn{}, + } + t.quicConfig = &quic.Config{AllowConnectionWindowIncrease: t.allowWindowIncrease} for _, opt := range opts { if err := opt(t); err != nil { return nil, err @@ -157,8 +161,9 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp scope.Done() return nil, fmt.Errorf("secured connection gated") } - - return newConn(t, sess, sconn, scope), nil + conn := newConn(t, sess, sconn, scope) + t.addConn(sess, conn) + return conn, nil } func (t *transport) dial(ctx context.Context, addr string, sni string, certHashes []multihash.DecodedMultihash) (*webtransport.Session, error) { @@ -313,6 +318,29 @@ func (t *transport) Close() error { return nil } +func (t *transport) allowWindowIncrease(conn quic.Connection, size uint64) bool { + t.connMx.Lock() + defer t.connMx.Unlock() + + c, ok := t.conns[conn.Context().Value(quic.ConnectionTracingKey).(uint64)] + if !ok { + return false + } + return c.allowWindowIncrease(size) +} + +func (t *transport) addConn(sess *webtransport.Session, c *conn) { + t.connMx.Lock() + t.conns[sess.Context().Value(quic.ConnectionTracingKey).(uint64)] = c + t.connMx.Unlock() +} + +func (t *transport) removeConn(sess *webtransport.Session) { + t.connMx.Lock() + delete(t.conns, sess.Context().Value(quic.ConnectionTracingKey).(uint64)) + t.connMx.Unlock() +} + // extractSNI returns what the SNI should be for the given maddr. If there is an // SNI component in the multiaddr, then it will be returned and // foundSniComponent will be true. If there's no SNI component, but there is a diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index ea8d276a21..de4cf2263c 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -1,6 +1,7 @@ package libp2pwebtransport_test import ( + "bytes" "context" "crypto/ecdsa" "crypto/elliptic" @@ -14,7 +15,10 @@ import ( "io" "math/big" "net" + "os" + "runtime" "strings" + "sync/atomic" "testing" "time" @@ -26,6 +30,7 @@ import ( libp2pwebtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport" "github.com/golang/mock/gomock" + quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" "github.com/multiformats/go-multibase" @@ -88,7 +93,7 @@ func getCerthashComponent(t *testing.T, b []byte) ma.Multiaddr { func TestTransport(t *testing.T) { serverID, serverKey := newIdentity(t) - tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager) + tr, err := libp2pwebtransport.New(serverKey, nil, &network.NullResourceManager{}) require.NoError(t, err) defer tr.(io.Closer).Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) @@ -98,7 +103,7 @@ func TestTransport(t *testing.T) { addrChan := make(chan ma.Multiaddr) go func() { _, clientKey := newIdentity(t) - tr2, err := libp2pwebtransport.New(clientKey, nil, network.NullResourceManager) + tr2, err := libp2pwebtransport.New(clientKey, nil, &network.NullResourceManager{}) require.NoError(t, err) defer tr2.(io.Closer).Close() @@ -134,7 +139,7 @@ func TestTransport(t *testing.T) { func TestHashVerification(t *testing.T) { serverID, serverKey := newIdentity(t) - tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager) + tr, err := libp2pwebtransport.New(serverKey, nil, &network.NullResourceManager{}) require.NoError(t, err) defer tr.(io.Closer).Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) @@ -147,7 +152,7 @@ func TestHashVerification(t *testing.T) { }() _, clientKey := newIdentity(t) - tr2, err := libp2pwebtransport.New(clientKey, nil, network.NullResourceManager) + tr2, err := libp2pwebtransport.New(clientKey, nil, &network.NullResourceManager{}) require.NoError(t, err) defer tr2.(io.Closer).Close() @@ -185,7 +190,7 @@ func TestCanDial(t *testing.T) { } _, key := newIdentity(t) - tr, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) + tr, err := libp2pwebtransport.New(key, nil, &network.NullResourceManager{}) require.NoError(t, err) defer tr.(io.Closer).Close() @@ -211,7 +216,7 @@ func TestListenAddrValidity(t *testing.T) { } _, key := newIdentity(t) - tr, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) + tr, err := libp2pwebtransport.New(key, nil, &network.NullResourceManager{}) require.NoError(t, err) defer tr.(io.Closer).Close() @@ -228,7 +233,7 @@ func TestListenAddrValidity(t *testing.T) { func TestListenerAddrs(t *testing.T) { _, key := newIdentity(t) - tr, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) + tr, err := libp2pwebtransport.New(key, nil, &network.NullResourceManager{}) require.NoError(t, err) defer tr.(io.Closer).Close() @@ -266,7 +271,7 @@ func TestResourceManagerDialing(t *testing.T) { func TestResourceManagerListening(t *testing.T) { clientID, key := newIdentity(t) - cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) + cl, err := libp2pwebtransport.New(key, nil, &network.NullResourceManager{}) require.NoError(t, err) defer cl.(io.Closer).Close() @@ -345,7 +350,7 @@ func TestConnectionGaterDialing(t *testing.T) { connGater := NewMockConnectionGater(ctrl) serverID, serverKey := newIdentity(t) - tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager) + tr, err := libp2pwebtransport.New(serverKey, nil, &network.NullResourceManager{}) require.NoError(t, err) defer tr.(io.Closer).Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) @@ -356,7 +361,7 @@ func TestConnectionGaterDialing(t *testing.T) { require.Equal(t, stripCertHashes(ln.Multiaddr()), addrs.RemoteMultiaddr()) }) _, key := newIdentity(t) - cl, err := libp2pwebtransport.New(key, connGater, network.NullResourceManager) + cl, err := libp2pwebtransport.New(key, connGater, &network.NullResourceManager{}) require.NoError(t, err) defer cl.(io.Closer).Close() _, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID) @@ -369,7 +374,7 @@ func TestConnectionGaterInterceptAccept(t *testing.T) { connGater := NewMockConnectionGater(ctrl) serverID, serverKey := newIdentity(t) - tr, err := libp2pwebtransport.New(serverKey, connGater, network.NullResourceManager) + tr, err := libp2pwebtransport.New(serverKey, connGater, &network.NullResourceManager{}) require.NoError(t, err) defer tr.(io.Closer).Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) @@ -382,7 +387,7 @@ func TestConnectionGaterInterceptAccept(t *testing.T) { }) _, key := newIdentity(t) - cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) + cl, err := libp2pwebtransport.New(key, nil, &network.NullResourceManager{}) require.NoError(t, err) defer cl.(io.Closer).Close() _, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID) @@ -395,7 +400,7 @@ func TestConnectionGaterInterceptSecured(t *testing.T) { connGater := NewMockConnectionGater(ctrl) serverID, serverKey := newIdentity(t) - tr, err := libp2pwebtransport.New(serverKey, connGater, network.NullResourceManager) + tr, err := libp2pwebtransport.New(serverKey, connGater, &network.NullResourceManager{}) require.NoError(t, err) defer tr.(io.Closer).Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) @@ -403,7 +408,7 @@ func TestConnectionGaterInterceptSecured(t *testing.T) { defer ln.Close() clientID, key := newIdentity(t) - cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) + cl, err := libp2pwebtransport.New(key, nil, &network.NullResourceManager{}) require.NoError(t, err) defer cl.(io.Closer).Close() @@ -461,7 +466,7 @@ func TestStaticTLSConf(t *testing.T) { tlsConf := getTLSConf(t, net.ParseIP("127.0.0.1"), time.Now(), time.Now().Add(365*24*time.Hour)) serverID, serverKey := newIdentity(t) - tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager, libp2pwebtransport.WithTLSConfig(tlsConf)) + tr, err := libp2pwebtransport.New(serverKey, nil, &network.NullResourceManager{}, libp2pwebtransport.WithTLSConfig(tlsConf)) require.NoError(t, err) defer tr.(io.Closer).Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) @@ -471,7 +476,7 @@ func TestStaticTLSConf(t *testing.T) { t.Run("fails when the certificate is invalid", func(t *testing.T) { _, key := newIdentity(t) - cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) + cl, err := libp2pwebtransport.New(key, nil, &network.NullResourceManager{}) require.NoError(t, err) defer cl.(io.Closer).Close() @@ -485,7 +490,7 @@ func TestStaticTLSConf(t *testing.T) { t.Run("fails when dialing with a wrong certhash", func(t *testing.T) { _, key := newIdentity(t) - cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) + cl, err := libp2pwebtransport.New(key, nil, &network.NullResourceManager{}) require.NoError(t, err) defer cl.(io.Closer).Close() @@ -500,7 +505,7 @@ func TestStaticTLSConf(t *testing.T) { store := x509.NewCertPool() store.AddCert(tlsConf.Certificates[0].Leaf) tlsConf := &tls.Config{RootCAs: store} - cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager, libp2pwebtransport.WithTLSClientConfig(tlsConf)) + cl, err := libp2pwebtransport.New(key, nil, &network.NullResourceManager{}, libp2pwebtransport.WithTLSClientConfig(tlsConf)) require.NoError(t, err) defer cl.(io.Closer).Close() @@ -513,7 +518,7 @@ func TestStaticTLSConf(t *testing.T) { func TestAcceptQueueFilledUp(t *testing.T) { serverID, serverKey := newIdentity(t) - tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager) + tr, err := libp2pwebtransport.New(serverKey, nil, &network.NullResourceManager{}) require.NoError(t, err) defer tr.(io.Closer).Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) @@ -523,7 +528,7 @@ func TestAcceptQueueFilledUp(t *testing.T) { newConn := func() (tpt.CapableConn, error) { t.Helper() _, key := newIdentity(t) - cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) + cl, err := libp2pwebtransport.New(key, nil, &network.NullResourceManager{}) require.NoError(t, err) defer cl.(io.Closer).Close() return cl.Dial(context.Background(), ln.Multiaddr(), serverID) @@ -553,7 +558,7 @@ func TestSNIIsSent(t *testing.T) { return tlsConf, nil }, } - tr, err := libp2pwebtransport.New(key, nil, network.NullResourceManager, libp2pwebtransport.WithTLSConfig(tlsConf)) + tr, err := libp2pwebtransport.New(key, nil, &network.NullResourceManager{}, libp2pwebtransport.WithTLSConfig(tlsConf)) require.NoError(t, err) defer tr.(io.Closer).Close() @@ -561,7 +566,7 @@ func TestSNIIsSent(t *testing.T) { require.NoError(t, err) _, key2 := newIdentity(t) - clientTr, err := libp2pwebtransport.New(key2, nil, network.NullResourceManager) + clientTr, err := libp2pwebtransport.New(key2, nil, &network.NullResourceManager{}) require.NoError(t, err) defer tr.(io.Closer).Close() @@ -582,5 +587,130 @@ func TestSNIIsSent(t *testing.T) { case <-time.After(time.Minute): t.Fatalf("Expected to get server name") } +} + +type reportingRcmgr struct { + network.NullResourceManager + report chan<- int +} + +func (m *reportingRcmgr) OpenConnection(dir network.Direction, usefd bool, endpoint ma.Multiaddr) (network.ConnManagementScope, error) { + return &reportingScope{report: m.report}, nil +} + +type reportingScope struct { + network.NullScope + report chan<- int +} + +func (s *reportingScope) ReserveMemory(size int, _ uint8) error { + s.report <- size + return nil +} + +func TestFlowControlWindowIncrease(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("this test is flaky on Windows") + } + + rtt := 10 * time.Millisecond + timeout := 5 * time.Second + + if os.Getenv("CI") != "" { + rtt = 40 * time.Millisecond + timeout = 15 * time.Second + } + + serverID, serverKey := newIdentity(t) + serverWindowIncreases := make(chan int, 100) + serverRcmgr := &reportingRcmgr{report: serverWindowIncreases} + tr, err := libp2pwebtransport.New(serverKey, nil, serverRcmgr) + require.NoError(t, err) + defer tr.(io.Closer).Close() + ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + require.NoError(t, err) + defer ln.Close() + + go func() { + conn, err := ln.Accept() + require.NoError(t, err) + str, err := conn.AcceptStream() + require.NoError(t, err) + _, err = io.CopyBuffer(str, str, make([]byte, 2<<10)) + require.NoError(t, err) + str.CloseWrite() + }() + + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: ln.Addr().String(), + DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 }, + }) + require.NoError(t, err) + defer proxy.Close() + _, clientKey := newIdentity(t) + clientWindowIncreases := make(chan int, 100) + clientRcmgr := &reportingRcmgr{report: clientWindowIncreases} + tr2, err := libp2pwebtransport.New(clientKey, nil, clientRcmgr) + require.NoError(t, err) + defer tr2.(io.Closer).Close() + + var addr ma.Multiaddr + for _, comp := range ma.Split(ln.Multiaddr()) { + if _, err := comp.ValueForProtocol(ma.P_UDP); err == nil { + addr = addr.Encapsulate(ma.StringCast(fmt.Sprintf("/udp/%d", proxy.LocalPort()))) + continue + } + if addr == nil { + addr = comp + continue + } + addr = addr.Encapsulate(comp) + } + + conn, err := tr2.Dial(context.Background(), addr, serverID) + require.NoError(t, err) + str, err := conn.OpenStream(context.Background()) + require.NoError(t, err) + var increasesDone uint32 // to be used atomically + go func() { + for { + _, err := str.Write(bytes.Repeat([]byte{0x42}, 1<<10)) + require.NoError(t, err) + if atomic.LoadUint32(&increasesDone) > 0 { + str.CloseWrite() + return + } + } + }() + done := make(chan struct{}) + go func() { + defer close(done) + _, err := io.ReadAll(str) + require.NoError(t, err) + }() + + var numServerIncreases, numClientIncreases int + timer := time.NewTimer(timeout) + defer timer.Stop() + for { + select { + case <-serverWindowIncreases: + numServerIncreases++ + case <-clientWindowIncreases: + numClientIncreases++ + case <-timer.C: + t.Fatalf("didn't receive enough window increases (client: %d, server: %d)", numClientIncreases, numServerIncreases) + } + if numClientIncreases >= 1 && numServerIncreases >= 1 { + atomic.AddUint32(&increasesDone, 1) + break + } + } + + select { + case <-done: + case <-time.After(timeout): + t.Fatal("timeout") + } }