transport tests: refactor workers in TestMoreStreamsThanOurLimits #2472

Merged · 4 commits · Aug 17, 2023
Changes from 2 commits
171 changes: 101 additions & 70 deletions p2p/test/transport/transport_test.go
@@ -358,7 +358,7 @@ func TestMoreStreamsThanOurLimits(t *testing.T) {
 	for _, tc := range transportsToTest {
 		t.Run(tc.Name, func(t *testing.T) {
 			listener := tc.HostGenerator(t, TransportTestCaseOpts{})
-			dialer := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true})
+			dialer := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true, NoRcmgr: true})
 			defer listener.Close()
 			defer dialer.Close()

@@ -370,101 +370,132 @@ func TestMoreStreamsThanOurLimits(t *testing.T) {
 			var handledStreams atomic.Int32
 			var sawFirstErr atomic.Bool
 
-			semaphore := make(chan struct{}, streamCount)
-			// Start with a single stream at a time. If that works, we'll increase the number of concurrent streams.
-			semaphore <- struct{}{}
+			workQueue := make(chan struct{}, streamCount)
+			for i := 0; i < streamCount; i++ {
+				workQueue <- struct{}{}
+			}
+			close(workQueue)
 
 			listener.SetStreamHandler("echo", func(s network.Stream) {
 				// Wait a bit so that we have more parallel streams open at the same time
 				time.Sleep(time.Millisecond * 10)
 				io.Copy(s, s)
 				s.Close()
 			})
 
 			wg := sync.WaitGroup{}
-			wg.Add(streamCount)
 			errCh := make(chan error, 1)
 			var completedStreams atomic.Int32
-			for i := 0; i < streamCount; i++ {
-				go func() {
-					<-semaphore
-					var didErr bool
-					defer wg.Done()
-					defer completedStreams.Add(1)
-					defer func() {
-						select {
-						case semaphore <- struct{}{}:
-						default:
-						}
-						if !didErr && !sawFirstErr.Load() {
-							// No error! We can add one more stream to our concurrency limit.
-							select {
-							case semaphore <- struct{}{}:
-							default:
-							}
-						}
-					}()
-
-					var s network.Stream
-					var err error
-					// maxRetries is an arbitrary retry amount if there's any error.
-					maxRetries := streamCount * 4
-					shouldRetry := func(err error) bool {
-						didErr = true
-						sawFirstErr.Store(true)
-						maxRetries--
-						if maxRetries == 0 || len(errCh) > 0 {
-							select {
-							case errCh <- errors.New("max retries exceeded"):
-							default:
-							}
-							return false
-						}
-						return true
-					}
-
-					for {
-						s, err = dialer.NewStream(context.Background(), listener.ID(), "echo")
-						if err != nil {
-							if shouldRetry(err) {
-								time.Sleep(50 * time.Millisecond)
-								continue
-							}
-						}
-						err = func(s network.Stream) error {
-							defer s.Close()
-							_, err = s.Write([]byte("hello"))
-							if err != nil {
-								return err
-							}
-
-							err = s.CloseWrite()
-							if err != nil {
-								return err
-							}
-
-							b, err := io.ReadAll(s)
-							if err != nil {
-								return err
-							}
-							if !bytes.Equal(b, []byte("hello")) {
-								return errors.New("received data does not match sent data")
-							}
-							handledStreams.Add(1)
-							return nil
-						}(s)
-						if err != nil && shouldRetry(err) {
-							time.Sleep(50 * time.Millisecond)
-							continue
-						}
-						return
-					}
-				}()
-			}
+			const maxWorkerCount = streamCount
+			workerCount := 4
+
+			var startWorker func(workerIdx int)
+			startWorker = func(workerIdx int) {
+				wg.Add(1)
+				defer wg.Done()
+				for {
+					_, ok := <-workQueue
+					if !ok {
+						return
+					}
+
+					// Inline function so we can use defer
+					func() {
+						var didErr bool
+						defer completedStreams.Add(1)
+						defer func() {
+							// Only the first worker adds more workers
+							if workerIdx == 0 && !didErr && !sawFirstErr.Load() {
+								nextWorkerCount := workerCount * 2
+								if nextWorkerCount < maxWorkerCount {
+									for i := workerCount; i < nextWorkerCount; i++ {
+										go startWorker(i)
+									}
+									workerCount = nextWorkerCount
+								}
+							}
+						}()
+
+						var s network.Stream
+						var err error
+						// maxRetries is an arbitrary retry amount if there's any error.
+						maxRetries := streamCount * 4
+						shouldRetry := func(err error) bool {
+							didErr = true
+							sawFirstErr.Store(true)
+							maxRetries--
+							if maxRetries == 0 || len(errCh) > 0 {
+								select {
+								case errCh <- errors.New("max retries exceeded"):
+								default:
+								}
+								return false
+							}
+							return true
+						}
+
+						for {
+							s, err = dialer.NewStream(context.Background(), listener.ID(), "echo")
+							if err != nil {
+								if shouldRetry(err) {
+									time.Sleep(50 * time.Millisecond)
+									continue
+								}
+							}
+							err = func(s network.Stream) error {
+								defer s.Close()
+								err = s.SetDeadline(time.Now().Add(100 * time.Millisecond))
+								if err != nil {
+									return err
+								}
+
+								_, err = s.Write([]byte("hello"))
+								if err != nil {
+									return err
+								}
+
+								err = s.CloseWrite()
+								if err != nil {
+									return err
+								}
+
+								b, err := io.ReadAll(s)
+								if err != nil {
+									return err
+								}
+								if !bytes.Equal(b, []byte("hello")) {
+									return errors.New("received data does not match sent data")
+								}
+								handledStreams.Add(1)
+
+								return nil
+							}(s)
+							if err != nil && shouldRetry(err) {
+								time.Sleep(50 * time.Millisecond)
+								continue
+							}
+							return
+						}
+					}()
+				}
+			}
+
+			// Create any initial parallel workers
+			for i := 1; i < workerCount; i++ {
+				go startWorker(i)
+			}
+
+			// Start the first worker
+			startWorker(0)
+
 			wg.Wait()
 			close(errCh)
 
 			require.NoError(t, <-errCh)
 			require.Equal(t, streamCount, int(handledStreams.Load()))
 			require.True(t, sawFirstErr.Load())
 		})
 	}
 }
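For readers skimming the diff, the new shape is: a work queue pre-filled with one token per stream and then closed, a small initial pool of workers, and a designated first worker that doubles the pool whenever work completes without errors. The sketch below isolates that ramp-up pattern; it is a minimal approximation, not the PR's code. `process`, `jobCount`, and the pool sizes are hypothetical stand-ins (in the test, one work item is a full dial/write/CloseWrite/ReadAll echo round trip), and registering with the WaitGroup in `spawn` before launching the goroutine is a defensive rearrangement of the PR's in-worker `wg.Add(1)`.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// process stands in for one unit of work; in the PR it would be a full
// echo-stream round trip against the listener. Hypothetical placeholder.
func process() error { return nil }

func main() {
	const jobCount = 1024 // illustrative, mirrors streamCount
	const maxWorkerCount = jobCount

	// Pre-fill the queue, then close it so workers exit once it drains.
	workQueue := make(chan struct{}, jobCount)
	for i := 0; i < jobCount; i++ {
		workQueue <- struct{}{}
	}
	close(workQueue)

	var handled atomic.Int32
	var sawErr atomic.Bool
	var wg sync.WaitGroup
	workerCount := 4 // initial parallelism, as in the PR

	var startWorker func(idx int)
	// spawn registers with the WaitGroup before launching the goroutine,
	// so wg.Wait cannot return while a just-spawned worker is starting up.
	spawn := func(idx int) {
		wg.Add(1)
		go func() {
			defer wg.Done()
			startWorker(idx)
		}()
	}
	startWorker = func(idx int) {
		for range workQueue {
			if err := process(); err != nil {
				sawErr.Store(true)
				continue
			}
			handled.Add(1)
			// Only worker 0 grows the pool, and only while no error has
			// been seen; doubling mirrors the PR's ramp-up strategy.
			if idx == 0 && !sawErr.Load() {
				if next := workerCount * 2; next < maxWorkerCount {
					for i := workerCount; i < next; i++ {
						spawn(i)
					}
					workerCount = next
				}
			}
		}
	}

	for i := 1; i < workerCount; i++ {
		spawn(i)
	}
	startWorker(0) // worker 0 runs on the calling goroutine
	wg.Wait()      // wait for the spawned workers to finish draining

	fmt.Println("handled:", handled.Load())
}
```

Closing `workQueue` after pre-filling is what lets `for range workQueue` terminate once the queue drains, and restricting pool growth to worker 0 keeps `workerCount` confined to a single goroutine, so it needs no lock.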