diff --git a/src/main/java/io/lettuce/core/RedisHandshake.java b/src/main/java/io/lettuce/core/RedisHandshake.java index 111d30328e..a7ccc7a9a6 100644 --- a/src/main/java/io/lettuce/core/RedisHandshake.java +++ b/src/main/java/io/lettuce/core/RedisHandshake.java @@ -38,6 +38,8 @@ import io.lettuce.core.protocol.ConnectionInitializer; import io.lettuce.core.protocol.ProtocolVersion; import io.netty.channel.Channel; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; /** * Redis RESP2/RESP3 handshake using the configured {@link ProtocolVersion} and other options for connection initialization and @@ -49,7 +51,7 @@ */ class RedisHandshake implements ConnectionInitializer { - private static final RedisVersion CLIENT_SET_INFO_SINCE = RedisVersion.of("7.2"); + private static final InternalLogger LOG = InternalLoggerFactory.getInstance(RedisHandshake.class); private final RedisCommandBuilder commandBuilder = new RedisCommandBuilder<>(StringCodec.UTF8); @@ -99,8 +101,21 @@ public CompletionStage initialize(Channel channel) { new RedisConnectionException("Protocol version" + this.requestedProtocolVersion + " not supported")); } - return handshake.thenCompose( - ignore -> applyPostHandshake(channel, connectionState.getRedisVersion(), getNegotiatedProtocolVersion())); + // post-handshake commands, whose execution failures would cause the connection to be considered + // unsuccessfully established + CompletableFuture postHandshake = applyPostHandshake(channel); + + // post-handshake commands, executed in a 'fire and forget' manner, to avoid having to react to different + // implementations or versions of the server runtime, and whose execution result (whether a success or a + // failure ) should not alter the outcome of the connection attempt + CompletableFuture connectionMetadata = applyConnectionMetadata(channel).handle((result, error) -> { + if (error != null) { + LOG.debug("Error applying connection metadata", error); + } + return null; + }); + + return handshake.thenCompose(ignore -> postHandshake).thenCompose(ignore -> connectionMetadata); } private CompletionStage tryHandshakeResp3(Channel channel) { @@ -237,41 +252,44 @@ private AsyncCommand> dispatchHello(Channel return dispatch(channel, this.commandBuilder.hello(3, null, null, connectionState.getClientName())); } - private CompletableFuture applyPostHandshake(Channel channel, String redisVersion, - ProtocolVersion negotiatedProtocolVersion) { + private CompletableFuture applyPostHandshake(Channel channel) { List> postHandshake = new ArrayList<>(); - ConnectionMetadata metadata = connectionState.getConnectionMetadata(); + if (connectionState.getDb() > 0) { + postHandshake.add(new AsyncCommand<>(this.commandBuilder.select(connectionState.getDb()))); + } - if (metadata.getClientName() != null && negotiatedProtocolVersion == ProtocolVersion.RESP2) { - postHandshake.add(new AsyncCommand<>(this.commandBuilder.clientSetname(connectionState.getClientName()))); + if (connectionState.isReadOnly()) { + postHandshake.add(new AsyncCommand<>(this.commandBuilder.readOnly())); } - if (negotiatedProtocolVersion == ProtocolVersion.RESP3) { + if (postHandshake.isEmpty()) { + return CompletableFuture.completedFuture(null); + } - RedisVersion currentVersion = RedisVersion.of(redisVersion); + return dispatch(channel, postHandshake); + } - if (currentVersion.isGreaterThanOrEqualTo(CLIENT_SET_INFO_SINCE)) { + private CompletableFuture applyConnectionMetadata(Channel channel) { - if (LettuceStrings.isNotEmpty(metadata.getLibraryName())) { - postHandshake - .add(new AsyncCommand<>(this.commandBuilder.clientSetinfo("lib-name", metadata.getLibraryName()))); - } + List> postHandshake = new ArrayList<>(); - if (LettuceStrings.isNotEmpty(metadata.getLibraryVersion())) { - postHandshake.add( - new AsyncCommand<>(this.commandBuilder.clientSetinfo("lib-ver", metadata.getLibraryVersion()))); - } - } + ConnectionMetadata metadata = connectionState.getConnectionMetadata(); + ProtocolVersion negotiatedProtocolVersion = getNegotiatedProtocolVersion(); + + if (metadata.getClientName() != null && negotiatedProtocolVersion == ProtocolVersion.RESP2) { + postHandshake.add(new AsyncCommand<>(this.commandBuilder.clientSetname(connectionState.getClientName()))); } - if (connectionState.getDb() > 0) { - postHandshake.add(new AsyncCommand<>(this.commandBuilder.select(connectionState.getDb()))); + if (LettuceStrings.isNotEmpty(metadata.getLibraryName())) { + postHandshake.add( + new AsyncCommand<>(this.commandBuilder.clientSetinfo("lib-name", metadata.getLibraryName()))); } - if (connectionState.isReadOnly()) { - postHandshake.add(new AsyncCommand<>(this.commandBuilder.readOnly())); + if (LettuceStrings.isNotEmpty(metadata.getLibraryVersion())) { + postHandshake.add( + new AsyncCommand<>(this.commandBuilder.clientSetinfo("lib-ver", metadata.getLibraryVersion()))); } if (postHandshake.isEmpty()) { diff --git a/src/test/java/io/lettuce/core/RedisHandshakeUnitTests.java b/src/test/java/io/lettuce/core/RedisHandshakeUnitTests.java index ad818c25d7..ed96e8be7f 100644 --- a/src/test/java/io/lettuce/core/RedisHandshakeUnitTests.java +++ b/src/test/java/io/lettuce/core/RedisHandshakeUnitTests.java @@ -3,7 +3,9 @@ import static org.assertj.core.api.Assertions.*; import java.nio.ByteBuffer; +import java.util.List; import java.util.Map; +import java.util.concurrent.CompletionStage; import org.junit.jupiter.api.Test; @@ -19,6 +21,8 @@ */ class RedisHandshakeUnitTests { + public static final String ERR_UNKNOWN_COMMAND = "ERR unknown command 'CLIENT', with args beginning with: 'SETINFO' 'lib-name' 'Lettuce'"; + @Test void handshakeWithResp3ShouldPass() { @@ -71,6 +75,34 @@ void handshakeWithDiscoveryShouldDowngrade() { assertThat(state.getNegotiatedProtocolVersion()).isEqualTo(ProtocolVersion.RESP2); } + @Test + void handshakeFireAndForgetPostHandshake() { + + EmbeddedChannel channel = new EmbeddedChannel(true, false); + + ConnectionMetadata connectionMetdata = new ConnectionMetadata(); + connectionMetdata.setLibraryName("library-name"); + connectionMetdata.setLibraryVersion("library-version"); + + ConnectionState state = new ConnectionState(); + state.setCredentialsProvider(new StaticCredentialsProvider(null, null)); + state.apply(connectionMetdata); + RedisHandshake handshake = new RedisHandshake(null, false, state); + CompletionStage handshakeInit = handshake.initialize(channel); + + AsyncCommand> hello = channel.readOutbound(); + helloResponse(hello.getOutput()); + hello.complete(); + + List>> postHandshake = channel.readOutbound(); + postHandshake.get(0).getOutput().setError(ERR_UNKNOWN_COMMAND); + postHandshake.get(0).completeExceptionally(new RedisException(ERR_UNKNOWN_COMMAND)); + postHandshake.get(0).complete(); + + assertThat(postHandshake.size()).isEqualTo(2); + assertThat(handshakeInit.toCompletableFuture().isCompletedExceptionally()).isFalse(); + } + @Test void shouldParseVersionWithCharacters() {