Skip to content

Commit

Permalink
Make all the unit tests work. ITs are failing pending on another PR.
Browse files Browse the repository at this point in the history
  • Loading branch information
yihanzhen committed Apr 9, 2018
1 parent 75d1d0b commit d5ed3f2
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.google.api.client.util.BackOff;
import com.google.api.client.util.ExponentialBackOff;
import com.google.api.gax.paging.Page;
import com.google.api.gax.rpc.ServerStream;
import com.google.api.pathtemplate.PathTemplate;
import com.google.cloud.BaseService;
import com.google.cloud.ByteArray;
Expand Down Expand Up @@ -828,7 +829,7 @@ public ReadContext singleUse() {

@Override
public ReadContext singleUse(TimestampBound bound) {
return setActive(new SingleReadContext(this, bound, rawGrpcRpc, defaultPrefetchChunks));
return setActive(new SingleReadContext(this, bound, gapicRpc, defaultPrefetchChunks));
}

@Override
Expand All @@ -839,7 +840,7 @@ public ReadOnlyTransaction singleUseReadOnlyTransaction() {
@Override
public ReadOnlyTransaction singleUseReadOnlyTransaction(TimestampBound bound) {
return setActive(
new SingleUseReadOnlyTransaction(this, bound, rawGrpcRpc, defaultPrefetchChunks));
new SingleUseReadOnlyTransaction(this, bound, gapicRpc, defaultPrefetchChunks));
}

@Override
Expand All @@ -850,12 +851,12 @@ public ReadOnlyTransaction readOnlyTransaction() {
@Override
public ReadOnlyTransaction readOnlyTransaction(TimestampBound bound) {
return setActive(
new MultiUseReadOnlyTransaction(this, bound, rawGrpcRpc, defaultPrefetchChunks));
new MultiUseReadOnlyTransaction(this, bound, gapicRpc, defaultPrefetchChunks));
}

@Override
public TransactionRunner readWriteTransaction() {
return setActive(new TransactionRunnerImpl(this, rawGrpcRpc, defaultPrefetchChunks));
return setActive(new TransactionRunnerImpl(this, gapicRpc, defaultPrefetchChunks));
}

@Override
Expand Down Expand Up @@ -1055,20 +1056,14 @@ ResultSet executeQueryInternalWithOptions(
new ResumableStreamIterator(MAX_BUFFERED_CHUNKS, QUERY) {
@Override
CloseableIterator<PartialResultSet> startStream(@Nullable ByteString resumeToken) {
GrpcStreamIterator stream = new GrpcStreamIterator(prefetchChunks);
SpannerRpc.StreamingCall call =
rpc.executeQuery(
resumeToken == null
? request
: request.toBuilder().setResumeToken(resumeToken).build(),
stream.consumer(),
session.options);
// We get one message for free.
if (prefetchChunks > 1) {
call.request(prefetchChunks - 1);
}
stream.setCall(call);
return stream;
return new CloseableServerStreamIterator<PartialResultSet>(rpc.executeQuery(
resumeToken == null
? request
: request.toBuilder().setResumeToken(resumeToken).build(),
null,
session.options));

// let resume fail for now
}
};
return new GrpcResultSet(stream, this, queryMode);
Expand Down Expand Up @@ -1168,20 +1163,14 @@ ResultSet readInternalWithOptions(
new ResumableStreamIterator(MAX_BUFFERED_CHUNKS, READ) {
@Override
CloseableIterator<PartialResultSet> startStream(@Nullable ByteString resumeToken) {
GrpcStreamIterator stream = new GrpcStreamIterator(prefetchChunks);
SpannerRpc.StreamingCall call =
rpc.read(
resumeToken == null
? request
: request.toBuilder().setResumeToken(resumeToken).build(),
stream.consumer(),
session.options);
// We get one message for free.
if (prefetchChunks > 1) {
call.request(prefetchChunks - 1);
}
stream.setCall(call);
return stream;
return new CloseableServerStreamIterator<PartialResultSet>(rpc.read(
resumeToken == null
? request
: request.toBuilder().setResumeToken(resumeToken).build(),
null,
session.options));

// let resume fail for now
}
};
GrpcResultSet resultSet =
Expand Down Expand Up @@ -2287,6 +2276,32 @@ interface CloseableIterator<T> extends Iterator<T> {
void close(@Nullable String message);
}

private static final class CloseableServerStreamIterator<T> implements CloseableIterator<T> {

private final ServerStream<T> stream;
private final Iterator<T> iterator;

public CloseableServerStreamIterator(ServerStream<T> stream) {
this.stream = stream;
this.iterator = stream.iterator();
}

@Override
public boolean hasNext() {
return iterator.hasNext();
}

@Override
public T next() {
return iterator.next();
}

@Override
public void close(@Nullable String message) {
stream.cancel();
}
}

/** Adapts a streaming read/query call into an iterator over partial result sets. */
@VisibleForTesting
static class GrpcStreamIterator extends AbstractIterator<PartialResultSet>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@
import com.google.api.gax.core.CredentialsProvider;
import com.google.api.gax.core.GaxProperties;
import com.google.api.gax.grpc.GaxGrpcProperties;
import com.google.api.gax.grpc.GrpcCallContext;
import com.google.api.gax.grpc.GrpcTransportChannel;
import com.google.api.gax.rpc.ApiClientHeaderProvider;
import com.google.api.gax.rpc.FixedTransportChannelProvider;
import com.google.api.gax.rpc.HeaderProvider;
import com.google.api.gax.rpc.ServerStream;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.api.pathtemplate.PathTemplate;
import com.google.cloud.ServiceOptions;
Expand Down Expand Up @@ -72,6 +74,7 @@
import com.google.spanner.v1.PartitionQueryRequest;
import com.google.spanner.v1.PartitionReadRequest;
import com.google.spanner.v1.PartitionResponse;
import com.google.spanner.v1.PartialResultSet;
import com.google.spanner.v1.ReadRequest;
import com.google.spanner.v1.RollbackRequest;
import com.google.spanner.v1.Session;
Expand Down Expand Up @@ -335,15 +338,19 @@ public void deleteSession(String sessionName, @Nullable Map<Option, ?> options)
}

@Override
public StreamingCall read(
public ServerStream<PartialResultSet> read(
ReadRequest request, ResultStreamConsumer consumer, @Nullable Map<Option, ?> options) {
throw new UnsupportedOperationException("Not implemented yet.");
GrpcCallContext context = GrpcCallContext.createDefault()
.withChannelAffinity(Option.CHANNEL_HINT.getLong(options).intValue());
return stub.streamingReadCallable().call(request, context);
}

@Override
public StreamingCall executeQuery(
public ServerStream<PartialResultSet> executeQuery(
ExecuteSqlRequest request, ResultStreamConsumer consumer, @Nullable Map<Option, ?> options) {
throw new UnsupportedOperationException("Not implemented yet.");
GrpcCallContext context = GrpcCallContext.createDefault()
.withChannelAffinity(Option.CHANNEL_HINT.getLong(options).intValue());
return stub.executeStreamingSqlCallable().call(request, context);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.google.api.gax.grpc.GaxGrpcProperties;
import com.google.api.gax.rpc.ApiClientHeaderProvider;
import com.google.api.gax.rpc.HeaderProvider;
import com.google.api.gax.rpc.ServerStream;
import com.google.api.pathtemplate.PathTemplate;
import com.google.cloud.NoCredentials;
import com.google.cloud.ServiceOptions;
Expand Down Expand Up @@ -366,25 +367,15 @@ public void deleteSession(String sessionName, @Nullable Map<Option, ?> options)
}

@Override
public StreamingCall read(
public ServerStream<PartialResultSet> read(
ReadRequest request, ResultStreamConsumer consumer, @Nullable Map<Option, ?> options) {
return doStreamingCall(
SpannerGrpc.METHOD_STREAMING_READ,
request,
consumer,
request.getSession(),
Option.CHANNEL_HINT.getLong(options));
throw new UnsupportedOperationException("Not implemented: read");
}

@Override
public StreamingCall executeQuery(
public ServerStream<PartialResultSet> executeQuery(
ExecuteSqlRequest request, ResultStreamConsumer consumer, @Nullable Map<Option, ?> options) {
return doStreamingCall(
SpannerGrpc.METHOD_EXECUTE_STREAMING_SQL,
request,
consumer,
request.getSession(),
Option.CHANNEL_HINT.getLong(options));
throw new UnsupportedOperationException("Not implemented: executeQuery");
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.google.cloud.spanner.spi.v1;

import com.google.api.gax.rpc.ServerStream;
import com.google.cloud.ServiceRpc;
import com.google.cloud.spanner.SpannerException;
import com.google.cloud.spanner.spi.v1.SpannerRpc.Option;
Expand Down Expand Up @@ -197,10 +198,10 @@ Session createSession(String databaseName, @Nullable Map<String, String> labels,

void deleteSession(String sessionName, @Nullable Map<Option, ?> options) throws SpannerException;

StreamingCall read(
ServerStream<PartialResultSet> read(
ReadRequest request, ResultStreamConsumer consumer, @Nullable Map<Option, ?> options);

StreamingCall executeQuery(
ServerStream<PartialResultSet> executeQuery(
ExecuteSqlRequest request, ResultStreamConsumer consumer, @Nullable Map<Option, ?> options);

Transaction beginTransaction(BeginTransactionRequest request, @Nullable Map<Option, ?> options)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* Copyright 2018 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.spanner;

import com.google.api.gax.rpc.ApiCallContext;
import com.google.api.gax.rpc.ResponseObserver;
import com.google.api.gax.rpc.ServerStream;
import com.google.api.gax.rpc.ServerStreamingCallable;
import com.google.api.gax.rpc.StreamController;
import com.google.common.base.Preconditions;
import com.google.common.collect.Queues;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.CancellationException;

public class ServerStreamingStashCallable<RequestT, ResponseT>
extends ServerStreamingCallable<RequestT, ResponseT> {
private List<ResponseT> responseList;

public ServerStreamingStashCallable() {
responseList = new ArrayList<>();
}

public ServerStreamingStashCallable(List<ResponseT> responseList) {
this.responseList = responseList;
}

@Override
public void call(
RequestT request, ResponseObserver<ResponseT> responseObserver, ApiCallContext context) {
Preconditions.checkNotNull(responseObserver);

StreamControllerStash<ResponseT> controller =
new StreamControllerStash<>(responseList, responseObserver);
controller.start();
}

// Minimal implementation of back pressure aware stream controller. Not threadsafe
private static class StreamControllerStash<ResponseT> implements StreamController {
final ResponseObserver<ResponseT> observer;
final Queue<ResponseT> queue;
boolean autoFlowControl = true;
long numPending;
Throwable error;
boolean delivering, closed;

public StreamControllerStash(
List<ResponseT> responseList, ResponseObserver<ResponseT> observer) {
this.observer = observer;
this.queue = Queues.newArrayDeque(responseList);
}

public void start() {
observer.onStart(this);
if (autoFlowControl) {
numPending = Integer.MAX_VALUE;
}
deliver();
}

@Override
public void disableAutoInboundFlowControl() {
autoFlowControl = false;
}

@Override
public void request(int count) {
numPending += count;
deliver();
}

@Override
public void cancel() {
error = new CancellationException("User cancelled stream");
deliver();
}

private void deliver() {
if (delivering || closed) return;
delivering = true;

try {
while (error == null && numPending > 0 && !queue.isEmpty()) {
numPending--;
observer.onResponse(queue.poll());
}

if (error != null || queue.isEmpty()) {
if (error != null) {
observer.onError(error);
} else {
observer.onComplete();
}
closed = true;
}
} finally {
delivering = false;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;

import com.google.api.gax.rpc.ServerStream;
import com.google.api.gax.rpc.ServerStreamingCallable;
import com.google.cloud.Timestamp;
import com.google.cloud.spanner.spi.v1.SpannerRpc;
import com.google.protobuf.ByteString;
Expand Down Expand Up @@ -280,18 +282,15 @@ public void request(int numMessages) {}
}

private void mockRead(final PartialResultSet myResultSet) {
final ArgumentCaptor<SpannerRpc.ResultStreamConsumer> consumer =
ArgumentCaptor.forClass(SpannerRpc.ResultStreamConsumer.class);
Mockito.when(rpc.read(Mockito.<ReadRequest>any(), consumer.capture(), Mockito.eq(options)))
.then(
new Answer<SpannerRpc.StreamingCall>() {
@Override
public SpannerRpc.StreamingCall answer(InvocationOnMock invocation) throws Throwable {
consumer.getValue().onPartialResultSet(myResultSet);
consumer.getValue().onCompleted();
return new NoOpStreamingCall();
}
});
ServerStreamingCallable<ReadRequest, PartialResultSet> serverStreamingCallable =
new ServerStreamingStashCallable(Arrays.<PartialResultSet>asList(myResultSet));
final ServerStream<PartialResultSet> mockServerStream = serverStreamingCallable.call(null);
Mockito.when(
rpc.read(
Mockito.<ReadRequest>any(),
Mockito.<SpannerRpc.ResultStreamConsumer>any(),
Mockito.eq(options)))
.thenReturn(mockServerStream);
}

@Test
Expand Down

0 comments on commit d5ed3f2

Please sign in to comment.