diff --git a/vault/request_forwarding.go b/vault/request_forwarding.go index d031566d3927..95afcc7a5483 100644 --- a/vault/request_forwarding.go +++ b/vault/request_forwarding.go @@ -90,6 +90,13 @@ func (c *Core) startForwarding() error { go func() { defer shutdownWg.Done() + // closeCh is used to shutdown the spawned goroutines once this + // function returns + closeCh := make(chan struct{}) + defer func() { + close(closeCh) + }() + if c.logger.IsInfo() { c.logger.Info("core/startClusterListener: starting listener", "listener_address", laddr) } @@ -134,7 +141,6 @@ func (c *Core) startForwarding() error { if conn == nil { continue } - defer conn.Close() // Type assert to TLS connection and handshake to populate the // connection state @@ -159,11 +165,28 @@ func (c *Core) startForwarding() error { c.clusterParamsLock.RLock() rpcServer := c.rpcServer c.clusterParamsLock.RUnlock() + + shutdownWg.Add(2) + // quitCh is used to close the connection and the second + // goroutine if the server closes before closeCh. + quitCh := make(chan struct{}) + go func() { + select { + case <-quitCh: + case <-closeCh: + } + tlsConn.Close() + shutdownWg.Done() + }() + go func() { fws.ServeConn(tlsConn, &http2.ServeConnOpts{ Handler: rpcServer, }) - tlsConn.Close() + // close the quitCh which will close the connection and + // the other goroutine. + close(quitCh) + shutdownWg.Done() }() default: