diff --git a/internal/test/ouroboros_mock/connection.go b/internal/test/ouroboros_mock/connection.go index a769410b..04a2988d 100644 --- a/internal/test/ouroboros_mock/connection.go +++ b/internal/test/ouroboros_mock/connection.go @@ -19,6 +19,7 @@ import ( "fmt" "net" "reflect" + "sync" "time" "github.com/blinklabs-io/gouroboros/cbor" @@ -42,6 +43,8 @@ type Connection struct { conversation []ConversationEntry muxer *muxer.Muxer muxerRecvChan chan *muxer.Segment + doneChan chan any + onceClose sync.Once } // NewConnection returns a new Connection with the provided conversation entries @@ -51,6 +54,7 @@ func NewConnection( ) net.Conn { c := &Connection{ conversation: conversation, + doneChan: make(chan any), } c.conn, c.mockConn = net.Pipe() // Start a muxer on the mocked side of the connection @@ -91,14 +95,20 @@ func (c *Connection) Write(b []byte) (n int, err error) { // Close closes both sides of the connection. This is needed to satisfy the net.Conn interface func (c *Connection) Close() error { - c.muxer.Stop() - if err := c.conn.Close(); err != nil { - return err - } - if err := c.mockConn.Close(); err != nil { - return err - } - return nil + var retErr error + c.onceClose.Do(func() { + close(c.doneChan) + c.muxer.Stop() + if err := c.conn.Close(); err != nil { + retErr = err + return + } + if err := c.mockConn.Close(); err != nil { + retErr = err + return + } + }) + return retErr } // LocalAddr provides a proxy to the client-side connection's LocalAddr function. This is needed to satisfy the net.Conn interface @@ -128,6 +138,11 @@ func (c *Connection) SetWriteDeadline(t time.Time) error { func (c *Connection) asyncLoop() { for _, entry := range c.conversation { + select { + case <-c.doneChan: + return + default: + } switch entry.Type { case EntryTypeInput: if err := c.processInputEntry(entry); err != nil {