Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RESP3 tests #2780

Merged
merged 18 commits into from
Jun 1, 2023
8 changes: 4 additions & 4 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,13 +671,13 @@ def __init__(
if self.encoder is None:
self.encoder = self.connection_pool.get_encoder()
if self.encoder.decode_responses:
self.health_check_response: Iterable[Union[str, bytes]] = [
"pong",
self.health_check_response = [
["pong", self.HEALTH_CHECK_MESSAGE],
self.HEALTH_CHECK_MESSAGE,
]
else:
self.health_check_response = [
b"pong",
[b"pong", self.encoder.encode(self.HEALTH_CHECK_MESSAGE)],
self.encoder.encode(self.HEALTH_CHECK_MESSAGE),
]
if self.push_handler_func is None:
Expand Down Expand Up @@ -807,7 +807,7 @@ async def parse_response(self, block: bool = True, timeout: float = 0):
conn, conn.read_response, timeout=read_timeout, push_request=True
)

if conn.health_check_interval and response == self.health_check_response:
if conn.health_check_interval and response in self.health_check_response:
# ignore the health check message as user might not expect it
return None
return response
Expand Down
2 changes: 2 additions & 0 deletions redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,8 @@ def __init__(
kwargs.update({"retry": self.retry})

kwargs["response_callbacks"] = self.__class__.RESPONSE_CALLBACKS.copy()
if kwargs.get("protocol") in ["3", 3]:
dvora-h marked this conversation as resolved.
Show resolved Hide resolved
kwargs["response_callbacks"].update(self.__class__.RESP3_RESPONSE_CALLBACKS)
self.connection_kwargs = kwargs

if startup_nodes:
Expand Down
27 changes: 24 additions & 3 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,16 +333,35 @@ def _error_message(self, exception):
async def on_connect(self) -> None:
"""Initialize the connection, authenticate and select a database"""
self._parser.on_connect(self)
parser = self._parser

auth_args = None
# if credential provider or username and/or password are set, authenticate
if self.credential_provider or (self.username or self.password):
cred_provider = (
self.credential_provider
or UsernamePasswordCredentialProvider(self.username, self.password)
)
auth_args = cred_provider.get_credentials()
# avoid checking health here -- PING will fail if we try
# to check the health prior to the AUTH
# if resp version is specified and we have auth args,
# we need to send them via HELLO
if auth_args and self.protocol not in [2, "2"]:
if isinstance(self._parser, _AsyncRESP2Parser):
self.set_parser(_AsyncRESP3Parser)
# update cluster exception classes
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
self._parser.on_connect(self)
if len(auth_args) == 1:
auth_args = ["default", auth_args[0]]
await self.send_command("HELLO", self.protocol, "AUTH", *auth_args)
response = await self.read_response()
if response.get(b"proto") not in [2, "2"] and response.get(
"proto"
) not in [2, "2"]:
raise ConnectionError("Invalid RESP version")
# avoid checking health here -- PING will fail if we try
# to check the health prior to the AUTH
elif auth_args:
await self.send_command("AUTH", *auth_args, check_health=False)

try:
Expand All @@ -359,9 +378,11 @@ async def on_connect(self) -> None:
raise AuthenticationError("Invalid Username or Password")

# if resp version is specified, switch to it
if self.protocol != 2:
elif self.protocol != 2:
if isinstance(self._parser, _AsyncRESP2Parser):
self.set_parser(_AsyncRESP3Parser)
# update cluster exception classes
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
self._parser.on_connect(self)
await self.send_command("HELLO", self.protocol)
response = await self.read_response()
Expand Down
38 changes: 24 additions & 14 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,9 +331,15 @@ def parse_xinfo_stream(response, **options):
data["last-entry"] = (last[0], pairs_to_dict(last[1]))
else:
data["entries"] = {_id: pairs_to_dict(entry) for _id, entry in data["entries"]}
data["groups"] = [
pairs_to_dict(group, decode_keys=True) for group in data["groups"]
]
if isinstance(data["groups"][0], list):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try/except instead == cheaper

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It won't work here because of the content of data["groups"]

data["groups"] = [
pairs_to_dict(group, decode_keys=True) for group in data["groups"]
]
else:
data["groups"] = [
{str_if_bytes(k): v for k, v in group.items()}
for group in data["groups"]
]
return data


Expand Down Expand Up @@ -581,14 +587,15 @@ def parse_command_resp3(response, **options):
cmd_name = str_if_bytes(command[0])
cmd_dict["name"] = cmd_name
cmd_dict["arity"] = command[1]
cmd_dict["flags"] = command[2]
cmd_dict["flags"] = {str_if_bytes(flag) for flag in command[2]}
cmd_dict["first_key_pos"] = command[3]
cmd_dict["last_key_pos"] = command[4]
cmd_dict["step_count"] = command[5]
cmd_dict["acl_categories"] = command[6]
cmd_dict["tips"] = command[7]
cmd_dict["key_specifications"] = command[8]
cmd_dict["subcommands"] = command[9]
if len(command) > 7:
cmd_dict["tips"] = command[7]
cmd_dict["key_specifications"] = command[8]
cmd_dict["subcommands"] = command[9]

commands[cmd_name] = cmd_dict
return commands
Expand Down Expand Up @@ -626,17 +633,20 @@ def parse_acl_getuser(response, **options):
if data["channels"] == [""]:
data["channels"] = []
if "selectors" in data:
data["selectors"] = [
list(map(str_if_bytes, selector)) for selector in data["selectors"]
]
if data["selectors"] != [] and isinstance(data["selectors"][0], list):
data["selectors"] = [
list(map(str_if_bytes, selector)) for selector in data["selectors"]
]
elif data["selectors"] != []:
data["selectors"] = [
{str_if_bytes(k): str_if_bytes(v) for k, v in selector.items()}
for selector in data["selectors"]
]

# split 'commands' into separate 'categories' and 'commands' lists
commands, categories = [], []
for command in data["commands"].split(" "):
if "@" in command:
categories.append(command)
else:
commands.append(command)
categories.append(command) if "@" in command else commands.append(command)

data["commands"] = commands
data["categories"] = categories
Expand Down
22 changes: 19 additions & 3 deletions redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from redis.parsers import CommandsParser, Encoder
from redis.retry import Retry
from redis.utils import (
HIREDIS_AVAILABLE,
dict_merge,
list_keys_to_dict,
merge_result,
Expand Down Expand Up @@ -1608,7 +1609,15 @@ class ClusterPubSub(PubSub):
https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html
"""

def __init__(self, redis_cluster, node=None, host=None, port=None, **kwargs):
def __init__(
self,
redis_cluster,
node=None,
host=None,
port=None,
push_handler_func=None,
**kwargs,
):
"""
When a pubsub instance is created without specifying a node, a single
node will be transparently chosen for the pubsub connection on the
Expand All @@ -1633,7 +1642,10 @@ def __init__(self, redis_cluster, node=None, host=None, port=None, **kwargs):
self.node_pubsub_mapping = {}
self._pubsubs_generator = self._pubsubs_generator()
super().__init__(
**kwargs, connection_pool=connection_pool, encoder=redis_cluster.encoder
connection_pool=connection_pool,
encoder=redis_cluster.encoder,
push_handler_func=push_handler_func,
**kwargs,
)

def set_pubsub_node(self, cluster, node=None, host=None, port=None):
Expand Down Expand Up @@ -1717,14 +1729,18 @@ def execute_command(self, *args):
# register a callback that re-subscribes to any channels we
# were listening to when we were disconnected
self.connection.register_connect_callback(self.on_connect)
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
self.connection._parser.set_push_handler(self.push_handler_func)
connection = self.connection
self._execute(connection, connection.send_command, *args)

def _get_node_pubsub(self, node):
try:
return self.node_pubsub_mapping[node.name]
except KeyError:
pubsub = node.redis_connection.pubsub()
pubsub = node.redis_connection.pubsub(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good!

push_handler_func=self.push_handler_func
)
self.node_pubsub_mapping[node.name] = pubsub
return pubsub

Expand Down
23 changes: 22 additions & 1 deletion redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,14 +276,33 @@ def _error_message(self, exception):
def on_connect(self):
"Initialize the connection, authenticate and select a database"
self._parser.on_connect(self)
parser = self._parser

auth_args = None
# if credential provider or username and/or password are set, authenticate
if self.credential_provider or (self.username or self.password):
cred_provider = (
self.credential_provider
or UsernamePasswordCredentialProvider(self.username, self.password)
)
auth_args = cred_provider.get_credentials()
# if resp version is specified and we have auth args,
# we need to send them via HELLO
if auth_args and self.protocol != 2:
if isinstance(self._parser, _RESP2Parser):
self.set_parser(_RESP3Parser)
# update cluster exception classes
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
self._parser.on_connect(self)
if len(auth_args) == 1:
auth_args = ["default", auth_args[0]]
self.send_command("HELLO", self.protocol, "AUTH", *auth_args)
response = self.read_response()
if response.get(b"proto") != int(self.protocol) and response.get(
"proto"
) != int(self.protocol):
raise ConnectionError("Invalid RESP version")
elif auth_args:
# avoid checking health here -- PING will fail if we try
# to check the health prior to the AUTH
self.send_command("AUTH", *auth_args, check_health=False)
Expand All @@ -302,9 +321,11 @@ def on_connect(self):
raise AuthenticationError("Invalid Username or Password")

# if resp version is specified, switch to it
if self.protocol != 2:
elif self.protocol != 2:
if isinstance(self._parser, _RESP2Parser):
self.set_parser(_RESP3Parser)
# update cluster exception classes
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
self._parser.on_connect(self)
self.send_command("HELLO", self.protocol)
response = self.read_response()
Expand Down
3 changes: 2 additions & 1 deletion redis/parsers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .base import BaseParser
from .base import BaseParser, _AsyncRESPBase
from .commands import AsyncCommandsParser, CommandsParser
from .encoders import Encoder
from .hiredis import _AsyncHiredisParser, _HiredisParser
Expand All @@ -8,6 +8,7 @@
__all__ = [
"AsyncCommandsParser",
"_AsyncHiredisParser",
"_AsyncRESPBase",
"_AsyncRESP2Parser",
"_AsyncRESP3Parser",
"CommandsParser",
Expand Down
14 changes: 10 additions & 4 deletions redis/parsers/resp3.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,12 @@ def _read_response(self, disable_decoding=False, push_request=False):
# bool value
elif byte == b"#":
return response == b"t"
# bulk response and verbatim strings
elif byte in (b"$", b"="):
# bulk response
elif byte == b"$":
response = self._buffer.read(int(response))
# verbatim string response
elif byte == b"=":
response = self._buffer.read(int(response))[4:]
# array response
elif byte == b"*":
response = [
Expand Down Expand Up @@ -195,9 +198,12 @@ async def _read_response(
# bool value
elif byte == b"#":
return response == b"t"
# bulk response and verbatim strings
elif byte in (b"$", b"="):
# bulk response
elif byte == b"$":
response = await self._read(int(response))
# verbatim string response
elif byte == b"=":
response = (await self._read(int(response)))[4:]
# array response
elif byte == b"*":
response = [
Expand Down
27 changes: 25 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,8 +475,31 @@ def wait_for_command(client, monitor, command, key=None):


def is_resp2_connection(r):
if isinstance(r, redis.Redis):
if isinstance(r, redis.Redis) or isinstance(r, redis.asyncio.Redis):
protocol = r.connection_pool.connection_kwargs.get("protocol")
elif isinstance(r, redis.RedisCluster):
elif isinstance(r, redis.cluster.AbstractRedisCluster):
protocol = r.nodes_manager.connection_kwargs.get("protocol")
return protocol in ["2", 2, None]


def get_protocol_version(r):
if isinstance(r, redis.Redis) or isinstance(r, redis.asyncio.Redis):
return r.connection_pool.connection_kwargs.get("protocol")
elif isinstance(r, redis.cluster.AbstractRedisCluster):
return r.nodes_manager.connection_kwargs.get("protocol")


def assert_resp_response(r, response, resp2_expected, resp3_expected):
protocol = get_protocol_version(r)
if protocol in [2, "2", None]:
assert response == resp2_expected
else:
assert response == resp3_expected


def assert_resp_response_in(r, response, resp2_expected, resp3_expected):
protocol = get_protocol_version(r)
if protocol in [2, "2", None]:
assert response in resp2_expected
else:
assert response in resp3_expected
23 changes: 0 additions & 23 deletions tests/test_asyncio/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,29 +236,6 @@ async def wait_for_command(
return None


def get_protocol_version(r):
if isinstance(r, redis.Redis):
return r.connection_pool.connection_kwargs.get("protocol")
elif isinstance(r, redis.RedisCluster):
return r.nodes_manager.connection_kwargs.get("protocol")


def assert_resp_response(r, response, resp2_expected, resp3_expected):
protocol = get_protocol_version(r)
if protocol in [2, "2", None]:
assert response == resp2_expected
else:
assert response == resp3_expected


def assert_resp_response_in(r, response, resp2_expected, resp3_expected):
protocol = get_protocol_version(r)
if protocol in [2, "2", None]:
assert response in resp2_expected
else:
assert response in resp3_expected


# python 3.6 doesn't have the asynccontextmanager decorator. Provide it here.
class AsyncContextManager:
def __init__(self, async_generator):
Expand Down
Loading