diff --git a/deadline.go b/deadline.go new file mode 100644 index 0000000..dd2dfaf --- /dev/null +++ b/deadline.go @@ -0,0 +1,80 @@ +// Copied from the go standard library. +// +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE-BSD file. + +package multiplex + +import ( + "sync" + "time" +) + +// pipeDeadline is an abstraction for handling timeouts. +type pipeDeadline struct { + mu sync.Mutex // Guards timer and cancel + timer *time.Timer + cancel chan struct{} // Must be non-nil +} + +func makePipeDeadline() pipeDeadline { + return pipeDeadline{cancel: make(chan struct{})} +} + +// set sets the point in time when the deadline will time out. +// A timeout event is signaled by closing the channel returned by waiter. +// Once a timeout has occurred, the deadline can be refreshed by specifying a +// t value in the future. +// +// A zero value for t prevents timeout. +func (d *pipeDeadline) set(t time.Time) { + d.mu.Lock() + defer d.mu.Unlock() + + if d.timer != nil && !d.timer.Stop() { + <-d.cancel // Wait for the timer callback to finish and close cancel + } + d.timer = nil + + // Time is zero, then there is no deadline. + closed := isClosedChan(d.cancel) + if t.IsZero() { + if closed { + d.cancel = make(chan struct{}) + } + return + } + + // Time in the future, setup a timer to cancel in the future. + if dur := time.Until(t); dur > 0 { + if closed { + d.cancel = make(chan struct{}) + } + d.timer = time.AfterFunc(dur, func() { + close(d.cancel) + }) + return + } + + // Time in the past, so close immediately. + if !closed { + close(d.cancel) + } +} + +// wait returns a channel that is closed when the deadline is exceeded. +func (d *pipeDeadline) wait() chan struct{} { + d.mu.Lock() + defer d.mu.Unlock() + return d.cancel +} + +func isClosedChan(c <-chan struct{}) bool { + select { + case <-c: + return true + default: + return false + } +} diff --git a/multiplex.go b/multiplex.go index 3fa48c9..5869fd1 100644 --- a/multiplex.go +++ b/multiplex.go @@ -34,6 +34,9 @@ var ErrTwoInitiators = errors.New("two initiators") // In this case, we close the connection to be safe. var ErrInvalidState = errors.New("received an unexpected message from the peer") +var errTimeout = timeout{} +var errStreamClosed = errors.New("stream closed") + var ( NewStreamTimeout = time.Minute ResetStreamTimeout = 2 * time.Minute @@ -41,6 +44,20 @@ var ( WriteCoalesceDelay = 100 * time.Microsecond ) +type timeout struct{} + +func (_ timeout) Error() string { + return "i/o deadline exceeded" +} + +func (_ timeout) Temporary() bool { + return true +} + +func (_ timeout) Timeout() bool { + return true +} + // +1 for initiator const ( newStreamTag = 0 @@ -93,11 +110,13 @@ func NewMultiplex(con net.Conn, initiator bool) *Multiplex { func (mp *Multiplex) newStream(id streamID, name string) (s *Stream) { s = &Stream{ - id: id, - name: name, - dataIn: make(chan []byte, 8), - reset: make(chan struct{}), - mp: mp, + id: id, + name: name, + dataIn: make(chan []byte, 8), + reset: make(chan struct{}), + rDeadline: makePipeDeadline(), + wDeadline: makePipeDeadline(), + mp: mp, } s.closedLocal, s.doCloseLocal = context.WithCancel(context.Background()) @@ -148,7 +167,7 @@ func (mp *Multiplex) IsClosed() bool { } } -func (mp *Multiplex) sendMsg(ctx context.Context, header uint64, data []byte) error { +func (mp *Multiplex) sendMsg(done <-chan struct{}, header uint64, data []byte) error { buf := pool.Get(len(data) + 20) n := 0 @@ -161,8 +180,8 @@ func (mp *Multiplex) sendMsg(ctx context.Context, header uint64, data []byte) er return nil case <-mp.shutdown: return ErrShutdown - case <-ctx.Done(): - return ctx.Err() + case <-done: + return errTimeout } } @@ -295,7 +314,7 @@ func (mp *Multiplex) NewNamedStream(name string) (*Stream, error) { ctx, cancel := context.WithTimeout(context.Background(), NewStreamTimeout) defer cancel() - err := mp.sendMsg(ctx, header, []byte(name)) + err := mp.sendMsg(ctx.Done(), header, []byte(name)) if err != nil { return nil, err } @@ -410,6 +429,8 @@ func (mp *Multiplex) handleIncoming() { msch.clLock.Unlock() + msch.cancelDeadlines() + mp.chLock.Lock() delete(mp.channels, ch) mp.chLock.Unlock() @@ -435,6 +456,7 @@ func (mp *Multiplex) handleIncoming() { msch.clLock.Unlock() if cleanup { + msch.cancelDeadlines() mp.chLock.Lock() delete(mp.channels, ch) mp.chLock.Unlock() @@ -505,7 +527,7 @@ func (mp *Multiplex) sendResetMsg(header uint64, hard bool) { ctx, cancel := context.WithTimeout(context.Background(), ResetStreamTimeout) defer cancel() - err := mp.sendMsg(ctx, header, nil) + err := mp.sendMsg(ctx.Done(), header, nil) if err != nil && !mp.isShutdown() { if hard { log.Warningf("error sending reset message: %s; killing connection", err.Error()) diff --git a/stream.go b/stream.go index 987849f..f8cc27e 100644 --- a/stream.go +++ b/stream.go @@ -38,9 +38,7 @@ type Stream struct { // for later memory pool freeing exbuf []byte - deadlineLock sync.Mutex - wDeadline time.Time - rDeadline time.Time + rDeadline, wDeadline pipeDeadline clLock sync.Mutex closedRemote bool @@ -70,15 +68,7 @@ func (s *Stream) preloadData() { } } -func (s *Stream) waitForData(ctx context.Context) error { - s.deadlineLock.Lock() - if !s.rDeadline.IsZero() { - dctx, cancel := context.WithDeadline(ctx, s.rDeadline) - defer cancel() - ctx = dctx - } - s.deadlineLock.Unlock() - +func (s *Stream) waitForData() error { select { case <-s.reset: // This is the only place where it's safe to return these. @@ -91,8 +81,8 @@ func (s *Stream) waitForData(ctx context.Context) error { s.extra = read s.exbuf = read return nil - case <-ctx.Done(): - return ctx.Err() + case <-s.rDeadline.wait(): + return errTimeout } } @@ -125,7 +115,7 @@ func (s *Stream) Read(b []byte) (int, error) { default: } if s.extra == nil { - err := s.waitForData(context.Background()) + err := s.waitForData() if err != nil { return 0, err } @@ -172,21 +162,7 @@ func (s *Stream) write(b []byte) (int, error) { return 0, errors.New("cannot write to closed stream") } - s.deadlineLock.Lock() - wDeadlineCtx, cleanup := func(s *Stream) (context.Context, context.CancelFunc) { - if s.wDeadline.IsZero() { - return s.closedLocal, nil - } else { - return context.WithDeadline(s.closedLocal, s.wDeadline) - } - }(s) - s.deadlineLock.Unlock() - - err := s.mp.sendMsg(wDeadlineCtx, s.id.header(messageTag), b) - - if cleanup != nil { - cleanup() - } + err := s.mp.sendMsg(s.wDeadline.wait(), s.id.header(messageTag), b) if err != nil { if err == context.Canceled { @@ -206,7 +182,7 @@ func (s *Stream) Close() error { ctx, cancel := context.WithTimeout(context.Background(), ResetStreamTimeout) defer cancel() - err := s.mp.sendMsg(ctx, s.id.header(closeTag), nil) + err := s.mp.sendMsg(ctx.Done(), s.id.header(closeTag), nil) if s.isClosed() { return nil @@ -219,6 +195,7 @@ func (s *Stream) Close() error { s.doCloseLocal() if remote { + s.cancelDeadlines() s.mp.chLock.Lock() delete(s.mp.channels, s.id) s.mp.chLock.Unlock() @@ -252,6 +229,7 @@ func (s *Stream) Reset() error { close(s.reset) s.doCloseLocal() s.closedRemote = true + s.cancelDeadlines() go s.mp.sendResetMsg(s.id.header(resetTag), true) @@ -264,24 +242,44 @@ func (s *Stream) Reset() error { return nil } +func (s *Stream) cancelDeadlines() { + s.rDeadline.set(time.Time{}) + s.wDeadline.set(time.Time{}) +} + func (s *Stream) SetDeadline(t time.Time) error { - s.deadlineLock.Lock() - defer s.deadlineLock.Unlock() - s.rDeadline = t - s.wDeadline = t + s.clLock.Lock() + defer s.clLock.Unlock() + + if s.closedRemote || s.isClosed() { + return errStreamClosed + } + + s.rDeadline.set(t) + s.wDeadline.set(t) return nil } func (s *Stream) SetReadDeadline(t time.Time) error { - s.deadlineLock.Lock() - defer s.deadlineLock.Unlock() - s.rDeadline = t + s.clLock.Lock() + defer s.clLock.Unlock() + + if s.closedRemote { + return errStreamClosed + } + + s.rDeadline.set(t) return nil } func (s *Stream) SetWriteDeadline(t time.Time) error { - s.deadlineLock.Lock() - defer s.deadlineLock.Unlock() - s.wDeadline = t + s.clLock.Lock() + defer s.clLock.Unlock() + + if s.isClosed() { + return errStreamClosed + } + + s.wDeadline.set(t) return nil }