Skip to content

Commit

Permalink
Compress and cache cluster state during validate join request (opense…
Browse files Browse the repository at this point in the history
…arch-project#7321)

* Compress and cache cluster state during validate join request

Signed-off-by: Aman Khare <amkhar@amazon.com>

* Add changelog and license

Signed-off-by: Aman Khare <amkhar@amazon.com>

* Add javadoc and correct styling

Signed-off-by: Aman Khare <amkhar@amazon.com>

* Add new handler for sending compressed cluster state in validate join flow and refactor code

Signed-off-by: Aman Khare <amkhar@amazon.com>

* Refactor util method

Signed-off-by: Aman Khare <amkhar@amazon.com>

* optimize imports

Signed-off-by: Aman Khare <amkhar@amazon.com>

* Use cluster state version based cache instead of time based cache

Signed-off-by: Aman Khare <amkhar@amazon.com>

* style fix

Signed-off-by: Aman Khare <amkhar@amazon.com>

* fix styling 2

Signed-off-by: Aman Khare <amkhar@amazon.com>

* Use concurrent hashmap instead of cache, add UT class for ClusterStateUtils

Signed-off-by: Aman Khare <amkhar@amazon.com>

* style fix

Signed-off-by: Aman Khare <amkhar@amazon.com>

* Use AtomicReference instead of ConcurrentHashMap

Signed-off-by: Aman Khare <amkhar@amazon.com>

* Use method overloading to simplify the caller code

Signed-off-by: Aman Khare <amkhar@amazon.com>

* Resolve conflicts

Signed-off-by: Aman Khare <amkhar@amazon.com>

* Change code structure to separate the flow for JoinHelper and PublicationTransportHelper

Signed-off-by: Aman Khare <amkhar@amazon.com>

* Remove unnecessary input.setVersion line

Co-authored-by: Andrew Ross <andrross@amazon.com>
Signed-off-by: Aman Khare <85096200+amkhar@users.noreply.github.com>

---------

Signed-off-by: Aman Khare <amkhar@amazon.com>
Signed-off-by: Aman Khare <85096200+amkhar@users.noreply.github.com>
Co-authored-by: Aman Khare <amkhar@amazon.com>
Co-authored-by: Andrew Ross <andrross@amazon.com>
  • Loading branch information
3 people authored and sandeshkr419 committed Jun 8, 2023
1 parent 9cf4d71 commit 4663e11
Show file tree
Hide file tree
Showing 9 changed files with 485 additions and 101 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Add ZSTD compression for snapshotting ([#2996](https://github.com/opensearch-project/OpenSearch/pull/2996))
- Change `com.amazonaws.sdk.ec2MetadataServiceEndpointOverride` to `aws.ec2MetadataServiceEndpoint` ([7372](https://github.com/opensearch-project/OpenSearch/pull/7372/))
- Change `com.amazonaws.sdk.stsEndpointOverride` to `aws.stsEndpointOverride` ([7372](https://github.com/opensearch-project/OpenSearch/pull/7372/))
- Compress and cache cluster state during validate join request ([#7321](https://github.com/opensearch-project/OpenSearch/pull/7321))

### Deprecated

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.cluster.coordination;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.Version;
import org.opensearch.common.CheckedConsumer;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.compress.Compressor;
import org.opensearch.common.compress.CompressorFactory;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.InputStreamStreamInput;
import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.opensearch.common.io.stream.NamedWriteableRegistry;
import org.opensearch.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.transport.BytesTransportRequest;

import java.io.IOException;

/**
* A helper class to utilize the compressed stream.
*
* @opensearch.internal
*/
public final class CompressedStreamUtils {
private static final Logger logger = LogManager.getLogger(CompressedStreamUtils.class);

public static BytesReference createCompressedStream(Version version, CheckedConsumer<StreamOutput, IOException> outputConsumer)
throws IOException {
final BytesStreamOutput bStream = new BytesStreamOutput();
try (StreamOutput stream = new OutputStreamStreamOutput(CompressorFactory.COMPRESSOR.threadLocalOutputStream(bStream))) {
stream.setVersion(version);
outputConsumer.accept(stream);
}
final BytesReference serializedByteRef = bStream.bytes();
logger.trace("serialized writable object for node version [{}] with size [{}]", version, serializedByteRef.length());
return serializedByteRef;
}

public static StreamInput decompressBytes(BytesTransportRequest request, NamedWriteableRegistry namedWriteableRegistry)
throws IOException {
final Compressor compressor = CompressorFactory.compressor(request.bytes());
final StreamInput in;
if (compressor != null) {
in = new InputStreamStreamInput(compressor.threadLocalInputStream(request.bytes().streamInput()));
} else {
in = request.bytes().streamInput();
}
in.setVersion(request.version());
return new NamedWriteableAwareStreamInput(in, namedWriteableRegistry);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ public Coordinator(
this.onJoinValidators,
rerouteService,
nodeHealthService,
this::onNodeCommissionStatusChange
this::onNodeCommissionStatusChange,
namedWriteableRegistry
);
this.persistedStateSupplier = persistedStateSupplier;
this.noClusterManagerBlockService = new NoClusterManagerBlockService(settings, clusterSettings);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.Version;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.cluster.ClusterState;
Expand All @@ -49,7 +50,9 @@
import org.opensearch.cluster.routing.allocation.AllocationService;
import org.opensearch.cluster.service.ClusterManagerService;
import org.opensearch.common.Priority;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.io.stream.NamedWriteableRegistry;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
Expand All @@ -58,6 +61,7 @@
import org.opensearch.monitor.StatusInfo;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.threadpool.ThreadPool.Names;
import org.opensearch.transport.BytesTransportRequest;
import org.opensearch.transport.RemoteTransportException;
import org.opensearch.transport.TransportChannel;
import org.opensearch.transport.TransportException;
Expand Down Expand Up @@ -98,6 +102,7 @@ public class JoinHelper {

public static final String JOIN_ACTION_NAME = "internal:cluster/coordination/join";
public static final String VALIDATE_JOIN_ACTION_NAME = "internal:cluster/coordination/join/validate";
public static final String VALIDATE_COMPRESSED_JOIN_ACTION_NAME = JOIN_ACTION_NAME + "/validate_compressed";
public static final String START_JOIN_ACTION_NAME = "internal:cluster/coordination/start_join";

// the timeout for Zen1 join attempts
Expand All @@ -122,6 +127,8 @@ public class JoinHelper {

private final Supplier<JoinTaskExecutor> joinTaskExecutorGenerator;
private final Consumer<Boolean> nodeCommissioned;
private final NamedWriteableRegistry namedWriteableRegistry;
private final AtomicReference<Tuple<Long, BytesReference>> serializedState = new AtomicReference<>();

JoinHelper(
Settings settings,
Expand All @@ -135,13 +142,16 @@ public class JoinHelper {
Collection<BiConsumer<DiscoveryNode, ClusterState>> joinValidators,
RerouteService rerouteService,
NodeHealthService nodeHealthService,
Consumer<Boolean> nodeCommissioned
Consumer<Boolean> nodeCommissioned,
NamedWriteableRegistry namedWriteableRegistry
) {
this.clusterManagerService = clusterManagerService;
this.transportService = transportService;
this.nodeHealthService = nodeHealthService;
this.joinTimeout = JOIN_TIMEOUT_SETTING.get(settings);
this.nodeCommissioned = nodeCommissioned;
this.namedWriteableRegistry = namedWriteableRegistry;

this.joinTaskExecutorGenerator = () -> new JoinTaskExecutor(settings, allocationService, logger, rerouteService) {

private final long term = currentTermSupplier.getAsLong();
Expand Down Expand Up @@ -208,22 +218,52 @@ public ClusterTasksResult<JoinTaskExecutor.Task> execute(ClusterState currentSta
ThreadPool.Names.GENERIC,
ValidateJoinRequest::new,
(request, channel, task) -> {
final ClusterState localState = currentStateSupplier.get();
if (localState.metadata().clusterUUIDCommitted()
&& localState.metadata().clusterUUID().equals(request.getState().metadata().clusterUUID()) == false) {
throw new CoordinationStateRejectedException(
"join validation on cluster state"
+ " with a different cluster uuid "
+ request.getState().metadata().clusterUUID()
+ " than local cluster uuid "
+ localState.metadata().clusterUUID()
+ ", rejecting"
);
}
joinValidators.forEach(action -> action.accept(transportService.getLocalNode(), request.getState()));
runJoinValidators(currentStateSupplier, request.getState(), joinValidators);
channel.sendResponse(Empty.INSTANCE);
}
);

transportService.registerRequestHandler(
VALIDATE_COMPRESSED_JOIN_ACTION_NAME,
ThreadPool.Names.GENERIC,
BytesTransportRequest::new,
(request, channel, task) -> {
handleCompressedValidateJoinRequest(currentStateSupplier, joinValidators, request);
channel.sendResponse(Empty.INSTANCE);
}
);

}

private void runJoinValidators(
Supplier<ClusterState> currentStateSupplier,
ClusterState incomingState,
Collection<BiConsumer<DiscoveryNode, ClusterState>> joinValidators
) {
final ClusterState localState = currentStateSupplier.get();
if (localState.metadata().clusterUUIDCommitted()
&& localState.metadata().clusterUUID().equals(incomingState.metadata().clusterUUID()) == false) {
throw new CoordinationStateRejectedException(
"join validation on cluster state"
+ " with a different cluster uuid "
+ incomingState.metadata().clusterUUID()
+ " than local cluster uuid "
+ localState.metadata().clusterUUID()
+ ", rejecting"
);
}
joinValidators.forEach(action -> action.accept(transportService.getLocalNode(), incomingState));
}

private void handleCompressedValidateJoinRequest(
Supplier<ClusterState> currentStateSupplier,
Collection<BiConsumer<DiscoveryNode, ClusterState>> joinValidators,
BytesTransportRequest request
) throws IOException {
try (StreamInput input = CompressedStreamUtils.decompressBytes(request, namedWriteableRegistry)) {
ClusterState incomingState = ClusterState.readFrom(input, transportService.getLocalNode());
runJoinValidators(currentStateSupplier, incomingState, joinValidators);
}
}

private JoinCallback transportJoinCallback(TransportRequest request, TransportChannel channel) {
Expand Down Expand Up @@ -407,12 +447,42 @@ public String executor() {
}

public void sendValidateJoinRequest(DiscoveryNode node, ClusterState state, ActionListener<TransportResponse.Empty> listener) {
transportService.sendRequest(
node,
VALIDATE_JOIN_ACTION_NAME,
new ValidateJoinRequest(state),
new ActionListenerResponseHandler<>(listener, i -> Empty.INSTANCE, ThreadPool.Names.GENERIC)
);
if (node.getVersion().before(Version.V_3_0_0)) {
transportService.sendRequest(
node,
VALIDATE_JOIN_ACTION_NAME,
new ValidateJoinRequest(state),
new ActionListenerResponseHandler<>(listener, i -> Empty.INSTANCE, ThreadPool.Names.GENERIC)
);
} else {
try {
final BytesReference bytes = serializedState.updateAndGet(cachedState -> {
if (cachedState == null || cachedState.v1() != state.version()) {
try {
return new Tuple<>(
state.version(),
CompressedStreamUtils.createCompressedStream(node.getVersion(), state::writeTo)
);
} catch (IOException e) {
// mandatory as AtomicReference doesn't rethrow IOException.
throw new RuntimeException(e);
}
} else {
return cachedState;
}
}).v2();
final BytesTransportRequest request = new BytesTransportRequest(bytes, node.getVersion());
transportService.sendRequest(
node,
VALIDATE_COMPRESSED_JOIN_ACTION_NAME,
request,
new ActionListenerResponseHandler<>(listener, i -> Empty.INSTANCE, ThreadPool.Names.GENERIC)
);
} catch (Exception e) {
logger.warn("error sending cluster state to {}", node);
listener.onFailure(e);
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,8 @@
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.compress.Compressor;
import org.opensearch.common.compress.CompressorFactory;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.InputStreamStreamInput;
import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.opensearch.common.io.stream.NamedWriteableRegistry;
import org.opensearch.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.util.io.IOUtils;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.BytesTransportRequest;
import org.opensearch.transport.TransportChannel;
Expand Down Expand Up @@ -168,17 +160,9 @@ public PublishClusterStateStats stats() {
}

private PublishWithJoinResponse handleIncomingPublishRequest(BytesTransportRequest request) throws IOException {
final Compressor compressor = CompressorFactory.compressor(request.bytes());
StreamInput in = request.bytes().streamInput();
try {
if (compressor != null) {
in = new InputStreamStreamInput(compressor.threadLocalInputStream(in));
}
in = new NamedWriteableAwareStreamInput(in, namedWriteableRegistry);
in.setVersion(request.version());
// If true we received full cluster state - otherwise diffs
try (StreamInput in = CompressedStreamUtils.decompressBytes(request, namedWriteableRegistry)) {
ClusterState incomingState;
if (in.readBoolean()) {
final ClusterState incomingState;
// Close early to release resources used by the de-compression as early as possible
try (StreamInput input = in) {
incomingState = ClusterState.readFrom(input, transportService.getLocalNode());
Expand All @@ -198,7 +182,6 @@ private PublishWithJoinResponse handleIncomingPublishRequest(BytesTransportReque
incompatibleClusterStateDiffReceivedCount.incrementAndGet();
throw new IncompatibleClusterStateVersionException("have no local cluster state");
} else {
ClusterState incomingState;
try {
final Diff<ClusterState> diff;
// Close stream early to release resources used by the de-compression as early as possible
Expand All @@ -225,8 +208,6 @@ private PublishWithJoinResponse handleIncomingPublishRequest(BytesTransportReque
return response;
}
}
} finally {
IOUtils.close(in);
}
}

Expand Down Expand Up @@ -254,13 +235,10 @@ public PublicationContext newPublicationContext(ClusterChangedEvent clusterChang
}

private static BytesReference serializeFullClusterState(ClusterState clusterState, Version nodeVersion) throws IOException {
final BytesStreamOutput bStream = new BytesStreamOutput();
try (StreamOutput stream = new OutputStreamStreamOutput(CompressorFactory.COMPRESSOR.threadLocalOutputStream(bStream))) {
stream.setVersion(nodeVersion);
final BytesReference serializedState = CompressedStreamUtils.createCompressedStream(nodeVersion, stream -> {
stream.writeBoolean(true);
clusterState.writeTo(stream);
}
final BytesReference serializedState = bStream.bytes();
});
logger.trace(
"serialized full cluster state version [{}] for node version [{}] with size [{}]",
clusterState.version(),
Expand All @@ -271,13 +249,10 @@ private static BytesReference serializeFullClusterState(ClusterState clusterStat
}

private static BytesReference serializeDiffClusterState(Diff<ClusterState> diff, Version nodeVersion) throws IOException {
final BytesStreamOutput bStream = new BytesStreamOutput();
try (StreamOutput stream = new OutputStreamStreamOutput(CompressorFactory.COMPRESSOR.threadLocalOutputStream(bStream))) {
stream.setVersion(nodeVersion);
return CompressedStreamUtils.createCompressedStream(nodeVersion, stream -> {
stream.writeBoolean(false);
diff.writeTo(stream);
}
return bStream.bytes();
});
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,8 @@ grant codeBase "file:${gradle.worker.jar}" {
grant {
// since the gradle test worker jar is on the test classpath, our tests should be able to read it
permission java.io.FilePermission "${gradle.worker.jar}", "read";
permission java.lang.RuntimePermission "accessDeclaredMembers";
permission java.lang.RuntimePermission "reflectionFactoryAccess";
permission java.lang.RuntimePermission "accessClassInPackage.sun.reflect";
permission java.lang.reflect.ReflectPermission "suppressAccessChecks";
};
Loading

0 comments on commit 4663e11

Please sign in to comment.