diff --git a/api/src/main/java/io/grpc/ForwardingServerBuilder.java b/api/src/main/java/io/grpc/ForwardingServerBuilder.java index 696a441e9b6..8ea355de6e9 100644 --- a/api/src/main/java/io/grpc/ForwardingServerBuilder.java +++ b/api/src/main/java/io/grpc/ForwardingServerBuilder.java @@ -61,6 +61,12 @@ public T executor(@Nullable Executor executor) { return thisT(); } + @Override + public T callExecutor(ServerCallExecutorSupplier executorSupplier) { + delegate().callExecutor(executorSupplier); + return thisT(); + } + @Override public T addService(ServerServiceDefinition service) { delegate().addService(service); diff --git a/api/src/main/java/io/grpc/ServerBuilder.java b/api/src/main/java/io/grpc/ServerBuilder.java index 4402edb7d92..9dd97790f27 100644 --- a/api/src/main/java/io/grpc/ServerBuilder.java +++ b/api/src/main/java/io/grpc/ServerBuilder.java @@ -74,6 +74,30 @@ public static ServerBuilder forPort(int port) { */ public abstract T executor(@Nullable Executor executor); + + /** + * Allows for defining a way to provide a custom executor to handle the server call. + * This executor is the result of calling + * {@link ServerCallExecutorSupplier#getExecutor(ServerCall, Metadata)} per RPC. + * + *

It's an optional parameter. If it is provided, the {@link #executor(Executor)} would still + * run necessary tasks before the {@link ServerCallExecutorSupplier} is ready to be called, then + * it switches over. + * + *

If it is provided, {@link #directExecutor()} optimization is disabled. But if calling + * {@link ServerCallExecutorSupplier} returns null, the server call is still handled by the + * default {@link #executor(Executor)} as a fallback. + * + * @param executorSupplier the server call executor provider + * @return this + * @since 1.39.0 + * + * */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/8274") + public T callExecutor(ServerCallExecutorSupplier executorSupplier) { + return thisT(); + } + /** * Adds a service implementation to the handler registry. * diff --git a/api/src/main/java/io/grpc/ServerCallExecutorSupplier.java b/api/src/main/java/io/grpc/ServerCallExecutorSupplier.java new file mode 100644 index 00000000000..c990dc943e5 --- /dev/null +++ b/api/src/main/java/io/grpc/ServerCallExecutorSupplier.java @@ -0,0 +1,34 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.util.concurrent.Executor; +import javax.annotation.Nullable; + +/** + * Defines what executor handles the server call, based on each RPC call information at runtime. + * */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8274") +public interface ServerCallExecutorSupplier { + + /** + * Returns an executor to handle the server call. + * It should never throw. It should return null to fallback to the default executor. + * */ + @Nullable + Executor getExecutor(ServerCall call, Metadata metadata); +} diff --git a/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java b/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java index 715161c0635..d20c60cd446 100644 --- a/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java @@ -24,6 +24,7 @@ import io.grpc.HandlerRegistry; import io.grpc.Server; import io.grpc.ServerBuilder; +import io.grpc.ServerCallExecutorSupplier; import io.grpc.ServerInterceptor; import io.grpc.ServerServiceDefinition; import io.grpc.ServerStreamTracer; @@ -67,6 +68,12 @@ public T directExecutor() { return thisT(); } + @Override + public T callExecutor(ServerCallExecutorSupplier executorSupplier) { + delegate().callExecutor(executorSupplier); + return thisT(); + } + @Override public T executor(@Nullable Executor executor) { delegate().executor(executor); diff --git a/core/src/main/java/io/grpc/internal/SerializingExecutor.java b/core/src/main/java/io/grpc/internal/SerializingExecutor.java index 6ac57635cfb..73133a339e4 100644 --- a/core/src/main/java/io/grpc/internal/SerializingExecutor.java +++ b/core/src/main/java/io/grpc/internal/SerializingExecutor.java @@ -59,7 +59,7 @@ private static AtomicHelper getAtomicHelper() { private static final int RUNNING = -1; /** Underlying executor that all submitted Runnable objects are run on. */ - private final Executor executor; + private Executor executor; /** A list of Runnables to be run in order. */ private final Queue runQueue = new ConcurrentLinkedQueue<>(); @@ -76,6 +76,15 @@ public SerializingExecutor(Executor executor) { this.executor = executor; } + /** + * Only call this from this SerializingExecutor Runnable, so that the executor is immediately + * visible to this SerializingExecutor executor. + * */ + public void setExecutor(Executor executor) { + Preconditions.checkNotNull(executor, "'executor' must not be null."); + this.executor = executor; + } + /** * Runs the given runnable strictly after all Runnables that were submitted * before it, and using the {@code executor} passed to the constructor. . @@ -118,7 +127,8 @@ private void schedule(@Nullable Runnable removable) { public void run() { Runnable r; try { - while ((r = runQueue.poll()) != null) { + Executor oldExecutor = executor; + while (oldExecutor == executor && (r = runQueue.poll()) != null ) { try { r.run(); } catch (RuntimeException e) { diff --git a/core/src/main/java/io/grpc/internal/ServerImpl.java b/core/src/main/java/io/grpc/internal/ServerImpl.java index 21f13cf5b4d..6bfe2d38ab3 100644 --- a/core/src/main/java/io/grpc/internal/ServerImpl.java +++ b/core/src/main/java/io/grpc/internal/ServerImpl.java @@ -28,6 +28,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; +import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import io.grpc.Attributes; @@ -46,6 +47,7 @@ import io.grpc.InternalServerInterceptors; import io.grpc.Metadata; import io.grpc.ServerCall; +import io.grpc.ServerCallExecutorSupplier; import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptor; import io.grpc.ServerMethodDefinition; @@ -125,6 +127,7 @@ public final class ServerImpl extends io.grpc.Server implements InternalInstrume private final InternalChannelz channelz; private final CallTracer serverCallTracer; private final Deadline.Ticker ticker; + private final ServerCallExecutorSupplier executorSupplier; /** * Construct a server. @@ -159,6 +162,7 @@ public final class ServerImpl extends io.grpc.Server implements InternalInstrume this.serverCallTracer = builder.callTracerFactory.create(); this.ticker = checkNotNull(builder.ticker, "ticker"); channelz.addServer(this); + this.executorSupplier = builder.executorSupplier; } /** @@ -469,11 +473,11 @@ private void streamCreatedInternal( final Executor wrappedExecutor; // This is a performance optimization that avoids the synchronization and queuing overhead // that comes with SerializingExecutor. - if (executor == directExecutor()) { + if (executorSupplier != null || executor != directExecutor()) { + wrappedExecutor = new SerializingExecutor(executor); + } else { wrappedExecutor = new SerializeReentrantCallsDirectExecutor(); stream.optimizeForDirectExecutor(); - } else { - wrappedExecutor = new SerializingExecutor(executor); } if (headers.containsKey(MESSAGE_ENCODING_KEY)) { @@ -499,30 +503,37 @@ private void streamCreatedInternal( final JumpToApplicationThreadServerStreamListener jumpListener = new JumpToApplicationThreadServerStreamListener( - wrappedExecutor, executor, stream, context, tag); + wrappedExecutor, executor, stream, context, tag); stream.setListener(jumpListener); - // Run in wrappedExecutor so jumpListener.setListener() is called before any callbacks - // are delivered, including any errors. Callbacks can still be triggered, but they will be - // queued. - - final class StreamCreated extends ContextRunnable { - StreamCreated() { + final SettableFuture> future = SettableFuture.create(); + // Run in serializing executor so jumpListener.setListener() is called before any callbacks + // are delivered, including any errors. MethodLookup() and HandleServerCall() are proactively + // queued before any callbacks are queued at serializing executor. + // MethodLookup() runs on the default executor. + // When executorSupplier is enabled, MethodLookup() may set/change the executor in the + // SerializingExecutor before it finishes running. + // Then HandleServerCall() and callbacks would switch to the executorSupplier executor. + // Otherwise, they all run on the default executor. + + final class MethodLookup extends ContextRunnable { + MethodLookup() { super(context); } @Override public void runInContext() { - PerfMark.startTask("ServerTransportListener$StreamCreated.startCall", tag); + PerfMark.startTask("ServerTransportListener$MethodLookup.startCall", tag); PerfMark.linkIn(link); try { runInternal(); } finally { - PerfMark.stopTask("ServerTransportListener$StreamCreated.startCall", tag); + PerfMark.stopTask("ServerTransportListener$MethodLookup.startCall", tag); } } private void runInternal() { - ServerStreamListener listener = NOOP_LISTENER; + ServerMethodDefinition wrapMethod; + ServerCallParameters callParams; try { ServerMethodDefinition method = registry.lookupMethod(methodName); if (method == null) { @@ -530,21 +541,82 @@ private void runInternal() { } if (method == null) { Status status = Status.UNIMPLEMENTED.withDescription( - "Method not found: " + methodName); + "Method not found: " + methodName); // TODO(zhangkun83): this error may be recorded by the tracer, and if it's kept in // memory as a map whose key is the method name, this would allow a misbehaving // client to blow up the server in-memory stats storage by sending large number of // distinct unimplemented method // names. (https://github.com/grpc/grpc-java/issues/2285) + jumpListener.setListener(NOOP_LISTENER); stream.close(status, new Metadata()); context.cancel(null); + future.cancel(false); return; } - listener = startCall(stream, methodName, method, headers, context, statsTraceCtx, tag); + wrapMethod = wrapMethod(stream, method, statsTraceCtx); + callParams = maySwitchExecutor(wrapMethod, stream, headers, context, tag); + future.set(callParams); } catch (Throwable t) { + jumpListener.setListener(NOOP_LISTENER); stream.close(Status.fromThrowable(t), new Metadata()); context.cancel(null); + future.cancel(false); throw t; + } + } + + private ServerCallParameters maySwitchExecutor( + final ServerMethodDefinition methodDef, + final ServerStream stream, + final Metadata headers, + final Context.CancellableContext context, + final Tag tag) { + final ServerCallImpl call = new ServerCallImpl<>( + stream, + methodDef.getMethodDescriptor(), + headers, + context, + decompressorRegistry, + compressorRegistry, + serverCallTracer, + tag); + if (executorSupplier != null) { + Executor switchingExecutor = executorSupplier.getExecutor(call, headers); + if (switchingExecutor != null) { + ((SerializingExecutor)wrappedExecutor).setExecutor(switchingExecutor); + } + } + return new ServerCallParameters<>(call, methodDef.getServerCallHandler()); + } + } + + final class HandleServerCall extends ContextRunnable { + HandleServerCall() { + super(context); + } + + @Override + public void runInContext() { + PerfMark.startTask("ServerTransportListener$HandleServerCall.startCall", tag); + PerfMark.linkIn(link); + try { + runInternal(); + } finally { + PerfMark.stopTask("ServerTransportListener$HandleServerCall.startCall", tag); + } + } + + private void runInternal() { + ServerStreamListener listener = NOOP_LISTENER; + if (future.isCancelled()) { + return; + } + try { + listener = startWrappedCall(methodName, Futures.getDone(future), headers); + } catch (Throwable ex) { + stream.close(Status.fromThrowable(ex), new Metadata()); + context.cancel(null); + throw new IllegalStateException(ex); } finally { jumpListener.setListener(listener); } @@ -568,7 +640,8 @@ public void cancelled(Context context) { } } - wrappedExecutor.execute(new StreamCreated()); + wrappedExecutor.execute(new MethodLookup()); + wrappedExecutor.execute(new HandleServerCall()); } private Context.CancellableContext createContext( @@ -593,9 +666,8 @@ private Context.CancellableContext createContext( } /** Never returns {@code null}. */ - private ServerStreamListener startCall(ServerStream stream, String fullMethodName, - ServerMethodDefinition methodDef, Metadata headers, - Context.CancellableContext context, StatsTraceContext statsTraceCtx, Tag tag) { + private ServerMethodDefinition wrapMethod(ServerStream stream, + ServerMethodDefinition methodDef, StatsTraceContext statsTraceCtx) { // TODO(ejona86): should we update fullMethodName to have the canonical path of the method? statsTraceCtx.serverCallStarted( new ServerCallInfoImpl<>( @@ -609,34 +681,31 @@ private ServerStreamListener startCall(ServerStream stream, String ServerMethodDefinition interceptedDef = methodDef.withServerCallHandler(handler); ServerMethodDefinition wMethodDef = binlog == null ? interceptedDef : binlog.wrapMethodDefinition(interceptedDef); - return startWrappedCall(fullMethodName, wMethodDef, stream, headers, context, tag); + return wMethodDef; + } + + private final class ServerCallParameters { + ServerCallImpl call; + ServerCallHandler callHandler; + + public ServerCallParameters(ServerCallImpl call, + ServerCallHandler callHandler) { + this.call = call; + this.callHandler = callHandler; + } } private ServerStreamListener startWrappedCall( String fullMethodName, - ServerMethodDefinition methodDef, - ServerStream stream, - Metadata headers, - Context.CancellableContext context, - Tag tag) { - - ServerCallImpl call = new ServerCallImpl<>( - stream, - methodDef.getMethodDescriptor(), - headers, - context, - decompressorRegistry, - compressorRegistry, - serverCallTracer, - tag); - - ServerCall.Listener listener = - methodDef.getServerCallHandler().startCall(call, headers); - if (listener == null) { + ServerCallParameters params, + Metadata headers) { + ServerCall.Listener callListener = + params.callHandler.startCall(params.call, headers); + if (callListener == null) { throw new NullPointerException( - "startCall() returned a null listener for method " + fullMethodName); + "startCall() returned a null listener for method " + fullMethodName); } - return call.newServerStreamListener(listener); + return params.call.newServerStreamListener(callListener); } } diff --git a/core/src/main/java/io/grpc/internal/ServerImplBuilder.java b/core/src/main/java/io/grpc/internal/ServerImplBuilder.java index aafdc150fb2..277e476143d 100644 --- a/core/src/main/java/io/grpc/internal/ServerImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/ServerImplBuilder.java @@ -32,6 +32,7 @@ import io.grpc.InternalChannelz; import io.grpc.Server; import io.grpc.ServerBuilder; +import io.grpc.ServerCallExecutorSupplier; import io.grpc.ServerInterceptor; import io.grpc.ServerMethodDefinition; import io.grpc.ServerServiceDefinition; @@ -93,6 +94,8 @@ public static ServerBuilder forPort(int port) { @Nullable BinaryLog binlog; InternalChannelz channelz = InternalChannelz.instance(); CallTracer.Factory callTracerFactory = CallTracer.getDefaultFactory(); + @Nullable + ServerCallExecutorSupplier executorSupplier; /** * An interface to provide to provide transport specific information for the server. This method @@ -122,6 +125,12 @@ public ServerImplBuilder executor(@Nullable Executor executor) { return this; } + @Override + public ServerImplBuilder callExecutor(ServerCallExecutorSupplier executorSupplier) { + this.executorSupplier = checkNotNull(executorSupplier); + return this; + } + @Override public ServerImplBuilder addService(ServerServiceDefinition service) { registryBuilder.addService(checkNotNull(service, "service")); diff --git a/core/src/test/java/io/grpc/inprocess/InProcessTransportTest.java b/core/src/test/java/io/grpc/inprocess/InProcessTransportTest.java index a9879375929..7325cda73cc 100644 --- a/core/src/test/java/io/grpc/inprocess/InProcessTransportTest.java +++ b/core/src/test/java/io/grpc/inprocess/InProcessTransportTest.java @@ -20,14 +20,18 @@ import static org.junit.Assert.fail; import io.grpc.CallOptions; +import io.grpc.ClientCall; import io.grpc.ManagedChannel; import io.grpc.Metadata; +import io.grpc.MethodDescriptor; import io.grpc.Server; import io.grpc.ServerCall; +import io.grpc.ServerCall.Listener; import io.grpc.ServerCallHandler; import io.grpc.ServerServiceDefinition; import io.grpc.ServerStreamTracer; import io.grpc.Status; +import io.grpc.Status.Code; import io.grpc.StatusRuntimeException; import io.grpc.internal.AbstractTransportTest; import io.grpc.internal.GrpcUtil; @@ -37,6 +41,8 @@ import io.grpc.testing.GrpcCleanupRule; import io.grpc.testing.TestMethodDescriptors; import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; @@ -131,4 +137,44 @@ public ServerCall.Listener startCall( assertEquals(s.getCause(), e.getCause()); } } + + @Test + public void methodNotFound() throws Exception { + server = null; + ServerServiceDefinition definition = ServerServiceDefinition.builder("service_foo") + .addMethod(TestMethodDescriptors.voidMethod(), new ServerCallHandler() { + @Override + public Listener startCall(ServerCall call, Metadata headers) { + return null; + } + }) + .build(); + Server failingServer = InProcessServerBuilder + .forName("nocall-service") + .addService(definition) + .directExecutor() + .build() + .start(); + grpcCleanupRule.register(failingServer); + ManagedChannel channel = InProcessChannelBuilder + .forName("nocall-service") + .propagateCauseWithStatus(true) + .build(); + grpcCleanupRule.register(channel); + MethodDescriptor nonMatchMethod = + MethodDescriptor.newBuilder() + .setType(MethodDescriptor.MethodType.UNKNOWN) + .setFullMethodName("Waiter/serve") + .setRequestMarshaller(TestMethodDescriptors.voidMarshaller()) + .setResponseMarshaller(TestMethodDescriptors.voidMarshaller()) + .build(); + ClientCall call = channel.newCall(nonMatchMethod, CallOptions.DEFAULT); + try { + ClientCalls.futureUnaryCall(call, null).get(5, TimeUnit.SECONDS); + fail("Call should fail."); + } catch (ExecutionException ex) { + StatusRuntimeException s = (StatusRuntimeException)ex.getCause(); + assertEquals(s.getStatus().getCode(), Code.UNIMPLEMENTED); + } + } } diff --git a/core/src/test/java/io/grpc/internal/SerializingExecutorTest.java b/core/src/test/java/io/grpc/internal/SerializingExecutorTest.java index 55f40819299..a1a8146b7bb 100644 --- a/core/src/test/java/io/grpc/internal/SerializingExecutorTest.java +++ b/core/src/test/java/io/grpc/internal/SerializingExecutorTest.java @@ -209,6 +209,38 @@ public void run() { assertEquals(Arrays.asList(1, 2, 3), runs); } + @Test + public void switchable() { + final SerializingExecutor testExecutor = + new SerializingExecutor(MoreExecutors.directExecutor()); + testExecutor.execute(new Runnable() { + @Override + public void run() { + runs.add(1); + testExecutor.setExecutor(singleExecutor); + } + }); + testExecutor.execute(new AddToRuns(-2)); + assertThat(runs).isEqualTo(Arrays.asList(1)); + singleExecutor.drain(); + assertThat(runs).isEqualTo(Arrays.asList(1, -2)); + } + + @Test + public void notSwitch() { + executor.execute(new Runnable() { + @Override + public void run() { + runs.add(1); + executor.setExecutor(singleExecutor); + } + }); + executor.execute(new AddToRuns(-2)); + assertThat(runs).isEqualTo(Collections.emptyList()); + singleExecutor.drain(); + assertThat(runs).isEqualTo(Arrays.asList(1, -2)); + } + private static class SingleExecutor implements Executor { private Runnable runnable; diff --git a/core/src/test/java/io/grpc/internal/ServerImplTest.java b/core/src/test/java/io/grpc/internal/ServerImplTest.java index 5b5f5384d30..2a9dbd5a1fe 100644 --- a/core/src/test/java/io/grpc/internal/ServerImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerImplTest.java @@ -65,6 +65,7 @@ import io.grpc.MethodDescriptor; import io.grpc.ServerCall; import io.grpc.ServerCall.Listener; +import io.grpc.ServerCallExecutorSupplier; import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptor; import io.grpc.ServerMethodDefinition; @@ -73,6 +74,7 @@ import io.grpc.ServerTransportFilter; import io.grpc.ServiceDescriptor; import io.grpc.Status; +import io.grpc.Status.Code; import io.grpc.StringMarshaller; import io.grpc.internal.ServerImpl.JumpToApplicationThreadServerStreamListener; import io.grpc.internal.ServerImplBuilder.ClientTransportServersBuilder; @@ -458,6 +460,127 @@ public void methodNotFound() throws Exception { assertEquals(Status.Code.UNIMPLEMENTED, statusCaptor.getValue().getCode()); } + + @Test + public void executorSupplierSameExecutorBasic() throws Exception { + builder.executorSupplier = new ServerCallExecutorSupplier() { + @Override + public Executor getExecutor(ServerCall call, Metadata metadata) { + return executor.getScheduledExecutorService(); + } + }; + basicExchangeSuccessful(); + } + + @Test + public void executorSupplierNullBasic() throws Exception { + builder.executorSupplier = new ServerCallExecutorSupplier() { + @Override + public Executor getExecutor(ServerCall call, Metadata metadata) { + return null; + } + }; + basicExchangeSuccessful(); + } + + @Test + @SuppressWarnings("unchecked") + public void executorSupplierSwitchExecutor() throws Exception { + SingleExecutor switchingExecutor = new SingleExecutor(); + ServerCallExecutorSupplier mockSupplier = mock(ServerCallExecutorSupplier.class); + when(mockSupplier.getExecutor(any(ServerCall.class), any(Metadata.class))) + .thenReturn(switchingExecutor); + builder.executorSupplier = mockSupplier; + final AtomicReference> callReference + = new AtomicReference<>(); + mutableFallbackRegistry.addService(ServerServiceDefinition.builder( + new ServiceDescriptor("Waiter", METHOD)) + .addMethod(METHOD, + new ServerCallHandler() { + @Override + public ServerCall.Listener startCall( + ServerCall call, + Metadata headers) { + callReference.set(call); + return callListener; + } + }).build()); + + createAndStartServer(); + ServerTransportListener transportListener + = transportServer.registerNewServerTransport(new SimpleServerTransport()); + transportListener.transportReady(Attributes.EMPTY); + Metadata requestHeaders = new Metadata(); + StatsTraceContext statsTraceCtx = + StatsTraceContext.newServerContext( + streamTracerFactories, "Waiter/serve", requestHeaders); + when(stream.statsTraceContext()).thenReturn(statsTraceCtx); + transportListener.streamCreated(stream, "Waiter/serve", requestHeaders); + verify(stream).setListener(isA(ServerStreamListener.class)); + verify(stream, atLeast(1)).statsTraceContext(); + + assertEquals(1, executor.runDueTasks()); + verify(fallbackRegistry).lookupMethod("Waiter/serve", AUTHORITY); + verify(streamTracerFactory).newServerStreamTracer(eq("Waiter/serve"), same(requestHeaders)); + ArgumentCaptor> callCapture = ArgumentCaptor.forClass(ServerCall.class); + verify(mockSupplier).getExecutor(callCapture.capture(), eq(requestHeaders)); + + assertThat(switchingExecutor.runnable).isNotNull(); + assertEquals(0, executor.numPendingTasks()); + switchingExecutor.drain(); + ServerCall call = callReference.get(); + assertNotNull(call); + assertThat(call).isEqualTo(callCapture.getValue()); + } + + @Test + @SuppressWarnings("CheckReturnValue") + public void executorSupplierFutureNotSet() throws Exception { + builder.executorSupplier = new ServerCallExecutorSupplier() { + @Override + public Executor getExecutor(ServerCall call, Metadata metadata) { + throw new IllegalStateException("Yeah!"); + } + }; + doThrow(new IllegalStateException("Yeah")).doNothing() + .when(stream).close(any(Status.class), any(Metadata.class)); + final AtomicReference> callReference + = new AtomicReference<>(); + mutableFallbackRegistry.addService(ServerServiceDefinition.builder( + new ServiceDescriptor("Waiter", METHOD)) + .addMethod(METHOD, + new ServerCallHandler() { + @Override + public ServerCall.Listener startCall( + ServerCall call, + Metadata headers) { + callReference.set(call); + return callListener; + } + }).build()); + + createAndStartServer(); + ServerTransportListener transportListener + = transportServer.registerNewServerTransport(new SimpleServerTransport()); + transportListener.transportReady(Attributes.EMPTY); + Metadata requestHeaders = new Metadata(); + StatsTraceContext statsTraceCtx = + StatsTraceContext.newServerContext( + streamTracerFactories, "Waiter/serve", requestHeaders); + when(stream.statsTraceContext()).thenReturn(statsTraceCtx); + transportListener.streamCreated(stream, "Waiter/serve", requestHeaders); + verify(stream).setListener(isA(ServerStreamListener.class)); + verify(stream, atLeast(1)).statsTraceContext(); + + assertEquals(1, executor.runDueTasks()); + verify(fallbackRegistry).lookupMethod("Waiter/serve", AUTHORITY); + assertThat(callReference.get()).isNull(); + verify(stream, times(2)).close(statusCaptor.capture(), any(Metadata.class)); + Status status = statusCaptor.getAllValues().get(1); + assertEquals(Code.UNKNOWN, status.getCode()); + assertThat(status.getCause() instanceof IllegalStateException); + } + @Test public void decompressorNotFound() throws Exception { String decompressorName = "NON_EXISTENT_DECOMPRESSOR"; @@ -1513,4 +1636,24 @@ public ListenableFuture getStats() { /** Allows more precise catch blocks than plain Error to avoid catching AssertionError. */ private static final class TestError extends Error {} + + private static class SingleExecutor implements Executor { + private Runnable runnable; + + @Override + public void execute(Runnable r) { + if (runnable != null) { + fail("Already have runnable scheduled"); + } + runnable = r; + } + + public void drain() { + if (runnable != null) { + Runnable r = runnable; + runnable = null; + r.run(); + } + } + } }