Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Serialization bugs can cause node drops #1885

Merged
merged 3 commits into from
Jan 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 103 additions & 33 deletions server/src/main/java/org/opensearch/transport/InboundHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.threadpool.ThreadPool;

import java.io.EOFException;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
Expand Down Expand Up @@ -149,27 +150,13 @@ private void messageReceived(TcpChannel channel, InboundMessage message, long st
streamInput = namedWriteableStream(message.openOrGetStreamInput());
assertRemoteVersion(streamInput, header.getVersion());
if (header.isError()) {
handlerResponseError(streamInput, handler);
handlerResponseError(requestId, streamInput, handler);
} else {
handleResponse(remoteAddress, streamInput, handler);
}
// Check the entire message has been read
final int nextByte = streamInput.read();
// calling read() is useful to make sure the message is fully read, even if there is an EOS marker
if (nextByte != -1) {
throw new IllegalStateException(
"Message not fully read (response) for requestId ["
+ requestId
+ "], handler ["
+ handler
+ "], error ["
+ header.isError()
+ "]; resetting"
);
handleResponse(requestId, remoteAddress, streamInput, handler);
}
} else {
assert header.isError() == false;
handleResponse(remoteAddress, EMPTY_STREAM_INPUT, handler);
handleResponse(requestId, remoteAddress, EMPTY_STREAM_INPUT, handler);
}
}
}
Expand Down Expand Up @@ -246,22 +233,11 @@ private <T extends TransportRequest> void handleRequest(TcpChannel channel, Head
assertRemoteVersion(stream, header.getVersion());
final RequestHandlerRegistry<T> reg = requestHandlers.getHandler(action);
assert reg != null;
final T request = reg.newRequest(stream);

final T request = newRequest(requestId, action, stream, reg);
request.remoteAddress(new TransportAddress(channel.getRemoteAddress()));
// in case we throw an exception, i.e. when the limit is hit, we don't want to verify
final int nextByte = stream.read();
// calling read() is useful to make sure the message is fully read, even if there is some kind of EOS marker
if (nextByte != -1) {
throw new IllegalStateException(
"Message not fully read (request) for requestId ["
+ requestId
+ "], action ["
+ action
+ "], available ["
+ stream.available()
+ "]; resetting"
);
}
checkStreamIsFullyConsumed(requestId, action, stream);

final String executor = reg.getExecutor();
if (ThreadPool.Names.SAME.equals(executor)) {
try {
Expand All @@ -279,6 +255,97 @@ private <T extends TransportRequest> void handleRequest(TcpChannel channel, Head
}
}

/**
 * Deserializes a transport request from the given stream using the supplied handler registry.
 * <p>
 * If the stream ends before the registry has finished reading the request (signalled by an
 * {@link EOFException}), the failure is rethrown as an {@link IllegalStateException} that names
 * the request id and action and keeps the original exception as its cause.
 *
 * @param <T> transport request type
 * @param requestId request identifier, used only in the error message
 * @param action action name, used only in the error message
 * @param stream stream the request is read from
 * @param reg request handler registry that performs the actual deserialization
 * @return the deserialized request instance
 * @throws IOException any I/O failure other than {@link EOFException} raised while reading the request
 * @throws IllegalStateException if the stream is exhausted before the request is fully deserialized
 */
private <T extends TransportRequest> T newRequest(
    final long requestId,
    final String action,
    final StreamInput stream,
    final RequestHandlerRegistry<T> reg
) throws IOException {
    final T request;
    try {
        request = reg.newRequest(stream);
    } catch (final EOFException eof) {
        // Another flavor of (de)serialization issue: the stream holds fewer bytes than the
        // request handler needs in order to deserialize the payload.
        final String reason = "Message fully read (request) but more data is expected for requestId ["
            + requestId
            + "], action ["
            + action
            + "]; resetting";
        throw new IllegalStateException(reason, eof);
    }
    return request;
}

/**
 * Verifies that a request stream has no unread bytes left after deserialization.
 * <p>
 * One more byte is read from the stream; anything other than the end-of-stream value
 * ({@code -1}) means the request handler did not consume the whole message.
 *
 * @param requestId request identifier, used only in the error message
 * @param action action name, used only in the error message
 * @param stream stream that should already be fully consumed
 * @throws IOException if reading from the stream or querying its available bytes fails
 * @throws IllegalStateException if the stream still contains data
 */
private void checkStreamIsFullyConsumed(final long requestId, final String action, final StreamInput stream) throws IOException {
    // Calling read() makes sure the message is fully read, even if there is some kind of EOS marker.
    if (stream.read() == -1) {
        return;
    }

    final String reason = "Message not fully read (request) for requestId ["
        + requestId
        + "], action ["
        + action
        + "], available ["
        + stream.available()
        + "]; resetting";
    throw new IllegalStateException(reason);
}

/**
* Checks if the stream is fully consumed and throws an exception if that is not the case.
* @param requestId request identifier
* @param handler response handler
* @param stream stream
* @param error "true" if response represents error, "false" otherwise
* @throws IOException IOException
*/
private void checkStreamIsFullyConsumed(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we also closing the stream on exception?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but not in the handler. AFAIK the streams for InboundMessage are closed within InboundPipeline.

final long requestId,
final TransportResponseHandler<?> handler,
final StreamInput stream,
final boolean error
) throws IOException {
if (stream != EMPTY_STREAM_INPUT) {
// Check the entire message has been read
final int nextByte = stream.read();
// calling read() is useful to make sure the message is fully read, even if there is an EOS marker
if (nextByte != -1) {
throw new IllegalStateException(
"Message not fully read (response) for requestId ["
+ requestId
+ "], handler ["
+ handler
+ "], error ["
+ error
+ "]; resetting"
);
}
}
}

private static void sendErrorResponse(String actionName, TransportChannel transportChannel, Exception e) {
try {
transportChannel.sendResponse(e);
Expand All @@ -289,6 +356,7 @@ private static void sendErrorResponse(String actionName, TransportChannel transp
}

private <T extends TransportResponse> void handleResponse(
final long requestId,
InetSocketAddress remoteAddress,
final StreamInput stream,
final TransportResponseHandler<T> handler
Expand All @@ -297,6 +365,7 @@ private <T extends TransportResponse> void handleResponse(
try {
response = handler.read(stream);
response.remoteAddress(new TransportAddress(remoteAddress));
checkStreamIsFullyConsumed(requestId, handler, stream, false);
} catch (Exception e) {
final Exception serializationException = new TransportSerializationException(
"Failed to deserialize response from handler [" + handler + "]",
Expand All @@ -322,10 +391,11 @@ private <T extends TransportResponse> void doHandleResponse(TransportResponseHan
}
}

private void handlerResponseError(StreamInput stream, final TransportResponseHandler<?> handler) {
private void handlerResponseError(final long requestId, StreamInput stream, final TransportResponseHandler<?> handler) {
Exception error;
try {
error = stream.readException();
checkStreamIsFullyConsumed(requestId, handler, stream, true);
} catch (Exception e) {
error = new TransportSerializationException(
"Failed to deserialize exception response from stream for handler [" + handler + "]",
Expand Down
Loading