Skip to content

Commit

Permalink
Send the CLIENT SETINFO command in a fire-and-forget way (#2823)
Browse files Browse the repository at this point in the history
* Send the CLIENT SETINFO command in a fire-and-forget way

* Attempt to improve method name; handle case where none of the metadata is provided

* Adding a debug message to indicate that the metadata was dropped in case this happens
  • Loading branch information
tishun authored Apr 10, 2024
1 parent c603c81 commit 764fdf3
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 24 deletions.
66 changes: 42 additions & 24 deletions src/main/java/io/lettuce/core/RedisHandshake.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
import io.lettuce.core.protocol.ConnectionInitializer;
import io.lettuce.core.protocol.ProtocolVersion;
import io.netty.channel.Channel;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

/**
* Redis RESP2/RESP3 handshake using the configured {@link ProtocolVersion} and other options for connection initialization and
Expand All @@ -49,7 +51,7 @@
*/
class RedisHandshake implements ConnectionInitializer {

private static final RedisVersion CLIENT_SET_INFO_SINCE = RedisVersion.of("7.2");
private static final InternalLogger LOG = InternalLoggerFactory.getInstance(RedisHandshake.class);

private final RedisCommandBuilder<String, String> commandBuilder = new RedisCommandBuilder<>(StringCodec.UTF8);

Expand Down Expand Up @@ -99,8 +101,21 @@ public CompletionStage<Void> initialize(Channel channel) {
new RedisConnectionException("Protocol version" + this.requestedProtocolVersion + " not supported"));
}

return handshake.thenCompose(
ignore -> applyPostHandshake(channel, connectionState.getRedisVersion(), getNegotiatedProtocolVersion()));
// post-handshake commands, whose execution failures would cause the connection to be considered
// unsuccessfully established
CompletableFuture<Void> postHandshake = applyPostHandshake(channel);

// post-handshake commands, executed in a 'fire and forget' manner, to avoid having to react to different
// implementations or versions of the server runtime, and whose execution result (whether a success or a
// failure ) should not alter the outcome of the connection attempt
CompletableFuture<Void> connectionMetadata = applyConnectionMetadata(channel).handle((result, error) -> {
if (error != null) {
LOG.debug("Error applying connection metadata", error);
}
return null;
});

return handshake.thenCompose(ignore -> postHandshake).thenCompose(ignore -> connectionMetadata);
}

