Skip to content

Commit

Permalink
[BUG] Serialization bugs can cause node drops (#1885)
Browse files Browse the repository at this point in the history
This commit restructures InboundHandler to ensure all data 
is consumed over the wire.

Signed-off-by: Andriy Redko <andriy.redko@aiven.io>
  • Loading branch information
reta committed Jan 14, 2022
1 parent e7d44c2 commit f059738
Show file tree
Hide file tree
Showing 2 changed files with 430 additions and 36 deletions.
136 changes: 103 additions & 33 deletions server/src/main/java/org/opensearch/transport/InboundHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.threadpool.ThreadPool;

import java.io.EOFException;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
Expand Down Expand Up @@ -149,27 +150,13 @@ private void messageReceived(TcpChannel channel, InboundMessage message, long st
streamInput = namedWriteableStream(message.openOrGetStreamInput());
assertRemoteVersion(streamInput, header.getVersion());
if (header.isError()) {
handlerResponseError(streamInput, handler);
handlerResponseError(requestId, streamInput, handler);
} else {
handleResponse(remoteAddress, streamInput, handler);
}
// Check the entire message has been read
final int nextByte = streamInput.read();
// calling read() is useful to make sure the message is fully read, even if there is an EOS marker
if (nextByte != -1) {
throw new IllegalStateException(
"Message not fully read (response) for requestId ["
+ requestId
+ "], handler ["
+ handler
+ "], error ["
+ header.isError()
+ "]; resetting"
);
handleResponse(requestId, remoteAddress, streamInput, handler);
}
} else {
assert header.isError() == false;
handleResponse(remoteAddress, EMPTY_STREAM_INPUT, handler);
handleResponse(requestId, remoteAddress, EMPTY_STREAM_INPUT, handler);
}
}
}
Expand Down Expand Up @@ -246,22 +233,11 @@ private <T extends TransportRequest> void handleRequest(TcpChannel channel, Head
assertRemoteVersion(stream, header.getVersion());
final RequestHandlerRegistry<T> reg = requestHandlers.getHandler(action);
assert reg != null;
final T request = reg.newRequest(stream);

final T request = newRequest(requestId, action, stream, reg);
request.remoteAddress(new TransportAddress(channel.getRemoteAddress()));
// in case we throw an exception, i.e. when the limit is hit, we don't want to verify
final int nextByte = stream.read();
// calling read() is useful to make sure the message is fully read, even if there some kind of EOS marker
if (nextByte != -1) {
throw new IllegalStateException(
"Message not fully read (request) for requestId ["
+ requestId
+ "], action ["
+ action
+ "], available ["
+ stream.available()
+ "]; resetting"
);
}
checkStreamIsFullyConsumed(requestId, action, stream);

final String executor = reg.getExecutor();
if (ThreadPool.Names.SAME.equals(executor)) {
try {
Expand All @@ -279,6 +255,97 @@ private <T extends TransportRequest> void handleRequest(TcpChannel channel, Head
}
}

/**
* Creates new request instance out of input stream. Throws IllegalStateException if the end of
* the stream was reached before the request is fully deserialized from the stream.
* @param <T> transport request type
* @param requestId request identifier
* @param action action name
* @param stream stream
* @param reg request handler registry
* @return new request instance
* @throws IOException IOException
* @throws IllegalStateException IllegalStateException
*/
private <T extends TransportRequest> T newRequest(
final long requestId,
final String action,
final StreamInput stream,
final RequestHandlerRegistry<T> reg
) throws IOException {
try {
return reg.newRequest(stream);
} catch (final EOFException e) {
// Another favor of (de)serialization issues is when stream contains less bytes than
// the request handler needs to deserialize the payload.
throw new IllegalStateException(
"Message fully read (request) but more data is expected for requestId ["
+ requestId
+ "], action ["
+ action
+ "]; resetting",
e
);
}
}

/**
* Checks if the stream is fully consumed and throws the exceptions if that is not the case.
* @param requestId request identifier
* @param action action name
* @param stream stream
* @throws IOException IOException
*/
private void checkStreamIsFullyConsumed(final long requestId, final String action, final StreamInput stream) throws IOException {
// in case we throw an exception, i.e. when the limit is hit, we don't want to verify
final int nextByte = stream.read();

// calling read() is useful to make sure the message is fully read, even if there some kind of EOS marker
if (nextByte != -1) {
throw new IllegalStateException(
"Message not fully read (request) for requestId ["
+ requestId
+ "], action ["
+ action
+ "], available ["
+ stream.available()
+ "]; resetting"
);
}
}

/**
* Checks if the stream is fully consumed and throws the exceptions if that is not the case.
* @param requestId request identifier
* @param handler response handler
* @param stream stream
* @param error "true" if response represents error, "false" otherwise
* @throws IOException IOException
*/
private void checkStreamIsFullyConsumed(
final long requestId,
final TransportResponseHandler<?> handler,
final StreamInput stream,
final boolean error
) throws IOException {
if (stream != EMPTY_STREAM_INPUT) {
// Check the entire message has been read
final int nextByte = stream.read();
// calling read() is useful to make sure the message is fully read, even if there is an EOS marker
if (nextByte != -1) {
throw new IllegalStateException(
"Message not fully read (response) for requestId ["
+ requestId
+ "], handler ["
+ handler
+ "], error ["
+ error
+ "]; resetting"
);
}
}
}

private static void sendErrorResponse(String actionName, TransportChannel transportChannel, Exception e) {
try {
transportChannel.sendResponse(e);
Expand All @@ -289,6 +356,7 @@ private static void sendErrorResponse(String actionName, TransportChannel transp
}

private <T extends TransportResponse> void handleResponse(
final long requestId,
InetSocketAddress remoteAddress,
final StreamInput stream,
final TransportResponseHandler<T> handler
Expand All @@ -297,6 +365,7 @@ private <T extends TransportResponse> void handleResponse(
try {
response = handler.read(stream);
response.remoteAddress(new TransportAddress(remoteAddress));
checkStreamIsFullyConsumed(requestId, handler, stream, false);
} catch (Exception e) {
final Exception serializationException = new TransportSerializationException(
"Failed to deserialize response from handler [" + handler + "]",
Expand All @@ -322,10 +391,11 @@ private <T extends TransportResponse> void doHandleResponse(TransportResponseHan
}
}

private void handlerResponseError(StreamInput stream, final TransportResponseHandler<?> handler) {
private void handlerResponseError(final long requestId, StreamInput stream, final TransportResponseHandler<?> handler) {
Exception error;
try {
error = stream.readException();
checkStreamIsFullyConsumed(requestId, handler, stream, true);
} catch (Exception e) {
error = new TransportSerializationException(
"Failed to deserialize exception response from stream for handler [" + handler + "]",
Expand Down
Loading

0 comments on commit f059738

Please sign in to comment.