diff --git a/muxer/muxer.go b/muxer/muxer.go index 92aea3d5..b9e261e9 100644 --- a/muxer/muxer.go +++ b/muxer/muxer.go @@ -61,6 +61,7 @@ type Muxer struct { startChan chan bool doneChan chan bool waitGroup sync.WaitGroup + waitGroupMutex sync.Mutex protocolSenders map[uint16]map[ProtocolRole]chan *Segment protocolReceivers map[uint16]map[ProtocolRole]chan *Segment protocolReceiversMutex sync.Mutex @@ -89,7 +90,9 @@ func New(conn net.Conn) *Muxer { // We must do this to break out of pending Read() calls to shut down cleanly _ = m.conn.Close() // Wait for other goroutines to shutdown + m.waitGroupMutex.Lock() m.waitGroup.Wait() + m.waitGroupMutex.Unlock() // Close ErrorChan to signify to consumer that we're shutting down close(m.errorChan) }() @@ -136,11 +139,20 @@ func (m *Muxer) sendError(err error) { } // RegisterProtocol registers the provided protocol ID with the muxer. It returns a channel for sending, -// a channel for receiving, and a channel to know when the muxer is shutting down +// a channel for receiving, and a channel to know when the muxer is shutting down. If the muxer is shutting +// down, this function will return nil values. func (m *Muxer) RegisterProtocol( protocolId uint16, protocolRole ProtocolRole, ) (chan *Segment, chan *Segment, chan bool) { + m.waitGroupMutex.Lock() + defer m.waitGroupMutex.Unlock() + // Check for shutdown + select { + case <-m.doneChan: + return nil, nil, nil + default: + } // Generate channels senderChan := make(chan *Segment, 10) receiverChan := make(chan *Segment, 10) diff --git a/protocol/protocol.go b/protocol/protocol.go index a329af88..9a9c8315 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -126,6 +126,10 @@ func (p *Protocol) Start() { p.config.ProtocolId, muxerProtocolRole, ) + if p.muxerDoneChan == nil { + p.SendError(fmt.Errorf("could not register protocol with muxer")) + return + } // Create channels p.sendQueueChan = make(chan Message, 50)