diff --git a/mux.go b/mux.go index 458a0e2..5568e68 100644 --- a/mux.go +++ b/mux.go @@ -109,7 +109,7 @@ func VerifyConfig(config *Config) error { // Server is used to initialize a new server-side connection. // There must be at most one server-side connection. If a nil config is // provided, the DefaultConfiguration will be used. -func Server(conn net.Conn, config *Config, mm MemoryManager) (*Session, error) { +func Server(conn net.Conn, config *Config, mm func() (MemoryManager, error)) (*Session, error) { if config == nil { config = DefaultConfig() } @@ -121,7 +121,7 @@ func Server(conn net.Conn, config *Config, mm MemoryManager) (*Session, error) { // Client is used to initialize a new client-side connection. // There must be at most one client-side connection. -func Client(conn net.Conn, config *Config, mm MemoryManager) (*Session, error) { +func Client(conn net.Conn, config *Config, mm func() (MemoryManager, error)) (*Session, error) { if config == nil { config = DefaultConfig() } diff --git a/session.go b/session.go index ac9b5ef..3f445fb 100644 --- a/session.go +++ b/session.go @@ -22,19 +22,24 @@ import ( // Memory is allocated: // 1. When opening / accepting a new stream. This uses the highest priority. // 2. When trying to increase the stream receive window. This uses a lower priority. +// This is a subset of the libp2p's resource manager ResourceScopeSpan interface. type MemoryManager interface { - // ReserveMemory reserves memory / buffer. ReserveMemory(size int, prio uint8) error + // ReleaseMemory explicitly releases memory previously reserved with ReserveMemory ReleaseMemory(size int) + + // Done ends the span and releases associated resources. + Done() } type nullMemoryManagerImpl struct{} func (n nullMemoryManagerImpl) ReserveMemory(size int, prio uint8) error { return nil } func (n nullMemoryManagerImpl) ReleaseMemory(size int) {} +func (n nullMemoryManagerImpl) Done() {} -var nullMemoryManager MemoryManager = &nullMemoryManagerImpl{} +var nullMemoryManager = &nullMemoryManagerImpl{} // Session is used to wrap a reliable ordered connection and to // multiplex it into multiple streams. @@ -65,7 +70,7 @@ type Session struct { // reader is a buffered reader reader io.Reader - memoryManager MemoryManager + newMemoryManager func() (MemoryManager, error) // pings is used to track inflight pings pingLock sync.Mutex @@ -120,31 +125,31 @@ type Session struct { } // newSession is used to construct a new session -func newSession(config *Config, conn net.Conn, client bool, readBuf int, memoryManager MemoryManager) *Session { +func newSession(config *Config, conn net.Conn, client bool, readBuf int, newMemoryManager func() (MemoryManager, error)) *Session { var reader io.Reader = conn if readBuf > 0 { reader = bufio.NewReaderSize(reader, readBuf) } - if memoryManager == nil { - memoryManager = nullMemoryManager + if newMemoryManager == nil { + newMemoryManager = func() (MemoryManager, error) { return nullMemoryManager, nil } } s := &Session{ - config: config, - client: client, - logger: log.New(config.LogOutput, "", log.LstdFlags), - conn: conn, - reader: reader, - streams: make(map[uint32]*Stream), - inflight: make(map[uint32]struct{}), - synCh: make(chan struct{}, config.AcceptBacklog), - acceptCh: make(chan *Stream, config.AcceptBacklog), - sendCh: make(chan []byte, 64), - pongCh: make(chan uint32, config.PingBacklog), - pingCh: make(chan uint32), - recvDoneCh: make(chan struct{}), - sendDoneCh: make(chan struct{}), - shutdownCh: make(chan struct{}), - memoryManager: memoryManager, + config: config, + client: client, + logger: log.New(config.LogOutput, "", log.LstdFlags), + conn: conn, + reader: reader, + streams: make(map[uint32]*Stream), + inflight: make(map[uint32]struct{}), + synCh: make(chan struct{}, config.AcceptBacklog), + acceptCh: make(chan *Stream, config.AcceptBacklog), + sendCh: make(chan []byte, 64), + pongCh: make(chan uint32, config.PingBacklog), + pingCh: make(chan uint32), + recvDoneCh: make(chan struct{}), + sendDoneCh: make(chan struct{}), + shutdownCh: make(chan struct{}), + newMemoryManager: newMemoryManager, } if client { s.nextStreamID = 1 @@ -211,7 +216,11 @@ func (s *Session) OpenStream(ctx context.Context) (*Stream, error) { return nil, s.shutdownErr } - if err := s.memoryManager.ReserveMemory(initialStreamWindow, 255); err != nil { + span, err := s.newMemoryManager() + if err != nil { + return nil, fmt.Errorf("failed to create resource scope span: %w", err) + } + if err := span.ReserveMemory(initialStreamWindow, 255); err != nil { return nil, err } @@ -219,6 +228,7 @@ GET_ID: // Get an ID, and check for stream exhaustion id := atomic.LoadUint32(&s.nextStreamID) if id >= math.MaxUint32-1 { + span.Done() return nil, ErrStreamsExhausted } if !atomic.CompareAndSwapUint32(&s.nextStreamID, id, id+2) { @@ -226,7 +236,7 @@ GET_ID: } // Register the stream - stream := newStream(s, id, streamInit, initialStreamWindow) + stream := newStream(s, id, streamInit, initialStreamWindow, span) s.streamLock.Lock() s.streams[id] = stream s.inflight[id] = struct{}{} @@ -234,6 +244,7 @@ GET_ID: // Send the window update to create if err := stream.sendWindowUpdate(); err != nil { + defer span.Done() select { case <-s.synCh: default: @@ -293,14 +304,10 @@ func (s *Session) Close() error { s.streamLock.Lock() defer s.streamLock.Unlock() - var memory int for id, stream := range s.streams { - memory += stream.memory stream.forceClose() delete(s.streams, id) - } - if memory > 0 { - s.memoryManager.ReleaseMemory(memory) + stream.memorySpan.Done() } return nil } @@ -781,10 +788,14 @@ func (s *Session) incomingStream(id uint32) error { } // Allocate a new stream - if err := s.memoryManager.ReserveMemory(initialStreamWindow, 255); err != nil { + span, err := s.newMemoryManager() + if err != nil { + return fmt.Errorf("failed to create resource span: %w", err) + } + if err := span.ReserveMemory(initialStreamWindow, 255); err != nil { return err } - stream := newStream(s, id, streamSYNReceived, initialStreamWindow) + stream := newStream(s, id, streamSYNReceived, initialStreamWindow, span) s.streamLock.Lock() defer s.streamLock.Unlock() @@ -795,14 +806,14 @@ func (s *Session) incomingStream(id uint32) error { if sendErr := s.sendMsg(s.goAway(goAwayProtoErr), nil, nil); sendErr != nil { s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) } - s.memoryManager.ReleaseMemory(stream.memory) + span.Done() return ErrDuplicateStream } if s.numIncomingStreams >= s.config.MaxIncomingStreams { // too many active streams at the same time s.logger.Printf("[WARN] yamux: MaxIncomingStreams exceeded, forcing stream reset") - s.memoryManager.ReleaseMemory(stream.memory) + defer span.Done() hdr := encode(typeWindowUpdate, flagRST, id, 0) return s.sendMsg(hdr, nil, nil) } @@ -817,6 +828,7 @@ func (s *Session) incomingStream(id uint32) error { return nil default: // Backlog exceeded! RST the stream + defer span.Done() s.logger.Printf("[WARN] yamux: backlog exceeded, forcing stream reset") s.deleteStream(id) hdr := encode(typeWindowUpdate, flagRST, id, 0) @@ -855,8 +867,8 @@ func (s *Session) deleteStream(id uint32) { s.numIncomingStreams-- } } - s.memoryManager.ReleaseMemory(str.memory) delete(s.streams, id) + str.memorySpan.Done() } // establishStream is used to mark a stream that was in the diff --git a/stream.go b/stream.go index 9175268..ac39577 100644 --- a/stream.go +++ b/stream.go @@ -31,7 +31,7 @@ const ( type Stream struct { sendWindow uint32 - memory int + memorySpan MemoryManager id uint32 session *Session @@ -53,15 +53,15 @@ type Stream struct { // newStream is used to construct a new stream within a given session for an ID. // It assumes that a memory allocation has been obtained for the initialWindow. -func newStream(session *Session, id uint32, state streamState, initialWindow uint32) *Stream { +func newStream(session *Session, id uint32, state streamState, initialWindow uint32, memorySpan MemoryManager) *Stream { s := &Stream{ id: id, session: session, state: state, sendWindow: initialStreamWindow, - memory: int(initialWindow), readDeadline: makePipeDeadline(), writeDeadline: makePipeDeadline(), + memorySpan: memorySpan, // Initialize the recvBuf with initialStreamWindow, not config.InitialStreamWindowSize. // The peer isn't allowed to send more data than initialStreamWindow until we've sent // the first window update (which will grant it up to config.InitialStreamWindowSize). @@ -229,9 +229,8 @@ func (s *Stream) sendWindowUpdate() error { } if recvWindow > s.recvWindow { grow := recvWindow - s.recvWindow - if err := s.session.memoryManager.ReserveMemory(int(grow), 128); err == nil { + if err := s.memorySpan.ReserveMemory(int(grow), 128); err == nil { s.recvWindow = recvWindow - s.memory += int(grow) _, delta = s.recvBuf.GrowTo(s.recvWindow, true) } }