Merge existing changes to the forked version (#1)
* [GROW-2938] do not reset redis_connection on an error

* [GROW-2938] add backoff to more errors

* [GROW-2938] recover from SlotNotCoveredError

* [GROW-2938] prevent get_node_from_slot from failing due to concurrent cluster slots refresh

* [GROW-2938] add retry to ClusterPipeline

(cherry picked from commit 63e06dd)
zach-iee committed Jun 23, 2023
1 parent 49d9cb7 commit 28d80e1
Showing 2 changed files with 120 additions and 84 deletions.
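
Taken together, these changes ensure RedisCluster always stores a usable self.retry (even the default Retry(default_backoff(), 0)) and hands it to pipelines. A minimal sketch of the resulting user-facing configuration, using redis-py's public API (the host/port values are placeholders):

from redis.backoff import ExponentialBackoff
from redis.cluster import RedisCluster
from redis.retry import Retry

# Placeholder connection details; any reachable cluster node works.
rc = RedisCluster(
    host="localhost",
    port=6379,
    retry=Retry(ExponentialBackoff(), 3),
    cluster_error_retry_attempts=3,
)
rc.set("key", "value")

# After this commit, the pipeline inherits rc.retry instead of being
# built without one.
pipe = rc.pipeline()
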
117 changes: 73 additions & 44 deletions redis/cluster.py
@@ -575,7 +575,8 @@ def __init__(
self.retry = retry
kwargs.update({"retry": self.retry})
else:
kwargs.update({"retry": Retry(default_backoff(), 0)})
self.retry = Retry(default_backoff(), 0)
kwargs["retry"] = self.retry

self.encoder = Encoder(
kwargs.get("encoding", "utf-8"),
@@ -759,6 +760,7 @@ def pipeline(self, transaction=None, shard_hint=None):
read_from_replicas=self.read_from_replicas,
reinitialize_steps=self.reinitialize_steps,
lock=self._lock,
retry=self.retry,
)

def lock(
@@ -842,41 +844,49 @@ def set_response_callback(self, command, callback):
def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]:
# Determine which nodes the command should be executed on.
# Returns a list of target nodes.
command = args[0].upper()
if len(args) >= 2 and f"{args[0]} {args[1]}".upper() in self.command_flags:
command = f"{args[0]} {args[1]}".upper()

nodes_flag = kwargs.pop("nodes_flag", None)
if nodes_flag is not None:
# nodes flag passed by the user
command_flag = nodes_flag
else:
# get the nodes group for this command if it was predefined
command_flag = self.command_flags.get(command)
if command_flag == self.__class__.RANDOM:
# return a random node
return [self.get_random_node()]
elif command_flag == self.__class__.PRIMARIES:
# return all primaries
return self.get_primaries()
elif command_flag == self.__class__.REPLICAS:
# return all replicas
return self.get_replicas()
elif command_flag == self.__class__.ALL_NODES:
# return all nodes
return self.get_nodes()
elif command_flag == self.__class__.DEFAULT_NODE:
# return the cluster's default node
return [self.nodes_manager.default_node]
elif command in self.__class__.SEARCH_COMMANDS[0]:
return [self.nodes_manager.default_node]
else:
# get the node that holds the key's slot
slot = self.determine_slot(*args)
node = self.nodes_manager.get_node_from_slot(
slot, self.read_from_replicas and command in READ_COMMANDS
)
return [node]
try:
command = args[0].upper()
if len(args) >= 2 and f"{args[0]} {args[1]}".upper() in self.command_flags:
command = f"{args[0]} {args[1]}".upper()

nodes_flag = kwargs.pop("nodes_flag", None)
if nodes_flag is not None:
# nodes flag passed by the user
command_flag = nodes_flag
else:
# get the nodes group for this command if it was predefined
command_flag = self.command_flags.get(command)
if command_flag == self.__class__.RANDOM:
# return a random node
return [self.get_random_node()]
elif command_flag == self.__class__.PRIMARIES:
# return all primaries
return self.get_primaries()
elif command_flag == self.__class__.REPLICAS:
# return all replicas
return self.get_replicas()
elif command_flag == self.__class__.ALL_NODES:
# return all nodes
return self.get_nodes()
elif command_flag == self.__class__.DEFAULT_NODE:
# return the cluster's default node
return [self.nodes_manager.default_node]
elif command in self.__class__.SEARCH_COMMANDS[0]:
return [self.nodes_manager.default_node]
else:
# get the node that holds the key's slot
slot = self.determine_slot(*args)
node = self.nodes_manager.get_node_from_slot(
slot, self.read_from_replicas and command in READ_COMMANDS
)
return [node]
except SlotNotCoveredError as e:
self.reinitialize_counter += 1
if self._should_reinitialized():
self.nodes_manager.initialize()
# Reset the counter
self.reinitialize_counter = 0
raise e
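
The new except branch mirrors the existing MOVED-error handling: each SlotNotCoveredError bumps a counter, and every reinitialize_steps-th failure triggers a full slot-table rebuild before the error is re-raised to the caller's retry loop. A standalone sketch of that counter-gated pattern (the class below is illustrative, not a redis-py internal):

# Illustrative sketch only; `refresh` stands in for nodes_manager.initialize().
class CounterGatedRefresh:
    def __init__(self, reinitialize_steps: int = 5):
        self.reinitialize_steps = reinitialize_steps
        self.counter = 0

    def on_slot_error(self, refresh) -> None:
        self.counter += 1
        if self.reinitialize_steps and self.counter % self.reinitialize_steps == 0:
            refresh()         # rebuild the slot table
            self.counter = 0  # start counting toward the next rebuild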

def _should_reinitialized(self):
# To reinitialize the cluster on every MOVED error,
@@ -1068,6 +1078,12 @@ def execute_command(self, *args, **kwargs):
# The nodes and slots cache were reinitialized.
# Try again with the new cluster setup.
retry_attempts -= 1
if self.retry and isinstance(e, self.retry._supported_errors):
backoff = self.retry._backoff.compute(
self.cluster_error_retry_attempts - retry_attempts
)
if backoff > 0:
time.sleep(backoff)
continue
else:
# raise the exception
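
With this hunk, execute_command sleeps between cluster-level retry attempts instead of immediately hammering the cluster again. A hedged sketch of the delay computation in isolation; note that it reads Retry's private _backoff and _supported_errors attributes, exactly as the diff does, and the surrounding retry loop is elided:

import time

from redis.backoff import ExponentialBackoff
from redis.exceptions import ConnectionError
from redis.retry import Retry

retry = Retry(ExponentialBackoff(), retries=3)
error = ConnectionError("node unreachable")  # stand-in for the caught exception
failures = 1  # corresponds to cluster_error_retry_attempts - retry_attempts

if isinstance(error, retry._supported_errors):
    delay = retry._backoff.compute(failures)  # grows with each failed attempt
    if delay > 0:
        time.sleep(delay)
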
@@ -1127,8 +1143,6 @@ def _execute_command(self, target_node, *args, **kwargs):
# Remove the failed node from the startup nodes before we try
# to reinitialize the cluster
self.nodes_manager.startup_nodes.pop(target_node.name, None)
# Reset the cluster node's connection
target_node.redis_connection = None
self.nodes_manager.initialize()
raise e
except MovedError as e:
@@ -1148,6 +1162,13 @@ def _execute_command(self, target_node, *args, **kwargs):
else:
self.nodes_manager.update_moved_exception(e)
moved = True
except SlotNotCoveredError as e:
self.reinitialize_counter += 1
if self._should_reinitialized():
self.nodes_manager.initialize()
# Reset the counter
self.reinitialize_counter = 0
raise e
except TryAgainError:
if ttl < self.RedisClusterRequestTTL / 2:
time.sleep(0.05)
@@ -1379,7 +1400,10 @@ def get_node_from_slot(self, slot, read_from_replicas=False, server_type=None):
# randomly choose one of the replicas
node_idx = random.randint(1, len(self.slots_cache[slot]) - 1)

return self.slots_cache[slot][node_idx]
try:
return self.slots_cache[slot][node_idx]
except IndexError:
return self.slots_cache[slot][0]
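
The try/except guards against a race with a concurrent slot-table refresh: another thread can shrink slots_cache[slot] after node_idx was chosen, so the lookup falls back to the primary instead of leaking an IndexError. A toy illustration with plain lists standing in for the cache:

# Toy illustration of the race; plain data, not redis-py objects.
nodes = ["primary", "replica-1", "replica-2"]
node_idx = 2       # chosen while two replicas were known
nodes = nodes[:1]  # a concurrent topology refresh shrinks the entry
try:
    node = nodes[node_idx]
except IndexError:
    node = nodes[0]  # fall back to the primary instead of failing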

def get_nodes_by_server_type(self, server_type):
"""
@@ -1744,6 +1768,7 @@ def __init__(
cluster_error_retry_attempts: int = 3,
reinitialize_steps: int = 5,
lock=None,
retry: Optional["Retry"] = None,
**kwargs,
):
""" """
@@ -1769,6 +1794,7 @@ def __init__(
if lock is None:
lock = threading.Lock()
self._lock = lock
self.retry = retry

def __repr__(self):
""" """
@@ -1901,8 +1927,9 @@ def send_cluster_commands(
stack,
raise_on_error=raise_on_error,
allow_redirections=allow_redirections,
attempts_count=self.cluster_error_retry_attempts - retry_attempts,
)
except (ClusterDownError, ConnectionError) as e:
except (ClusterDownError, ConnectionError, TimeoutError) as e:
if retry_attempts > 0:
# Try again with the new cluster setup. All other errors
# should be raised.
Expand All @@ -1912,7 +1939,7 @@ def send_cluster_commands(
raise e

def _send_cluster_commands(
self, stack, raise_on_error=True, allow_redirections=True
self, stack, raise_on_error=True, allow_redirections=True, attempts_count=0
):
"""
Send a bunch of cluster commands to the redis cluster.
@@ -1967,9 +1994,11 @@ def _send_cluster_commands(
redis_node = self.get_redis_connection(node)
try:
connection = get_connection(redis_node, c.args)
except ConnectionError:
# Connection retries are being handled in the node's
# Retry object. Reinitialize the node -> slot table.
except (ConnectionError, TimeoutError) as e:
if self.retry and isinstance(e, self.retry._supported_errors):
backoff = self.retry._backoff.compute(attempts_count)
if backoff > 0:
time.sleep(backoff)
self.nodes_manager.initialize()
if is_default_node:
self.replace_default_node()
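
On the pipeline path, ConnectionError and TimeoutError during connection acquisition now also back off via the shared Retry object before the topology is refreshed and the batch is retried. A sketch of how a caller would exercise this (connection details are placeholders):

from redis.backoff import ConstantBackoff
from redis.cluster import RedisCluster
from redis.retry import Retry

# Placeholder host/port; ConstantBackoff(0.5) sleeps 0.5s per failure.
rc = RedisCluster(host="localhost", port=6379, retry=Retry(ConstantBackoff(0.5), 3))

pipe = rc.pipeline()  # receives rc.retry via the new constructor kwarg
pipe.set("k1", "v1")
pipe.get("k1")
results = pipe.execute()  # connection failures now sleep per the backoff
                          # before the slot table is rebuilt and retried
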
87 changes: 47 additions & 40 deletions tests/test_cluster.py
@@ -1,13 +1,19 @@
import binascii
import datetime
import uuid
import warnings
from time import sleep
from unittest.mock import DEFAULT, Mock, call, patch

import pytest

from redis import Redis
from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff
from redis.backoff import (
ConstantBackoff,
ExponentialBackoff,
NoBackoff,
default_backoff,
)
from redis.cluster import (
PRIMARY,
REDIS_CLUSTER_HASH_SLOTS,
@@ -30,6 +36,7 @@
RedisClusterException,
RedisError,
ResponseError,
SlotNotCoveredError,
TimeoutError,
)
from redis.retry import Retry
@@ -716,45 +723,6 @@ def test_not_require_full_coverage_cluster_down_error(self, r):
else:
raise e

def test_timeout_error_topology_refresh_reuse_connections(self, r):
"""
By mocking TIMEOUT errors, we'll force the cluster topology to be reinitialized,
and then ensure that only the impacted connection is replaced
"""
node = r.get_node_from_key("key")
r.set("key", "value")
node_conn_origin = {}
for n in r.get_nodes():
node_conn_origin[n.name] = n.redis_connection
real_func = r.get_redis_connection(node).parse_response

class counter:
def __init__(self, val=0):
self.val = int(val)

count = counter(0)
with patch.object(Redis, "parse_response") as parse_response:

def moved_redirect_effect(connection, *args, **options):
# raise a timeout for 5 times so we'll need to reinitialize the topology
if count.val == 4:
parse_response.side_effect = real_func
count.val += 1
raise TimeoutError()

parse_response.side_effect = moved_redirect_effect
assert r.get("key") == b"value"
for node_name, conn in node_conn_origin.items():
if node_name == node.name:
# The old redis connection of the timed out node should have been
# deleted and replaced
assert conn != r.get_redis_connection(node)
else:
# other nodes' redis connection should have been reused during the
# topology refresh
cur_node = r.get_node(node_name=node_name)
assert conn == r.get_redis_connection(cur_node)

def test_cluster_get_set_retry_object(self, request):
retry = Retry(NoBackoff(), 2)
r = _get_client(RedisCluster, request, retry=retry)
@@ -822,6 +790,45 @@ def raise_connection_error():
assert "myself" not in nodes.get(curr_default_node.name).get("flags")
assert r.get_default_node() != curr_default_node

@pytest.mark.parametrize("error", [ConnectionError, TimeoutError])
def test_additional_backoff_redis_cluster(self, error):
with patch.object(ConstantBackoff, "compute") as compute:

def _compute(target_node, *args, **kwargs):
return 1

compute.side_effect = _compute
with patch.object(RedisCluster, "_execute_command") as execute_command:

def raise_error(target_node, *args, **kwargs):
execute_command.failed_calls += 1
raise error("mocked error")

execute_command.side_effect = raise_error

rc = get_mocked_redis_client(
host=default_host,
port=default_port,
retry=Retry(ConstantBackoff(1), 3),
)

with pytest.raises(error):
rc.get("bar")
assert compute.call_count == rc.cluster_error_retry_attempts

@pytest.mark.parametrize("reinitialize_steps", [2, 10, 99])
def test_recover_slot_not_covered_error(self, request, reinitialize_steps):
rc = _get_client(RedisCluster, request, reinitialize_steps=reinitialize_steps)
key = uuid.uuid4().hex

rc.nodes_manager.slots_cache[rc.keyslot(key)] = []

for _ in range(0, reinitialize_steps):
with pytest.raises(SlotNotCoveredError):
rc.get(key)

rc.get(key)


@pytest.mark.onlycluster
class TestClusterRedisCommands:
