Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Defer writing headers until the first message stanza is sent #275

Merged
merged 6 commits into from
Oct 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions stub/src/main/java/io/grpc/kotlin/ServerCalls.kt
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,6 @@ object ServerCalls {
call: ServerCall<RequestT, ResponseT>,
implementation: (Flow<RequestT>) -> Flow<ResponseT>
): ServerCall.Listener<RequestT> {
call.sendHeaders(GrpcMetadata())

val readiness = Readiness { call.isReady }
val requestsChannel = Channel<RequestT>(1)

Expand Down Expand Up @@ -241,12 +239,26 @@ object ServerCalls {
val rpcScope = CoroutineScope(context)
rpcScope.async {
val mutex = Mutex()
val headersSent = AtomicBoolean(false) // enforces only sending headers once
val failure = runCatching {
implementation(requests).collect {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this breaks if the responses are empty. Could we at least also do the check after collection is complete?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, thanks. Added that and tests to catch that case.

// once we have a response message, check if we've sent headers yet - if not, do so
if (headersSent.compareAndSet(false, true)) {
mutex.withLock {
call.sendHeaders(GrpcMetadata())
}
}
readiness.suspendUntilReady()
mutex.withLock { call.sendMessage(it) }
}
}.exceptionOrNull()
// check headers again once we're done collecting the response flow - if we received
// no elements or threw an exception, then we wouldn't have sent them
if (headersSent.compareAndSet(false, true)) {
mutex.withLock {
call.sendHeaders(GrpcMetadata())
}
}
val closeStatus = when (failure) {
null -> Status.OK
is CancellationException -> Status.CANCELLED.withCause(failure)
Expand Down
139 changes: 139 additions & 0 deletions stub/src/test/java/io/grpc/kotlin/ServerCallsTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -798,4 +798,143 @@ class ServerCallsTest : AbstractCallsTest() {
ClientCalls.unaryRpc(channel, sayHelloMethod, helloRequest(""))
).isEqualTo(helloReply("Hello!"))
}

@Test
fun serverCallListenerDefersHeaders() = runBlocking {
val requestReceived = Job()
val responseReleased = Job()
val channel = makeChannel(
ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) {
requestReceived.complete()
responseReleased.join()
helloReply("Hello, ${it.name}")
}
)

val call = channel.newCall(sayHelloMethod, CallOptions.DEFAULT)

val headersReceived = Job()
val responseReceived = CompletableDeferred<HelloReply>()
val closeStatus = CompletableDeferred<Status>()

call.start(
object : ClientCall.Listener<HelloReply>() {
override fun onHeaders(headers: Metadata) {
headersReceived.complete()
}

override fun onMessage(message: HelloReply) {
responseReceived.complete(message)
}

override fun onClose(status: Status, trailers: Metadata) {
closeStatus.complete(status)
}
},
Metadata()
)
call.sendMessage(helloRequest("Bob"))
call.request(1)
call.halfClose()
// wait for the handler to begin
requestReceived.join()
delay(200)
// headers should not have been sent
assertThat(headersReceived.isCompleted).isFalse()
// release the handler
responseReleased.complete()
headersReceived.join()
assertThat(responseReceived.await()).isEqualTo(helloReply("Hello, Bob"))
assertThat(closeStatus.await().code).isEqualTo(Status.Code.OK)
}

@Test
fun serverCallListenerDefersHeadersOnException() = runBlocking {
val requestReceived = Job()
val responseReleased = Job()
val channel = makeChannel(
ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) {
requestReceived.complete()
responseReleased.join()
throw StatusException(Status.INTERNAL.withDescription("no response frames"))
}
)

val call = channel.newCall(sayHelloMethod, CallOptions.DEFAULT)

val headersReceived = Job()
val closeStatus = CompletableDeferred<Status>()

call.start(
object : ClientCall.Listener<HelloReply>() {
override fun onHeaders(headers: Metadata) {
headersReceived.complete()
}

override fun onClose(status: Status, trailers: Metadata) {
closeStatus.complete(status)
}
},
Metadata()
)
call.sendMessage(helloRequest("Bob"))
call.request(1)
call.halfClose()
// wait for the handler to begin
requestReceived.join()
delay(200)
// headers should not have been sent
assertThat(headersReceived.isCompleted).isFalse()
// release the handler
responseReleased.complete()
headersReceived.join()
val status = closeStatus.await()
assertThat(status.code).isEqualTo(Status.Code.INTERNAL)
assertThat(status.description).contains("no response frames")
}

@Test
fun serverCallListenerDefersHeadersOnEmptyStream() = runBlocking {
val requestReceived = Job()
val responseReleased = Job()
val channel = makeChannel(
ServerCalls.serverStreamingServerMethodDefinition(context, serverStreamingSayHelloMethod) {
request -> flow {
requestReceived.complete()
responseReleased.join()
}
}
)

val call = channel.newCall(serverStreamingSayHelloMethod, CallOptions.DEFAULT)

val headersReceived = Job()
val closeStatus = CompletableDeferred<Status>()

call.start(
object : ClientCall.Listener<HelloReply>() {
override fun onHeaders(headers: Metadata) {
headersReceived.complete()
}

override fun onClose(status: Status, trailers: Metadata) {
closeStatus.complete(status)
}
},
Metadata()
)
call.sendMessage(multiHelloRequest("Bob", "Fred"))
call.request(1)
call.halfClose()
// wait for the handler to begin
requestReceived.join()
delay(200)
// headers should not have been sent
assertThat(headersReceived.isCompleted).isFalse()
// release the handler
responseReleased.complete()
headersReceived.join()
val status = closeStatus.await()
assertThat(status.code).isEqualTo(Status.Code.OK)
}
}