Skip to content

Commit

Permalink
Merge pull request #8987 from gyuho/tls-shutdown
Browse files Browse the repository at this point in the history
embed: fix *grpc.Server panic on GracefulStop with TLS-enabled server
  • Loading branch information
gyuho committed Dec 8, 2017
2 parents fc2eecf + 9bd07c9 commit 015c04b
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 29 deletions.
56 changes: 42 additions & 14 deletions embed/etcd.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ func StartEtcd(inCfg *Config) (e *Etcd, err error) {
return
}
if !serving {
// errored before starting gRPC server for serveCtx.grpcServerC
// errored before starting gRPC server for serveCtx.serversC
for _, sctx := range e.sctxs {
close(sctx.grpcServerC)
close(sctx.serversC)
}
}
e.Close()
Expand Down Expand Up @@ -219,23 +219,35 @@ func (e *Etcd) Config() Config {
return e.cfg
}

// Close gracefully shuts down all servers/listeners.
// Client requests will be terminated with request timeout.
// After timeout, enforce remaining requests be closed immediately.
func (e *Etcd) Close() {
e.closeOnce.Do(func() { close(e.stopc) })

// close client requests with request timeout
timeout := 2 * time.Second
if e.Server != nil {
timeout = e.Server.Cfg.ReqTimeout()
}
for _, sctx := range e.sctxs {
for gs := range sctx.grpcServerC {
e.stopGRPCServer(gs)
for ss := range sctx.serversC {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
stopServers(ctx, ss)
cancel()
}
}

for _, sctx := range e.sctxs {
sctx.cancel()
}

for i := range e.Clients {
if e.Clients[i] != nil {
e.Clients[i].Close()
}
}

for i := range e.metricsListeners {
e.metricsListeners[i].Close()
}
Expand All @@ -255,25 +267,38 @@ func (e *Etcd) Close() {
}
}

func (e *Etcd) stopGRPCServer(gs *grpc.Server) {
timeout := 2 * time.Second
if e.Server != nil {
timeout = e.Server.Cfg.ReqTimeout()
func stopServers(ctx context.Context, ss *servers) {
shutdownNow := func() {
// first, close the http.Server
ss.http.Shutdown(ctx)
// then close grpc.Server; cancels all active RPCs
ss.grpc.Stop()
}

// do not grpc.Server.GracefulStop with TLS enabled etcd server
// See https://github.com/grpc/grpc-go/issues/1384#issuecomment-317124531
// and https://github.com/coreos/etcd/issues/8916
if ss.secure {
shutdownNow()
return
}

ch := make(chan struct{})
go func() {
defer close(ch)
// close listeners to stop accepting new connections,
// will block on any existing transports
gs.GracefulStop()
ss.grpc.GracefulStop()
}()

// wait until all pending RPCs are finished
select {
case <-ch:
case <-time.After(timeout):
case <-ctx.Done():
// took too long, manually close open transports
// e.g. watch streams
gs.Stop()
shutdownNow()

// concurrent GracefulStop should be interrupted
<-ch
}
Expand All @@ -297,7 +322,9 @@ func startPeerListeners(cfg *Config) (peers []*peerListener, err error) {
for i := range peers {
if peers[i] != nil && peers[i].close != nil {
plog.Info("stopping listening for peers on ", cfg.LPUrls[i].String())
peers[i].close(context.Background())
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
peers[i].close(ctx)
cancel()
}
}
}()
Expand Down Expand Up @@ -334,6 +361,7 @@ func (e *Etcd) servePeers() (err error) {
return err
}
}

for _, p := range e.Peers {
gs := v3rpc.Server(e.Server, peerTLScfg)
m := cmux.New(p.Listener)
Expand All @@ -349,8 +377,8 @@ func (e *Etcd) servePeers() (err error) {
// gracefully shutdown http.Server
// close open listeners, idle connections
// until context cancel or time-out
e.stopGRPCServer(gs)
return srv.Shutdown(ctx)
stopServers(ctx, &servers{secure: peerTLScfg != nil, grpc: gs, http: srv})
return nil
}
}

Expand Down
21 changes: 13 additions & 8 deletions embed/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,19 @@ type serveCtx struct {

userHandlers map[string]http.Handler
serviceRegister func(*grpc.Server)
grpcServerC chan *grpc.Server
serversC chan *servers
}

// servers groups the gRPC and HTTP servers that are started together,
// so they can be handed to stopServers and shut down as a unit.
type servers struct {
// secure is true when the servers were created with a TLS config;
// stopServers skips grpc.Server.GracefulStop in that case.
secure bool
// grpc is the gRPC server.
grpc *grpc.Server
// http is the HTTP server.
http *http.Server
}

func newServeCtx() *serveCtx {
ctx, cancel := context.WithCancel(context.Background())
return &serveCtx{ctx: ctx, cancel: cancel, userHandlers: make(map[string]http.Handler),
grpcServerC: make(chan *grpc.Server, 2), // in case sctx.insecure,sctx.secure true
serversC: make(chan *servers, 2), // in case sctx.insecure,sctx.secure true
}
}

Expand All @@ -84,7 +90,6 @@ func (sctx *serveCtx) serve(

if sctx.insecure {
gs := v3rpc.Server(s, nil, gopts...)
sctx.grpcServerC <- gs
v3electionpb.RegisterElectionServer(gs, servElection)
v3lockpb.RegisterLockServer(gs, servLock)
if sctx.serviceRegister != nil {
Expand All @@ -93,9 +98,7 @@ func (sctx *serveCtx) serve(
grpcl := m.Match(cmux.HTTP2())
go func() { errHandler(gs.Serve(grpcl)) }()

opts := []grpc.DialOption{
grpc.WithInsecure(),
}
opts := []grpc.DialOption{grpc.WithInsecure()}
gwmux, err := sctx.registerGateway(opts)
if err != nil {
return err
Expand All @@ -109,6 +112,8 @@ func (sctx *serveCtx) serve(
}
httpl := m.Match(cmux.HTTP1())
go func() { errHandler(srvhttp.Serve(httpl)) }()

sctx.serversC <- &servers{grpc: gs, http: srvhttp}
plog.Noticef("serving insecure client requests on %s, this is strongly discouraged!", sctx.l.Addr().String())
}

Expand All @@ -118,7 +123,6 @@ func (sctx *serveCtx) serve(
return tlsErr
}
gs := v3rpc.Server(s, tlscfg, gopts...)
sctx.grpcServerC <- gs
v3electionpb.RegisterElectionServer(gs, servElection)
v3lockpb.RegisterLockServer(gs, servLock)
if sctx.serviceRegister != nil {
Expand Down Expand Up @@ -150,10 +154,11 @@ func (sctx *serveCtx) serve(
}
go func() { errHandler(srv.Serve(tlsl)) }()

sctx.serversC <- &servers{secure: true, grpc: gs, http: srv}
plog.Infof("serving client requests on %s", sctx.l.Addr().String())
}

close(sctx.grpcServerC)
close(sctx.serversC)
return m.Serve()
}

Expand Down
34 changes: 27 additions & 7 deletions integration/embed_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func TestEmbedEtcd(t *testing.T) {
{werr: "expected IP"},
}

urls := newEmbedURLs(10)
urls := newEmbedURLs(false, 10)

// setup defaults
for i := range tests {
Expand Down Expand Up @@ -105,12 +105,19 @@ func TestEmbedEtcd(t *testing.T) {
}
}

// TestEmbedEtcdGracefulStop ensures embedded server stops
func TestEmbedEtcdGracefulStopSecure(t *testing.T) { testEmbedEtcdGracefulStop(t, true) }
func TestEmbedEtcdGracefulStopInsecure(t *testing.T) { testEmbedEtcdGracefulStop(t, false) }

// testEmbedEtcdGracefulStop ensures embedded server stops
// cutting existing transports.
func TestEmbedEtcdGracefulStop(t *testing.T) {
func testEmbedEtcdGracefulStop(t *testing.T, secure bool) {
cfg := embed.NewConfig()
if secure {
cfg.ClientTLSInfo = testTLSInfo
cfg.PeerTLSInfo = testTLSInfo
}

urls := newEmbedURLs(2)
urls := newEmbedURLs(secure, 2)
setupEmbedCfg(cfg, []url.URL{urls[0]}, []url.URL{urls[1]})

cfg.Dir = filepath.Join(os.TempDir(), fmt.Sprintf("embed-etcd"))
Expand All @@ -123,7 +130,16 @@ func TestEmbedEtcdGracefulStop(t *testing.T) {
}
<-e.Server.ReadyNotify() // wait for e.Server to join the cluster

cli, err := clientv3.New(clientv3.Config{Endpoints: []string{urls[0].String()}})
clientCfg := clientv3.Config{
Endpoints: []string{urls[0].String()},
}
if secure {
clientCfg.TLS, err = testTLSInfo.ClientConfig()
if err != nil {
t.Fatal(err)
}
}
cli, err := clientv3.New(clientCfg)
if err != nil {
t.Fatal(err)
}
Expand All @@ -146,9 +162,13 @@ func TestEmbedEtcdGracefulStop(t *testing.T) {
}
}

func newEmbedURLs(n int) (urls []url.URL) {
func newEmbedURLs(secure bool, n int) (urls []url.URL) {
scheme := "unix"
if secure {
scheme = "unixs"
}
for i := 0; i < n; i++ {
u, _ := url.Parse(fmt.Sprintf("unix://localhost:%d%06d", os.Getpid(), i))
u, _ := url.Parse(fmt.Sprintf("%s://localhost:%d%06d", scheme, os.Getpid(), i))
urls = append(urls, *u)
}
return urls
Expand Down

0 comments on commit 015c04b

Please sign in to comment.