diff --git a/transport/handler_server.go b/transport/handler_server.go index f1f6caf89c26..7e0fdb35938b 100644 --- a/transport/handler_server.go +++ b/transport/handler_server.go @@ -123,10 +123,9 @@ type serverHandlerTransport struct { // when WriteStatus is called. writes chan func() - mu sync.Mutex - // streamDone indicates whether WriteStatus has been called and writes channel - // has been closed. - streamDone bool + // block concurrent WriteStatus calls + // e.g. grpc/(*serverStream).SendMsg/RecvMsg + writeStatusMu sync.Mutex } func (ht *serverHandlerTransport) Close() error { @@ -177,13 +176,9 @@ func (ht *serverHandlerTransport) do(fn func()) error { } func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) error { - ht.mu.Lock() - if ht.streamDone { - ht.mu.Unlock() - return nil - } - ht.streamDone = true - ht.mu.Unlock() + ht.writeStatusMu.Lock() + defer ht.writeStatusMu.Unlock() + err := ht.do(func() { ht.writeCommonHeaders(s) @@ -222,7 +217,11 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro } } }) - close(ht.writes) + + if err == nil { // transport has not been closed + ht.Close() + close(ht.writes) + } return err } diff --git a/transport/handler_server_test.go b/transport/handler_server_test.go index 06fe813ca912..8505e1a7fd38 100644 --- a/transport/handler_server_test.go +++ b/transport/handler_server_test.go @@ -391,9 +391,10 @@ func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) { } } +// TestHandlerTransport_HandleStreams_MultiWriteStatus ensures that +// concurrent "WriteStatus"s do not panic writing to closed "writes" channel. func TestHandlerTransport_HandleStreams_MultiWriteStatus(t *testing.T) { - st := newHandleStreamTest(t) - handleStream := func(s *Stream) { + testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *Stream) { if want := "/service/foo.bar"; s.method != want { t.Errorf("stream method = %q; want %q", s.method, want) } @@ -408,9 +409,27 @@ func TestHandlerTransport_HandleStreams_MultiWriteStatus(t *testing.T) { }() } wg.Wait() - } + }) +} + +// TestHandlerTransport_HandleStreams_WriteStatusWrite ensures that "Write" +// following "WriteStatus" does not panic writing to closed "writes" channel. +func TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) { + testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *Stream) { + if want := "/service/foo.bar"; s.method != want { + t.Errorf("stream method = %q; want %q", s.method, want) + } + st.bodyw.Close() // no body + + st.ht.WriteStatus(s, status.New(codes.OK, "")) + st.ht.Write(s, []byte("hdr"), []byte("data"), &Options{}) + }) +} + +func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *Stream)) { + st := newHandleStreamTest(t) st.ht.HandleStreams( - func(s *Stream) { go handleStream(s) }, + func(s *Stream) { go handleStream(st, s) }, func(ctx context.Context, method string) context.Context { return ctx }, ) }