From e9adcb3aadd05131b2931e47ff1e7dc4cf6be207 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Wed, 3 May 2023 19:49:07 +0300 Subject: [PATCH 01/16] fix command response in resp3 --- redis/client.py | 21 +++++++++++++++++++++ redis/parsers/resp3.py | 20 ++++++++++++++++---- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/redis/client.py b/redis/client.py index 71048f548f..565f133d6f 100755 --- a/redis/client.py +++ b/redis/client.py @@ -574,6 +574,26 @@ def parse_command(response, **options): return commands +def parse_command_resp3(response, **options): + commands = {} + for command in response: + cmd_dict = {} + cmd_name = str_if_bytes(command[0]) + cmd_dict["name"] = cmd_name + cmd_dict["arity"] = command[1] + cmd_dict["flags"] = command[2] + cmd_dict["first_key_pos"] = command[3] + cmd_dict["last_key_pos"] = command[4] + cmd_dict["step_count"] = command[5] + cmd_dict["acl_categories"] = command[6] + cmd_dict["tips"] = command[7] + cmd_dict["key_specifications"] = command[8] + cmd_dict["subcommands"] = command[9] + + commands[cmd_name] = cmd_dict + return commands + + def parse_pubsub_numsub(response, **options): return list(zip(response[0::2], response[1::2])) @@ -874,6 +894,7 @@ class AbstractRedis: if isinstance(r, list) else bool_ok(r), **string_keys_to_dict("XREAD XREADGROUP", parse_xread_resp3), + "COMMAND": parse_command_resp3, "STRALGO": lambda r, **options: { str_if_bytes(key): str_if_bytes(value) for key, value in r.items() } diff --git a/redis/parsers/resp3.py b/redis/parsers/resp3.py index 93fb6ff554..f9ef732840 100644 --- a/redis/parsers/resp3.py +++ b/redis/parsers/resp3.py @@ -80,10 +80,16 @@ def _read_response(self, disable_decoding=False, push_request=False): ] # set response elif byte == b"~": - response = { + # redis can return unhashable types (like dict) in a set, + # so we need to first convert to a list, and then try to convert it to a set + response = [ self._read_response(disable_decoding=disable_decoding) for _ in range(int(response)) - } + ] + try: + response = set(response) + except TypeError as e: + pass # map response elif byte == b"%": response = { @@ -199,10 +205,16 @@ async def _read_response( ] # set response elif byte == b"~": - response = { + # redis can return unhashable types (like dict) in a set, + # so we need to first convert to a list, and then try to convert it to a set + response = [ (await self._read_response(disable_decoding=disable_decoding)) for _ in range(int(response)) - } + ] + try: + response = set(response) + except TypeError: + pass # map response elif byte == b"%": response = { From 32e46a73a786aba8b57b8051aba7e4638af8663b Mon Sep 17 00:00:00 2001 From: dvora-h Date: Wed, 3 May 2023 19:50:21 +0300 Subject: [PATCH 02/16] linters --- redis/parsers/resp3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/parsers/resp3.py b/redis/parsers/resp3.py index f9ef732840..5cd7f388dd 100644 --- a/redis/parsers/resp3.py +++ b/redis/parsers/resp3.py @@ -88,7 +88,7 @@ def _read_response(self, disable_decoding=False, push_request=False): ] try: response = set(response) - except TypeError as e: + except TypeError: pass # map response elif byte == b"%": From c4d1baa77e3c301436c677becdf728745cf64ac2 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Thu, 4 May 2023 03:05:55 +0300 Subject: [PATCH 03/16] acl_log & acl_getuser --- redis/client.py | 14 +++++++------- redis/connection.py | 16 +++++++++++++++- tests/test_commands.py | 11 ++++++++--- 3 files changed, 30 insertions(+), 11 deletions(-) diff --git a/redis/client.py b/redis/client.py index 565f133d6f..30c24ce383 100755 --- a/redis/client.py +++ b/redis/client.py @@ -626,17 +626,17 @@ def parse_acl_getuser(response, **options): if data["channels"] == [""]: data["channels"] = [] if "selectors" in data: - data["selectors"] = [ - list(map(str_if_bytes, selector)) for selector in data["selectors"] - ] + if data["selectors"] != [] and isinstance(data["selectors"][0], list): + data["selectors"] = [ + list(map(str_if_bytes, selector)) for selector in data["selectors"] + ] + elif data["selectors"] != []: + data["selectors"] = [{str_if_bytes(k): str_if_bytes(v) for k, v in selector.items()} for selector in data["selectors"]] # split 'commands' into separate 'categories' and 'commands' lists commands, categories = [], [] for command in data["commands"].split(" "): - if "@" in command: - categories.append(command) - else: - commands.append(command) + categories.append(command) if "@" in command else commands.append(command) data["commands"] = commands data["categories"] = categories diff --git a/redis/connection.py b/redis/connection.py index 19c80e08f5..76430773a0 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -277,6 +277,7 @@ def on_connect(self): "Initialize the connection, authenticate and select a database" self._parser.on_connect(self) + auth_args = None # if credential provider or username and/or password are set, authenticate if self.credential_provider or (self.username or self.password): cred_provider = ( @@ -284,6 +285,19 @@ def on_connect(self): or UsernamePasswordCredentialProvider(self.username, self.password) ) auth_args = cred_provider.get_credentials() + # if resp version is specified and we have auth args, + # we need to send them via HELLO + if auth_args and self.protocol != 2: + if isinstance(self._parser, _RESP2Parser): + self.set_parser(_RESP3Parser) + self._parser.on_connect(self) + self.send_command("HELLO", self.protocol, "AUTH", *auth_args) + response = self.read_response() + if response.get(b"proto") != int(self.protocol) and response.get( + "proto" + ) != int(self.protocol): + raise ConnectionError("Invalid RESP version") + elif auth_args: # avoid checking health here -- PING will fail if we try # to check the health prior to the AUTH self.send_command("AUTH", *auth_args, check_health=False) @@ -302,7 +316,7 @@ def on_connect(self): raise AuthenticationError("Invalid Username or Password") # if resp version is specified, switch to it - if self.protocol != 2: + elif self.protocol != 2: if isinstance(self._parser, _RESP2Parser): self.set_parser(_RESP3Parser) self._parser.on_connect(self) diff --git a/tests/test_commands.py b/tests/test_commands.py index 1af69c83c0..75af1ffa5e 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -317,9 +317,14 @@ def teardown(): assert set(acl["keys"]) == {"~cache:*", "~objects:*"} assert len(acl["passwords"]) == 2 assert set(acl["channels"]) == {"&message:*"} - assert acl["selectors"] == [ - ["commands", "-@all +set", "keys", "%W~app*", "channels", ""] - ] + if is_resp2_connection(r): + assert acl["selectors"] == [ + ["commands", "-@all +set", "keys", "%W~app*", "channels", ""] + ] + else: + assert acl["selectors"] == [ + {"commands": "-@all +set", "keys": "%W~app*", "channels": ""} + ] @skip_if_server_version_lt("6.0.0") def test_acl_help(self, r): From 0360c13c5e6ea2ce6a9fdb1930ad3fdc88413660 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Thu, 4 May 2023 03:13:31 +0300 Subject: [PATCH 04/16] client_info --- redis/parsers/resp3.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/redis/parsers/resp3.py b/redis/parsers/resp3.py index 5cd7f388dd..0220a8f6f2 100644 --- a/redis/parsers/resp3.py +++ b/redis/parsers/resp3.py @@ -69,9 +69,12 @@ def _read_response(self, disable_decoding=False, push_request=False): # bool value elif byte == b"#": return response == b"t" - # bulk response and verbatim strings - elif byte in (b"$", b"="): + # bulk response + elif byte == b"$": response = self._buffer.read(int(response)) + # verbatim string response + elif byte == b"=": + response = self._buffer.read(int(response))[4:] # array response elif byte == b"*": response = [ @@ -194,9 +197,12 @@ async def _read_response( # bool value elif byte == b"#": return response == b"t" - # bulk response and verbatim strings - elif byte in (b"$", b"="): + # bulk response + elif byte == b"$": response = await self._read(int(response)) + # verbatim string response + elif byte == b"=": + response = await self._read(int(response))[4:] # array response elif byte == b"*": response = [ From fb8e4619bc16ce7449c4c2002aab8c741e970d87 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Sun, 28 May 2023 04:33:08 +0300 Subject: [PATCH 05/16] test_commands and test_asyncio/test_commands --- redis/asyncio/connection.py | 20 +++++++++-- redis/client.py | 12 +++++-- redis/connection.py | 2 ++ redis/parsers/resp3.py | 2 +- tests/test_asyncio/test_commands.py | 6 +++- tests/test_commands.py | 51 +++++++++++++++++++++-------- 6 files changed, 72 insertions(+), 21 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index bc872ff358..8b5f7de27b 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -334,6 +334,7 @@ async def on_connect(self) -> None: """Initialize the connection, authenticate and select a database""" self._parser.on_connect(self) + auth_args = None # if credential provider or username and/or password are set, authenticate if self.credential_provider or (self.username or self.password): cred_provider = ( @@ -341,8 +342,21 @@ async def on_connect(self) -> None: or UsernamePasswordCredentialProvider(self.username, self.password) ) auth_args = cred_provider.get_credentials() - # avoid checking health here -- PING will fail if we try - # to check the health prior to the AUTH + # if resp version is specified and we have auth args, + # we need to send them via HELLO + if auth_args and self.protocol != 2: + if isinstance(self._parser, _AsyncRESP2Parser): + self.set_parser(_AsyncRESP3Parser) + self._parser.on_connect(self) + await self.send_command("HELLO", self.protocol, "AUTH", *auth_args) + response = await self.read_response() + if response.get(b"proto") != int(self.protocol) and response.get( + "proto" + ) != int(self.protocol): + raise ConnectionError("Invalid RESP version") + # avoid checking health here -- PING will fail if we try + # to check the health prior to the AUTH + elif auth_args: await self.send_command("AUTH", *auth_args, check_health=False) try: @@ -359,7 +373,7 @@ async def on_connect(self) -> None: raise AuthenticationError("Invalid Username or Password") # if resp version is specified, switch to it - if self.protocol != 2: + elif self.protocol != 2: if isinstance(self._parser, _AsyncRESP2Parser): self.set_parser(_AsyncRESP3Parser) self._parser.on_connect(self) diff --git a/redis/client.py b/redis/client.py index 548d259b45..4f07b097a3 100755 --- a/redis/client.py +++ b/redis/client.py @@ -331,9 +331,15 @@ def parse_xinfo_stream(response, **options): data["last-entry"] = (last[0], pairs_to_dict(last[1])) else: data["entries"] = {_id: pairs_to_dict(entry) for _id, entry in data["entries"]} - data["groups"] = [ - pairs_to_dict(group, decode_keys=True) for group in data["groups"] - ] + if isinstance(data["groups"][0], list): + data["groups"] = [ + pairs_to_dict(group, decode_keys=True) for group in data["groups"] + ] + else: + data["groups"] = [ + {str_if_bytes(k): v for k, v in group.items()} + for group in data["groups"] + ] return data diff --git a/redis/connection.py b/redis/connection.py index 76430773a0..4890d81f3a 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -291,6 +291,8 @@ def on_connect(self): if isinstance(self._parser, _RESP2Parser): self.set_parser(_RESP3Parser) self._parser.on_connect(self) + if len(auth_args) == 1: + auth_args = ["default", auth_args[0]] self.send_command("HELLO", self.protocol, "AUTH", *auth_args) response = self.read_response() if response.get(b"proto") != int(self.protocol) and response.get( diff --git a/redis/parsers/resp3.py b/redis/parsers/resp3.py index 0220a8f6f2..854554b277 100644 --- a/redis/parsers/resp3.py +++ b/redis/parsers/resp3.py @@ -202,7 +202,7 @@ async def _read_response( response = await self._read(int(response)) # verbatim string response elif byte == b"=": - response = await self._read(int(response))[4:] + response = (await self._read(int(response)))[4:] # array response elif byte == b"*": response = [ diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index 866929b2e4..7f9364d557 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -73,7 +73,11 @@ class TestResponseCallbacks: """Tests for the response callback system""" async def test_response_callbacks(self, r: redis.Redis): - assert r.response_callbacks == redis.Redis.RESPONSE_CALLBACKS + resp3_callbacks = redis.Redis.RESPONSE_CALLBACKS.copy() + resp3_callbacks.update(redis.Redis.RESP3_RESPONSE_CALLBACKS) + assert_resp_response( + r, r.response_callbacks, redis.Redis.RESPONSE_CALLBACKS, resp3_callbacks + ) assert id(r.response_callbacks) != id(redis.Redis.RESPONSE_CALLBACKS) r.set_response_callback("GET", lambda x: "static") await r.set("a", "foo") diff --git a/tests/test_commands.py b/tests/test_commands.py index 75af1ffa5e..b638dcb039 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -56,7 +56,10 @@ class TestResponseCallbacks: "Tests for the response callback system" def test_response_callbacks(self, r): - assert r.response_callbacks == redis.Redis.RESPONSE_CALLBACKS + callbacks = redis.Redis.RESPONSE_CALLBACKS + if not is_resp2_connection(r): + callbacks.update(redis.Redis.RESP3_RESPONSE_CALLBACKS) + assert r.response_callbacks == callbacks assert id(r.response_callbacks) != id(redis.Redis.RESPONSE_CALLBACKS) r.set_response_callback("GET", lambda x: "static") r["a"] = "foo" @@ -1129,7 +1132,10 @@ def test_lcs(self, r): r.mset({"foo": "ohmytext", "bar": "mynewtext"}) assert r.lcs("foo", "bar") == b"mytext" assert r.lcs("foo", "bar", len=True) == 6 - result = [b"matches", [[[4, 7], [5, 8]]], b"len", 6] + if is_resp2_connection(r): + result = [b"matches", [[[4, 7], [5, 8]]], b"len", 6] + else: + result = {b"matches": [[[4, 7], [5, 8]]], b"len": 6} assert r.lcs("foo", "bar", idx=True, minmatchlen=3) == result with pytest.raises(redis.ResponseError): assert r.lcs("foo", "bar", len=True, idx=True) @@ -2532,23 +2538,36 @@ def test_bzpopmin(self, r): @skip_if_server_version_lt("7.0.0") def test_zmpop(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - res = [b"a", [[b"a1", b"1"], [b"a2", b"2"]]] + if is_resp2_connection(r): + res = [b"a", [[b"a1", b"1"], [b"a2", b"2"]]] + else: + res = [b"a", [[b"a1", 1.0], [b"a2", 2.0]]] assert r.zmpop("2", ["b", "a"], min=True, count=2) == res with pytest.raises(redis.DataError): r.zmpop("2", ["b", "a"], count=2) r.zadd("b", {"b1": 10, "ab": 9, "b3": 8}) - assert r.zmpop("2", ["b", "a"], max=True) == [b"b", [[b"b1", b"10"]]] + if is_resp2_connection(r): + res = [b"b", [[b"b1", b"10"]]] + else: + res = [b"b", [[b"b1", 10.0]]] + assert r.zmpop("2", ["b", "a"], max=True) == res @pytest.mark.onlynoncluster @skip_if_server_version_lt("7.0.0") def test_bzmpop(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - res = [b"a", [[b"a1", b"1"], [b"a2", b"2"]]] + if is_resp2_connection(r): + res = [b"a", [[b"a1", b"1"], [b"a2", b"2"]]] + else: + res = [b"a", [[b"a1", 1.0], [b"a2", 2.0]]] assert r.bzmpop(1, "2", ["b", "a"], min=True, count=2) == res with pytest.raises(redis.DataError): r.bzmpop(1, "2", ["b", "a"], count=2) r.zadd("b", {"b1": 10, "ab": 9, "b3": 8}) - res = [b"b", [[b"b1", b"10"]]] + if is_resp2_connection(r): + res = [b"b", [[b"b1", b"10"]]] + else: + res = [b"b", [[b"b1", 10.0]]] assert r.bzmpop(0, "2", ["b", "a"], max=True) == res assert r.bzmpop(1, "2", ["foo", "bar"], max=True) is None @@ -4025,7 +4044,7 @@ def test_xadd_explicit_ms(self, r: redis.Redis): ms = message_id[: message_id.index(b"-")] assert ms == b"9999999999999999999" - @skip_if_server_version_lt("6.2.0") + @skip_if_server_version_lt("7.0.0") def test_xautoclaim(self, r): stream = "stream" group = "group" @@ -4040,7 +4059,7 @@ def test_xautoclaim(self, r): # trying to claim a message that isn't already pending doesn't # do anything response = r.xautoclaim(stream, group, consumer2, min_idle_time=0) - assert response == [b"0-0", []] + assert response == [b"0-0", [], []] # read the group as consumer1 to initially claim the messages r.xreadgroup(group, consumer1, streams={stream: ">"}) @@ -4335,7 +4354,7 @@ def test_xinfo_stream_full(self, r): if is_resp2_connection(r): assert m1 in info["entries"] else: - assert m1 in info["entries"][0] + assert m1 in info["entries"].keys() assert len(info["groups"]) == 1 @skip_if_server_version_lt("5.0.0") @@ -4874,10 +4893,16 @@ def test_command(self, r): @skip_if_server_version_lt("7.0.0") @skip_if_redis_enterprise() def test_command_getkeysandflags(self, r: redis.Redis): - res = [ - [b"mylist1", [b"RW", b"access", b"delete"]], - [b"mylist2", [b"RW", b"insert"]], - ] + if is_resp2_connection(r): + res = [ + [b"mylist1", [b"RW", b"access", b"delete"]], + [b"mylist2", [b"RW", b"insert"]], + ] + else: + res = [ + [b"mylist1", {b"RW", b"access", b"delete"}], + [b"mylist2", {b"RW", b"insert"}], + ] assert res == r.command_getkeysandflags( "LMOVE", "mylist1", "mylist2", "left", "left" ) From bc44be5d6ddea138c1700b5004bec2dfb2c3d777 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Sun, 28 May 2023 15:40:49 +0300 Subject: [PATCH 06/16] fix test_command_parser --- redis/client.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/redis/client.py b/redis/client.py index 4f07b097a3..19dc8c0013 100755 --- a/redis/client.py +++ b/redis/client.py @@ -587,14 +587,15 @@ def parse_command_resp3(response, **options): cmd_name = str_if_bytes(command[0]) cmd_dict["name"] = cmd_name cmd_dict["arity"] = command[1] - cmd_dict["flags"] = command[2] + cmd_dict["flags"] = {str_if_bytes(flag) for flag in command[2]} cmd_dict["first_key_pos"] = command[3] cmd_dict["last_key_pos"] = command[4] cmd_dict["step_count"] = command[5] cmd_dict["acl_categories"] = command[6] - cmd_dict["tips"] = command[7] - cmd_dict["key_specifications"] = command[8] - cmd_dict["subcommands"] = command[9] + if len(command) > 7: + cmd_dict["tips"] = command[7] + cmd_dict["key_specifications"] = command[8] + cmd_dict["subcommands"] = command[9] commands[cmd_name] = cmd_dict return commands From 89eb5768076956b96708172f38ccacbe3048a81a Mon Sep 17 00:00:00 2001 From: dvora-h Date: Sun, 28 May 2023 15:53:59 +0300 Subject: [PATCH 07/16] fix asyncio/test_connection/test_invalid_response --- redis/parsers/__init__.py | 3 ++- tests/test_asyncio/test_connection.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/redis/parsers/__init__.py b/redis/parsers/__init__.py index 0586016a61..1c092017ef 100644 --- a/redis/parsers/__init__.py +++ b/redis/parsers/__init__.py @@ -1,4 +1,4 @@ -from .base import BaseParser +from .base import _AsyncRESPBase, BaseParser from .commands import AsyncCommandsParser, CommandsParser from .encoders import Encoder from .hiredis import _AsyncHiredisParser, _HiredisParser @@ -8,6 +8,7 @@ __all__ = [ "AsyncCommandsParser", "_AsyncHiredisParser", + "_AsyncRESPBase" "_AsyncRESP2Parser", "_AsyncRESP3Parser", "CommandsParser", diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 3a8cf8d9c2..f502739c90 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -11,7 +11,7 @@ from redis.asyncio.retry import Retry from redis.backoff import NoBackoff from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError -from redis.parsers import _AsyncHiredisParser, _AsyncRESP2Parser, _AsyncRESP3Parser +from redis.parsers import _AsyncHiredisParser, _AsyncRESPBase, _AsyncRESP2Parser, _AsyncRESP3Parser from redis.utils import HIREDIS_AVAILABLE from tests.conftest import skip_if_server_version_lt @@ -26,11 +26,11 @@ async def test_invalid_response(create_redis): raw = b"x" fake_stream = MockStream(raw + b"\r\n") - parser: _AsyncRESP2Parser = r.connection._parser + parser: _AsyncRESPBase = r.connection._parser with mock.patch.object(parser, "_stream", fake_stream): with pytest.raises(InvalidResponse) as cm: await parser.read_response() - if isinstance(parser, _AsyncRESP2Parser): + if isinstance(parser, _AsyncRESPBase): assert str(cm.value) == f"Protocol Error: {raw!r}" else: assert ( From 3c511e0451605a7c605db284123feae8d572bfb3 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Mon, 29 May 2023 02:07:54 +0300 Subject: [PATCH 08/16] linters --- redis/asyncio/connection.py | 2 ++ redis/client.py | 5 ++++- redis/parsers/__init__.py | 4 ++-- tests/test_asyncio/test_connection.py | 7 ++++++- tests/test_asyncio/test_pipeline.py | 2 -- 5 files changed, 14 insertions(+), 6 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 8b5f7de27b..d78fd7afd7 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -348,6 +348,8 @@ async def on_connect(self) -> None: if isinstance(self._parser, _AsyncRESP2Parser): self.set_parser(_AsyncRESP3Parser) self._parser.on_connect(self) + if len(auth_args) == 1: + auth_args = ["default", auth_args[0]] await self.send_command("HELLO", self.protocol, "AUTH", *auth_args) response = await self.read_response() if response.get(b"proto") != int(self.protocol) and response.get( diff --git a/redis/client.py b/redis/client.py index 19dc8c0013..4aa8f7010a 100755 --- a/redis/client.py +++ b/redis/client.py @@ -638,7 +638,10 @@ def parse_acl_getuser(response, **options): list(map(str_if_bytes, selector)) for selector in data["selectors"] ] elif data["selectors"] != []: - data["selectors"] = [{str_if_bytes(k): str_if_bytes(v) for k, v in selector.items()} for selector in data["selectors"]] + data["selectors"] = [ + {str_if_bytes(k): str_if_bytes(v) for k, v in selector.items()} + for selector in data["selectors"] + ] # split 'commands' into separate 'categories' and 'commands' lists commands, categories = [], [] diff --git a/redis/parsers/__init__.py b/redis/parsers/__init__.py index 1c092017ef..6cc32e3cae 100644 --- a/redis/parsers/__init__.py +++ b/redis/parsers/__init__.py @@ -1,4 +1,4 @@ -from .base import _AsyncRESPBase, BaseParser +from .base import BaseParser, _AsyncRESPBase from .commands import AsyncCommandsParser, CommandsParser from .encoders import Encoder from .hiredis import _AsyncHiredisParser, _HiredisParser @@ -8,7 +8,7 @@ __all__ = [ "AsyncCommandsParser", "_AsyncHiredisParser", - "_AsyncRESPBase" + "_AsyncRESPBase", "_AsyncRESP2Parser", "_AsyncRESP3Parser", "CommandsParser", diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index f502739c90..c5b21055e0 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -11,7 +11,12 @@ from redis.asyncio.retry import Retry from redis.backoff import NoBackoff from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError -from redis.parsers import _AsyncHiredisParser, _AsyncRESPBase, _AsyncRESP2Parser, _AsyncRESP3Parser +from redis.parsers import ( + _AsyncHiredisParser, + _AsyncRESP2Parser, + _AsyncRESP3Parser, + _AsyncRESPBase, +) from redis.utils import HIREDIS_AVAILABLE from tests.conftest import skip_if_server_version_lt diff --git a/tests/test_asyncio/test_pipeline.py b/tests/test_asyncio/test_pipeline.py index 3df57eb90f..b29aa53487 100644 --- a/tests/test_asyncio/test_pipeline.py +++ b/tests/test_asyncio/test_pipeline.py @@ -21,7 +21,6 @@ async def test_pipeline(self, r): .zadd("z", {"z1": 1}) .zadd("z", {"z2": 4}) .zincrby("z", 1, "z1") - .zrange("z", 0, 5, withscores=True) ) assert await pipe.execute() == [ True, @@ -29,7 +28,6 @@ async def test_pipeline(self, r): True, True, 2.0, - [(b"z1", 2.0), (b"z2", 4)], ] async def test_pipeline_memoryview(self, r): From 89ef178b7b655bce5ea145a5dadc2ccb5a4ab105 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Wed, 31 May 2023 23:27:40 +0300 Subject: [PATCH 09/16] all the tests --- redis/asyncio/client.py | 6 +- redis/asyncio/cluster.py | 2 + redis/asyncio/connection.py | 5 + redis/cluster.py | 18 ++- redis/connection.py | 5 + tests/conftest.py | 23 ++++ tests/test_asyncio/test_cluster.py | 177 ++++++++++++++----------- tests/test_asyncio/test_commands.py | 39 +++--- tests/test_asyncio/test_pubsub.py | 2 +- tests/test_cluster.py | 193 ++++++++++++++++------------ tests/test_commands.py | 2 +- tests/test_function.py | 22 +++- 12 files changed, 304 insertions(+), 190 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 2cd2daddcc..be70143504 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -672,12 +672,12 @@ def __init__( self.encoder = self.connection_pool.get_encoder() if self.encoder.decode_responses: self.health_check_response: Iterable[Union[str, bytes]] = [ - "pong", + ["pong", self.HEALTH_CHECK_MESSAGE], self.HEALTH_CHECK_MESSAGE, ] else: self.health_check_response = [ - b"pong", + [b"pong", self.encoder.encode(self.HEALTH_CHECK_MESSAGE)], self.encoder.encode(self.HEALTH_CHECK_MESSAGE), ] if self.push_handler_func is None: @@ -807,7 +807,7 @@ async def parse_response(self, block: bool = True, timeout: float = 0): conn, conn.read_response, timeout=read_timeout, push_request=True ) - if conn.health_check_interval and response == self.health_check_response: + if conn.health_check_interval and response in self.health_check_response: # ignore the health check message as user might not expect it return None return response diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 525c17b22d..4a606ad38f 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -319,6 +319,8 @@ def __init__( kwargs.update({"retry": self.retry}) kwargs["response_callbacks"] = self.__class__.RESPONSE_CALLBACKS.copy() + if kwargs.get("protocol") in ["3", 3]: + kwargs["response_callbacks"].update(self.__class__.RESP3_RESPONSE_CALLBACKS) self.connection_kwargs = kwargs if startup_nodes: diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index d78fd7afd7..ae4d67b49a 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -333,6 +333,7 @@ def _error_message(self, exception): async def on_connect(self) -> None: """Initialize the connection, authenticate and select a database""" self._parser.on_connect(self) + parser = self._parser auth_args = None # if credential provider or username and/or password are set, authenticate @@ -347,6 +348,8 @@ async def on_connect(self) -> None: if auth_args and self.protocol != 2: if isinstance(self._parser, _AsyncRESP2Parser): self.set_parser(_AsyncRESP3Parser) + # update cluster exception classes + self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES self._parser.on_connect(self) if len(auth_args) == 1: auth_args = ["default", auth_args[0]] @@ -378,6 +381,8 @@ async def on_connect(self) -> None: elif self.protocol != 2: if isinstance(self._parser, _AsyncRESP2Parser): self.set_parser(_AsyncRESP3Parser) + # update cluster exception classes + self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES self._parser.on_connect(self) await self.send_command("HELLO", self.protocol) response = await self.read_response() diff --git a/redis/cluster.py b/redis/cluster.py index 182ec6d733..0463dd227c 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -32,6 +32,7 @@ from redis.parsers import CommandsParser, Encoder from redis.retry import Retry from redis.utils import ( + HIREDIS_AVAILABLE, dict_merge, list_keys_to_dict, merge_result, @@ -1603,7 +1604,15 @@ class ClusterPubSub(PubSub): https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html """ - def __init__(self, redis_cluster, node=None, host=None, port=None, **kwargs): + def __init__( + self, + redis_cluster, + node=None, + host=None, + port=None, + push_handler_func=None, + **kwargs, + ): """ When a pubsub instance is created without specifying a node, a single node will be transparently chosen for the pubsub connection on the @@ -1626,7 +1635,10 @@ def __init__(self, redis_cluster, node=None, host=None, port=None, **kwargs): ) self.cluster = redis_cluster super().__init__( - **kwargs, connection_pool=connection_pool, encoder=redis_cluster.encoder + **kwargs, + connection_pool=connection_pool, + encoder=redis_cluster.encoder, + push_handler_func=push_handler_func, ) def set_pubsub_node(self, cluster, node=None, host=None, port=None): @@ -1710,6 +1722,8 @@ def execute_command(self, *args, **kwargs): # register a callback that re-subscribes to any channels we # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) + if self.push_handler_func is not None and not HIREDIS_AVAILABLE: + self.connection._parser.set_push_handler(self.push_handler_func) connection = self.connection self._execute(connection, connection.send_command, *args) diff --git a/redis/connection.py b/redis/connection.py index 4890d81f3a..ee3bece11c 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -276,6 +276,7 @@ def _error_message(self, exception): def on_connect(self): "Initialize the connection, authenticate and select a database" self._parser.on_connect(self) + parser = self._parser auth_args = None # if credential provider or username and/or password are set, authenticate @@ -290,6 +291,8 @@ def on_connect(self): if auth_args and self.protocol != 2: if isinstance(self._parser, _RESP2Parser): self.set_parser(_RESP3Parser) + # update cluster exception classes + self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES self._parser.on_connect(self) if len(auth_args) == 1: auth_args = ["default", auth_args[0]] @@ -321,6 +324,8 @@ def on_connect(self): elif self.protocol != 2: if isinstance(self._parser, _RESP2Parser): self.set_parser(_RESP3Parser) + # update cluster exception classes + self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES self._parser.on_connect(self) self.send_command("HELLO", self.protocol) response = self.read_response() diff --git a/tests/conftest.py b/tests/conftest.py index c471f3d837..1b98107c74 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -480,3 +480,26 @@ def is_resp2_connection(r): elif isinstance(r, redis.RedisCluster): protocol = r.nodes_manager.connection_kwargs.get("protocol") return protocol in ["2", 2, None] + + +def get_protocol_version(r): + if isinstance(r, redis.Redis): + return r.connection_pool.connection_kwargs.get("protocol") + elif isinstance(r, redis.cluster.AbstractRedisCluster): + return r.nodes_manager.connection_kwargs.get("protocol") + + +def assert_resp_response(r, response, resp2_expected, resp3_expected): + protocol = get_protocol_version(r) + if protocol in [2, "2", None]: + assert response == resp2_expected + else: + assert response == resp3_expected + + +def assert_resp_response_in(r, response, resp2_expected, resp3_expected): + protocol = get_protocol_version(r) + if protocol in [2, "2", None]: + assert response in resp2_expected + else: + assert response in resp3_expected diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index a80fa30cb9..173f0fd1ab 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -31,6 +31,7 @@ from redis.parsers import AsyncCommandsParser from redis.utils import str_if_bytes from tests.conftest import ( + assert_resp_response, skip_if_redis_enterprise, skip_if_server_version_lt, skip_unless_arch_bits, @@ -1613,7 +1614,8 @@ async def test_cluster_zdiff(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) await r.zadd("{foo}b", {"a1": 1, "a2": 2}) assert await r.zdiff(["{foo}a", "{foo}b"]) == [b"a3"] - assert await r.zdiff(["{foo}a", "{foo}b"], withscores=True) == [b"a3", b"3"] + response = await r.zdiff(["{foo}a", "{foo}b"], withscores=True) + assert_resp_response(r, response, [b"a3", b"3"], [[b"a3", 3.0]]) @skip_if_server_version_lt("6.2.0") async def test_cluster_zdiffstore(self, r: RedisCluster) -> None: @@ -1621,7 +1623,8 @@ async def test_cluster_zdiffstore(self, r: RedisCluster) -> None: await r.zadd("{foo}b", {"a1": 1, "a2": 2}) assert await r.zdiffstore("{foo}out", ["{foo}a", "{foo}b"]) assert await r.zrange("{foo}out", 0, -1) == [b"a3"] - assert await r.zrange("{foo}out", 0, -1, withscores=True) == [(b"a3", 3.0)] + response = await r.zrange("{foo}out", 0, -1, withscores=True) + assert_resp_response(r, response, [(b"a3", b"3")], [[b"a3", 3.0]]) @skip_if_server_version_lt("6.2.0") async def test_cluster_zinter(self, r: RedisCluster) -> None: @@ -1635,32 +1638,41 @@ async def test_cluster_zinter(self, r: RedisCluster) -> None: ["{foo}a", "{foo}b", "{foo}c"], aggregate="foo", withscores=True ) # aggregate with SUM - assert await r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ - (b"a3", 8), - (b"a1", 9), - ] + response = await r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) + assert_resp_response( + r, response, [(b"a3", b"8"), (b"a1", b"9")], [[b"a3", 8.0], [b"a1", 9.0]] + ) # aggregate with MAX - assert await r.zinter( + response = await r.zinter( ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True - ) == [(b"a3", 5), (b"a1", 6)] + ) + assert_resp_response( + r, response, [(b"a3", b"5"), (b"a1", b"6")], [[b"a3", 5.0], [b"a1", 6.0]] + ) # aggregate with MIN - assert await r.zinter( + response = await r.zinter( ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True - ) == [(b"a1", 1), (b"a3", 1)] + ) + assert_resp_response( + r, response, [(b"a1", b"1"), (b"a3", b"1")], [[b"a1", 1.0], [b"a3", 1.0]] + ) # with weights - assert await r.zinter( - {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True - ) == [(b"a3", 20), (b"a1", 23)] + res = await r.zinter({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True) + assert_resp_response( + r, res, [(b"a3", b"20"), (b"a1", b"23")], [[b"a3", 20.0], [b"a1", 23.0]] + ) async def test_cluster_zinterstore_sum(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 2 - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a3", 8), - (b"a1", 9), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a3", 8), (b"a1", 9)], + [[b"a3", 8.0], [b"a1", 9.0]], + ) async def test_cluster_zinterstore_max(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) @@ -1672,10 +1684,12 @@ async def test_cluster_zinterstore_max(self, r: RedisCluster) -> None: ) == 2 ) - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a3", 5), - (b"a1", 6), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a3", 5), (b"a1", 6)], + [[b"a3", 5.0], [b"a1", 6.0]], + ) async def test_cluster_zinterstore_min(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) @@ -1687,10 +1701,12 @@ async def test_cluster_zinterstore_min(self, r: RedisCluster) -> None: ) == 2 ) - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a1", 1), - (b"a3", 3), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a1", 1), (b"a3", 3)], + [[b"a1", 1.0], [b"a3", 3.0]], + ) async def test_cluster_zinterstore_with_weight(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) @@ -1699,10 +1715,12 @@ async def test_cluster_zinterstore_with_weight(self, r: RedisCluster) -> None: assert ( await r.zinterstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 2 ) - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a3", 20), - (b"a1", 23), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a3", 20), (b"a1", 23)], + [[b"a3", 20.0], [b"a1", 23.0]], + ) @skip_if_server_version_lt("4.9.0") async def test_cluster_bzpopmax(self, r: RedisCluster) -> None: @@ -1767,10 +1785,12 @@ async def test_cluster_zrangestore(self, r: RedisCluster) -> None: assert await r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] assert await r.zrangestore("{foo}b", "{foo}a", 1, 2) assert await r.zrange("{foo}b", 0, -1) == [b"a2", b"a3"] - assert await r.zrange("{foo}b", 0, -1, withscores=True) == [ - (b"a2", 2), - (b"a3", 3), - ] + assert_resp_response( + r, + await r.zrange("{foo}b", 0, -1, withscores=True), + [(b"a2", 2), (b"a3", 3)], + [[b"a2", 2.0], [b"a3", 3.0]], + ) # reversed order assert await r.zrangestore("{foo}b", "{foo}a", 1, 2, desc=True) assert await r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] @@ -1797,36 +1817,49 @@ async def test_cluster_zunion(self, r: RedisCluster) -> None: b"a3", b"a1", ] - assert await r.zunion(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] + assert_resp_response( + r, + await r.zunion(["{foo}a", "{foo}b", "{foo}c"], withscores=True), + [(b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9)], + [[b"a2", 3.0], [b"a4", 4.0], [b"a3", 8.0], [b"a1", 9.0]], + ) # max - assert await r.zunion( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True - ) == [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)] + assert_resp_response( + r, + await r.zunion( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True + ), + [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)], + [[b"a2", 2.0], [b"a4", 4.0], [b"a3", 5.0], [b"a1", 6.0]], + ) # min - assert await r.zunion( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True - ) == [(b"a1", 1), (b"a2", 1), (b"a3", 1), (b"a4", 4)] + assert_resp_response( + r, + await r.zunion( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True + ), + [(b"a1", 1), (b"a2", 1), (b"a3", 1), (b"a4", 4)], + [[b"a1", 1.0], [b"a2", 1.0], [b"a3", 1.0], [b"a4", 4.0]], + ) # with weight - assert await r.zunion( - {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True - ) == [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)] + assert_resp_response( + r, + await r.zunion({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True), + [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)], + [[b"a2", 5.0], [b"a4", 12.0], [b"a3", 20.0], [b"a1", 23.0]], + ) async def test_cluster_zunionstore_sum(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 4 - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9)], + [[b"a2", 3.0], [b"a4", 4.0], [b"a3", 8.0], [b"a1", 9.0]], + ) async def test_cluster_zunionstore_max(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) @@ -1838,12 +1871,12 @@ async def test_cluster_zunionstore_max(self, r: RedisCluster) -> None: ) == 4 ) - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 2), - (b"a4", 4), - (b"a3", 5), - (b"a1", 6), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)], + [[b"a2", 2.0], [b"a4", 4.0], [b"a3", 5.0], [b"a1", 6.0]], + ) async def test_cluster_zunionstore_min(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) @@ -1855,12 +1888,12 @@ async def test_cluster_zunionstore_min(self, r: RedisCluster) -> None: ) == 4 ) - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a1", 1), - (b"a2", 2), - (b"a3", 3), - (b"a4", 4), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a1", 1), (b"a2", 2), (b"a3", 3), (b"a4", 4)], + [[b"a1", 1.0], [b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]], + ) async def test_cluster_zunionstore_with_weight(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) @@ -1869,12 +1902,12 @@ async def test_cluster_zunionstore_with_weight(self, r: RedisCluster) -> None: assert ( await r.zunionstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 4 ) - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)], + [[b"a2", 5.0], [b"a4", 12.0], [b"a3", 20.0], [b"a1", 23.0]], + ) @skip_if_server_version_lt("2.8.9") async def test_cluster_pfcount(self, r: RedisCluster) -> None: diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index 7f9364d557..1f6e0d9c74 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -127,27 +127,24 @@ async def test_acl_getuser_setuser(self, r_teardown): r = r_teardown(username) # test enabled=False assert await r.acl_setuser(username, enabled=False, reset=True) - assert await r.acl_getuser(username) == { - "categories": ["-@all"], - "commands": [], - "channels": [b"*"], - "enabled": False, - "flags": ["off", "allchannels", "sanitize-payload"], - "keys": [], - "passwords": [], - } + acl = await r.acl_getuser(username) + assert acl["categories"] == ["-@all"] + assert acl["commands"] == [] + assert acl["keys"] == [] + assert acl["passwords"] == [] + assert "off" in acl["flags"] + assert acl["enabled"] is False # test nopass=True assert await r.acl_setuser(username, enabled=True, reset=True, nopass=True) - assert await r.acl_getuser(username) == { - "categories": ["-@all"], - "commands": [], - "channels": [b"*"], - "enabled": True, - "flags": ["on", "allchannels", "nopass", "sanitize-payload"], - "keys": [], - "passwords": [], - } + acl = await r.acl_getuser(username) + assert acl["categories"] == ["-@all"] + assert acl["commands"] == [] + assert acl["keys"] == [] + assert acl["passwords"] == [] + assert "on" in acl["flags"] + assert "nopass" in acl["flags"] + assert acl["enabled"] is True # test all args assert await r.acl_setuser( @@ -164,8 +161,8 @@ async def test_acl_getuser_setuser(self, r_teardown): assert set(acl["commands"]) == {"+get", "+mget", "-hset"} assert acl["enabled"] is True assert acl["channels"] == [b"*"] - assert acl["flags"] == ["on", "allchannels", "sanitize-payload"] - assert set(acl["keys"]) == {b"cache:*", b"objects:*"} + assert set(acl["flags"]) == {"on", "allchannels", "sanitize-payload"} + assert acl["keys"] == [b"cache:*", b"objects:*"] assert len(acl["passwords"]) == 2 # test reset=False keeps existing ACL and applies new ACL on top @@ -191,7 +188,7 @@ async def test_acl_getuser_setuser(self, r_teardown): assert set(acl["commands"]) == {"+get", "+mget"} assert acl["enabled"] is True assert acl["channels"] == [b"*"] - assert acl["flags"] == ["on", "allchannels", "sanitize-payload"] + assert set(acl["flags"]) == {"on", "allchannels", "sanitize-payload"} assert set(acl["keys"]) == {b"cache:*", b"objects:*"} assert len(acl["passwords"]) == 2 diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index 8cd5cf6fba..33ae989faf 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -421,7 +421,7 @@ async def test_get_message_without_subscribe(self, r: redis.Redis, pubsub): ) assert expect in info.exconly() - +@pytest.mark.onlynoncluster class TestPubSubRESP3Handler: def my_handler(self, message): self.message = ["my handler", message] diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 4a43eaea21..f48995fb00 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -39,6 +39,7 @@ from .conftest import ( _get_client, + assert_resp_response, is_resp2_connection, skip_if_redis_enterprise, skip_if_server_version_lt, @@ -1750,49 +1751,42 @@ def test_cluster_zinter(self, r): # invalid aggregation with pytest.raises(DataError): r.zinter(["{foo}a", "{foo}b", "{foo}c"], aggregate="foo", withscores=True) - if is_resp2_connection(r): - # aggregate with SUM - assert r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ - (b"a3", 8), - (b"a1", 9), - ] - # aggregate with MAX - assert r.zinter( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True - ) == [(b"a3", 5), (b"a1", 6)] - # aggregate with MIN - assert r.zinter( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True - ) == [(b"a1", 1), (b"a3", 1)] - # with weights - assert r.zinter( - {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True - ) == [(b"a3", 20), (b"a1", 23)] - else: - # aggregate with SUM - assert r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ - [b"a3", 8], - [b"a1", 9], - ] - # aggregate with MAX - assert r.zinter( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True - ) == [[b"a3", 5], [b"a1", 6]] - # aggregate with MIN - assert r.zinter( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True - ) == [[b"a1", 1], [b"a3", 1]] - # with weights - assert r.zinter( - {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True - ) == [[b"a3", 2], [b"a1", 2]] + assert_resp_response( + r, + r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True), + [(b"a3", 8), (b"a1", 9)], + [[b"a3", 8], [b"a1", 9]], + ) + assert_resp_response( + r, + r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True, aggregate="MAX"), + [(b"a3", 5), (b"a1", 6)], + [[b"a3", 5], [b"a1", 6]], + ) + assert_resp_response( + r, + r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True, aggregate="MIN"), + [(b"a1", 1), (b"a3", 1)], + [[b"a1", 1], [b"a3", 1]], + ) + assert_resp_response( + r, + r.zinter({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True), + [(b"a3", 20.0), (b"a1", 23.0)], + [[b"a3", 20.0], [b"a1", 23.0]], + ) def test_cluster_zinterstore_sum(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 2 - assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a3", 8), (b"a1", 9)] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a3", 8), (b"a1", 9)], + [[b"a3", 8.0], [b"a1", 9.0]], + ) def test_cluster_zinterstore_max(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) @@ -1802,7 +1796,12 @@ def test_cluster_zinterstore_max(self, r): r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX") == 2 ) - assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a3", 5), (b"a1", 6)] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a3", 5), (b"a1", 6)], + [[b"a3", 5.0], [b"a1", 6.0]], + ) def test_cluster_zinterstore_min(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) @@ -1812,14 +1811,24 @@ def test_cluster_zinterstore_min(self, r): r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN") == 2 ) - assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a1", 1), (b"a3", 3)] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a1", 1), (b"a3", 3)], + [[b"a1", 1.0], [b"a3", 3.0]], + ) def test_cluster_zinterstore_with_weight(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 2 - assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a3", 20), (b"a1", 23)] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a3", 20), (b"a1", 23)], + [[b"a3", 20.0], [b"a1", 23.0]], + ) @skip_if_server_version_lt("4.9.0") def test_cluster_bzpopmax(self, r): @@ -1852,7 +1861,12 @@ def test_cluster_zrangestore(self, r): assert r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] assert r.zrangestore("{foo}b", "{foo}a", 1, 2) assert r.zrange("{foo}b", 0, -1) == [b"a2", b"a3"] - assert r.zrange("{foo}b", 0, -1, withscores=True) == [(b"a2", 2), (b"a3", 3)] + assert_resp_response( + r, + r.zrange("{foo}b", 0, 1, withscores=True), + [(b"a2", 2), (b"a3", 3)], + [[b"a2", 2.0], [b"a3", 3.0]], + ) # reversed order assert r.zrangestore("{foo}b", "{foo}a", 1, 2, desc=True) assert r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] @@ -1874,39 +1888,45 @@ def test_cluster_zunion(self, r): r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) # sum assert r.zunion(["{foo}a", "{foo}b", "{foo}c"]) == [b"a2", b"a4", b"a3", b"a1"] - assert r.zunion(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] + assert_resp_response( + r, + r.zunion(["{foo}a", "{foo}b", "{foo}c"], withscores=True), + [(b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9)], + [[b"a2", 3.0], [b"a4", 4.0], [b"a3", 8.0], [b"a1", 9.0]], + ) # max - assert r.zunion( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True - ) == [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)] + assert_resp_response( + r, + r.zunion(["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True), + [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)], + [[b"a2", 2.0], [b"a4", 4.0], [b"a3", 5.0], [b"a1", 6.0]], + ) # min - assert r.zunion( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True - ) == [(b"a1", 1), (b"a2", 1), (b"a3", 1), (b"a4", 4)] + assert_resp_response( + r, + r.zunion(["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True), + [(b"a1", 1), (b"a2", 1), (b"a3", 1), (b"a4", 4)], + [[b"a1", 1.0], [b"a2", 1.0], [b"a3", 1.0], [b"a4", 4.0]], + ) # with weight - assert r.zunion({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] + assert_resp_response( + r, + r.zunion({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True), + [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)], + [[b"a2", 5.0], [b"a4", 12.0], [b"a3", 20.0], [b"a1", 23.0]], + ) def test_cluster_zunionstore_sum(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 4 - assert r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9)], + [[b"a2", 3.0], [b"a4", 4.0], [b"a3", 8.0], [b"a1", 9.0]], + ) def test_cluster_zunionstore_max(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) @@ -1916,12 +1936,12 @@ def test_cluster_zunionstore_max(self, r): r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX") == 4 ) - assert r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 2), - (b"a4", 4), - (b"a3", 5), - (b"a1", 6), - ] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)], + [[b"a2", 2.0], [b"a4", 4.0], [b"a3", 5.0], [b"a1", 6.0]], + ) def test_cluster_zunionstore_min(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) @@ -1931,24 +1951,24 @@ def test_cluster_zunionstore_min(self, r): r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN") == 4 ) - assert r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a1", 1), - (b"a2", 2), - (b"a3", 3), - (b"a4", 4), - ] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a1", 1), (b"a2", 2), (b"a3", 3), (b"a4", 4)], + [[b"a1", 1.0], [b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]], + ) def test_cluster_zunionstore_with_weight(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 4 - assert r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)], + [[b"a2", 5.0], [b"a4", 12.0], [b"a3", 20.0], [b"a1", 23.0]], + ) @skip_if_server_version_lt("2.8.9") def test_cluster_pfcount(self, r): @@ -2970,7 +2990,12 @@ def test_pipeline_readonly(self, r): with r.pipeline() as readonly_pipe: readonly_pipe.get("foo71").zrange("foo88", 0, 5, withscores=True) - assert readonly_pipe.execute() == [b"a1", [(b"z1", 1.0), (b"z2", 4)]] + assert_resp_response( + r, + readonly_pipe.execute(), + [b"a1", [(b"z1", 1.0), (b"z2", 4)]], + [b"a1", [[b"z1", 1.0], [b"z2", 4.0]]], + ) def test_moved_redirection_on_slave_with_default(self, r): """ diff --git a/tests/test_commands.py b/tests/test_commands.py index b638dcb039..75327e85a5 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -70,6 +70,7 @@ def test_case_insensitive_command_names(self, r): class TestRedisCommands: + @pytest.mark.onlynoncluster @skip_if_redis_enterprise() def test_auth(self, r, request): # sending an AUTH command before setting a user/password on the @@ -104,7 +105,6 @@ def teardown(): # connection field is not set in Redis Cluster, but that's ok # because the problem discussed above does not apply to Redis Cluster pass - r.auth(temp_pass) r.config_set("requirepass", "") r.acl_deluser(username) diff --git a/tests/test_function.py b/tests/test_function.py index 7ce66a38e6..bb32fdf27c 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -2,7 +2,7 @@ from redis.exceptions import ResponseError -from .conftest import skip_if_server_version_lt +from .conftest import assert_resp_response, skip_if_server_version_lt engine = "lua" lib = "mylib" @@ -64,12 +64,22 @@ def test_function_list(self, r): [[b"name", b"myfunc", b"description", None, b"flags", [b"no-writes"]]], ] ] - assert r.function_list() == res - assert r.function_list(library="*lib") == res - assert ( - r.function_list(withcode=True)[0][7] - == f"#!{engine} name={lib} \n {function}".encode() + resp3_res = [ + { + b"library_name": b"mylib", + b"engine": b"LUA", + b"functions": [ + {b"name": b"myfunc", b"description": None, b"flags": {b"no-writes"}} + ], + } + ] + assert_resp_response(r, r.function_list(), res, resp3_res) + assert_resp_response(r, r.function_list(library="*lib"), res, resp3_res) + res[0].extend( + [b"library_code", f"#!{engine} name={lib} \n {function}".encode()] ) + resp3_res[0][b"library_code"] = f"#!{engine} name={lib} \n {function}".encode() + assert_resp_response(r, r.function_list(withcode=True), res, resp3_res) @pytest.mark.onlycluster def test_function_list_on_cluster(self, r): From 7b87e20deace3fb84303c4af234357c6dc1f94aa Mon Sep 17 00:00:00 2001 From: dvora-h Date: Wed, 31 May 2023 23:35:42 +0300 Subject: [PATCH 10/16] push handler sharded pubsub --- redis/cluster.py | 4 +++- tests/test_pubsub.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/redis/cluster.py b/redis/cluster.py index 0e7e631f33..01f696dc74 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1738,7 +1738,9 @@ def _get_node_pubsub(self, node): try: return self.node_pubsub_mapping[node.name] except KeyError: - pubsub = node.redis_connection.pubsub() + pubsub = node.redis_connection.pubsub( + push_handler_func=self.push_handler_func + ) self.node_pubsub_mapping[node.name] = pubsub return pubsub diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 2f6b4bad80..3da734df99 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -608,6 +608,18 @@ def test_push_handler(self, r): assert wait_for_message(p) is None assert self.message == ["my handler", [b"message", b"foo", b"test message"]] + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") + def test_push_handler_sharded_pubsub(self, r): + if is_resp2_connection(r): + return + p = r.pubsub(push_handler_func=self.my_handler) + p.ssubscribe("foo") + assert wait_for_message(p, func=p.get_sharded_message) is None + assert self.message == ["my handler", [b"ssubscribe", b"foo", 1]] + assert r.spublish("foo", "test message") == 1 + assert wait_for_message(p, func=p.get_sharded_message) is None + assert self.message == ["my handler", [b"smessage", b"foo", b"test message"]] + class TestPubSubAutoDecoding: "These tests only validate that we get unicode values back" From dc7fa20ee492dd69f0412532fd10a33c51aba6d5 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Thu, 1 Jun 2023 13:11:42 +0300 Subject: [PATCH 11/16] Use assert_resp_response wherever possible --- tests/conftest.py | 4 +- tests/test_asyncio/conftest.py | 23 - tests/test_asyncio/test_commands.py | 25 +- tests/test_asyncio/test_pubsub.py | 4 +- tests/test_cluster.py | 18 +- tests/test_commands.py | 827 +++++++++++++--------------- tests/test_pubsub.py | 1 + 7 files changed, 397 insertions(+), 505 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 1b98107c74..faf14b9115 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -477,13 +477,13 @@ def wait_for_command(client, monitor, command, key=None): def is_resp2_connection(r): if isinstance(r, redis.Redis): protocol = r.connection_pool.connection_kwargs.get("protocol") - elif isinstance(r, redis.RedisCluster): + elif isinstance(r, redis.cluster.AbstractRedisCluster): protocol = r.nodes_manager.connection_kwargs.get("protocol") return protocol in ["2", 2, None] def get_protocol_version(r): - if isinstance(r, redis.Redis): + if isinstance(r, redis.Redis) or isinstance(r, redis.asyncio.Redis): return r.connection_pool.connection_kwargs.get("protocol") elif isinstance(r, redis.cluster.AbstractRedisCluster): return r.nodes_manager.connection_kwargs.get("protocol") diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index e8ab6b297f..28a6f0626f 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -236,29 +236,6 @@ async def wait_for_command( return None -def get_protocol_version(r): - if isinstance(r, redis.Redis): - return r.connection_pool.connection_kwargs.get("protocol") - elif isinstance(r, redis.RedisCluster): - return r.nodes_manager.connection_kwargs.get("protocol") - - -def assert_resp_response(r, response, resp2_expected, resp3_expected): - protocol = get_protocol_version(r) - if protocol in [2, "2", None]: - assert response == resp2_expected - else: - assert response == resp3_expected - - -def assert_resp_response_in(r, response, resp2_expected, resp3_expected): - protocol = get_protocol_version(r) - if protocol in [2, "2", None]: - assert response in resp2_expected - else: - assert response in resp3_expected - - # python 3.6 doesn't have the asynccontextmanager decorator. Provide it here. class AsyncContextManager: def __init__(self, async_generator): diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index 1f6e0d9c74..78376fd0e9 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -13,13 +13,14 @@ from redis import exceptions from redis.client import EMPTY_RESPONSE, NEVER_DECODE, parse_info from tests.conftest import ( + assert_resp_response, + assert_resp_response_in, + is_resp2_connection, skip_if_server_version_gte, skip_if_server_version_lt, skip_unless_arch_bits, ) -from .conftest import assert_resp_response, assert_resp_response_in - REDIS_6_VERSION = "5.9.0" @@ -2913,16 +2914,16 @@ async def test_xreadgroup(self, r: redis.Redis): # xreadgroup with noack does not have any items in the PEL await r.xgroup_destroy(stream, group) await r.xgroup_create(stream, group, "0") - # res = r.xreadgroup(group, consumer, streams={stream: ">"}, noack=True) - # empty_res = r.xreadgroup(group, consumer, streams={stream: "0"}) - # if is_resp2_connection(r): - # assert len(res[0][1]) == 2 - # # now there should be nothing pending - # assert len(empty_res[0][1]) == 0 - # else: - # assert len(res[strem_name][0]) == 2 - # # now there should be nothing pending - # assert len(empty_res[strem_name][0]) == 0 + res = await r.xreadgroup(group, consumer, streams={stream: ">"}, noack=True) + empty_res = await r.xreadgroup(group, consumer, streams={stream: "0"}) + if is_resp2_connection(r): + assert len(res[0][1]) == 2 + # now there should be nothing pending + assert len(empty_res[0][1]) == 0 + else: + assert len(res[strem_name][0]) == 2 + # now there should be nothing pending + assert len(empty_res[strem_name][0]) == 0 await r.xgroup_destroy(stream, group) await r.xgroup_create(stream, group, "0") diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index db92cc8a95..8160b3b0f1 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -17,10 +17,9 @@ from redis.exceptions import ConnectionError from redis.typing import EncodableT from redis.utils import HIREDIS_AVAILABLE -from tests.conftest import skip_if_server_version_lt +from tests.conftest import get_protocol_version, skip_if_server_version_lt from .compat import create_task, mock -from .conftest import get_protocol_version def with_timeout(t): @@ -421,6 +420,7 @@ async def test_get_message_without_subscribe(self, r: redis.Redis, pubsub): ) assert expect in info.exconly() + @pytest.mark.onlynoncluster class TestPubSubRESP3Handler: def my_handler(self, message): diff --git a/tests/test_cluster.py b/tests/test_cluster.py index f48995fb00..2ca323eaf5 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -40,7 +40,6 @@ from .conftest import ( _get_client, assert_resp_response, - is_resp2_connection, skip_if_redis_enterprise, skip_if_server_version_lt, skip_unless_arch_bits, @@ -1726,10 +1725,13 @@ def test_cluster_zdiff(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) r.zadd("{foo}b", {"a1": 1, "a2": 2}) assert r.zdiff(["{foo}a", "{foo}b"]) == [b"a3"] - if is_resp2_connection(r): - assert r.zdiff(["{foo}a", "{foo}b"], withscores=True) == [b"a3", b"3"] - else: - assert r.zdiff(["{foo}a", "{foo}b"], withscores=True) == [[b"a3", 3.0]] + response = r.zdiff(["{foo}a", "{foo}b"], withscores=True) + assert_resp_response( + r, + response, + [b"a3", b"3"], + [[b"a3", 3.0]], + ) @skip_if_server_version_lt("6.2.0") def test_cluster_zdiffstore(self, r): @@ -1737,10 +1739,8 @@ def test_cluster_zdiffstore(self, r): r.zadd("{foo}b", {"a1": 1, "a2": 2}) assert r.zdiffstore("{foo}out", ["{foo}a", "{foo}b"]) assert r.zrange("{foo}out", 0, -1) == [b"a3"] - if is_resp2_connection(r): - assert r.zrange("{foo}out", 0, -1, withscores=True) == [(b"a3", 3.0)] - else: - assert r.zrange("{foo}out", 0, -1, withscores=True) == [[b"a3", 3.0]] + response = r.zrange("{foo}out", 0, -1, withscores=True) + assert_resp_response(r, response, [(b"a3", 3.0)], [[b"a3", 3.0]]) @skip_if_server_version_lt("6.2.0") def test_cluster_zinter(self, r): diff --git a/tests/test_commands.py b/tests/test_commands.py index 75327e85a5..97fbb34925 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -13,6 +13,8 @@ from .conftest import ( _get_client, + assert_resp_response, + assert_resp_response_in, is_resp2_connection, skip_if_redis_enterprise, skip_if_server_version_gte, @@ -320,14 +322,12 @@ def teardown(): assert set(acl["keys"]) == {"~cache:*", "~objects:*"} assert len(acl["passwords"]) == 2 assert set(acl["channels"]) == {"&message:*"} - if is_resp2_connection(r): - assert acl["selectors"] == [ - ["commands", "-@all +set", "keys", "%W~app*", "channels", ""] - ] - else: - assert acl["selectors"] == [ - {"commands": "-@all +set", "keys": "%W~app*", "channels": ""} - ] + assert_resp_response( + r, + acl["selectors"], + ["commands", "-@all +set", "keys", "%W~app*", "channels", ""], + [{"commands": "-@all +set", "keys": "%W~app*", "channels": ""}], + ) @skip_if_server_version_lt("6.0.0") def test_acl_help(self, r): @@ -389,11 +389,13 @@ def teardown(): assert len(r.acl_log()) == 2 assert len(r.acl_log(count=1)) == 1 assert isinstance(r.acl_log()[0], dict) - if is_resp2_connection(r): - assert "client-info" in r.acl_log(count=1)[0] - else: - assert "client-info" in r.acl_log(count=1)[0].keys() - assert r.acl_log_reset() + expected = r.acl_log(count=1)[0] + assert_resp_response_in( + r, + "client-info", + expected, + expected.keys(), + ) @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise() @@ -1132,11 +1134,12 @@ def test_lcs(self, r): r.mset({"foo": "ohmytext", "bar": "mynewtext"}) assert r.lcs("foo", "bar") == b"mytext" assert r.lcs("foo", "bar", len=True) == 6 - if is_resp2_connection(r): - result = [b"matches", [[[4, 7], [5, 8]]], b"len", 6] - else: - result = {b"matches": [[[4, 7], [5, 8]]], b"len": 6} - assert r.lcs("foo", "bar", idx=True, minmatchlen=3) == result + assert_resp_response( + r, + r.lcs("foo", "bar", idx=True, minmatchlen=3), + [b"matches", [[[4, 7], [5, 8]]], b"len", 6], + {b"matches": [[[4, 7], [5, 8]]], b"len": 6}, + ) with pytest.raises(redis.ResponseError): assert r.lcs("foo", "bar", len=True, idx=True) @@ -1550,10 +1553,7 @@ def test_hrandfield(self, r): assert r.hrandfield("key") is not None assert len(r.hrandfield("key", 2)) == 2 # with values - if is_resp2_connection(r): - assert len(r.hrandfield("key", 2, True)) == 4 - else: - assert len(r.hrandfield("key", 2, True)) == 2 + assert_resp_response(r, len(r.hrandfield("key", 2, withvalues=True)), 4, 2) # without duplications assert len(r.hrandfield("key", 10)) == 5 # with duplications @@ -1706,30 +1706,26 @@ def test_stralgo_lcs(self, r): assert r.stralgo("LCS", key1, key2, specific_argument="keys") == res # test other labels assert r.stralgo("LCS", value1, value2, len=True) == len(res) - if is_resp2_connection(r): - assert r.stralgo("LCS", value1, value2, idx=True) == { - "len": len(res), - "matches": [[(4, 7), (5, 8)], [(2, 3), (0, 1)]], - } - assert r.stralgo("LCS", value1, value2, idx=True, withmatchlen=True) == { - "len": len(res), - "matches": [[4, (4, 7), (5, 8)], [2, (2, 3), (0, 1)]], - } - assert r.stralgo( - "LCS", value1, value2, idx=True, minmatchlen=4, withmatchlen=True - ) == {"len": len(res), "matches": [[4, (4, 7), (5, 8)]]} - else: - assert r.stralgo("LCS", value1, value2, idx=True) == { - "len": len(res), - "matches": [[[4, 7], [5, 8]], [[2, 3], [0, 1]]], - } - assert r.stralgo("LCS", value1, value2, idx=True, withmatchlen=True) == { - "len": len(res), - "matches": [[[4, 7], [5, 8], 4], [[2, 3], [0, 1], 2]], - } - assert r.stralgo( - "LCS", value1, value2, idx=True, minmatchlen=4, withmatchlen=True - ) == {"len": len(res), "matches": [[[4, 7], [5, 8], 4]]} + assert_resp_response( + r, + r.stralgo("LCS", value1, value2, idx=True), + {"len": len(res), "matches": [[(4, 7), (5, 8)], [(2, 3), (0, 1)]]}, + {"len": len(res), "matches": [[[4, 7], [5, 8]], [[2, 3], [0, 1]]]}, + ) + assert_resp_response( + r, + r.stralgo("LCS", value1, value2, idx=True, withmatchlen=True), + {"len": len(res), "matches": [[4, (4, 7), (5, 8)], [2, (2, 3), (0, 1)]]}, + {"len": len(res), "matches": [[[4, 7], [5, 8], 4], [[2, 3], [0, 1], 2]]}, + ) + assert_resp_response( + r, + r.stralgo( + "LCS", value1, value2, idx=True, withmatchlen=True, minmatchlen=4 + ), + {"len": len(res), "matches": [[4, (4, 7), (5, 8)]]}, + {"len": len(res), "matches": [[[4, 7], [5, 8], 4]]}, + ) @skip_if_server_version_lt("6.0.0") @skip_if_server_version_gte("7.0.0") @@ -2178,10 +2174,12 @@ def test_spop_multi_value(self, r): for value in values: assert value in s - if is_resp2_connection(r): - assert r.spop("a", 1) == list(set(s) - set(values)) - else: - assert r.spop("a", 1) == set(s) - set(values) + assert_resp_response( + r, + r.spop("a", 1), + list(set(s) - set(values)), + set(s) - set(values), + ) def test_srandmember(self, r): s = [b"1", b"2", b"3"] @@ -2232,18 +2230,12 @@ def test_script_debug(self, r): def test_zadd(self, r): mapping = {"a1": 1.0, "a2": 2.0, "a3": 3.0} r.zadd("a", mapping) - if is_resp2_connection(r): - assert r.zrange("a", 0, -1, withscores=True) == [ - (b"a1", 1.0), - (b"a2", 2.0), - (b"a3", 3.0), - ] - else: - assert r.zrange("a", 0, -1, withscores=True) == [ - [b"a1", 1.0], - [b"a2", 2.0], - [b"a3", 3.0], - ] + assert_resp_response( + r, + r.zrange("a", 0, -1, withscores=True), + [(b"a1", 1.0), (b"a2", 2.0), (b"a3", 3.0)], + [[b"a1", 1.0], [b"a2", 2.0], [b"a3", 3.0]], + ) # error cases with pytest.raises(exceptions.DataError): @@ -2260,32 +2252,32 @@ def test_zadd(self, r): def test_zadd_nx(self, r): assert r.zadd("a", {"a1": 1}) == 1 assert r.zadd("a", {"a1": 99, "a2": 2}, nx=True) == 1 - if is_resp2_connection(r): - assert r.zrange("a", 0, -1, withscores=True) == [(b"a1", 1.0), (b"a2", 2.0)] - else: - assert r.zrange("a", 0, -1, withscores=True) == [[b"a1", 1.0], [b"a2", 2.0]] + assert_resp_response( + r, + r.zrange("a", 0, -1, withscores=True), + [(b"a1", 1.0), (b"a2", 2.0)], + [[b"a1", 1.0], [b"a2", 2.0]], + ) def test_zadd_xx(self, r): assert r.zadd("a", {"a1": 1}) == 1 assert r.zadd("a", {"a1": 99, "a2": 2}, xx=True) == 0 - if is_resp2_connection(r): - assert r.zrange("a", 0, -1, withscores=True) == [(b"a1", 99.0)] - else: - assert r.zrange("a", 0, -1, withscores=True) == [[b"a1", 99.0]] + assert_resp_response( + r, + r.zrange("a", 0, -1, withscores=True), + [(b"a1", 99.0)], + [[b"a1", 99.0]], + ) def test_zadd_ch(self, r): assert r.zadd("a", {"a1": 1}) == 1 assert r.zadd("a", {"a1": 99, "a2": 2}, ch=True) == 2 - if is_resp2_connection(r): - assert r.zrange("a", 0, -1, withscores=True) == [ - (b"a2", 2.0), - (b"a1", 99.0), - ] - else: - assert r.zrange("a", 0, -1, withscores=True) == [ - [b"a2", 2.0], - [b"a1", 99.0], - ] + assert_resp_response( + r, + r.zrange("a", 0, -1, withscores=True), + [(b"a2", 2.0), (b"a1", 99.0)], + [[b"a2", 2.0], [b"a1", 99.0]], + ) def test_zadd_incr(self, r): assert r.zadd("a", {"a1": 1}) == 1 @@ -2333,10 +2325,12 @@ def test_zdiff(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) r.zadd("b", {"a1": 1, "a2": 2}) assert r.zdiff(["a", "b"]) == [b"a3"] - if is_resp2_connection(r): - assert r.zdiff(["a", "b"], withscores=True) == [b"a3", b"3"] - else: - assert r.zdiff(["a", "b"], withscores=True) == [[b"a3", 3.0]] + assert_resp_response( + r, + r.zdiff(["a", "b"], withscores=True), + [b"a3", b"3"], + [[b"a3", 3.0]], + ) @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.2.0") @@ -2345,10 +2339,12 @@ def test_zdiffstore(self, r): r.zadd("b", {"a1": 1, "a2": 2}) assert r.zdiffstore("out", ["a", "b"]) assert r.zrange("out", 0, -1) == [b"a3"] - if is_resp2_connection(r): - assert r.zrange("out", 0, -1, withscores=True) == [(b"a3", 3.0)] - else: - assert r.zrange("out", 0, -1, withscores=True) == [[b"a3", 3.0]] + assert_resp_response( + r, + r.zrange("out", 0, -1, withscores=True), + [(b"a3", 3.0)], + [[b"a3", 3.0]], + ) def test_zincrby(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) @@ -2373,48 +2369,34 @@ def test_zinter(self, r): # invalid aggregation with pytest.raises(exceptions.DataError): r.zinter(["a", "b", "c"], aggregate="foo", withscores=True) - if is_resp2_connection(r): - # aggregate with SUM - assert r.zinter(["a", "b", "c"], withscores=True) == [ - (b"a3", 8), - (b"a1", 9), - ] - # aggregate with MAX - assert r.zinter(["a", "b", "c"], aggregate="MAX", withscores=True) == [ - (b"a3", 5), - (b"a1", 6), - ] - # aggregate with MIN - assert r.zinter(["a", "b", "c"], aggregate="MIN", withscores=True) == [ - (b"a1", 1), - (b"a3", 1), - ] - # with weights - assert r.zinter({"a": 1, "b": 2, "c": 3}, withscores=True) == [ - (b"a3", 20), - (b"a1", 23), - ] - else: - # aggregate with SUM - assert r.zinter(["a", "b", "c"], withscores=True) == [ - [b"a3", 8], - [b"a1", 9], - ] - # aggregate with MAX - assert r.zinter(["a", "b", "c"], aggregate="MAX", withscores=True) == [ - [b"a3", 5], - [b"a1", 6], - ] - # aggregate with MIN - assert r.zinter(["a", "b", "c"], aggregate="MIN", withscores=True) == [ - [b"a1", 1], - [b"a3", 1], - ] - # with weights - assert r.zinter({"a": 1, "b": 2, "c": 3}, withscores=True) == [ - [b"a3", 20], - [b"a1", 23], - ] + # aggregate with SUM + assert_resp_response( + r, + r.zinter(["a", "b", "c"], withscores=True), + [(b"a3", 8), (b"a1", 9)], + [[b"a3", 8], [b"a1", 9]], + ) + # aggregate with MAX + assert_resp_response( + r, + r.zinter(["a", "b", "c"], aggregate="MAX", withscores=True), + [(b"a3", 5), (b"a1", 6)], + [[b"a3", 5], [b"a1", 6]], + ) + # aggregate with MIN + assert_resp_response( + r, + r.zinter(["a", "b", "c"], aggregate="MIN", withscores=True), + [(b"a1", 1), (b"a3", 1)], + [[b"a1", 1], [b"a3", 1]], + ) + # with weights + assert_resp_response( + r, + r.zinter({"a": 1, "b": 2, "c": 3}, withscores=True), + [(b"a3", 20), (b"a1", 23)], + [[b"a3", 20], [b"a1", 23]], + ) @pytest.mark.onlynoncluster @skip_if_server_version_lt("7.0.0") @@ -2431,10 +2413,12 @@ def test_zinterstore_sum(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("d", ["a", "b", "c"]) == 2 - if is_resp2_connection(r): - assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 8), (b"a1", 9)] - else: - assert r.zrange("d", 0, -1, withscores=True) == [[b"a3", 8], [b"a1", 9]] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a3", 8), (b"a1", 9)], + [[b"a3", 8], [b"a1", 9]], + ) @pytest.mark.onlynoncluster def test_zinterstore_max(self, r): @@ -2442,10 +2426,12 @@ def test_zinterstore_max(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("d", ["a", "b", "c"], aggregate="MAX") == 2 - if is_resp2_connection(r): - assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 5), (b"a1", 6)] - else: - assert r.zrange("d", 0, -1, withscores=True) == [[b"a3", 5], [b"a1", 6]] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a3", 5), (b"a1", 6)], + [[b"a3", 5], [b"a1", 6]], + ) @pytest.mark.onlynoncluster def test_zinterstore_min(self, r): @@ -2453,10 +2439,12 @@ def test_zinterstore_min(self, r): r.zadd("b", {"a1": 2, "a2": 3, "a3": 5}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("d", ["a", "b", "c"], aggregate="MIN") == 2 - if is_resp2_connection(r): - assert r.zrange("d", 0, -1, withscores=True) == [(b"a1", 1), (b"a3", 3)] - else: - assert r.zrange("d", 0, -1, withscores=True) == [[b"a1", 1], [b"a3", 3]] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a1", 1), (b"a3", 3)], + [[b"a1", 1], [b"a3", 3]], + ) @pytest.mark.onlynoncluster def test_zinterstore_with_weight(self, r): @@ -2464,34 +2452,36 @@ def test_zinterstore_with_weight(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("d", {"a": 1, "b": 2, "c": 3}) == 2 - if is_resp2_connection(r): - assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 20), (b"a1", 23)] - else: - assert r.zrange("d", 0, -1, withscores=True) == [[b"a3", 20], [b"a1", 23]] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a3", 20), (b"a1", 23)], + [[b"a3", 20], [b"a1", 23]], + ) @skip_if_server_version_lt("4.9.0") def test_zpopmax(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - if is_resp2_connection(r): - assert r.zpopmax("a") == [(b"a3", 3)] - # with count - assert r.zpopmax("a", count=2) == [(b"a2", 2), (b"a1", 1)] - else: - assert r.zpopmax("a") == [b"a3", 3.0] - # with count - assert r.zpopmax("a", count=2) == [[b"a2", 2], [b"a1", 1]] + assert_resp_response(r, r.zpopmax("a"), [(b"a3", 3)], [b"a3", 3.0]) + # with count + assert_resp_response( + r, + r.zpopmax("a", count=2), + [(b"a2", 2), (b"a1", 1)], + [[b"a2", 2], [b"a1", 1]], + ) @skip_if_server_version_lt("4.9.0") def test_zpopmin(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - if is_resp2_connection(r): - assert r.zpopmin("a") == [(b"a1", 1)] - # with count - assert r.zpopmin("a", count=2) == [(b"a2", 2), (b"a3", 3)] - else: - assert r.zpopmin("a") == [b"a1", 1.0] - # with count - assert r.zpopmin("a", count=2) == [[b"a2", 2], [b"a3", 3]] + assert_resp_response(r, r.zpopmin("a"), [(b"a1", 1)], [b"a1", 1.0]) + # with count + assert_resp_response( + r, + r.zpopmin("a", count=2), + [(b"a2", 2), (b"a3", 3)], + [[b"a2", 2], [b"a3", 3]], + ) @skip_if_server_version_lt("6.2.0") def test_zrandemember(self, r): @@ -2499,10 +2489,12 @@ def test_zrandemember(self, r): assert r.zrandmember("a") is not None assert len(r.zrandmember("a", 2)) == 2 # with scores - if is_resp2_connection(r): - assert len(r.zrandmember("a", 2, True)) == 4 - else: - assert len(r.zrandmember("a", 2, True)) == 2 + assert_resp_response( + r, + len(r.zrandmember("a", 2, withscores=True)), + 4, + 2, + ) # without duplications assert len(r.zrandmember("a", 10)) == 5 # with duplications @@ -2538,37 +2530,41 @@ def test_bzpopmin(self, r): @skip_if_server_version_lt("7.0.0") def test_zmpop(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - if is_resp2_connection(r): - res = [b"a", [[b"a1", b"1"], [b"a2", b"2"]]] - else: - res = [b"a", [[b"a1", 1.0], [b"a2", 2.0]]] - assert r.zmpop("2", ["b", "a"], min=True, count=2) == res + assert_resp_response( + r, + r.zmpop("2", ["b", "a"], min=True, count=2), + [b"a", [[b"a1", b"1"], [b"a2", b"2"]]], + [b"a", [[b"a1", 1.0], [b"a2", 2.0]]], + ) with pytest.raises(redis.DataError): r.zmpop("2", ["b", "a"], count=2) r.zadd("b", {"b1": 10, "ab": 9, "b3": 8}) - if is_resp2_connection(r): - res = [b"b", [[b"b1", b"10"]]] - else: - res = [b"b", [[b"b1", 10.0]]] - assert r.zmpop("2", ["b", "a"], max=True) == res + assert_resp_response( + r, + r.zmpop("2", ["b", "a"], max=True), + [b"b", [[b"b1", b"10"]]], + [b"b", [[b"b1", 10.0]]], + ) @pytest.mark.onlynoncluster @skip_if_server_version_lt("7.0.0") def test_bzmpop(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - if is_resp2_connection(r): - res = [b"a", [[b"a1", b"1"], [b"a2", b"2"]]] - else: - res = [b"a", [[b"a1", 1.0], [b"a2", 2.0]]] - assert r.bzmpop(1, "2", ["b", "a"], min=True, count=2) == res + assert_resp_response( + r, + r.bzmpop(1, "2", ["b", "a"], min=True, count=2), + [b"a", [[b"a1", b"1"], [b"a2", b"2"]]], + [b"a", [[b"a1", 1.0], [b"a2", 2.0]]], + ) with pytest.raises(redis.DataError): r.bzmpop(1, "2", ["b", "a"], count=2) r.zadd("b", {"b1": 10, "ab": 9, "b3": 8}) - if is_resp2_connection(r): - res = [b"b", [[b"b1", b"10"]]] - else: - res = [b"b", [[b"b1", 10.0]]] - assert r.bzmpop(0, "2", ["b", "a"], max=True) == res + assert_resp_response( + r, + r.bzmpop(0, "2", ["b", "a"], max=True), + [b"b", [[b"b1", b"10"]]], + [b"b", [[b"b1", 10.0]]], + ) assert r.bzmpop(1, "2", ["foo", "bar"], max=True) is None def test_zrange(self, r): @@ -2579,18 +2575,24 @@ def test_zrange(self, r): assert r.zrange("a", 0, 2, desc=True) == [b"a3", b"a2", b"a1"] # withscores - if is_resp2_connection(r): - assert r.zrange("a", 0, 1, withscores=True) == [(b"a1", 1.0), (b"a2", 2.0)] - assert r.zrange("a", 1, 2, withscores=True) == [(b"a2", 2.0), (b"a3", 3.0)] + assert_resp_response( + r, + r.zrange("a", 0, 1, withscores=True), + [(b"a1", 1.0), (b"a2", 2.0)], + [[b"a1", 1.0], [b"a2", 2.0]], + ) + assert_resp_response( + r, + r.zrange("a", 1, 2, withscores=True), + [(b"a2", 2.0), (b"a3", 3.0)], + [[b"a2", 2.0], [b"a3", 3.0]], + ) - # custom score function - assert r.zrange("a", 0, 1, withscores=True, score_cast_func=int) == [ - (b"a1", 1), - (b"a2", 2), - ] - else: - assert r.zrange("a", 0, 1, withscores=True) == [[b"a1", 1.0], [b"a2", 2.0]] - assert r.zrange("a", 1, 2, withscores=True) == [[b"a2", 2.0], [b"a3", 3.0]] + # # custom score function + # assert r.zrange("a", 0, 1, withscores=True, score_cast_func=int) == [ + # (b"a1", 1), + # (b"a2", 2), + # ] def test_zrange_errors(self, r): with pytest.raises(exceptions.DataError): @@ -2622,25 +2624,20 @@ def test_zrange_params(self, r): b"a3", b"a2", ] - if is_resp2_connection(r): - assert r.zrange("a", 2, 4, byscore=True, withscores=True) == [ - (b"a2", 2.0), - (b"a3", 3.0), - (b"a4", 4.0), - ] - assert r.zrange( - "a", 4, 2, desc=True, byscore=True, withscores=True, score_cast_func=int - ) == [(b"a4", 4), (b"a3", 3), (b"a2", 2)] - - else: - assert r.zrange("a", 2, 4, byscore=True, withscores=True) == [ - [b"a2", 2.0], - [b"a3", 3.0], - [b"a4", 4.0], - ] - assert r.zrange( + assert_resp_response( + r, + r.zrange("a", 2, 4, byscore=True, withscores=True), + [(b"a2", 2.0), (b"a3", 3.0), (b"a4", 4.0)], + [[b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]], + ) + assert_resp_response( + r, + r.zrange( "a", 4, 2, desc=True, byscore=True, withscores=True, score_cast_func=int - ) == [[b"a4", 4], [b"a3", 3], [b"a2", 2]] + ), + [(b"a4", 4), (b"a3", 3), (b"a2", 2)], + [[b"a4", 4], [b"a3", 3], [b"a2", 2]], + ) # rev assert r.zrange("a", 0, 1, desc=True) == [b"a5", b"a4"] @@ -2653,10 +2650,12 @@ def test_zrangestore(self, r): assert r.zrange("b", 0, -1) == [b"a1", b"a2"] assert r.zrangestore("b", "a", 1, 2) assert r.zrange("b", 0, -1) == [b"a2", b"a3"] - if is_resp2_connection(r): - assert r.zrange("b", 0, -1, withscores=True) == [(b"a2", 2), (b"a3", 3)] - else: - assert r.zrange("b", 0, -1, withscores=True) == [[b"a2", 2], [b"a3", 3]] + assert_resp_response( + r, + r.zrange("b", 0, -1, withscores=True), + [(b"a2", 2), (b"a3", 3)], + [[b"a2", 2], [b"a3", 3]], + ) # reversed order assert r.zrangestore("b", "a", 1, 2, desc=True) assert r.zrange("b", 0, -1) == [b"a1", b"a2"] @@ -2691,28 +2690,18 @@ def test_zrangebyscore(self, r): # slicing with start/num assert r.zrangebyscore("a", 2, 4, start=1, num=2) == [b"a3", b"a4"] # withscores - if is_resp2_connection(r): - assert r.zrangebyscore("a", 2, 4, withscores=True) == [ - (b"a2", 2.0), - (b"a3", 3.0), - (b"a4", 4.0), - ] - assert r.zrangebyscore("a", 2, 4, withscores=True, score_cast_func=int) == [ - (b"a2", 2), - (b"a3", 3), - (b"a4", 4), - ] - else: - assert r.zrangebyscore("a", 2, 4, withscores=True) == [ - [b"a2", 2.0], - [b"a3", 3.0], - [b"a4", 4.0], - ] - assert r.zrangebyscore("a", 2, 4, withscores=True, score_cast_func=int) == [ - [b"a2", 2], - [b"a3", 3], - [b"a4", 4], - ] + assert_resp_response( + r, + r.zrangebyscore("a", 2, 4, withscores=True), + [(b"a2", 2.0), (b"a3", 3.0), (b"a4", 4.0)], + [[b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]], + ) + assert_resp_response( + r, + r.zrangebyscore("a", 2, 4, withscores=True, score_cast_func=int), + [(b"a2", 2), (b"a3", 3), (b"a4", 4)], + [[b"a2", 2], [b"a3", 3], [b"a4", 4]], + ) def test_zrank(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -2759,32 +2748,25 @@ def test_zrevrange(self, r): assert r.zrevrange("a", 0, 1) == [b"a3", b"a2"] assert r.zrevrange("a", 1, 2) == [b"a2", b"a1"] - if is_resp2_connection(r): - # withscores - assert r.zrevrange("a", 0, 1, withscores=True) == [ - (b"a3", 3.0), - (b"a2", 2.0), - ] - assert r.zrevrange("a", 1, 2, withscores=True) == [ - (b"a2", 2.0), - (b"a1", 1.0), - ] + # withscores + assert_resp_response( + r, + r.zrevrange("a", 0, 1, withscores=True), + [(b"a3", 3.0), (b"a2", 2.0)], + [[b"a3", 3.0], [b"a2", 2.0]], + ) + assert_resp_response( + r, + r.zrevrange("a", 1, 2, withscores=True), + [(b"a2", 2.0), (b"a1", 1.0)], + [[b"a2", 2.0], [b"a1", 1.0]], + ) - # custom score function - assert r.zrevrange("a", 0, 1, withscores=True, score_cast_func=int) == [ - (b"a3", 3.0), - (b"a2", 2.0), - ] - else: - # withscores - assert r.zrevrange("a", 0, 1, withscores=True) == [ - [b"a3", 3.0], - [b"a2", 2.0], - ] - assert r.zrevrange("a", 1, 2, withscores=True) == [ - [b"a2", 2.0], - [b"a1", 1.0], - ] + # # custom score function + # assert r.zrevrange("a", 0, 1, withscores=True, score_cast_func=int) == [ + # (b"a3", 3.0), + # (b"a2", 2.0), + # ] def test_zrevrangebyscore(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -2792,28 +2774,20 @@ def test_zrevrangebyscore(self, r): # slicing with start/num assert r.zrevrangebyscore("a", 4, 2, start=1, num=2) == [b"a3", b"a2"] - if is_resp2_connection(r): - # withscores - assert r.zrevrangebyscore("a", 4, 2, withscores=True) == [ - (b"a4", 4.0), - (b"a3", 3.0), - (b"a2", 2.0), - ] - # custom score function - assert r.zrevrangebyscore( - "a", 4, 2, withscores=True, score_cast_func=int - ) == [ - (b"a4", 4), - (b"a3", 3), - (b"a2", 2), - ] - else: - # withscores - assert r.zrevrangebyscore("a", 4, 2, withscores=True) == [ - [b"a4", 4.0], - [b"a3", 3.0], - [b"a2", 2.0], - ] + # withscores + assert_resp_response( + r, + r.zrevrangebyscore("a", 4, 2, withscores=True), + [(b"a4", 4.0), (b"a3", 3.0), (b"a2", 2.0)], + [[b"a4", 4.0], [b"a3", 3.0], [b"a2", 2.0]], + ) + # custom score function + assert_resp_response( + r, + r.zrevrangebyscore("a", 4, 2, withscores=True, score_cast_func=int), + [(b"a4", 4.0), (b"a3", 3.0), (b"a2", 2.0)], + [[b"a4", 4.0], [b"a3", 3.0], [b"a2", 2.0]], + ) def test_zrevrank(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -2835,63 +2809,33 @@ def test_zunion(self, r): r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) # sum assert r.zunion(["a", "b", "c"]) == [b"a2", b"a4", b"a3", b"a1"] - - if is_resp2_connection(r): - assert r.zunion(["a", "b", "c"], withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] - # max - assert r.zunion(["a", "b", "c"], aggregate="MAX", withscores=True) == [ - (b"a2", 2), - (b"a4", 4), - (b"a3", 5), - (b"a1", 6), - ] - # min - assert r.zunion(["a", "b", "c"], aggregate="MIN", withscores=True) == [ - (b"a1", 1), - (b"a2", 1), - (b"a3", 1), - (b"a4", 4), - ] - # with weight - assert r.zunion({"a": 1, "b": 2, "c": 3}, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] - else: - assert r.zunion(["a", "b", "c"], withscores=True) == [ - [b"a2", 3], - [b"a4", 4], - [b"a3", 8], - [b"a1", 9], - ] - # max - assert r.zunion(["a", "b", "c"], aggregate="MAX", withscores=True) == [ - [b"a2", 2], - [b"a4", 4], - [b"a3", 5], - [b"a1", 6], - ] - # min - assert r.zunion(["a", "b", "c"], aggregate="MIN", withscores=True) == [ - [b"a1", 1], - [b"a2", 1], - [b"a3", 1], - [b"a4", 4], - ] - # with weight - assert r.zunion({"a": 1, "b": 2, "c": 3}, withscores=True) == [ - [b"a2", 5], - [b"a4", 12], - [b"a3", 20], - [b"a1", 23], - ] + assert_resp_response( + r, + r.zunion(["a", "b", "c"], withscores=True), + [(b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9)], + [[b"a2", 3], [b"a4", 4], [b"a3", 8], [b"a1", 9]], + ) + # max + assert_resp_response( + r, + r.zunion(["a", "b", "c"], aggregate="MAX", withscores=True), + [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)], + [[b"a2", 2], [b"a4", 4], [b"a3", 5], [b"a1", 6]], + ) + # min + assert_resp_response( + r, + r.zunion(["a", "b", "c"], aggregate="MIN", withscores=True), + [(b"a1", 1), (b"a2", 1), (b"a3", 1), (b"a4", 4)], + [[b"a1", 1], [b"a2", 1], [b"a3", 1], [b"a4", 4]], + ) + # with weight + assert_resp_response( + r, + r.zunion({"a": 1, "b": 2, "c": 3}, withscores=True), + [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)], + [[b"a2", 5], [b"a4", 12], [b"a3", 20], [b"a1", 23]], + ) @pytest.mark.onlynoncluster def test_zunionstore_sum(self, r): @@ -2899,21 +2843,12 @@ def test_zunionstore_sum(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("d", ["a", "b", "c"]) == 4 - - if is_resp2_connection(r): - assert r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] - else: - assert r.zrange("d", 0, -1, withscores=True) == [ - [b"a2", 3], - [b"a4", 4], - [b"a3", 8], - [b"a1", 9], - ] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9)], + [[b"a2", 3], [b"a4", 4], [b"a3", 8], [b"a1", 9]], + ) @pytest.mark.onlynoncluster def test_zunionstore_max(self, r): @@ -2921,20 +2856,12 @@ def test_zunionstore_max(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("d", ["a", "b", "c"], aggregate="MAX") == 4 - if is_resp2_connection(r): - assert r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 2), - (b"a4", 4), - (b"a3", 5), - (b"a1", 6), - ] - else: - assert r.zrange("d", 0, -1, withscores=True) == [ - [b"a2", 2], - [b"a4", 4], - [b"a3", 5], - [b"a1", 6], - ] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)], + [[b"a2", 2], [b"a4", 4], [b"a3", 5], [b"a1", 6]], + ) @pytest.mark.onlynoncluster def test_zunionstore_min(self, r): @@ -2942,20 +2869,12 @@ def test_zunionstore_min(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 4}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("d", ["a", "b", "c"], aggregate="MIN") == 4 - if is_resp2_connection(r): - assert r.zrange("d", 0, -1, withscores=True) == [ - (b"a1", 1), - (b"a2", 2), - (b"a3", 3), - (b"a4", 4), - ] - else: - assert r.zrange("d", 0, -1, withscores=True) == [ - [b"a1", 1], - [b"a2", 2], - [b"a3", 3], - [b"a4", 4], - ] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a1", 1), (b"a2", 2), (b"a3", 3), (b"a4", 4)], + [[b"a1", 1], [b"a2", 2], [b"a3", 3], [b"a4", 4]], + ) @pytest.mark.onlynoncluster def test_zunionstore_with_weight(self, r): @@ -2963,20 +2882,12 @@ def test_zunionstore_with_weight(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("d", {"a": 1, "b": 2, "c": 3}) == 4 - if is_resp2_connection(r): - assert r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] - else: - assert r.zrange("d", 0, -1, withscores=True) == [ - [b"a2", 5], - [b"a4", 12], - [b"a3", 20], - [b"a1", 23], - ] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)], + [[b"a2", 5], [b"a4", 12], [b"a3", 20], [b"a1", 23]], + ) @skip_if_server_version_lt("6.1.240") def test_zmscore(self, r): @@ -4351,10 +4262,12 @@ def test_xinfo_stream_full(self, r): info = r.xinfo_stream(stream, full=True) assert info["length"] == 1 - if is_resp2_connection(r): - assert m1 in info["entries"] - else: - assert m1 in info["entries"].keys() + assert_resp_response_in( + r, + m1, + info["entries"], + info["entries"].keys(), + ) assert len(info["groups"]) == 1 @skip_if_server_version_lt("5.0.0") @@ -4495,40 +4408,39 @@ def test_xread(self, r): m1 = r.xadd(stream, {"foo": "bar"}) m2 = r.xadd(stream, {"bing": "baz"}) - strem_name = stream.encode() + stream_name = stream.encode() expected_entries = [ get_stream_message(r, stream, m1), get_stream_message(r, stream, m2), ] # xread starting at 0 returns both messages - res = r.xread(streams={stream: 0}) - if is_resp2_connection(r): - assert res == [[strem_name, expected_entries]] - else: - assert res == {strem_name: [expected_entries]} + assert_resp_response( + r, + r.xread(streams={stream: 0}), + [[stream_name, expected_entries]], + {stream_name: [expected_entries]}, + ) expected_entries = [get_stream_message(r, stream, m1)] # xread starting at 0 and count=1 returns only the first message - res = r.xread(streams={stream: 0}, count=1) - if is_resp2_connection(r): - assert res == [[strem_name, expected_entries]] - else: - assert res == {strem_name: [expected_entries]} + assert_resp_response( + r, + r.xread(streams={stream: 0}, count=1), + [[stream_name, expected_entries]], + {stream_name: [expected_entries]}, + ) expected_entries = [get_stream_message(r, stream, m2)] # xread starting at m1 returns only the second message - res = r.xread(streams={stream: m1}) - if is_resp2_connection(r): - assert res == [[strem_name, expected_entries]] - else: - assert res == {strem_name: [expected_entries]} + assert_resp_response( + r, + r.xread(streams={stream: m1}), + [[stream_name, expected_entries]], + {stream_name: [expected_entries]}, + ) # xread starting at the last message returns an empty list - res = r.xread(streams={stream: m2}) - if is_resp2_connection(r): - assert res == [] - else: - assert res == {} + assert_resp_response(r, r.xread(streams={stream: m2}), [], {}) @skip_if_server_version_lt("5.0.0") def test_xreadgroup(self, r): @@ -4539,18 +4451,19 @@ def test_xreadgroup(self, r): m2 = r.xadd(stream, {"bing": "baz"}) r.xgroup_create(stream, group, 0) - strem_name = stream.encode() + stream_name = stream.encode() expected_entries = [ get_stream_message(r, stream, m1), get_stream_message(r, stream, m2), ] # xread starting at 0 returns both messages - res = r.xreadgroup(group, consumer, streams={stream: ">"}) - if is_resp2_connection(r): - assert res == [[strem_name, expected_entries]] - else: - assert res == {strem_name: [expected_entries]} + assert_resp_response( + r, + r.xreadgroup(group, consumer, streams={stream: ">"}), + [[stream_name, expected_entries]], + {stream_name: [expected_entries]}, + ) r.xgroup_destroy(stream, group) r.xgroup_create(stream, group, 0) @@ -4558,11 +4471,12 @@ def test_xreadgroup(self, r): expected_entries = [get_stream_message(r, stream, m1)] # xread with count=1 returns only the first message - res = r.xreadgroup(group, consumer, streams={stream: ">"}, count=1) - if is_resp2_connection(r): - assert res == [[strem_name, expected_entries]] - else: - assert res == {strem_name: [expected_entries]} + assert_resp_response( + r, + r.xreadgroup(group, consumer, streams={stream: ">"}, count=1), + [[stream_name, expected_entries]], + {stream_name: [expected_entries]}, + ) r.xgroup_destroy(stream, group) @@ -4571,10 +4485,9 @@ def test_xreadgroup(self, r): r.xgroup_create(stream, group, "$") # xread starting after the last message returns an empty message list - if is_resp2_connection(r): - assert r.xreadgroup(group, consumer, streams={stream: ">"}) == [] - else: - assert r.xreadgroup(group, consumer, streams={stream: ">"}) == {} + assert_resp_response( + r, r.xreadgroup(group, consumer, streams={stream: ">"}), [], {} + ) # xreadgroup with noack does not have any items in the PEL r.xgroup_destroy(stream, group) @@ -4586,9 +4499,9 @@ def test_xreadgroup(self, r): # now there should be nothing pending assert len(empty_res[0][1]) == 0 else: - assert len(res[strem_name][0]) == 2 + assert len(res[stream_name][0]) == 2 # now there should be nothing pending - assert len(empty_res[strem_name][0]) == 0 + assert len(empty_res[stream_name][0]) == 0 r.xgroup_destroy(stream, group) r.xgroup_create(stream, group, "0") @@ -4596,11 +4509,12 @@ def test_xreadgroup(self, r): expected_entries = [(m1, {}), (m2, {})] r.xreadgroup(group, consumer, streams={stream: ">"}) r.xtrim(stream, 0) - res = r.xreadgroup(group, consumer, streams={stream: "0"}) - if is_resp2_connection(r): - assert res == [[strem_name, expected_entries]] - else: - assert res == {strem_name: [expected_entries]} + assert_resp_response( + r, + r.xreadgroup(group, consumer, streams={stream: "0"}), + [[stream_name, expected_entries]], + {stream_name: [expected_entries]}, + ) @skip_if_server_version_lt("5.0.0") def test_xrevrange(self, r): @@ -4893,18 +4807,17 @@ def test_command(self, r): @skip_if_server_version_lt("7.0.0") @skip_if_redis_enterprise() def test_command_getkeysandflags(self, r: redis.Redis): - if is_resp2_connection(r): - res = [ + assert_resp_response( + r, + r.command_getkeysandflags("LMOVE", "mylist1", "mylist2", "left", "left"), + [ [b"mylist1", [b"RW", b"access", b"delete"]], [b"mylist2", [b"RW", b"insert"]], - ] - else: - res = [ + ], + [ [b"mylist1", {b"RW", b"access", b"delete"}], [b"mylist2", {b"RW", b"insert"}], - ] - assert res == r.command_getkeysandflags( - "LMOVE", "mylist1", "mylist2", "left", "left" + ], ) @pytest.mark.onlynoncluster diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 3da734df99..fc98966d74 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -609,6 +609,7 @@ def test_push_handler(self, r): assert self.message == ["my handler", [b"message", b"foo", b"test message"]] @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") + @skip_if_server_version_lt("7.0.0") def test_push_handler_sharded_pubsub(self, r): if is_resp2_connection(r): return From 8ac26e913bc1213e986aa4291cad3899915203fd Mon Sep 17 00:00:00 2001 From: dvora-h Date: Thu, 1 Jun 2023 13:27:20 +0300 Subject: [PATCH 12/16] fix test_xreadgroup --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index faf14b9115..6454750353 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -475,7 +475,7 @@ def wait_for_command(client, monitor, command, key=None): def is_resp2_connection(r): - if isinstance(r, redis.Redis): + if isinstance(r, redis.Redis) or isinstance(r, redis.asyncio.Redis): protocol = r.connection_pool.connection_kwargs.get("protocol") elif isinstance(r, redis.cluster.AbstractRedisCluster): protocol = r.nodes_manager.connection_kwargs.get("protocol") From a3476c6ef9deca67add978d533efc2dc31136e6d Mon Sep 17 00:00:00 2001 From: dvora-h Date: Thu, 1 Jun 2023 13:34:17 +0300 Subject: [PATCH 13/16] fix cluster_zdiffstore and cluster_zinter --- tests/test_asyncio/test_cluster.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 173f0fd1ab..58c0e0b0c7 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -1624,7 +1624,7 @@ async def test_cluster_zdiffstore(self, r: RedisCluster) -> None: assert await r.zdiffstore("{foo}out", ["{foo}a", "{foo}b"]) assert await r.zrange("{foo}out", 0, -1) == [b"a3"] response = await r.zrange("{foo}out", 0, -1, withscores=True) - assert_resp_response(r, response, [(b"a3", b"3")], [[b"a3", 3.0]]) + assert_resp_response(r, response, [(b"a3", 3.0)], [[b"a3", 3.0]]) @skip_if_server_version_lt("6.2.0") async def test_cluster_zinter(self, r: RedisCluster) -> None: @@ -1640,26 +1640,26 @@ async def test_cluster_zinter(self, r: RedisCluster) -> None: # aggregate with SUM response = await r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) assert_resp_response( - r, response, [(b"a3", b"8"), (b"a1", b"9")], [[b"a3", 8.0], [b"a1", 9.0]] + r, response, [(b"a3", 8), (b"a1", 9)], [[b"a3", 8], [b"a1", 9]] ) # aggregate with MAX response = await r.zinter( ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True ) assert_resp_response( - r, response, [(b"a3", b"5"), (b"a1", b"6")], [[b"a3", 5.0], [b"a1", 6.0]] + r, response, [(b"a3", 5), (b"a1", 6)], [[b"a3", 5], [b"a1", 6]] ) # aggregate with MIN response = await r.zinter( ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True ) assert_resp_response( - r, response, [(b"a1", b"1"), (b"a3", b"1")], [[b"a1", 1.0], [b"a3", 1.0]] + r, response, [(b"a1", 1), (b"a3", 1)], [[b"a1", 1], [b"a3", 1]] ) # with weights res = await r.zinter({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True) assert_resp_response( - r, res, [(b"a3", b"20"), (b"a1", b"23")], [[b"a3", 20.0], [b"a1", 23.0]] + r, res, [(b"a3", 20), (b"a1", 23)], [[b"a3", 20], [b"a1", 23]] ) async def test_cluster_zinterstore_sum(self, r: RedisCluster) -> None: From 6965b41088fd1644e7cc66963e1810cf345461e8 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Thu, 1 Jun 2023 15:08:28 +0300 Subject: [PATCH 14/16] fix review comments --- redis/asyncio/client.py | 2 +- redis/asyncio/connection.py | 2 +- redis/cluster.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index be70143504..18fdf94174 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -671,7 +671,7 @@ def __init__( if self.encoder is None: self.encoder = self.connection_pool.get_encoder() if self.encoder.decode_responses: - self.health_check_response: Iterable[Union[str, bytes]] = [ + self.health_check_response = [ ["pong", self.HEALTH_CHECK_MESSAGE], self.HEALTH_CHECK_MESSAGE, ] diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index ae4d67b49a..470eda1a15 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -345,7 +345,7 @@ async def on_connect(self) -> None: auth_args = cred_provider.get_credentials() # if resp version is specified and we have auth args, # we need to send them via HELLO - if auth_args and self.protocol != 2: + if auth_args and self.protocol not in [2, "2"]: if isinstance(self._parser, _AsyncRESP2Parser): self.set_parser(_AsyncRESP3Parser) # update cluster exception classes diff --git a/redis/cluster.py b/redis/cluster.py index 01f696dc74..898db29cdc 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1642,10 +1642,10 @@ def __init__( self.node_pubsub_mapping = {} self._pubsubs_generator = self._pubsubs_generator() super().__init__( - **kwargs, connection_pool=connection_pool, encoder=redis_cluster.encoder, push_handler_func=push_handler_func, + **kwargs, ) def set_pubsub_node(self, cluster, node=None, host=None, port=None): From 599bdc013a5061ad2843ccac67128ae128212f5c Mon Sep 17 00:00:00 2001 From: dvora-h Date: Thu, 1 Jun 2023 15:09:44 +0300 Subject: [PATCH 15/16] fix review comments --- redis/asyncio/connection.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 470eda1a15..ee359821db 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -355,9 +355,9 @@ async def on_connect(self) -> None: auth_args = ["default", auth_args[0]] await self.send_command("HELLO", self.protocol, "AUTH", *auth_args) response = await self.read_response() - if response.get(b"proto") != int(self.protocol) and response.get( + if response.get(b"proto") not in [2, "2"] and response.get( "proto" - ) != int(self.protocol): + ) not in [2, "2"]: raise ConnectionError("Invalid RESP version") # avoid checking health here -- PING will fail if we try # to check the health prior to the AUTH From 4e08e631d0bee57792a36dbb7717a33e58ad07f5 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Thu, 1 Jun 2023 15:12:39 +0300 Subject: [PATCH 16/16] linters --- redis/asyncio/connection.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index ee359821db..b51e4fd8ce 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -355,9 +355,10 @@ async def on_connect(self) -> None: auth_args = ["default", auth_args[0]] await self.send_command("HELLO", self.protocol, "AUTH", *auth_args) response = await self.read_response() - if response.get(b"proto") not in [2, "2"] and response.get( - "proto" - ) not in [2, "2"]: + if response.get(b"proto") not in [2, "2"] and response.get("proto") not in [ + 2, + "2", + ]: raise ConnectionError("Invalid RESP version") # avoid checking health here -- PING will fail if we try # to check the health prior to the AUTH