private CompletionStage<?> tryHandshakeResp3(Channel channel) {
Expand Down Expand Up @@ -237,41 +252,44 @@ private AsyncCommand<String, String, Map<String, Object>> dispatchHello(Channel
return dispatch(channel, this.commandBuilder.hello(3, null, null, connectionState.getClientName()));
}

private CompletableFuture<Void> applyPostHandshake(Channel channel, String redisVersion,
ProtocolVersion negotiatedProtocolVersion) {
private CompletableFuture<Void> applyPostHandshake(Channel channel) {

List<AsyncCommand<?, ?, ?>> postHandshake = new ArrayList<>();

ConnectionMetadata metadata = connectionState.getConnectionMetadata();
if (connectionState.getDb() > 0) {
postHandshake.add(new AsyncCommand<>(this.commandBuilder.select(connectionState.getDb())));
}

if (metadata.getClientName() != null && negotiatedProtocolVersion == ProtocolVersion.RESP2) {
postHandshake.add(new AsyncCommand<>(this.commandBuilder.clientSetname(connectionState.getClientName())));
if (connectionState.isReadOnly()) {
postHandshake.add(new AsyncCommand<>(this.commandBuilder.readOnly()));
}

if (negotiatedProtocolVersion == ProtocolVersion.RESP3) {
if (postHandshake.isEmpty()) {
return CompletableFuture.completedFuture(null);
}

RedisVersion currentVersion = RedisVersion.of(redisVersion);
return dispatch(channel, postHandshake);
}

if (currentVersion.isGreaterThanOrEqualTo(CLIENT_SET_INFO_SINCE)) {
private CompletableFuture<Void> applyConnectionMetadata(Channel channel) {

if (LettuceStrings.isNotEmpty(metadata.getLibraryName())) {
postHandshake
.add(new AsyncCommand<>(this.commandBuilder.clientSetinfo("lib-name", metadata.getLibraryName())));
}
List<AsyncCommand<?, ?, ?>> postHandshake = new ArrayList<>();

if (LettuceStrings.isNotEmpty(metadata.getLibraryVersion())) {
postHandshake.add(
new AsyncCommand<>(this.commandBuilder.clientSetinfo("lib-ver", metadata.getLibraryVersion())));
}
}
ConnectionMetadata metadata = connectionState.getConnectionMetadata();
ProtocolVersion negotiatedProtocolVersion = getNegotiatedProtocolVersion();

if (metadata.getClientName() != null && negotiatedProtocolVersion == ProtocolVersion.RESP2) {
postHandshake.add(new AsyncCommand<>(this.commandBuilder.clientSetname(connectionState.getClientName())));
}

if (connectionState.getDb() > 0) {
postHandshake.add(new AsyncCommand<>(this.commandBuilder.select(connectionState.getDb())));
if (LettuceStrings.isNotEmpty(metadata.getLibraryName())) {
postHandshake.add(
new AsyncCommand<>(this.commandBuilder.clientSetinfo("lib-name", metadata.getLibraryName())));
}

if (connectionState.isReadOnly()) {
postHandshake.add(new AsyncCommand<>(this.commandBuilder.readOnly()));
if (LettuceStrings.isNotEmpty(metadata.getLibraryVersion())) {
postHandshake.add(
new AsyncCommand<>(this.commandBuilder.clientSetinfo("lib-ver", metadata.getLibraryVersion())));
}

if (postHandshake.isEmpty()) {
Expand Down
32 changes: 32 additions & 0 deletions src/test/java/io/lettuce/core/RedisHandshakeUnitTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import static org.assertj.core.api.Assertions.*;

import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletionStage;

import org.junit.jupiter.api.Test;

Expand All @@ -19,6 +21,8 @@
*/
class RedisHandshakeUnitTests {

public static final String ERR_UNKNOWN_COMMAND = "ERR unknown command 'CLIENT', with args beginning with: 'SETINFO' 'lib-name' 'Lettuce'";

@Test
void handshakeWithResp3ShouldPass() {

Expand Down Expand Up @@ -71,6 +75,34 @@ void handshakeWithDiscoveryShouldDowngrade() {
assertThat(state.getNegotiatedProtocolVersion()).isEqualTo(ProtocolVersion.RESP2);
}

@Test
void handshakeFireAndForgetPostHandshake() {

EmbeddedChannel channel = new EmbeddedChannel(true, false);

ConnectionMetadata connectionMetdata = new ConnectionMetadata();
connectionMetdata.setLibraryName("library-name");
connectionMetdata.setLibraryVersion("library-version");

ConnectionState state = new ConnectionState();
state.setCredentialsProvider(new StaticCredentialsProvider(null, null));
state.apply(connectionMetdata);
RedisHandshake handshake = new RedisHandshake(null, false, state);
CompletionStage<Void> handshakeInit = handshake.initialize(channel);

AsyncCommand<String, String, Map<String, String>> hello = channel.readOutbound();
helloResponse(hello.getOutput());
hello.complete();

List<AsyncCommand<String, String, Map<String, String>>> postHandshake = channel.readOutbound();
postHandshake.get(0).getOutput().setError(ERR_UNKNOWN_COMMAND);
postHandshake.get(0).completeExceptionally(new RedisException(ERR_UNKNOWN_COMMAND));
postHandshake.get(0).complete();

assertThat(postHandshake.size()).isEqualTo(2);
assertThat(handshakeInit.toCompletableFuture().isCompletedExceptionally()).isFalse();
}

@Test
void shouldParseVersionWithCharacters() {

Expand Down

0 comments on commit 764fdf3

Please sign in to comment.