From 2e065b285618dd770e634e26c27886d9bc354a9c Mon Sep 17 00:00:00 2001 From: "Bergstrom, Ryan" Date: Thu, 17 Jun 2021 14:41:16 -0700 Subject: [PATCH 1/2] Defer writing headers until the first message stanza is sent --- .../main/java/io/grpc/kotlin/ServerCalls.kt | 8 ++- .../java/io/grpc/kotlin/ServerCallsTest.kt | 50 +++++++++++++++++++ 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/stub/src/main/java/io/grpc/kotlin/ServerCalls.kt b/stub/src/main/java/io/grpc/kotlin/ServerCalls.kt index 62d857ee..34a92ed2 100644 --- a/stub/src/main/java/io/grpc/kotlin/ServerCalls.kt +++ b/stub/src/main/java/io/grpc/kotlin/ServerCalls.kt @@ -211,8 +211,6 @@ object ServerCalls { call: ServerCall, implementation: (Flow) -> Flow ): ServerCall.Listener { - call.sendHeaders(GrpcMetadata()) - val readiness = Readiness { call.isReady } val requestsChannel = Channel(1) @@ -241,8 +239,14 @@ object ServerCalls { val rpcScope = CoroutineScope(context) rpcScope.async { val mutex = Mutex() + val headersSent = AtomicBoolean(false) // enforces only sending headers once val failure = runCatching { implementation(requests).collect { + if (headersSent.compareAndSet(false, true)) { + mutex.withLock { + call.sendHeaders(GrpcMetadata()) + } + } readiness.suspendUntilReady() mutex.withLock { call.sendMessage(it) } } diff --git a/stub/src/test/java/io/grpc/kotlin/ServerCallsTest.kt b/stub/src/test/java/io/grpc/kotlin/ServerCallsTest.kt index a3d4e9a2..bc24d342 100644 --- a/stub/src/test/java/io/grpc/kotlin/ServerCallsTest.kt +++ b/stub/src/test/java/io/grpc/kotlin/ServerCallsTest.kt @@ -798,4 +798,54 @@ class ServerCallsTest : AbstractCallsTest() { ClientCalls.unaryRpc(channel, sayHelloMethod, helloRequest("")) ).isEqualTo(helloReply("Hello!")) } + + @Test + fun serverCallListenerDefersHeaders() = runBlocking { + val requestReceived = Job() + val responseReleased = Job() + val channel = makeChannel( + ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) { + requestReceived.complete() + responseReleased.join() + helloReply("Hello, ${it.name}") + } + ) + + val call = channel.newCall(sayHelloMethod, CallOptions.DEFAULT) + + val headersReceived = Job() + val responseReceived = CompletableDeferred() + val closeStatus = CompletableDeferred() + + call.start( + object : ClientCall.Listener() { + override fun onHeaders(headers: Metadata) { + headersReceived.complete() + } + + override fun onMessage(message: HelloReply) { + responseReceived.complete(message) + } + + override fun onClose(status: Status, trailers: Metadata) { + closeStatus.complete(status) + } + }, + Metadata() + ) + call.sendMessage(helloRequest("Bob")) + call.request(1) + call.halfClose() + // wait for the handler to begin + requestReceived.join() + delay(200) + // headers should not have been sent + assertThat(headersReceived.isCompleted).isFalse() + // release the handler + responseReleased.complete() + headersReceived.join() + assertThat(responseReceived.await()).isEqualTo(helloReply("Hello, Bob")) + assertThat(closeStatus.await().code).isEqualTo(Status.Code.OK) + } + } From beccead950953b641cbc1c10a5b07d8de1cb4a4a Mon Sep 17 00:00:00 2001 From: "Bergstrom, Ryan" Date: Thu, 30 Sep 2021 13:39:11 -0700 Subject: [PATCH 2/2] Send headers even when no response messages are emitted --- .../main/java/io/grpc/kotlin/ServerCalls.kt | 8 ++ .../java/io/grpc/kotlin/ServerCallsTest.kt | 89 +++++++++++++++++++ 2 files changed, 97 insertions(+) diff --git a/stub/src/main/java/io/grpc/kotlin/ServerCalls.kt b/stub/src/main/java/io/grpc/kotlin/ServerCalls.kt index 34a92ed2..692983de 100644 --- a/stub/src/main/java/io/grpc/kotlin/ServerCalls.kt +++ b/stub/src/main/java/io/grpc/kotlin/ServerCalls.kt @@ -242,6 +242,7 @@ object ServerCalls { val headersSent = AtomicBoolean(false) // enforces only sending headers once val failure = runCatching { implementation(requests).collect { + // once we have a response message, check if we've sent headers yet - if not, do so if (headersSent.compareAndSet(false, true)) { mutex.withLock { call.sendHeaders(GrpcMetadata()) @@ -251,6 +252,13 @@ object ServerCalls { mutex.withLock { call.sendMessage(it) } } }.exceptionOrNull() + // check headers again once we're done collecting the response flow - if we received + // no elements or threw an exception, then we wouldn't have sent them + if (headersSent.compareAndSet(false, true)) { + mutex.withLock { + call.sendHeaders(GrpcMetadata()) + } + } val closeStatus = when (failure) { null -> Status.OK is CancellationException -> Status.CANCELLED.withCause(failure) diff --git a/stub/src/test/java/io/grpc/kotlin/ServerCallsTest.kt b/stub/src/test/java/io/grpc/kotlin/ServerCallsTest.kt index bc24d342..ae040df5 100644 --- a/stub/src/test/java/io/grpc/kotlin/ServerCallsTest.kt +++ b/stub/src/test/java/io/grpc/kotlin/ServerCallsTest.kt @@ -848,4 +848,93 @@ class ServerCallsTest : AbstractCallsTest() { assertThat(closeStatus.await().code).isEqualTo(Status.Code.OK) } + @Test + fun serverCallListenerDefersHeadersOnException() = runBlocking { + val requestReceived = Job() + val responseReleased = Job() + val channel = makeChannel( + ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) { + requestReceived.complete() + responseReleased.join() + throw StatusException(Status.INTERNAL.withDescription("no response frames")) + } + ) + + val call = channel.newCall(sayHelloMethod, CallOptions.DEFAULT) + + val headersReceived = Job() + val closeStatus = CompletableDeferred() + + call.start( + object : ClientCall.Listener() { + override fun onHeaders(headers: Metadata) { + headersReceived.complete() + } + + override fun onClose(status: Status, trailers: Metadata) { + closeStatus.complete(status) + } + }, + Metadata() + ) + call.sendMessage(helloRequest("Bob")) + call.request(1) + call.halfClose() + // wait for the handler to begin + requestReceived.join() + delay(200) + // headers should not have been sent + assertThat(headersReceived.isCompleted).isFalse() + // release the handler + responseReleased.complete() + headersReceived.join() + val status = closeStatus.await() + assertThat(status.code).isEqualTo(Status.Code.INTERNAL) + assertThat(status.description).contains("no response frames") + } + + @Test + fun serverCallListenerDefersHeadersOnEmptyStream() = runBlocking { + val requestReceived = Job() + val responseReleased = Job() + val channel = makeChannel( + ServerCalls.serverStreamingServerMethodDefinition(context, serverStreamingSayHelloMethod) { + request -> flow { + requestReceived.complete() + responseReleased.join() + } + } + ) + + val call = channel.newCall(serverStreamingSayHelloMethod, CallOptions.DEFAULT) + + val headersReceived = Job() + val closeStatus = CompletableDeferred() + + call.start( + object : ClientCall.Listener() { + override fun onHeaders(headers: Metadata) { + headersReceived.complete() + } + + override fun onClose(status: Status, trailers: Metadata) { + closeStatus.complete(status) + } + }, + Metadata() + ) + call.sendMessage(multiHelloRequest("Bob", "Fred")) + call.request(1) + call.halfClose() + // wait for the handler to begin + requestReceived.join() + delay(200) + // headers should not have been sent + assertThat(headersReceived.isCompleted).isFalse() + // release the handler + responseReleased.complete() + headersReceived.join() + val status = closeStatus.await() + assertThat(status.code).isEqualTo(Status.Code.OK) + } }