Merge existing changes to the forked version #1

Merged (5 commits) on Jun 23, 2023
117 changes: 73 additions & 44 deletions redis/cluster.py
@@ -591,7 +591,8 @@ def __init__(
self.retry = retry
kwargs.update({"retry": self.retry})
else:
kwargs.update({"retry": Retry(default_backoff(), 0)})
self.retry = Retry(default_backoff(), 0)
kwargs["retry"] = self.retry

self.encoder = Encoder(
kwargs.get("encoding", "utf-8"),
@@ -775,6 +776,7 @@ def pipeline(self, transaction=None, shard_hint=None):
read_from_replicas=self.read_from_replicas,
reinitialize_steps=self.reinitialize_steps,
lock=self._lock,
retry=self.retry,
)

def lock(
Expand Down Expand Up @@ -858,41 +860,49 @@ def set_response_callback(self, command, callback):
def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]:
# Determine which nodes the command should be executed on.
# Returns a list of target nodes.
command = args[0].upper()
if len(args) >= 2 and f"{args[0]} {args[1]}".upper() in self.command_flags:
command = f"{args[0]} {args[1]}".upper()

nodes_flag = kwargs.pop("nodes_flag", None)
if nodes_flag is not None:
# nodes flag passed by the user
command_flag = nodes_flag
else:
# get the nodes group for this command if it was predefined
command_flag = self.command_flags.get(command)
if command_flag == self.__class__.RANDOM:
# return a random node
return [self.get_random_node()]
elif command_flag == self.__class__.PRIMARIES:
# return all primaries
return self.get_primaries()
elif command_flag == self.__class__.REPLICAS:
# return all replicas
return self.get_replicas()
elif command_flag == self.__class__.ALL_NODES:
# return all nodes
return self.get_nodes()
elif command_flag == self.__class__.DEFAULT_NODE:
# return the cluster's default node
return [self.nodes_manager.default_node]
elif command in self.__class__.SEARCH_COMMANDS[0]:
return [self.nodes_manager.default_node]
else:
# get the node that holds the key's slot
slot = self.determine_slot(*args)
node = self.nodes_manager.get_node_from_slot(
slot, self.read_from_replicas and command in READ_COMMANDS
)
return [node]
try:
command = args[0].upper()
if len(args) >= 2 and f"{args[0]} {args[1]}".upper() in self.command_flags:
command = f"{args[0]} {args[1]}".upper()

nodes_flag = kwargs.pop("nodes_flag", None)
if nodes_flag is not None:
# nodes flag passed by the user
command_flag = nodes_flag
else:
# get the nodes group for this command if it was predefined
command_flag = self.command_flags.get(command)
if command_flag == self.__class__.RANDOM:
# return a random node
return [self.get_random_node()]
elif command_flag == self.__class__.PRIMARIES:
# return all primaries
return self.get_primaries()
elif command_flag == self.__class__.REPLICAS:
# return all replicas
return self.get_replicas()
elif command_flag == self.__class__.ALL_NODES:
# return all nodes
return self.get_nodes()
elif command_flag == self.__class__.DEFAULT_NODE:
# return the cluster's default node
return [self.nodes_manager.default_node]
elif command in self.__class__.SEARCH_COMMANDS[0]:
return [self.nodes_manager.default_node]
else:
# get the node that holds the key's slot
slot = self.determine_slot(*args)
node = self.nodes_manager.get_node_from_slot(
slot, self.read_from_replicas and command in READ_COMMANDS
)
return [node]
except SlotNotCoveredError as e:
self.reinitialize_counter += 1
if self._should_reinitialized():
self.nodes_manager.initialize()
# Reset the counter
self.reinitialize_counter = 0
raise e

def _should_reinitialized(self):
# To reinitialize the cluster on every MOVED error,
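
The new except branch counts SlotNotCoveredError occurrences and refreshes the topology once enough of them have accumulated, re-raising the error either way. It leans on the _should_reinitialized helper whose definition is truncated above; in redis-py that helper is essentially the following (a sketch for context, not part of this diff):

def _should_reinitialized(self):
    # reinitialize_steps == 0 disables topology refreshes entirely;
    # otherwise refresh on every reinitialize_steps-th error.
    if self.reinitialize_steps == 0:
        return False
    return self.reinitialize_counter % self.reinitialize_steps == 0
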
@@ -1084,6 +1094,12 @@ def execute_command(self, *args, **kwargs):
# The nodes and slots cache were reinitialized.
# Try again with the new cluster setup.
retry_attempts -= 1
if self.retry and isinstance(e, self.retry._supported_errors):
backoff = self.retry._backoff.compute(
self.cluster_error_retry_attempts - retry_attempts
)
if backoff > 0:
time.sleep(backoff)
continue
else:
# raise the exception
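
Where a cluster-level failure previously retried immediately, execute_command now sleeps between attempts using the client's Retry backoff, provided the error is one the Retry object supports. compute(failures) maps the number of failures so far to a delay in seconds; a sketch of the values involved, assuming the stock backoff classes:

from redis.backoff import ConstantBackoff, ExponentialBackoff

# ConstantBackoff always returns its configured delay...
assert ConstantBackoff(1).compute(0) == 1
assert ConstantBackoff(1).compute(5) == 1

# ...while ExponentialBackoff doubles the delay per failure, up to a cap.
b = ExponentialBackoff(cap=10, base=1)
assert b.compute(1) == 2 and b.compute(2) == 4
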
@@ -1143,8 +1159,6 @@ def _execute_command(self, target_node, *args, **kwargs):
# Remove the failed node from the startup nodes before we try
# to reinitialize the cluster
self.nodes_manager.startup_nodes.pop(target_node.name, None)
# Reset the cluster node's connection
target_node.redis_connection = None
self.nodes_manager.initialize()
raise e
except MovedError as e:
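
Note the deliberate deletion here: the failed node's redis_connection is no longer reset to None before reinitializing. Connection-level retries are already handled by the node's own Retry object, so the topology refresh can reuse every node's existing connections instead of discarding them (the test that pinned the old behavior is removed in tests/test_cluster.py below).
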
@@ -1164,6 +1178,13 @@ def _execute_command(self, target_node, *args, **kwargs):
else:
self.nodes_manager.update_moved_exception(e)
moved = True
except SlotNotCoveredError as e:
self.reinitialize_counter += 1
if self._should_reinitialized():
self.nodes_manager.initialize()
# Reset the counter
self.reinitialize_counter = 0
raise e
except TryAgainError:
if ttl < self.RedisClusterRequestTTL / 2:
time.sleep(0.05)
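
This handler mirrors the one added to _determine_nodes above: each SlotNotCoveredError bumps the shared reinitialize_counter, every reinitialize_steps-th error triggers a topology refresh, and the error is re-raised so the caller's retry loop decides what happens next.
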
@@ -1397,7 +1418,10 @@ def get_node_from_slot(self, slot, read_from_replicas=False, server_type=None):
# randomly choose one of the replicas
node_idx = random.randint(1, len(self.slots_cache[slot]) - 1)

return self.slots_cache[slot][node_idx]
try:
return self.slots_cache[slot][node_idx]
except IndexError:
return self.slots_cache[slot][0]

def get_nodes_by_server_type(self, server_type):
"""
@@ -1774,6 +1798,7 @@ def __init__(
cluster_error_retry_attempts: int = 3,
reinitialize_steps: int = 5,
lock=None,
retry: Optional["Retry"] = None,
**kwargs,
):
""" """
@@ -1799,6 +1824,7 @@ def __init__(
if lock is None:
lock = threading.Lock()
self._lock = lock
self.retry = retry

def __repr__(self):
""" """
@@ -1931,8 +1957,9 @@ def send_cluster_commands(
stack,
raise_on_error=raise_on_error,
allow_redirections=allow_redirections,
attempts_count=self.cluster_error_retry_attempts - retry_attempts,
)
except (ClusterDownError, ConnectionError) as e:
except (ClusterDownError, ConnectionError, TimeoutError) as e:
if retry_attempts > 0:
# Try again with the new cluster setup. All other errors
# should be raised.
@@ -1942,7 +1969,7 @@ def send_cluster_commands(
raise e

def _send_cluster_commands(
self, stack, raise_on_error=True, allow_redirections=True
self, stack, raise_on_error=True, allow_redirections=True, attempts_count=0
):
"""
Send a bunch of cluster commands to the redis cluster.
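
attempts_count carries the number of cluster-level attempts already spent into _send_cluster_commands, so the backoff computed there grows across retries instead of restarting from zero each time; TimeoutError now also participates in the retry loop rather than propagating immediately.
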
@@ -1997,9 +2024,11 @@ def _send_cluster_commands(
redis_node = self.get_redis_connection(node)
try:
connection = get_connection(redis_node, c.args)
except ConnectionError:
# Connection retries are being handled in the node's
# Retry object. Reinitialize the node -> slot table.
except (ConnectionError, TimeoutError) as e:
if self.retry and isinstance(e, self.retry._supported_errors):
backoff = self.retry._backoff.compute(attempts_count)
if backoff > 0:
time.sleep(backoff)
self.nodes_manager.initialize()
if is_default_node:
self.replace_default_node()
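
So where a broken connection used to trigger an immediate topology refresh, the pipeline now sleeps first according to the shared Retry's backoff. With hypothetical numbers: given retry=Retry(ConstantBackoff(1), 3) and attempts_count=2, compute(2) returns 1, the pipeline sleeps one second, reinitializes the slot table, and then retries the batch against the refreshed topology.
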
87 changes: 47 additions & 40 deletions tests/test_cluster.py
@@ -4,6 +4,7 @@
import socket
import socketserver
import threading
import uuid
import warnings
from queue import LifoQueue, Queue
from time import sleep
@@ -12,7 +13,12 @@
import pytest

from redis import Redis
from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff
from redis.backoff import (
ConstantBackoff,
ExponentialBackoff,
NoBackoff,
default_backoff,
)
from redis.cluster import (
PRIMARY,
REDIS_CLUSTER_HASH_SLOTS,
@@ -35,6 +41,7 @@
RedisClusterException,
RedisError,
ResponseError,
SlotNotCoveredError,
TimeoutError,
)
from redis.retry import Retry
@@ -788,45 +795,6 @@ def test_not_require_full_coverage_cluster_down_error(self, r):
else:
raise e

def test_timeout_error_topology_refresh_reuse_connections(self, r):
"""
By mocking TIMEOUT errors, we'll force the cluster topology to be reinitialized,
and then ensure that only the impacted connection is replaced
"""
node = r.get_node_from_key("key")
r.set("key", "value")
node_conn_origin = {}
for n in r.get_nodes():
node_conn_origin[n.name] = n.redis_connection
real_func = r.get_redis_connection(node).parse_response

class counter:
def __init__(self, val=0):
self.val = int(val)

count = counter(0)
with patch.object(Redis, "parse_response") as parse_response:

def moved_redirect_effect(connection, *args, **options):
# raise a timeout 5 times so the topology has to be reinitialized
if count.val == 4:
parse_response.side_effect = real_func
count.val += 1
raise TimeoutError()

parse_response.side_effect = moved_redirect_effect
assert r.get("key") == b"value"
for node_name, conn in node_conn_origin.items():
if node_name == node.name:
# The old redis connection of the timed out node should have been
# deleted and replaced
assert conn != r.get_redis_connection(node)
else:
# other nodes' redis connection should have been reused during the
# topology refresh
cur_node = r.get_node(node_name=node_name)
assert conn == r.get_redis_connection(cur_node)

def test_cluster_get_set_retry_object(self, request):
retry = Retry(NoBackoff(), 2)
r = _get_client(RedisCluster, request, retry=retry)
@@ -939,6 +907,45 @@ def address_remap(address):
n_used = sum((1 if p.n_connections else 0) for p in proxies)
assert n_used > 1

@pytest.mark.parametrize("error", [ConnectionError, TimeoutError])
def test_additional_backoff_redis_cluster(self, error):
with patch.object(ConstantBackoff, "compute") as compute:

def _compute(target_node, *args, **kwargs):
return 1

compute.side_effect = _compute
with patch.object(RedisCluster, "_execute_command") as execute_command:

def raise_error(target_node, *args, **kwargs):
execute_command.failed_calls += 1
raise error("mocked error")

execute_command.side_effect = raise_error

rc = get_mocked_redis_client(
host=default_host,
port=default_port,
retry=Retry(ConstantBackoff(1), 3),
)

with pytest.raises(error):
rc.get("bar")
assert compute.call_count == rc.cluster_error_retry_attempts
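
(The assertion pins one backoff computation per cluster-level retry: the mocked _execute_command fails on every attempt, so with the default cluster_error_retry_attempts of 3 the patched compute is consulted exactly three times before the error escapes.)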

@pytest.mark.parametrize("reinitialize_steps", [2, 10, 99])
def test_recover_slot_not_covered_error(self, request, reinitialize_steps):
rc = _get_client(RedisCluster, request, reinitialize_steps=reinitialize_steps)
key = uuid.uuid4().hex

rc.nodes_manager.slots_cache[rc.keyslot(key)] = []

for _ in range(0, reinitialize_steps):
with pytest.raises(SlotNotCoveredError):
rc.get(key)

rc.get(key)
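
(The test empties the slot's cache entry by hand, so every lookup raises SlotNotCoveredError and bumps the counter; once reinitialize_steps errors have accumulated, the topology refresh repopulates the slot cache and the final rc.get(key) succeeds.)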


@pytest.mark.onlycluster
class TestClusterRedisCommands: