diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 680c9eba0b17..f6bac0e8a00d 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -960,7 +960,12 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { } } if err := t.writeHeaderLocked(s); err != nil { - return status.Convert(err).Err() + switch e := err.(type) { + case ConnectionError: + return status.Error(codes.Unavailable, e.Desc) + default: + return status.Convert(err).Err() + } } return nil } diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index 21aff27db1df..ff27678294f1 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -45,6 +45,7 @@ import ( "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/leakcheck" "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/resolver" "google.golang.org/grpc/status" ) @@ -2136,6 +2137,70 @@ func (s) TestHeadersHTTPStatusGRPCStatus(t *testing.T) { } } +func (s) TestWriteHeaderConnectionError(t *testing.T) { + server, client, cancel := setUp(t, 0, notifyCall) + defer cancel() + defer server.stop() + + waitWhileTrue(t, func() (bool, error) { + server.mu.Lock() + defer server.mu.Unlock() + + if len(server.conns) == 0 { + return true, fmt.Errorf("timed-out while waiting for connection to be created on the server") + } + return false, nil + }) + + server.mu.Lock() + + if len(server.conns) != 1 { + t.Fatalf("Server has %d connections from the client, want 1", len(server.conns)) + } + + // Get the server transport for the connecton to the client. + var serverTransport *http2Server + for k := range server.conns { + serverTransport = k.(*http2Server) + } + notifyChan := make(chan struct{}) + server.h.notify = notifyChan + server.mu.Unlock() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + cstream, err := client.NewStream(ctx, &CallHdr{}) + if err != nil { + t.Fatalf("Client failed to create first stream. Err: %v", err) + } + + <-notifyChan // Wait for server stream to be established. + var sstream *Stream + // Access stream on the server. + serverTransport.mu.Lock() + for _, v := range serverTransport.activeStreams { + if v.id == cstream.id { + sstream = v + } + } + serverTransport.mu.Unlock() + if sstream == nil { + t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream.id) + } + + client.Close(fmt.Errorf("closed manually by test")) + + // Wait for server transport to be closed. + <-serverTransport.done + + // Write header on a closed server transport. + err = serverTransport.WriteHeader(sstream, metadata.MD{}) + st := status.Convert(err) + if st.Code() != codes.Unavailable { + t.Fatalf("WriteHeader() failed with status code %s, want %s", st.Code(), codes.Unavailable) + } +} + func (s) TestPingPong1B(t *testing.T) { runPingPongTest(t, 1) }