diff --git a/api/src/main/java/io/grpc/ForwardingServerBuilder.java b/api/src/main/java/io/grpc/ForwardingServerBuilder.java
index 696a441e9b6..8ea355de6e9 100644
--- a/api/src/main/java/io/grpc/ForwardingServerBuilder.java
+++ b/api/src/main/java/io/grpc/ForwardingServerBuilder.java
@@ -61,6 +61,12 @@ public T executor(@Nullable Executor executor) {
return thisT();
}
+ @Override
+ public T callExecutor(ServerCallExecutorSupplier executorSupplier) {
+ delegate().callExecutor(executorSupplier);
+ return thisT();
+ }
+
@Override
public T addService(ServerServiceDefinition service) {
delegate().addService(service);
diff --git a/api/src/main/java/io/grpc/ServerBuilder.java b/api/src/main/java/io/grpc/ServerBuilder.java
index 4402edb7d92..9dd97790f27 100644
--- a/api/src/main/java/io/grpc/ServerBuilder.java
+++ b/api/src/main/java/io/grpc/ServerBuilder.java
@@ -74,6 +74,30 @@ public static ServerBuilder> forPort(int port) {
*/
public abstract T executor(@Nullable Executor executor);
+
+ /**
+ * Allows for defining a way to provide a custom executor to handle the server call.
+ * This executor is the result of calling
+ * {@link ServerCallExecutorSupplier#getExecutor(ServerCall, Metadata)} per RPC.
+ *
+ *
It's an optional parameter. If it is provided, the {@link #executor(Executor)} would still
+ * run necessary tasks before the {@link ServerCallExecutorSupplier} is ready to be called, then
+ * it switches over.
+ *
+ *
If it is provided, {@link #directExecutor()} optimization is disabled. But if calling
+ * {@link ServerCallExecutorSupplier} returns null, the server call is still handled by the
+ * default {@link #executor(Executor)} as a fallback.
+ *
+ * @param executorSupplier the server call executor provider
+ * @return this
+ * @since 1.39.0
+ *
+ * */
+ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/8274")
+ public T callExecutor(ServerCallExecutorSupplier executorSupplier) {
+ return thisT();
+ }
+
/**
* Adds a service implementation to the handler registry.
*
diff --git a/api/src/main/java/io/grpc/ServerCallExecutorSupplier.java b/api/src/main/java/io/grpc/ServerCallExecutorSupplier.java
new file mode 100644
index 00000000000..c990dc943e5
--- /dev/null
+++ b/api/src/main/java/io/grpc/ServerCallExecutorSupplier.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright 2021 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc;
+
+import java.util.concurrent.Executor;
+import javax.annotation.Nullable;
+
+/**
+ * Defines what executor handles the server call, based on each RPC call information at runtime.
+ * */
+@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8274")
+public interface ServerCallExecutorSupplier {
+
+ /**
+ * Returns an executor to handle the server call.
+ * It should never throw. It should return null to fallback to the default executor.
+ * */
+ @Nullable
+ Executor getExecutor(ServerCall call, Metadata metadata);
+}
diff --git a/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java b/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java
index 715161c0635..d20c60cd446 100644
--- a/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java
+++ b/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java
@@ -24,6 +24,7 @@
import io.grpc.HandlerRegistry;
import io.grpc.Server;
import io.grpc.ServerBuilder;
+import io.grpc.ServerCallExecutorSupplier;
import io.grpc.ServerInterceptor;
import io.grpc.ServerServiceDefinition;
import io.grpc.ServerStreamTracer;
@@ -67,6 +68,12 @@ public T directExecutor() {
return thisT();
}
+ @Override
+ public T callExecutor(ServerCallExecutorSupplier executorSupplier) {
+ delegate().callExecutor(executorSupplier);
+ return thisT();
+ }
+
@Override
public T executor(@Nullable Executor executor) {
delegate().executor(executor);
diff --git a/core/src/main/java/io/grpc/internal/SerializingExecutor.java b/core/src/main/java/io/grpc/internal/SerializingExecutor.java
index 6ac57635cfb..73133a339e4 100644
--- a/core/src/main/java/io/grpc/internal/SerializingExecutor.java
+++ b/core/src/main/java/io/grpc/internal/SerializingExecutor.java
@@ -59,7 +59,7 @@ private static AtomicHelper getAtomicHelper() {
private static final int RUNNING = -1;
/** Underlying executor that all submitted Runnable objects are run on. */
- private final Executor executor;
+ private Executor executor;
/** A list of Runnables to be run in order. */
private final Queue runQueue = new ConcurrentLinkedQueue<>();
@@ -76,6 +76,15 @@ public SerializingExecutor(Executor executor) {
this.executor = executor;
}
+ /**
+ * Only call this from this SerializingExecutor Runnable, so that the executor is immediately
+ * visible to this SerializingExecutor executor.
+ * */
+ public void setExecutor(Executor executor) {
+ Preconditions.checkNotNull(executor, "'executor' must not be null.");
+ this.executor = executor;
+ }
+
/**
* Runs the given runnable strictly after all Runnables that were submitted
* before it, and using the {@code executor} passed to the constructor. .
@@ -118,7 +127,8 @@ private void schedule(@Nullable Runnable removable) {
public void run() {
Runnable r;
try {
- while ((r = runQueue.poll()) != null) {
+ Executor oldExecutor = executor;
+ while (oldExecutor == executor && (r = runQueue.poll()) != null ) {
try {
r.run();
} catch (RuntimeException e) {
diff --git a/core/src/main/java/io/grpc/internal/ServerImpl.java b/core/src/main/java/io/grpc/internal/ServerImpl.java
index 21f13cf5b4d..6bfe2d38ab3 100644
--- a/core/src/main/java/io/grpc/internal/ServerImpl.java
+++ b/core/src/main/java/io/grpc/internal/ServerImpl.java
@@ -28,6 +28,7 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
+import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.grpc.Attributes;
@@ -46,6 +47,7 @@
import io.grpc.InternalServerInterceptors;
import io.grpc.Metadata;
import io.grpc.ServerCall;
+import io.grpc.ServerCallExecutorSupplier;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.ServerMethodDefinition;
@@ -125,6 +127,7 @@ public final class ServerImpl extends io.grpc.Server implements InternalInstrume
private final InternalChannelz channelz;
private final CallTracer serverCallTracer;
private final Deadline.Ticker ticker;
+ private final ServerCallExecutorSupplier executorSupplier;
/**
* Construct a server.
@@ -159,6 +162,7 @@ public final class ServerImpl extends io.grpc.Server implements InternalInstrume
this.serverCallTracer = builder.callTracerFactory.create();
this.ticker = checkNotNull(builder.ticker, "ticker");
channelz.addServer(this);
+ this.executorSupplier = builder.executorSupplier;
}
/**
@@ -469,11 +473,11 @@ private void streamCreatedInternal(
final Executor wrappedExecutor;
// This is a performance optimization that avoids the synchronization and queuing overhead
// that comes with SerializingExecutor.
- if (executor == directExecutor()) {
+ if (executorSupplier != null || executor != directExecutor()) {
+ wrappedExecutor = new SerializingExecutor(executor);
+ } else {
wrappedExecutor = new SerializeReentrantCallsDirectExecutor();
stream.optimizeForDirectExecutor();
- } else {
- wrappedExecutor = new SerializingExecutor(executor);
}
if (headers.containsKey(MESSAGE_ENCODING_KEY)) {
@@ -499,30 +503,37 @@ private void streamCreatedInternal(
final JumpToApplicationThreadServerStreamListener jumpListener
= new JumpToApplicationThreadServerStreamListener(
- wrappedExecutor, executor, stream, context, tag);
+ wrappedExecutor, executor, stream, context, tag);
stream.setListener(jumpListener);
- // Run in wrappedExecutor so jumpListener.setListener() is called before any callbacks
- // are delivered, including any errors. Callbacks can still be triggered, but they will be
- // queued.
-
- final class StreamCreated extends ContextRunnable {
- StreamCreated() {
+ final SettableFuture> future = SettableFuture.create();
+ // Run in serializing executor so jumpListener.setListener() is called before any callbacks
+ // are delivered, including any errors. MethodLookup() and HandleServerCall() are proactively
+ // queued before any callbacks are queued at serializing executor.
+ // MethodLookup() runs on the default executor.
+ // When executorSupplier is enabled, MethodLookup() may set/change the executor in the
+ // SerializingExecutor before it finishes running.
+ // Then HandleServerCall() and callbacks would switch to the executorSupplier executor.
+ // Otherwise, they all run on the default executor.
+
+ final class MethodLookup extends ContextRunnable {
+ MethodLookup() {
super(context);
}
@Override
public void runInContext() {
- PerfMark.startTask("ServerTransportListener$StreamCreated.startCall", tag);
+ PerfMark.startTask("ServerTransportListener$MethodLookup.startCall", tag);
PerfMark.linkIn(link);
try {
runInternal();
} finally {
- PerfMark.stopTask("ServerTransportListener$StreamCreated.startCall", tag);
+ PerfMark.stopTask("ServerTransportListener$MethodLookup.startCall", tag);
}
}
private void runInternal() {
- ServerStreamListener listener = NOOP_LISTENER;
+ ServerMethodDefinition, ?> wrapMethod;
+ ServerCallParameters, ?> callParams;
try {
ServerMethodDefinition, ?> method = registry.lookupMethod(methodName);
if (method == null) {
@@ -530,21 +541,82 @@ private void runInternal() {
}
if (method == null) {
Status status = Status.UNIMPLEMENTED.withDescription(
- "Method not found: " + methodName);
+ "Method not found: " + methodName);
// TODO(zhangkun83): this error may be recorded by the tracer, and if it's kept in
// memory as a map whose key is the method name, this would allow a misbehaving
// client to blow up the server in-memory stats storage by sending large number of
// distinct unimplemented method
// names. (https://github.com/grpc/grpc-java/issues/2285)
+ jumpListener.setListener(NOOP_LISTENER);
stream.close(status, new Metadata());
context.cancel(null);
+ future.cancel(false);
return;
}
- listener = startCall(stream, methodName, method, headers, context, statsTraceCtx, tag);
+ wrapMethod = wrapMethod(stream, method, statsTraceCtx);
+ callParams = maySwitchExecutor(wrapMethod, stream, headers, context, tag);
+ future.set(callParams);
} catch (Throwable t) {
+ jumpListener.setListener(NOOP_LISTENER);
stream.close(Status.fromThrowable(t), new Metadata());
context.cancel(null);
+ future.cancel(false);
throw t;
+ }
+ }
+
+ private ServerCallParameters maySwitchExecutor(
+ final ServerMethodDefinition methodDef,
+ final ServerStream stream,
+ final Metadata headers,
+ final Context.CancellableContext context,
+ final Tag tag) {
+ final ServerCallImpl call = new ServerCallImpl<>(
+ stream,
+ methodDef.getMethodDescriptor(),
+ headers,
+ context,
+ decompressorRegistry,
+ compressorRegistry,
+ serverCallTracer,
+ tag);
+ if (executorSupplier != null) {
+ Executor switchingExecutor = executorSupplier.getExecutor(call, headers);
+ if (switchingExecutor != null) {
+ ((SerializingExecutor)wrappedExecutor).setExecutor(switchingExecutor);
+ }
+ }
+ return new ServerCallParameters<>(call, methodDef.getServerCallHandler());
+ }
+ }
+
+ final class HandleServerCall extends ContextRunnable {
+ HandleServerCall() {
+ super(context);
+ }
+
+ @Override
+ public void runInContext() {
+ PerfMark.startTask("ServerTransportListener$HandleServerCall.startCall", tag);
+ PerfMark.linkIn(link);
+ try {
+ runInternal();
+ } finally {
+ PerfMark.stopTask("ServerTransportListener$HandleServerCall.startCall", tag);
+ }
+ }
+
+ private void runInternal() {
+ ServerStreamListener listener = NOOP_LISTENER;
+ if (future.isCancelled()) {
+ return;
+ }
+ try {
+ listener = startWrappedCall(methodName, Futures.getDone(future), headers);
+ } catch (Throwable ex) {
+ stream.close(Status.fromThrowable(ex), new Metadata());
+ context.cancel(null);
+ throw new IllegalStateException(ex);
} finally {
jumpListener.setListener(listener);
}
@@ -568,7 +640,8 @@ public void cancelled(Context context) {
}
}
- wrappedExecutor.execute(new StreamCreated());
+ wrappedExecutor.execute(new MethodLookup());
+ wrappedExecutor.execute(new HandleServerCall());
}
private Context.CancellableContext createContext(
@@ -593,9 +666,8 @@ private Context.CancellableContext createContext(
}
/** Never returns {@code null}. */
- private ServerStreamListener startCall(ServerStream stream, String fullMethodName,
- ServerMethodDefinition methodDef, Metadata headers,
- Context.CancellableContext context, StatsTraceContext statsTraceCtx, Tag tag) {
+ private ServerMethodDefinition,?> wrapMethod(ServerStream stream,
+ ServerMethodDefinition methodDef, StatsTraceContext statsTraceCtx) {
// TODO(ejona86): should we update fullMethodName to have the canonical path of the method?
statsTraceCtx.serverCallStarted(
new ServerCallInfoImpl<>(
@@ -609,34 +681,31 @@ private ServerStreamListener startCall(ServerStream stream, String
ServerMethodDefinition interceptedDef = methodDef.withServerCallHandler(handler);
ServerMethodDefinition, ?> wMethodDef = binlog == null
? interceptedDef : binlog.wrapMethodDefinition(interceptedDef);
- return startWrappedCall(fullMethodName, wMethodDef, stream, headers, context, tag);
+ return wMethodDef;
+ }
+
+ private final class ServerCallParameters {
+ ServerCallImpl call;
+ ServerCallHandler callHandler;
+
+ public ServerCallParameters(ServerCallImpl call,
+ ServerCallHandler callHandler) {
+ this.call = call;
+ this.callHandler = callHandler;
+ }
}
private ServerStreamListener startWrappedCall(
String fullMethodName,
- ServerMethodDefinition methodDef,
- ServerStream stream,
- Metadata headers,
- Context.CancellableContext context,
- Tag tag) {
-
- ServerCallImpl call = new ServerCallImpl<>(
- stream,
- methodDef.getMethodDescriptor(),
- headers,
- context,
- decompressorRegistry,
- compressorRegistry,
- serverCallTracer,
- tag);
-
- ServerCall.Listener listener =
- methodDef.getServerCallHandler().startCall(call, headers);
- if (listener == null) {
+ ServerCallParameters params,
+ Metadata headers) {
+ ServerCall.Listener callListener =
+ params.callHandler.startCall(params.call, headers);
+ if (callListener == null) {
throw new NullPointerException(
- "startCall() returned a null listener for method " + fullMethodName);
+ "startCall() returned a null listener for method " + fullMethodName);
}
- return call.newServerStreamListener(listener);
+ return params.call.newServerStreamListener(callListener);
}
}
diff --git a/core/src/main/java/io/grpc/internal/ServerImplBuilder.java b/core/src/main/java/io/grpc/internal/ServerImplBuilder.java
index aafdc150fb2..277e476143d 100644
--- a/core/src/main/java/io/grpc/internal/ServerImplBuilder.java
+++ b/core/src/main/java/io/grpc/internal/ServerImplBuilder.java
@@ -32,6 +32,7 @@
import io.grpc.InternalChannelz;
import io.grpc.Server;
import io.grpc.ServerBuilder;
+import io.grpc.ServerCallExecutorSupplier;
import io.grpc.ServerInterceptor;
import io.grpc.ServerMethodDefinition;
import io.grpc.ServerServiceDefinition;
@@ -93,6 +94,8 @@ public static ServerBuilder> forPort(int port) {
@Nullable BinaryLog binlog;
InternalChannelz channelz = InternalChannelz.instance();
CallTracer.Factory callTracerFactory = CallTracer.getDefaultFactory();
+ @Nullable
+ ServerCallExecutorSupplier executorSupplier;
/**
* An interface to provide to provide transport specific information for the server. This method
@@ -122,6 +125,12 @@ public ServerImplBuilder executor(@Nullable Executor executor) {
return this;
}
+ @Override
+ public ServerImplBuilder callExecutor(ServerCallExecutorSupplier executorSupplier) {
+ this.executorSupplier = checkNotNull(executorSupplier);
+ return this;
+ }
+
@Override
public ServerImplBuilder addService(ServerServiceDefinition service) {
registryBuilder.addService(checkNotNull(service, "service"));
diff --git a/core/src/test/java/io/grpc/inprocess/InProcessTransportTest.java b/core/src/test/java/io/grpc/inprocess/InProcessTransportTest.java
index a9879375929..7325cda73cc 100644
--- a/core/src/test/java/io/grpc/inprocess/InProcessTransportTest.java
+++ b/core/src/test/java/io/grpc/inprocess/InProcessTransportTest.java
@@ -20,14 +20,18 @@
import static org.junit.Assert.fail;
import io.grpc.CallOptions;
+import io.grpc.ClientCall;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
+import io.grpc.MethodDescriptor;
import io.grpc.Server;
import io.grpc.ServerCall;
+import io.grpc.ServerCall.Listener;
import io.grpc.ServerCallHandler;
import io.grpc.ServerServiceDefinition;
import io.grpc.ServerStreamTracer;
import io.grpc.Status;
+import io.grpc.Status.Code;
import io.grpc.StatusRuntimeException;
import io.grpc.internal.AbstractTransportTest;
import io.grpc.internal.GrpcUtil;
@@ -37,6 +41,8 @@
import io.grpc.testing.GrpcCleanupRule;
import io.grpc.testing.TestMethodDescriptors;
import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
@@ -131,4 +137,44 @@ public ServerCall.Listener startCall(
assertEquals(s.getCause(), e.getCause());
}
}
+
+ @Test
+ public void methodNotFound() throws Exception {
+ server = null;
+ ServerServiceDefinition definition = ServerServiceDefinition.builder("service_foo")
+ .addMethod(TestMethodDescriptors.voidMethod(), new ServerCallHandler() {
+ @Override
+ public Listener startCall(ServerCall call, Metadata headers) {
+ return null;
+ }
+ })
+ .build();
+ Server failingServer = InProcessServerBuilder
+ .forName("nocall-service")
+ .addService(definition)
+ .directExecutor()
+ .build()
+ .start();
+ grpcCleanupRule.register(failingServer);
+ ManagedChannel channel = InProcessChannelBuilder
+ .forName("nocall-service")
+ .propagateCauseWithStatus(true)
+ .build();
+ grpcCleanupRule.register(channel);
+ MethodDescriptor nonMatchMethod =
+ MethodDescriptor.newBuilder()
+ .setType(MethodDescriptor.MethodType.UNKNOWN)
+ .setFullMethodName("Waiter/serve")
+ .setRequestMarshaller(TestMethodDescriptors.voidMarshaller())
+ .setResponseMarshaller(TestMethodDescriptors.voidMarshaller())
+ .build();
+ ClientCall call = channel.newCall(nonMatchMethod, CallOptions.DEFAULT);
+ try {
+ ClientCalls.futureUnaryCall(call, null).get(5, TimeUnit.SECONDS);
+ fail("Call should fail.");
+ } catch (ExecutionException ex) {
+ StatusRuntimeException s = (StatusRuntimeException)ex.getCause();
+ assertEquals(s.getStatus().getCode(), Code.UNIMPLEMENTED);
+ }
+ }
}
diff --git a/core/src/test/java/io/grpc/internal/SerializingExecutorTest.java b/core/src/test/java/io/grpc/internal/SerializingExecutorTest.java
index 55f40819299..a1a8146b7bb 100644
--- a/core/src/test/java/io/grpc/internal/SerializingExecutorTest.java
+++ b/core/src/test/java/io/grpc/internal/SerializingExecutorTest.java
@@ -209,6 +209,38 @@ public void run() {
assertEquals(Arrays.asList(1, 2, 3), runs);
}
+ @Test
+ public void switchable() {
+ final SerializingExecutor testExecutor =
+ new SerializingExecutor(MoreExecutors.directExecutor());
+ testExecutor.execute(new Runnable() {
+ @Override
+ public void run() {
+ runs.add(1);
+ testExecutor.setExecutor(singleExecutor);
+ }
+ });
+ testExecutor.execute(new AddToRuns(-2));
+ assertThat(runs).isEqualTo(Arrays.asList(1));
+ singleExecutor.drain();
+ assertThat(runs).isEqualTo(Arrays.asList(1, -2));
+ }
+
+ @Test
+ public void notSwitch() {
+ executor.execute(new Runnable() {
+ @Override
+ public void run() {
+ runs.add(1);
+ executor.setExecutor(singleExecutor);
+ }
+ });
+ executor.execute(new AddToRuns(-2));
+ assertThat(runs).isEqualTo(Collections.emptyList());
+ singleExecutor.drain();
+ assertThat(runs).isEqualTo(Arrays.asList(1, -2));
+ }
+
private static class SingleExecutor implements Executor {
private Runnable runnable;
diff --git a/core/src/test/java/io/grpc/internal/ServerImplTest.java b/core/src/test/java/io/grpc/internal/ServerImplTest.java
index 5b5f5384d30..2a9dbd5a1fe 100644
--- a/core/src/test/java/io/grpc/internal/ServerImplTest.java
+++ b/core/src/test/java/io/grpc/internal/ServerImplTest.java
@@ -65,6 +65,7 @@
import io.grpc.MethodDescriptor;
import io.grpc.ServerCall;
import io.grpc.ServerCall.Listener;
+import io.grpc.ServerCallExecutorSupplier;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.ServerMethodDefinition;
@@ -73,6 +74,7 @@
import io.grpc.ServerTransportFilter;
import io.grpc.ServiceDescriptor;
import io.grpc.Status;
+import io.grpc.Status.Code;
import io.grpc.StringMarshaller;
import io.grpc.internal.ServerImpl.JumpToApplicationThreadServerStreamListener;
import io.grpc.internal.ServerImplBuilder.ClientTransportServersBuilder;
@@ -458,6 +460,127 @@ public void methodNotFound() throws Exception {
assertEquals(Status.Code.UNIMPLEMENTED, statusCaptor.getValue().getCode());
}
+
+ @Test
+ public void executorSupplierSameExecutorBasic() throws Exception {
+ builder.executorSupplier = new ServerCallExecutorSupplier() {
+ @Override
+ public Executor getExecutor(ServerCall call, Metadata metadata) {
+ return executor.getScheduledExecutorService();
+ }
+ };
+ basicExchangeSuccessful();
+ }
+
+ @Test
+ public void executorSupplierNullBasic() throws Exception {
+ builder.executorSupplier = new ServerCallExecutorSupplier() {
+ @Override
+ public Executor getExecutor(ServerCall call, Metadata metadata) {
+ return null;
+ }
+ };
+ basicExchangeSuccessful();
+ }
+
+ @Test
+ @SuppressWarnings("unchecked")
+ public void executorSupplierSwitchExecutor() throws Exception {
+ SingleExecutor switchingExecutor = new SingleExecutor();
+ ServerCallExecutorSupplier mockSupplier = mock(ServerCallExecutorSupplier.class);
+ when(mockSupplier.getExecutor(any(ServerCall.class), any(Metadata.class)))
+ .thenReturn(switchingExecutor);
+ builder.executorSupplier = mockSupplier;
+ final AtomicReference> callReference
+ = new AtomicReference<>();
+ mutableFallbackRegistry.addService(ServerServiceDefinition.builder(
+ new ServiceDescriptor("Waiter", METHOD))
+ .addMethod(METHOD,
+ new ServerCallHandler() {
+ @Override
+ public ServerCall.Listener startCall(
+ ServerCall call,
+ Metadata headers) {
+ callReference.set(call);
+ return callListener;
+ }
+ }).build());
+
+ createAndStartServer();
+ ServerTransportListener transportListener
+ = transportServer.registerNewServerTransport(new SimpleServerTransport());
+ transportListener.transportReady(Attributes.EMPTY);
+ Metadata requestHeaders = new Metadata();
+ StatsTraceContext statsTraceCtx =
+ StatsTraceContext.newServerContext(
+ streamTracerFactories, "Waiter/serve", requestHeaders);
+ when(stream.statsTraceContext()).thenReturn(statsTraceCtx);
+ transportListener.streamCreated(stream, "Waiter/serve", requestHeaders);
+ verify(stream).setListener(isA(ServerStreamListener.class));
+ verify(stream, atLeast(1)).statsTraceContext();
+
+ assertEquals(1, executor.runDueTasks());
+ verify(fallbackRegistry).lookupMethod("Waiter/serve", AUTHORITY);
+ verify(streamTracerFactory).newServerStreamTracer(eq("Waiter/serve"), same(requestHeaders));
+ ArgumentCaptor> callCapture = ArgumentCaptor.forClass(ServerCall.class);
+ verify(mockSupplier).getExecutor(callCapture.capture(), eq(requestHeaders));
+
+ assertThat(switchingExecutor.runnable).isNotNull();
+ assertEquals(0, executor.numPendingTasks());
+ switchingExecutor.drain();
+ ServerCall call = callReference.get();
+ assertNotNull(call);
+ assertThat(call).isEqualTo(callCapture.getValue());
+ }
+
+ @Test
+ @SuppressWarnings("CheckReturnValue")
+ public void executorSupplierFutureNotSet() throws Exception {
+ builder.executorSupplier = new ServerCallExecutorSupplier() {
+ @Override
+ public Executor getExecutor(ServerCall call, Metadata metadata) {
+ throw new IllegalStateException("Yeah!");
+ }
+ };
+ doThrow(new IllegalStateException("Yeah")).doNothing()
+ .when(stream).close(any(Status.class), any(Metadata.class));
+ final AtomicReference> callReference
+ = new AtomicReference<>();
+ mutableFallbackRegistry.addService(ServerServiceDefinition.builder(
+ new ServiceDescriptor("Waiter", METHOD))
+ .addMethod(METHOD,
+ new ServerCallHandler() {
+ @Override
+ public ServerCall.Listener startCall(
+ ServerCall call,
+ Metadata headers) {
+ callReference.set(call);
+ return callListener;
+ }
+ }).build());
+
+ createAndStartServer();
+ ServerTransportListener transportListener
+ = transportServer.registerNewServerTransport(new SimpleServerTransport());
+ transportListener.transportReady(Attributes.EMPTY);
+ Metadata requestHeaders = new Metadata();
+ StatsTraceContext statsTraceCtx =
+ StatsTraceContext.newServerContext(
+ streamTracerFactories, "Waiter/serve", requestHeaders);
+ when(stream.statsTraceContext()).thenReturn(statsTraceCtx);
+ transportListener.streamCreated(stream, "Waiter/serve", requestHeaders);
+ verify(stream).setListener(isA(ServerStreamListener.class));
+ verify(stream, atLeast(1)).statsTraceContext();
+
+ assertEquals(1, executor.runDueTasks());
+ verify(fallbackRegistry).lookupMethod("Waiter/serve", AUTHORITY);
+ assertThat(callReference.get()).isNull();
+ verify(stream, times(2)).close(statusCaptor.capture(), any(Metadata.class));
+ Status status = statusCaptor.getAllValues().get(1);
+ assertEquals(Code.UNKNOWN, status.getCode());
+ assertThat(status.getCause() instanceof IllegalStateException);
+ }
+
@Test
public void decompressorNotFound() throws Exception {
String decompressorName = "NON_EXISTENT_DECOMPRESSOR";
@@ -1513,4 +1636,24 @@ public ListenableFuture getStats() {
/** Allows more precise catch blocks than plain Error to avoid catching AssertionError. */
private static final class TestError extends Error {}
+
+ private static class SingleExecutor implements Executor {
+ private Runnable runnable;
+
+ @Override
+ public void execute(Runnable r) {
+ if (runnable != null) {
+ fail("Already have runnable scheduled");
+ }
+ runnable = r;
+ }
+
+ public void drain() {
+ if (runnable != null) {
+ Runnable r = runnable;
+ runnable = null;
+ r.run();
+ }
+ }
+ }
}