From 2e46613dcdf9db7f948f3258fee1c9199da8cbfa Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Fri, 27 Sep 2024 14:16:16 +0300 Subject: [PATCH] Client side caching refactoring (#3350) * Restructure client side caching code Right now the client side caching code is implemented mostly on the level of Connections, which is too low. We need to have a shared cache across several connections. Move the cache implementation higher, while trying to encapsulate it better, into a `CacheMixin` class. This is work in progress, many details still need to be taken care of! * Temporary refactor * Finished CacheProxyConnection implementation, added comments * Added test cases and scheduler dependency * Added support for RedisCluster and multi-threaded test cases * Added support for BlockingConnectionPool * Fixed docker-compose command * Revert port changes * Initial take on Sentinel support * Remove keys option after usage * Added condition to remove keys entry on async * Added same keys entry removal in pipeline * Added caching support for Sentinel * Added locking when accesing cache object * Rmoved keys option from options * Removed redundant entities * Added cache support for SSLConnection * Moved ssl argument handling to cover cluster case * Revert local test changes * Fixed bug with missing async operator * Revert accidental changes * Added API to return cache object * Added eviction policy configuration * Added mark to skip test on cluster * Removed test case that makes no sense * Skip tests in RESP2 * Added scheduler to dev_requirements * Codestyle changes * Fixed characters per line restriction * Fixed line length * Removed blank lines in imports * Fixed imports codestyle * Added CacheInterface abstraction * Removed redundant references * Moved hardcoded values to constants, restricted dependency versions * Changed defaults to correct values * Added custom background scheduler, added unit testing * Codestyle changes * Updated RESP2 restriction * Cahnged typing to more generic * Restrict pytest-asyncio version to 0.23 * Added upper version limit * Removed usntable multithreaded tests * Removed more flacky multithreaded tests * Fixed issue with Sentinel killing healthcheck thread before execution * Removed cachetools dependency, added custom cache implementation * Updated test cases * Updated typings * Updated types * Revert changes * Removed use_cache, make health_check configurable, removed retry logic around can_read() * Revert test skip * Added documentation and codestyle fixes * Updated excluded wordlist * Added health_check thread cancelling in BlockingPool * Revert argument rename, extended documentation * Updated NodesManager to create shared cache between all nodes * Codestyle fixes * Updated docs * Added version restrictions * Added missing property getter * Updated Redis server version * Skip on long exception message * Removed keys entry as it's csc specific * Updated exception message for CSC * Updated condition by adding server name check * Added test coverage for decoded responses * Codestyle changes * Removed background healthcheck, use connection reference approach instead * Removed unused imports * Fixed broken tests * Codestyle changes * Fixed additional broken tests * Codestyle changes * Increased timer to avoid flackiness * Restrict tests cause of PyPy * Codestyle changes * Updated docs, convert getters function to properties, added dataclasses --------- Co-authored-by: Gabriel Erzse --- .github/wordlist.txt | 1 + .github/workflows/integration.yaml | 2 +- dev_requirements.txt | 1 - docs/examples/connection_examples.ipynb | 2 - docs/resp3_features.rst | 32 + redis/_cache.py | 385 ------ redis/_parsers/resp3.py | 48 +- redis/asyncio/client.py | 75 +- redis/asyncio/cluster.py | 86 +- redis/asyncio/connection.py | 119 +- redis/asyncio/sentinel.py | 1 - redis/cache.py | 401 ++++++ redis/client.py | 89 +- redis/cluster.py | 103 +- redis/commands/core.py | 6 +- redis/connection.py | 465 +++++-- redis/sentinel.py | 11 +- redis/utils.py | 39 + requirements.txt | 2 +- tests/conftest.py | 67 +- tests/test_asyncio/conftest.py | 1 - tests/test_asyncio/test_cache.py | 408 ------ tests/test_asyncio/test_cluster.py | 2 - tests/test_asyncio/test_connection.py | 1 - tests/test_asyncio/test_hash.py | 2 +- tests/test_asyncio/test_pubsub.py | 2 +- tests/test_cache.py | 1500 ++++++++++++++++------- tests/test_cluster.py | 11 +- tests/test_connection.py | 221 +++- tests/test_utils.py | 27 + 30 files changed, 2344 insertions(+), 1766 deletions(-) delete mode 100644 redis/_cache.py create mode 100644 redis/cache.py delete mode 100644 tests/test_asyncio/test_cache.py create mode 100644 tests/test_utils.py diff --git a/.github/wordlist.txt b/.github/wordlist.txt index ca2102b825..3ea543748e 100644 --- a/.github/wordlist.txt +++ b/.github/wordlist.txt @@ -1,6 +1,7 @@ APM ARGV BFCommands +CacheImpl CFCommands CMSCommands ClusterNode diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index c4da3bf3aa..b10edf2fb4 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -27,7 +27,7 @@ env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} # this speeds up coverage with Python 3.12: https://github.com/nedbat/coveragepy/issues/1665 COVERAGE_CORE: sysmon - REDIS_IMAGE: redis:7.4-rc2 + REDIS_IMAGE: redis:latest REDIS_STACK_IMAGE: redis/redis-stack-server:latest jobs: diff --git a/dev_requirements.txt b/dev_requirements.txt index 931784cdaf..37a107d16d 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,5 +1,4 @@ black==24.3.0 -cachetools click==8.0.4 flake8-isort flake8 diff --git a/docs/examples/connection_examples.ipynb b/docs/examples/connection_examples.ipynb index cddded2865..fd60e2a495 100644 --- a/docs/examples/connection_examples.ipynb +++ b/docs/examples/connection_examples.ipynb @@ -69,9 +69,7 @@ }, { "cell_type": "markdown", - "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "### By default this library uses the RESP 2 protocol. To enable RESP3, set protocol=3." ] diff --git a/docs/resp3_features.rst b/docs/resp3_features.rst index 11c01985a0..326495b775 100644 --- a/docs/resp3_features.rst +++ b/docs/resp3_features.rst @@ -67,3 +67,35 @@ This means that should you want to perform something, on a given push notificati >> p = r.pubsub(push_handler_func=our_func) In the example above, upon receipt of a push notification, rather than log the message, in the case where specific text occurs, an IOError is raised. This example, highlights how one could start implementing a customized message handler. + +Client-side caching +------------------- + +Client-side caching is a technique used to create high performance services. +It utilizes the memory on application servers, typically separate from the database nodes, to cache a subset of the data directly on the application side. +For more information please check `official Redis documentation `_. +Please notice that this feature only available with RESP3 protocol enabled in sync client only. Supported in standalone, Cluster and Sentinel clients. + +Basic usage: + +Enable caching with default configuration: + +.. code:: python + + >>> import redis + >>> from redis.cache import CacheConfig + >>> r = redis.Redis(host='localhost', port=6379, protocol=3, cache_config=CacheConfig()) + +The same interface applies to Redis Cluster and Sentinel. + +Enable caching with custom cache implementation: + +.. code:: python + + >>> import redis + >>> from foo.bar import CacheImpl + >>> r = redis.Redis(host='localhost', port=6379, protocol=3, cache=CacheImpl()) + +CacheImpl should implement a `CacheInterface` specified in `redis.cache` package. + +More comprehensive documentation soon will be available at `official Redis documentation `_. diff --git a/redis/_cache.py b/redis/_cache.py deleted file mode 100644 index 90288383d6..0000000000 --- a/redis/_cache.py +++ /dev/null @@ -1,385 +0,0 @@ -import copy -import random -import time -from abc import ABC, abstractmethod -from collections import OrderedDict, defaultdict -from enum import Enum -from typing import List, Sequence, Union - -from redis.typing import KeyT, ResponseT - - -class EvictionPolicy(Enum): - LRU = "lru" - LFU = "lfu" - RANDOM = "random" - - -DEFAULT_EVICTION_POLICY = EvictionPolicy.LRU - -DEFAULT_DENY_LIST = [ - "BF.CARD", - "BF.DEBUG", - "BF.EXISTS", - "BF.INFO", - "BF.MEXISTS", - "BF.SCANDUMP", - "CF.COMPACT", - "CF.COUNT", - "CF.DEBUG", - "CF.EXISTS", - "CF.INFO", - "CF.MEXISTS", - "CF.SCANDUMP", - "CMS.INFO", - "CMS.QUERY", - "DUMP", - "EXPIRETIME", - "FT.AGGREGATE", - "FT.ALIASADD", - "FT.ALIASDEL", - "FT.ALIASUPDATE", - "FT.CURSOR", - "FT.EXPLAIN", - "FT.EXPLAINCLI", - "FT.GET", - "FT.INFO", - "FT.MGET", - "FT.PROFILE", - "FT.SEARCH", - "FT.SPELLCHECK", - "FT.SUGGET", - "FT.SUGLEN", - "FT.SYNDUMP", - "FT.TAGVALS", - "FT._ALIASADDIFNX", - "FT._ALIASDELIFX", - "HRANDFIELD", - "JSON.DEBUG", - "PEXPIRETIME", - "PFCOUNT", - "PTTL", - "SRANDMEMBER", - "TDIGEST.BYRANK", - "TDIGEST.BYREVRANK", - "TDIGEST.CDF", - "TDIGEST.INFO", - "TDIGEST.MAX", - "TDIGEST.MIN", - "TDIGEST.QUANTILE", - "TDIGEST.RANK", - "TDIGEST.REVRANK", - "TDIGEST.TRIMMED_MEAN", - "TOPK.INFO", - "TOPK.LIST", - "TOPK.QUERY", - "TOUCH", - "TTL", -] - -DEFAULT_ALLOW_LIST = [ - "BITCOUNT", - "BITFIELD_RO", - "BITPOS", - "EXISTS", - "GEODIST", - "GEOHASH", - "GEOPOS", - "GEORADIUSBYMEMBER_RO", - "GEORADIUS_RO", - "GEOSEARCH", - "GET", - "GETBIT", - "GETRANGE", - "HEXISTS", - "HGET", - "HGETALL", - "HKEYS", - "HLEN", - "HMGET", - "HSTRLEN", - "HVALS", - "JSON.ARRINDEX", - "JSON.ARRLEN", - "JSON.GET", - "JSON.MGET", - "JSON.OBJKEYS", - "JSON.OBJLEN", - "JSON.RESP", - "JSON.STRLEN", - "JSON.TYPE", - "LCS", - "LINDEX", - "LLEN", - "LPOS", - "LRANGE", - "MGET", - "SCARD", - "SDIFF", - "SINTER", - "SINTERCARD", - "SISMEMBER", - "SMEMBERS", - "SMISMEMBER", - "SORT_RO", - "STRLEN", - "SUBSTR", - "SUNION", - "TS.GET", - "TS.INFO", - "TS.RANGE", - "TS.REVRANGE", - "TYPE", - "XLEN", - "XPENDING", - "XRANGE", - "XREAD", - "XREVRANGE", - "ZCARD", - "ZCOUNT", - "ZDIFF", - "ZINTER", - "ZINTERCARD", - "ZLEXCOUNT", - "ZMSCORE", - "ZRANGE", - "ZRANGEBYLEX", - "ZRANGEBYSCORE", - "ZRANK", - "ZREVRANGE", - "ZREVRANGEBYLEX", - "ZREVRANGEBYSCORE", - "ZREVRANK", - "ZSCORE", - "ZUNION", -] - -_RESPONSE = "response" -_KEYS = "keys" -_CTIME = "ctime" -_ACCESS_COUNT = "access_count" - - -class AbstractCache(ABC): - """ - An abstract base class for client caching implementations. - If you want to implement your own cache you must support these methods. - """ - - @abstractmethod - def set( - self, - command: Union[str, Sequence[str]], - response: ResponseT, - keys_in_command: List[KeyT], - ): - pass - - @abstractmethod - def get(self, command: Union[str, Sequence[str]]) -> ResponseT: - pass - - @abstractmethod - def delete_command(self, command: Union[str, Sequence[str]]): - pass - - @abstractmethod - def delete_commands(self, commands: List[Union[str, Sequence[str]]]): - pass - - @abstractmethod - def flush(self): - pass - - @abstractmethod - def invalidate_key(self, key: KeyT): - pass - - -class _LocalCache(AbstractCache): - """ - A caching mechanism for storing redis commands and their responses. - - Args: - max_size (int): The maximum number of commands to be stored in the cache. - ttl (int): The time-to-live for each command in seconds. - eviction_policy (EvictionPolicy): The eviction policy to use for removing commands when the cache is full. - - Attributes: - max_size (int): The maximum number of commands to be stored in the cache. - ttl (int): The time-to-live for each command in seconds. - eviction_policy (EvictionPolicy): The eviction policy used for cache management. - cache (OrderedDict): The ordered dictionary to store commands and their metadata. - key_commands_map (defaultdict): A mapping of keys to the set of commands that use each key. - commands_ttl_list (list): A list to keep track of the commands in the order they were added. # noqa - """ - - def __init__( - self, - max_size: int = 10000, - ttl: int = 0, - eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY, - ): - self.max_size = max_size - self.ttl = ttl - self.eviction_policy = eviction_policy - self.cache = OrderedDict() - self.key_commands_map = defaultdict(set) - self.commands_ttl_list = [] - - def set( - self, - command: Union[str, Sequence[str]], - response: ResponseT, - keys_in_command: List[KeyT], - ): - """ - Set a redis command and its response in the cache. - - Args: - command (Union[str, Sequence[str]]): The redis command. - response (ResponseT): The response associated with the command. - keys_in_command (List[KeyT]): The list of keys used in the command. - """ - if len(self.cache) >= self.max_size: - self._evict() - self.cache[command] = { - _RESPONSE: response, - _KEYS: keys_in_command, - _CTIME: time.monotonic(), - _ACCESS_COUNT: 0, # Used only for LFU - } - self._update_key_commands_map(keys_in_command, command) - self.commands_ttl_list.append(command) - - def get(self, command: Union[str, Sequence[str]]) -> ResponseT: - """ - Get the response for a redis command from the cache. - - Args: - command (Union[str, Sequence[str]]): The redis command. - - Returns: - ResponseT: The response associated with the command, or None if the command is not in the cache. # noqa - """ - if command in self.cache: - if self._is_expired(command): - self.delete_command(command) - return - self._update_access(command) - return copy.deepcopy(self.cache[command]["response"]) - - def delete_command(self, command: Union[str, Sequence[str]]): - """ - Delete a redis command and its metadata from the cache. - - Args: - command (Union[str, Sequence[str]]): The redis command to be deleted. - """ - if command in self.cache: - keys_in_command = self.cache[command].get("keys") - self._del_key_commands_map(keys_in_command, command) - self.commands_ttl_list.remove(command) - del self.cache[command] - - def delete_commands(self, commands: List[Union[str, Sequence[str]]]): - """ - Delete multiple commands and their metadata from the cache. - - Args: - commands (List[Union[str, Sequence[str]]]): The list of commands to be - deleted. - """ - for command in commands: - self.delete_command(command) - - def flush(self): - """Clear the entire cache, removing all redis commands and metadata.""" - self.cache.clear() - self.key_commands_map.clear() - self.commands_ttl_list = [] - - def _is_expired(self, command: Union[str, Sequence[str]]) -> bool: - """ - Check if a redis command has expired based on its time-to-live. - - Args: - command (Union[str, Sequence[str]]): The redis command. - - Returns: - bool: True if the command has expired, False otherwise. - """ - if self.ttl == 0: - return False - return time.monotonic() - self.cache[command]["ctime"] > self.ttl - - def _update_access(self, command: Union[str, Sequence[str]]): - """ - Update the access information for a redis command based on the eviction policy. - - Args: - command (Union[str, Sequence[str]]): The redis command. - """ - if self.eviction_policy == EvictionPolicy.LRU: - self.cache.move_to_end(command) - elif self.eviction_policy == EvictionPolicy.LFU: - self.cache[command]["access_count"] = ( - self.cache.get(command, {}).get("access_count", 0) + 1 - ) - self.cache.move_to_end(command) - elif self.eviction_policy == EvictionPolicy.RANDOM: - pass # Random eviction doesn't require updates - - def _evict(self): - """Evict a redis command from the cache based on the eviction policy.""" - if self._is_expired(self.commands_ttl_list[0]): - self.delete_command(self.commands_ttl_list[0]) - elif self.eviction_policy == EvictionPolicy.LRU: - self.cache.popitem(last=False) - elif self.eviction_policy == EvictionPolicy.LFU: - min_access_command = min( - self.cache, key=lambda k: self.cache[k].get("access_count", 0) - ) - self.cache.pop(min_access_command) - elif self.eviction_policy == EvictionPolicy.RANDOM: - random_command = random.choice(list(self.cache.keys())) - self.cache.pop(random_command) - - def _update_key_commands_map( - self, keys: List[KeyT], command: Union[str, Sequence[str]] - ): - """ - Update the key_commands_map with command that uses the keys. - - Args: - keys (List[KeyT]): The list of keys used in the command. - command (Union[str, Sequence[str]]): The redis command. - """ - for key in keys: - self.key_commands_map[key].add(command) - - def _del_key_commands_map( - self, keys: List[KeyT], command: Union[str, Sequence[str]] - ): - """ - Remove a redis command from the key_commands_map. - - Args: - keys (List[KeyT]): The list of keys used in the redis command. - command (Union[str, Sequence[str]]): The redis command. - """ - for key in keys: - self.key_commands_map[key].remove(command) - - def invalidate_key(self, key: KeyT): - """ - Invalidate (delete) all redis commands associated with a specific key. - - Args: - key (KeyT): The key to be invalidated. - """ - if key not in self.key_commands_map: - return - commands = list(self.key_commands_map[key]) - for command in commands: - self.delete_command(command) diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py index 3547fcf355..281546430b 100644 --- a/redis/_parsers/resp3.py +++ b/redis/_parsers/resp3.py @@ -116,6 +116,12 @@ def _read_response(self, disable_decoding=False, push_request=False): response = self.handle_push_response( response, disable_decoding, push_request ) + if not push_request: + return self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + else: + return response else: raise InvalidResponse(f"Protocol Error: {raw!r}") @@ -124,19 +130,10 @@ def _read_response(self, disable_decoding=False, push_request=False): return response def handle_push_response(self, response, disable_decoding, push_request): - if response[0] in _INVALIDATION_MESSAGE: - if self.invalidation_push_handler_func: - res = self.invalidation_push_handler_func(response) - else: - res = None - else: - res = self.pubsub_push_handler_func(response) - if not push_request: - return self._read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: - return res + if response[0] not in _INVALIDATION_MESSAGE: + return self.pubsub_push_handler_func(response) + if self.invalidation_push_handler_func: + return self.invalidation_push_handler_func(response) def set_pubsub_push_handler(self, pubsub_push_handler_func): self.pubsub_push_handler_func = pubsub_push_handler_func @@ -151,7 +148,7 @@ def __init__(self, socket_read_size): self.pubsub_push_handler_func = self.handle_pubsub_push_response self.invalidation_push_handler_func = None - def handle_pubsub_push_response(self, response): + async def handle_pubsub_push_response(self, response): logger = getLogger("push_response") logger.info("Push response: " + str(response)) return response @@ -259,6 +256,12 @@ async def _read_response( response = await self.handle_push_response( response, disable_decoding, push_request ) + if not push_request: + return await self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + else: + return response else: raise InvalidResponse(f"Protocol Error: {raw!r}") @@ -267,19 +270,10 @@ async def _read_response( return response async def handle_push_response(self, response, disable_decoding, push_request): - if response[0] in _INVALIDATION_MESSAGE: - if self.invalidation_push_handler_func: - res = self.invalidation_push_handler_func(response) - else: - res = None - else: - res = self.pubsub_push_handler_func(response) - if not push_request: - return await self._read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: - return res + if response[0] not in _INVALIDATION_MESSAGE: + return await self.pubsub_push_handler_func(response) + if self.invalidation_push_handler_func: + return await self.invalidation_push_handler_func(response) def set_pubsub_push_handler(self, pubsub_push_handler_func): self.pubsub_push_handler_func = pubsub_push_handler_func diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 70a5e997ef..039ebfdfae 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -26,12 +26,6 @@ cast, ) -from redis._cache import ( - DEFAULT_ALLOW_LIST, - DEFAULT_DENY_LIST, - DEFAULT_EVICTION_POLICY, - AbstractCache, -) from redis._parsers.helpers import ( _RedisCallbacks, _RedisCallbacksRESP2, @@ -239,13 +233,6 @@ def __init__( redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, - cache_enabled: bool = False, - client_cache: Optional[AbstractCache] = None, - cache_max_size: int = 100, - cache_ttl: int = 0, - cache_policy: str = DEFAULT_EVICTION_POLICY, - cache_deny_list: List[str] = DEFAULT_DENY_LIST, - cache_allow_list: List[str] = DEFAULT_ALLOW_LIST, ): """ Initialize a new Redis client. @@ -295,13 +282,6 @@ def __init__( "lib_version": lib_version, "redis_connect_func": redis_connect_func, "protocol": protocol, - "cache_enabled": cache_enabled, - "client_cache": client_cache, - "cache_max_size": cache_max_size, - "cache_ttl": cache_ttl, - "cache_policy": cache_policy, - "cache_deny_list": cache_deny_list, - "cache_allow_list": cache_allow_list, } # based on input, setup appropriate connection args if unix_socket_path is not None: @@ -626,31 +606,22 @@ async def _disconnect_raise(self, conn: Connection, error: Exception): async def execute_command(self, *args, **options): """Execute a command and return a parsed response""" await self.initialize() - command_name = args[0] - keys = options.pop("keys", None) # keys are used only for client side caching pool = self.connection_pool + command_name = args[0] conn = self.connection or await pool.get_connection(command_name, **options) - response_from_cache = await conn._get_from_local_cache(args) + + if self.single_connection_client: + await self._single_conn_lock.acquire() try: - if response_from_cache is not None: - return response_from_cache - else: - try: - if self.single_connection_client: - await self._single_conn_lock.acquire() - response = await conn.retry.call_with_retry( - lambda: self._send_command_parse_response( - conn, command_name, *args, **options - ), - lambda error: self._disconnect_raise(conn, error), - ) - if keys: - conn._add_to_local_cache(args, response, keys) - return response - finally: - if self.single_connection_client: - self._single_conn_lock.release() + return await conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_raise(conn, error), + ) finally: + if self.single_connection_client: + self._single_conn_lock.release() if not self.connection: await pool.release(conn) @@ -672,6 +643,9 @@ async def parse_response( if EMPTY_RESPONSE in options: options.pop(EMPTY_RESPONSE) + # Remove keys entry, it needs only for cache. + options.pop("keys", None) + if command_name in self.response_callbacks: # Mypy bug: https://github.com/python/mypy/issues/10977 command_name = cast(str, command_name) @@ -679,24 +653,6 @@ async def parse_response( return await retval if inspect.isawaitable(retval) else retval return response - def flush_cache(self): - if self.connection: - self.connection.flush_cache() - else: - self.connection_pool.flush_cache() - - def delete_command_from_cache(self, command): - if self.connection: - self.connection.delete_command_from_cache(command) - else: - self.connection_pool.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key): - if self.connection: - self.connection.invalidate_key_from_cache(key) - else: - self.connection_pool.invalidate_key_from_cache(key) - StrictRedis = Redis @@ -1333,7 +1289,6 @@ def multi(self): def execute_command( self, *args, **kwargs ) -> Union["Pipeline", Awaitable["Pipeline"]]: - kwargs.pop("keys", None) # the keys are used only for client side caching if (self.watching or args[0] == "WATCH") and not self.explicit_transaction: return self.immediate_execute_command(*args, **kwargs) return self.pipeline_execute_command(*args, **kwargs) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 40b2948a7f..4e82e5448f 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -19,12 +19,6 @@ Union, ) -from redis._cache import ( - DEFAULT_ALLOW_LIST, - DEFAULT_DENY_LIST, - DEFAULT_EVICTION_POLICY, - AbstractCache, -) from redis._parsers import AsyncCommandsParser, Encoder from redis._parsers.helpers import ( _RedisCallbacks, @@ -276,13 +270,6 @@ def __init__( ssl_ciphers: Optional[str] = None, protocol: Optional[int] = 2, address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, - cache_enabled: bool = False, - client_cache: Optional[AbstractCache] = None, - cache_max_size: int = 100, - cache_ttl: int = 0, - cache_policy: str = DEFAULT_EVICTION_POLICY, - cache_deny_list: List[str] = DEFAULT_DENY_LIST, - cache_allow_list: List[str] = DEFAULT_ALLOW_LIST, ) -> None: if db: raise RedisClusterException( @@ -326,14 +313,6 @@ def __init__( "socket_timeout": socket_timeout, "retry": retry, "protocol": protocol, - # Client cache related kwargs - "cache_enabled": cache_enabled, - "client_cache": client_cache, - "cache_max_size": cache_max_size, - "cache_ttl": cache_ttl, - "cache_policy": cache_policy, - "cache_deny_list": cache_deny_list, - "cache_allow_list": cache_allow_list, } if ssl: @@ -938,18 +917,6 @@ def lock( thread_local=thread_local, ) - def flush_cache(self): - if self.nodes_manager: - self.nodes_manager.flush_cache() - - def delete_command_from_cache(self, command): - if self.nodes_manager: - self.nodes_manager.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key): - if self.nodes_manager: - self.nodes_manager.invalidate_key_from_cache(key) - class ClusterNode: """ @@ -1067,6 +1034,9 @@ async def parse_response( if EMPTY_RESPONSE in kwargs: kwargs.pop(EMPTY_RESPONSE) + # Remove keys entry, it needs only for cache. + kwargs.pop("keys", None) + # Return response if command in self.response_callbacks: return self.response_callbacks[command](response, **kwargs) @@ -1076,25 +1046,16 @@ async def parse_response( async def execute_command(self, *args: Any, **kwargs: Any) -> Any: # Acquire connection connection = self.acquire_connection() - keys = kwargs.pop("keys", None) - response_from_cache = await connection._get_from_local_cache(args) - if response_from_cache is not None: - self._free.append(connection) - return response_from_cache - else: - # Execute command - await connection.send_packed_command(connection.pack_command(*args), False) + # Execute command + await connection.send_packed_command(connection.pack_command(*args), False) - # Read response - try: - response = await self.parse_response(connection, args[0], **kwargs) - if keys: - connection._add_to_local_cache(args, response, keys) - return response - finally: - # Release connection - self._free.append(connection) + # Read response + try: + return await self.parse_response(connection, args[0], **kwargs) + finally: + # Release connection + self._free.append(connection) async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: # Acquire connection @@ -1121,18 +1082,6 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: return ret - def flush_cache(self): - for connection in self._connections: - connection.flush_cache() - - def delete_command_from_cache(self, command): - for connection in self._connections: - connection.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key): - for connection in self._connections: - connection.invalidate_key_from_cache(key) - class NodesManager: __slots__ = ( @@ -1416,18 +1365,6 @@ def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: return self.address_remap((host, port)) return host, port - def flush_cache(self): - for node in self.nodes_cache.values(): - node.flush_cache() - - def delete_command_from_cache(self, command): - for node in self.nodes_cache.values(): - node.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key): - for node in self.nodes_cache.values(): - node.invalidate_key_from_cache(key) - class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): """ @@ -1516,7 +1453,6 @@ def execute_command( or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`] - Rest of the kwargs are passed to the Redis connection """ - kwargs.pop("keys", None) # the keys are used only for client side caching self._command_stack.append( PipelineCommand(len(self._command_stack), *args, **kwargs) ) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 2ac6637986..ddbd22c95d 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -49,16 +49,9 @@ ResponseError, TimeoutError, ) -from redis.typing import EncodableT, KeysT, ResponseT +from redis.typing import EncodableT from redis.utils import HIREDIS_AVAILABLE, get_lib_version, str_if_bytes -from .._cache import ( - DEFAULT_ALLOW_LIST, - DEFAULT_DENY_LIST, - DEFAULT_EVICTION_POLICY, - AbstractCache, - _LocalCache, -) from .._parsers import ( BaseParser, Encoder, @@ -121,9 +114,6 @@ class AbstractConnection: "encoder", "ssl_context", "protocol", - "client_cache", - "cache_deny_list", - "cache_allow_list", "_reader", "_writer", "_parser", @@ -158,13 +148,6 @@ def __init__( encoder_class: Type[Encoder] = Encoder, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, - cache_enabled: bool = False, - client_cache: Optional[AbstractCache] = None, - cache_max_size: int = 10000, - cache_ttl: int = 0, - cache_policy: str = DEFAULT_EVICTION_POLICY, - cache_deny_list: List[str] = DEFAULT_DENY_LIST, - cache_allow_list: List[str] = DEFAULT_ALLOW_LIST, ): if (username or password) and credential_provider is not None: raise DataError( @@ -222,18 +205,6 @@ def __init__( if p < 2 or p > 3: raise ConnectionError("protocol must be either 2 or 3") self.protocol = protocol - if cache_enabled: - _cache = _LocalCache(cache_max_size, cache_ttl, cache_policy) - else: - _cache = None - self.client_cache = client_cache if client_cache is not None else _cache - if self.client_cache is not None: - if self.protocol not in [3, "3"]: - raise RedisError( - "client caching is only supported with protocol version 3 or higher" - ) - self.cache_deny_list = cache_deny_list - self.cache_allow_list = cache_allow_list def __del__(self, _warnings: Any = warnings): # For some reason, the individual streams don't get properly garbage @@ -425,11 +396,6 @@ async def on_connect(self) -> None: # if a database is specified, switch to it. Also pipeline this if self.db: await self.send_command("SELECT", self.db) - # if client caching is enabled, start tracking - if self.client_cache: - await self.send_command("CLIENT", "TRACKING", "ON") - await self.read_response() - self._parser.set_invalidation_push_handler(self._cache_invalidation_process) # read responses from pipeline for _ in (sent for sent in (self.lib_name, self.lib_version) if sent): @@ -464,9 +430,6 @@ async def disconnect(self, nowait: bool = False) -> None: raise TimeoutError( f"Timed out closing connection after {self.socket_connect_timeout}" ) from None - finally: - if self.client_cache: - self.client_cache.flush() async def _send_ping(self): """Send PING, expect PONG in return""" @@ -688,60 +651,9 @@ def _socket_is_empty(self): """Check if the socket is empty""" return len(self._reader._buffer) == 0 - def _cache_invalidation_process( - self, data: List[Union[str, Optional[List[str]]]] - ) -> None: - """ - Invalidate (delete) all redis commands associated with a specific key. - `data` is a list of strings, where the first string is the invalidation message - and the second string is the list of keys to invalidate. - (if the list of keys is None, then all keys are invalidated) - """ - if data[1] is None: - self.client_cache.flush() - else: - for key in data[1]: - self.client_cache.invalidate_key(str_if_bytes(key)) - - async def _get_from_local_cache(self, command: str): - """ - If the command is in the local cache, return the response - """ - if ( - self.client_cache is None - or command[0] in self.cache_deny_list - or command[0] not in self.cache_allow_list - ): - return None + async def process_invalidation_messages(self): while not self._socket_is_empty(): await self.read_response(push_request=True) - return self.client_cache.get(command) - - def _add_to_local_cache( - self, command: Tuple[str], response: ResponseT, keys: List[KeysT] - ): - """ - Add the command and response to the local cache if the command - is allowed to be cached - """ - if ( - self.client_cache is not None - and (self.cache_deny_list == [] or command[0] not in self.cache_deny_list) - and (self.cache_allow_list == [] or command[0] in self.cache_allow_list) - ): - self.client_cache.set(command, response, keys) - - def flush_cache(self): - if self.client_cache: - self.client_cache.flush() - - def delete_command_from_cache(self, command): - if self.client_cache: - self.client_cache.delete_command(command) - - def invalidate_key_from_cache(self, key): - if self.client_cache: - self.client_cache.invalidate_key(key) class Connection(AbstractConnection): @@ -1177,18 +1089,12 @@ def make_connection(self): async def ensure_connection(self, connection: AbstractConnection): """Ensure that the connection object is connected and valid""" await connection.connect() - # if client caching is not enabled connections that the pool - # provides should be ready to send a command. - # if not, the connection was either returned to the + # connections that the pool provides should be ready to send + # a command. if not, the connection was either returned to the # pool before all data has been read or the socket has been # closed. either way, reconnect and verify everything is good. - # (if caching enabled the connection will not always be ready - # to send a command because it may contain invalidation messages) try: - if ( - await connection.can_read_destructive() - and connection.client_cache is None - ): + if await connection.can_read_destructive(): raise ConnectionError("Connection has data") from None except (ConnectionError, OSError): await connection.disconnect() @@ -1235,21 +1141,6 @@ def set_retry(self, retry: "Retry") -> None: for conn in self._in_use_connections: conn.retry = retry - def flush_cache(self): - connections = chain(self._available_connections, self._in_use_connections) - for connection in connections: - connection.flush_cache() - - def delete_command_from_cache(self, command: str): - connections = chain(self._available_connections, self._in_use_connections) - for connection in connections: - connection.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key: str): - connections = chain(self._available_connections, self._in_use_connections) - for connection in connections: - connection.invalidate_key_from_cache(key) - class BlockingConnectionPool(ConnectionPool): """ diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py index 6fd233adc8..5d4608ed2f 100644 --- a/redis/asyncio/sentinel.py +++ b/redis/asyncio/sentinel.py @@ -225,7 +225,6 @@ async def execute_command(self, *args, **kwargs): once - If set to True, then execute the resulting command on a single node at random, rather than across the entire sentinel cluster. """ - kwargs.pop("keys", None) # the keys are used only for client side caching once = bool(kwargs.get("once", False)) if "once" in kwargs.keys(): kwargs.pop("once") diff --git a/redis/cache.py b/redis/cache.py new file mode 100644 index 0000000000..9971edd256 --- /dev/null +++ b/redis/cache.py @@ -0,0 +1,401 @@ +from abc import ABC, abstractmethod +from collections import OrderedDict +from dataclasses import dataclass +from enum import Enum +from typing import Any, List, Optional, Union + + +class CacheEntryStatus(Enum): + VALID = "VALID" + IN_PROGRESS = "IN_PROGRESS" + + +class EvictionPolicyType(Enum): + time_based = "time_based" + frequency_based = "frequency_based" + + +@dataclass(frozen=True) +class CacheKey: + command: str + redis_keys: tuple + + +class CacheEntry: + def __init__( + self, + cache_key: CacheKey, + cache_value: bytes, + status: CacheEntryStatus, + connection_ref, + ): + self.cache_key = cache_key + self.cache_value = cache_value + self.status = status + self.connection_ref = connection_ref + + def __hash__(self): + return hash( + (self.cache_key, self.cache_value, self.status, self.connection_ref) + ) + + def __eq__(self, other): + return hash(self) == hash(other) + + +class EvictionPolicyInterface(ABC): + @property + @abstractmethod + def cache(self): + pass + + @cache.setter + def cache(self, value): + pass + + @property + @abstractmethod + def type(self) -> EvictionPolicyType: + pass + + @abstractmethod + def evict_next(self) -> CacheKey: + pass + + @abstractmethod + def evict_many(self, count: int) -> List[CacheKey]: + pass + + @abstractmethod + def touch(self, cache_key: CacheKey) -> None: + pass + + +class CacheConfigurationInterface(ABC): + @abstractmethod + def get_cache_class(self): + pass + + @abstractmethod + def get_max_size(self) -> int: + pass + + @abstractmethod + def get_eviction_policy(self): + pass + + @abstractmethod + def is_exceeds_max_size(self, count: int) -> bool: + pass + + @abstractmethod + def is_allowed_to_cache(self, command: str) -> bool: + pass + + +class CacheInterface(ABC): + @property + @abstractmethod + def collection(self) -> OrderedDict: + pass + + @property + @abstractmethod + def config(self) -> CacheConfigurationInterface: + pass + + @property + @abstractmethod + def eviction_policy(self) -> EvictionPolicyInterface: + pass + + @property + @abstractmethod + def size(self) -> int: + pass + + @abstractmethod + def get(self, key: CacheKey) -> Union[CacheEntry, None]: + pass + + @abstractmethod + def set(self, entry: CacheEntry) -> bool: + pass + + @abstractmethod + def delete_by_cache_keys(self, cache_keys: List[CacheKey]) -> List[bool]: + pass + + @abstractmethod + def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]: + pass + + @abstractmethod + def flush(self) -> int: + pass + + @abstractmethod + def is_cachable(self, key: CacheKey) -> bool: + pass + + +class DefaultCache(CacheInterface): + def __init__( + self, + cache_config: CacheConfigurationInterface, + ) -> None: + self._cache = OrderedDict() + self._cache_config = cache_config + self._eviction_policy = self._cache_config.get_eviction_policy().value() + self._eviction_policy.cache = self + + @property + def collection(self) -> OrderedDict: + return self._cache + + @property + def config(self) -> CacheConfigurationInterface: + return self._cache_config + + @property + def eviction_policy(self) -> EvictionPolicyInterface: + return self._eviction_policy + + @property + def size(self) -> int: + return len(self._cache) + + def set(self, entry: CacheEntry) -> bool: + if not self.is_cachable(entry.cache_key): + return False + + self._cache[entry.cache_key] = entry + self._eviction_policy.touch(entry.cache_key) + + if self._cache_config.is_exceeds_max_size(len(self._cache)): + self._eviction_policy.evict_next() + + return True + + def get(self, key: CacheKey) -> Union[CacheEntry, None]: + entry = self._cache.get(key, None) + + if entry is None: + return None + + self._eviction_policy.touch(key) + return entry + + def delete_by_cache_keys(self, cache_keys: List[CacheKey]) -> List[bool]: + response = [] + + for key in cache_keys: + if self.get(key) is not None: + self._cache.pop(key) + response.append(True) + else: + response.append(False) + + return response + + def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]: + response = [] + keys_to_delete = [] + + for redis_key in redis_keys: + if isinstance(redis_key, bytes): + redis_key = redis_key.decode() + for cache_key in self._cache: + if redis_key in cache_key.redis_keys: + keys_to_delete.append(cache_key) + response.append(True) + + for key in keys_to_delete: + self._cache.pop(key) + + return response + + def flush(self) -> int: + elem_count = len(self._cache) + self._cache.clear() + return elem_count + + def is_cachable(self, key: CacheKey) -> bool: + return self._cache_config.is_allowed_to_cache(key.command) + + +class LRUPolicy(EvictionPolicyInterface): + def __init__(self): + self.cache = None + + @property + def cache(self): + return self._cache + + @cache.setter + def cache(self, cache: CacheInterface): + self._cache = cache + + @property + def type(self) -> EvictionPolicyType: + return EvictionPolicyType.time_based + + def evict_next(self) -> CacheKey: + self._assert_cache() + popped_entry = self._cache.collection.popitem(last=False) + return popped_entry[0] + + def evict_many(self, count: int) -> List[CacheKey]: + self._assert_cache() + if count > len(self._cache.collection): + raise ValueError("Evictions count is above cache size") + + popped_keys = [] + + for _ in range(count): + popped_entry = self._cache.collection.popitem(last=False) + popped_keys.append(popped_entry[0]) + + return popped_keys + + def touch(self, cache_key: CacheKey) -> None: + self._assert_cache() + + if self._cache.collection.get(cache_key) is None: + raise ValueError("Given entry does not belong to the cache") + + self._cache.collection.move_to_end(cache_key) + + def _assert_cache(self): + if self.cache is None or not isinstance(self.cache, CacheInterface): + raise ValueError("Eviction policy should be associated with valid cache.") + + +class EvictionPolicy(Enum): + LRU = LRUPolicy + + +class CacheConfig(CacheConfigurationInterface): + DEFAULT_CACHE_CLASS = DefaultCache + DEFAULT_EVICTION_POLICY = EvictionPolicy.LRU + DEFAULT_MAX_SIZE = 10000 + + DEFAULT_ALLOW_LIST = [ + "BITCOUNT", + "BITFIELD_RO", + "BITPOS", + "EXISTS", + "GEODIST", + "GEOHASH", + "GEOPOS", + "GEORADIUSBYMEMBER_RO", + "GEORADIUS_RO", + "GEOSEARCH", + "GET", + "GETBIT", + "GETRANGE", + "HEXISTS", + "HGET", + "HGETALL", + "HKEYS", + "HLEN", + "HMGET", + "HSTRLEN", + "HVALS", + "JSON.ARRINDEX", + "JSON.ARRLEN", + "JSON.GET", + "JSON.MGET", + "JSON.OBJKEYS", + "JSON.OBJLEN", + "JSON.RESP", + "JSON.STRLEN", + "JSON.TYPE", + "LCS", + "LINDEX", + "LLEN", + "LPOS", + "LRANGE", + "MGET", + "SCARD", + "SDIFF", + "SINTER", + "SINTERCARD", + "SISMEMBER", + "SMEMBERS", + "SMISMEMBER", + "SORT_RO", + "STRLEN", + "SUBSTR", + "SUNION", + "TS.GET", + "TS.INFO", + "TS.RANGE", + "TS.REVRANGE", + "TYPE", + "XLEN", + "XPENDING", + "XRANGE", + "XREAD", + "XREVRANGE", + "ZCARD", + "ZCOUNT", + "ZDIFF", + "ZINTER", + "ZINTERCARD", + "ZLEXCOUNT", + "ZMSCORE", + "ZRANGE", + "ZRANGEBYLEX", + "ZRANGEBYSCORE", + "ZRANK", + "ZREVRANGE", + "ZREVRANGEBYLEX", + "ZREVRANGEBYSCORE", + "ZREVRANK", + "ZSCORE", + "ZUNION", + ] + + def __init__( + self, + max_size: int = DEFAULT_MAX_SIZE, + cache_class: Any = DEFAULT_CACHE_CLASS, + eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY, + ): + self._cache_class = cache_class + self._max_size = max_size + self._eviction_policy = eviction_policy + + def get_cache_class(self): + return self._cache_class + + def get_max_size(self) -> int: + return self._max_size + + def get_eviction_policy(self) -> EvictionPolicy: + return self._eviction_policy + + def is_exceeds_max_size(self, count: int) -> bool: + return count > self._max_size + + def is_allowed_to_cache(self, command: str) -> bool: + return command in self.DEFAULT_ALLOW_LIST + + +class CacheFactoryInterface(ABC): + @abstractmethod + def get_cache(self) -> CacheInterface: + pass + + +class CacheFactory(CacheFactoryInterface): + def __init__(self, cache_config: Optional[CacheConfig] = None): + self._config = cache_config + + if self._config is None: + self._config = CacheConfig() + + def get_cache(self) -> CacheInterface: + cache_class = self._config.get_cache_class() + return cache_class(cache_config=self._config) diff --git a/redis/client.py b/redis/client.py index b7a1f88d92..bf3432e7eb 100755 --- a/redis/client.py +++ b/redis/client.py @@ -6,12 +6,6 @@ from itertools import chain from typing import Any, Callable, Dict, List, Optional, Type, Union -from redis._cache import ( - DEFAULT_ALLOW_LIST, - DEFAULT_DENY_LIST, - DEFAULT_EVICTION_POLICY, - AbstractCache, -) from redis._parsers.encoders import Encoder from redis._parsers.helpers import ( _RedisCallbacks, @@ -19,6 +13,7 @@ _RedisCallbacksRESP3, bool_ok, ) +from redis.cache import CacheConfig, CacheInterface from redis.commands import ( CoreCommands, RedisModuleCommands, @@ -216,13 +211,8 @@ def __init__( redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, - cache_enabled: bool = False, - client_cache: Optional[AbstractCache] = None, - cache_max_size: int = 10000, - cache_ttl: int = 0, - cache_policy: str = DEFAULT_EVICTION_POLICY, - cache_deny_list: List[str] = DEFAULT_DENY_LIST, - cache_allow_list: List[str] = DEFAULT_ALLOW_LIST, + cache: Optional[CacheInterface] = None, + cache_config: Optional[CacheConfig] = None, ) -> None: """ Initialize a new Redis client. @@ -274,13 +264,6 @@ def __init__( "redis_connect_func": redis_connect_func, "credential_provider": credential_provider, "protocol": protocol, - "cache_enabled": cache_enabled, - "client_cache": client_cache, - "cache_max_size": cache_max_size, - "cache_ttl": cache_ttl, - "cache_policy": cache_policy, - "cache_deny_list": cache_deny_list, - "cache_allow_list": cache_allow_list, } # based on input, setup appropriate connection args if unix_socket_path is not None: @@ -322,12 +305,26 @@ def __init__( "ssl_ciphers": ssl_ciphers, } ) + if (cache_config or cache) and protocol in [3, "3"]: + kwargs.update( + { + "cache": cache, + "cache_config": cache_config, + } + ) connection_pool = ConnectionPool(**kwargs) self.auto_close_connection_pool = True else: self.auto_close_connection_pool = False self.connection_pool = connection_pool + + if (cache_config or cache) and self.connection_pool.get_protocol() not in [ + 3, + "3", + ]: + raise RedisError("Client caching is only supported with RESP version 3") + self.connection = None if single_connection_client: self.connection = self.connection_pool.get_connection("_") @@ -541,7 +538,7 @@ def _send_command_parse_response(self, conn, command_name, *args, **options): """ Send a command and parse the response """ - conn.send_command(*args) + conn.send_command(*args, **options) return self.parse_response(conn, command_name, **options) def _disconnect_raise(self, conn, error): @@ -559,25 +556,20 @@ def _disconnect_raise(self, conn, error): # COMMAND EXECUTION AND PROTOCOL PARSING def execute_command(self, *args, **options): + return self._execute_command(*args, **options) + + def _execute_command(self, *args, **options): """Execute a command and return a parsed response""" - command_name = args[0] - keys = options.pop("keys", None) pool = self.connection_pool + command_name = args[0] conn = self.connection or pool.get_connection(command_name, **options) - response_from_cache = conn._get_from_local_cache(args) try: - if response_from_cache is not None: - return response_from_cache - else: - response = conn.retry.call_with_retry( - lambda: self._send_command_parse_response( - conn, command_name, *args, **options - ), - lambda error: self._disconnect_raise(conn, error), - ) - if keys: - conn._add_to_local_cache(args, response, keys) - return response + return conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_raise(conn, error), + ) finally: if not self.connection: pool.release(conn) @@ -598,27 +590,15 @@ def parse_response(self, connection, command_name, **options): if EMPTY_RESPONSE in options: options.pop(EMPTY_RESPONSE) + # Remove keys entry, it needs only for cache. + options.pop("keys", None) + if command_name in self.response_callbacks: return self.response_callbacks[command_name](response, **options) return response - def flush_cache(self): - if self.connection: - self.connection.flush_cache() - else: - self.connection_pool.flush_cache() - - def delete_command_from_cache(self, command): - if self.connection: - self.connection.delete_command_from_cache(command) - else: - self.connection_pool.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key): - if self.connection: - self.connection.invalidate_key_from_cache(key) - else: - self.connection_pool.invalidate_key_from_cache(key) + def get_cache(self) -> Optional[CacheInterface]: + return self.connection_pool.cache StrictRedis = Redis @@ -1314,7 +1294,6 @@ def multi(self) -> None: self.explicit_transaction = True def execute_command(self, *args, **kwargs): - kwargs.pop("keys", None) # the keys are used only for client side caching if (self.watching or args[0] == "WATCH") and not self.explicit_transaction: return self.immediate_execute_command(*args, **kwargs) return self.pipeline_execute_command(*args, **kwargs) @@ -1441,6 +1420,8 @@ def _execute_transaction(self, connection, commands, raise_on_error) -> List: for r, cmd in zip(response, commands): if not isinstance(r, Exception): args, options = cmd + # Remove keys entry, it needs only for cache. + options.pop("keys", None) command_name = args[0] if command_name in self.response_callbacks: r = self.response_callbacks[command_name](r, **options) diff --git a/redis/cluster.py b/redis/cluster.py index be7685e9a1..fbf5428d40 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -9,6 +9,7 @@ from redis._parsers import CommandsParser, Encoder from redis._parsers.helpers import parse_scan from redis.backoff import default_backoff +from redis.cache import CacheConfig, CacheFactory, CacheFactoryInterface, CacheInterface from redis.client import CaseInsensitiveDict, PubSub, Redis from redis.commands import READ_COMMANDS, RedisClusterCommands from redis.commands.helpers import list_or_args @@ -167,13 +168,8 @@ def parse_cluster_myshardid(resp, **options): "ssl_password", "unix_socket_path", "username", - "cache_enabled", - "client_cache", - "cache_max_size", - "cache_ttl", - "cache_policy", - "cache_deny_list", - "cache_allow_list", + "cache", + "cache_config", ) KWARGS_DISABLED_KEYS = ("host", "port") @@ -507,6 +503,8 @@ def __init__( dynamic_startup_nodes: bool = True, url: Optional[str] = None, address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, + cache: Optional[CacheInterface] = None, + cache_config: Optional[CacheConfig] = None, **kwargs, ): """ @@ -630,6 +628,10 @@ def __init__( kwargs.get("encoding_errors", "strict"), kwargs.get("decode_responses", False), ) + protocol = kwargs.get("protocol", None) + if (cache_config or cache) and protocol not in [3, "3"]: + raise RedisError("Client caching is only supported with RESP version 3") + self.cluster_error_retry_attempts = cluster_error_retry_attempts self.command_flags = self.__class__.COMMAND_FLAGS.copy() self.node_flags = self.__class__.NODE_FLAGS.copy() @@ -642,6 +644,8 @@ def __init__( require_full_coverage=require_full_coverage, dynamic_startup_nodes=dynamic_startup_nodes, address_remap=address_remap, + cache=cache, + cache_config=cache_config, **kwargs, ) @@ -649,6 +653,7 @@ def __init__( self.__class__.CLUSTER_COMMANDS_RESPONSE_CALLBACKS ) self.result_callbacks = CaseInsensitiveDict(self.__class__.RESULT_CALLBACKS) + self.commands_parser = CommandsParser(self) self._lock = threading.Lock() @@ -1052,6 +1057,9 @@ def _parse_target_nodes(self, target_nodes): return nodes def execute_command(self, *args, **kwargs): + return self._internal_execute_command(*args, **kwargs) + + def _internal_execute_command(self, *args, **kwargs): """ Wrapper for ERRORS_ALLOW_RETRY error handling. @@ -1125,7 +1133,6 @@ def _execute_command(self, target_node, *args, **kwargs): """ Send a command to a node in the cluster """ - keys = kwargs.pop("keys", None) command = args[0] redis_node = None connection = None @@ -1154,19 +1161,13 @@ def _execute_command(self, target_node, *args, **kwargs): connection.send_command("ASKING") redis_node.parse_response(connection, "ASKING", **kwargs) asking = False - response_from_cache = connection._get_from_local_cache(args) - if response_from_cache is not None: - return response_from_cache - else: - connection.send_command(*args) - response = redis_node.parse_response(connection, command, **kwargs) - if command in self.cluster_response_callbacks: - response = self.cluster_response_callbacks[command]( - response, **kwargs - ) - if keys: - connection._add_to_local_cache(args, response, keys) - return response + connection.send_command(*args, **kwargs) + response = redis_node.parse_response(connection, command, **kwargs) + if command in self.cluster_response_callbacks: + response = self.cluster_response_callbacks[command]( + response, **kwargs + ) + return response except AuthenticationError: raise except (ConnectionError, TimeoutError) as e: @@ -1266,18 +1267,6 @@ def load_external_module(self, funcname, func): """ setattr(self, funcname, func) - def flush_cache(self): - if self.nodes_manager: - self.nodes_manager.flush_cache() - - def delete_command_from_cache(self, command): - if self.nodes_manager: - self.nodes_manager.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key): - if self.nodes_manager: - self.nodes_manager.invalidate_key_from_cache(key) - class ClusterNode: def __init__(self, host, port, server_type=None, redis_connection=None): @@ -1306,18 +1295,6 @@ def __del__(self): if self.redis_connection is not None: self.redis_connection.close() - def flush_cache(self): - if self.redis_connection is not None: - self.redis_connection.flush_cache() - - def delete_command_from_cache(self, command): - if self.redis_connection is not None: - self.redis_connection.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key): - if self.redis_connection is not None: - self.redis_connection.invalidate_key_from_cache(key) - class LoadBalancer: """ @@ -1348,6 +1325,9 @@ def __init__( dynamic_startup_nodes=True, connection_pool_class=ConnectionPool, address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, + cache: Optional[CacheInterface] = None, + cache_config: Optional[CacheConfig] = None, + cache_factory: Optional[CacheFactoryInterface] = None, **kwargs, ): self.nodes_cache = {} @@ -1360,6 +1340,9 @@ def __init__( self._dynamic_startup_nodes = dynamic_startup_nodes self.connection_pool_class = connection_pool_class self.address_remap = address_remap + self._cache = cache + self._cache_config = cache_config + self._cache_factory = cache_factory self._moved_exception = None self.connection_kwargs = kwargs self.read_load_balancer = LoadBalancer() @@ -1503,9 +1486,15 @@ def create_redis_node(self, host, port, **kwargs): # Create a redis node with a costumed connection pool kwargs.update({"host": host}) kwargs.update({"port": port}) + kwargs.update({"cache": self._cache}) r = Redis(connection_pool=self.connection_pool_class(**kwargs)) else: - r = Redis(host=host, port=port, **kwargs) + r = Redis( + host=host, + port=port, + cache=self._cache, + **kwargs, + ) return r def _get_or_create_cluster_node(self, host, port, role, tmp_nodes_cache): @@ -1554,6 +1543,7 @@ def initialize(self): # Make sure cluster mode is enabled on this node try: cluster_slots = str_if_bytes(r.execute_command("CLUSTER SLOTS")) + r.connection_pool.disconnect() except ResponseError: raise RedisClusterException( "Cluster mode is not enabled on this node" @@ -1634,6 +1624,12 @@ def initialize(self): f"one reachable node: {str(exception)}" ) from exception + if self._cache is None and self._cache_config is not None: + if self._cache_factory is None: + self._cache = CacheFactory(self._cache_config).get_cache() + else: + self._cache = self._cache_factory.get_cache() + # Create Redis connections to all nodes self.create_redis_connections(list(tmp_nodes_cache.values())) @@ -1681,18 +1677,6 @@ def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: return self.address_remap((host, port)) return host, port - def flush_cache(self): - for node in self.nodes_cache.values(): - node.flush_cache() - - def delete_command_from_cache(self, command): - for node in self.nodes_cache.values(): - node.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key): - for node in self.nodes_cache.values(): - node.invalidate_key_from_cache(key) - class ClusterPubSub(PubSub): """ @@ -2008,7 +1992,6 @@ def execute_command(self, *args, **kwargs): """ Wrapper function for pipeline_execute_command """ - kwargs.pop("keys", None) # the keys are used only for client side caching return self.pipeline_execute_command(*args, **kwargs) def pipeline_execute_command(self, *args, **options): @@ -2282,6 +2265,8 @@ def _send_cluster_commands( response = [] for c in sorted(stack, key=lambda x: x.position): if c.args[0] in self.cluster_response_callbacks: + # Remove keys entry, it needs only for cache. + c.options.pop("keys", None) c.result = self.cluster_response_callbacks[c.args[0]]( c.result, **c.options ) diff --git a/redis/commands/core.py b/redis/commands/core.py index d46e55446c..8986a48de2 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -5728,7 +5728,7 @@ def script_exists(self, *args: str) -> ResponseT: """ Check if a script exists in the script cache by specifying the SHAs of each script as ``args``. Returns a list of boolean values indicating if - if each already script exists in the cache. + if each already script exists in the cache_data. For more information see https://redis.io/commands/script-exists """ @@ -5742,7 +5742,7 @@ def script_debug(self, *args) -> None: def script_flush( self, sync_type: Union[Literal["SYNC"], Literal["ASYNC"]] = None ) -> ResponseT: - """Flush all scripts from the script cache. + """Flush all scripts from the script cache_data. ``sync_type`` is by default SYNC (synchronous) but it can also be ASYNC. @@ -5773,7 +5773,7 @@ def script_kill(self) -> ResponseT: def script_load(self, script: ScriptTextT) -> ResponseT: """ - Load a Lua ``script`` into the script cache. Returns the SHA. + Load a Lua ``script`` into the script cache_data. Returns the SHA. For more information see https://redis.io/commands/script-load """ diff --git a/redis/connection.py b/redis/connection.py index 1f862d0371..6aae2101c2 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -9,16 +9,18 @@ from itertools import chain from queue import Empty, Full, LifoQueue from time import time -from typing import Any, Callable, List, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, List, Optional, Type, Union from urllib.parse import parse_qs, unquote, urlparse -from ._cache import ( - DEFAULT_ALLOW_LIST, - DEFAULT_DENY_LIST, - DEFAULT_EVICTION_POLICY, - AbstractCache, - _LocalCache, +from redis.cache import ( + CacheEntry, + CacheEntryStatus, + CacheFactory, + CacheFactoryInterface, + CacheInterface, + CacheKey, ) + from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser from .backoff import NoBackoff from .credentials import CredentialProvider, UsernamePasswordCredentialProvider @@ -33,12 +35,13 @@ TimeoutError, ) from .retry import Retry -from .typing import KeysT, ResponseT from .utils import ( CRYPTOGRAPHY_AVAILABLE, HIREDIS_AVAILABLE, HIREDIS_PACK_AVAILABLE, SSL_AVAILABLE, + compare_versions, + ensure_string, format_error_message, get_lib_version, str_if_bytes, @@ -132,7 +135,76 @@ def pack(self, *args): return output -class AbstractConnection: +class ConnectionInterface: + @abstractmethod + def repr_pieces(self): + pass + + @abstractmethod + def register_connect_callback(self, callback): + pass + + @abstractmethod + def deregister_connect_callback(self, callback): + pass + + @abstractmethod + def set_parser(self, parser_class): + pass + + @abstractmethod + def connect(self): + pass + + @abstractmethod + def on_connect(self): + pass + + @abstractmethod + def disconnect(self, *args): + pass + + @abstractmethod + def check_health(self): + pass + + @abstractmethod + def send_packed_command(self, command, check_health=True): + pass + + @abstractmethod + def send_command(self, *args, **kwargs): + pass + + @abstractmethod + def can_read(self, timeout=0): + pass + + @abstractmethod + def read_response( + self, + disable_decoding=False, + *, + disconnect_on_error=True, + push_request=False, + ): + pass + + @abstractmethod + def pack_command(self, *args): + pass + + @abstractmethod + def pack_commands(self, commands): + pass + + @property + @abstractmethod + def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]: + pass + + +class AbstractConnection(ConnectionInterface): "Manages communication to and from a Redis server" def __init__( @@ -158,13 +230,6 @@ def __init__( credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, command_packer: Optional[Callable[[], None]] = None, - cache_enabled: bool = False, - client_cache: Optional[AbstractCache] = None, - cache_max_size: int = 10000, - cache_ttl: int = 0, - cache_policy: str = DEFAULT_EVICTION_POLICY, - cache_deny_list: List[str] = DEFAULT_DENY_LIST, - cache_allow_list: List[str] = DEFAULT_ALLOW_LIST, ): """ Initialize a new Connection. @@ -213,6 +278,7 @@ def __init__( self.next_health_check = 0 self.redis_connect_func = redis_connect_func self.encoder = Encoder(encoding, encoding_errors, decode_responses) + self.handshake_metadata = None self._sock = None self._socket_read_size = socket_read_size self.set_parser(parser_class) @@ -230,18 +296,6 @@ def __init__( # p = DEFAULT_RESP_VERSION self.protocol = p self._command_packer = self._construct_command_packer(command_packer) - if cache_enabled: - _cache = _LocalCache(cache_max_size, cache_ttl, cache_policy) - else: - _cache = None - self.client_cache = client_cache if client_cache is not None else _cache - if self.client_cache is not None: - if self.protocol not in [3, "3"]: - raise RedisError( - "client caching is only supported with protocol version 3 or higher" - ) - self.cache_deny_list = cache_deny_list - self.cache_allow_list = cache_allow_list def __repr__(self): repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) @@ -367,7 +421,7 @@ def on_connect(self): if len(auth_args) == 1: auth_args = ["default", auth_args[0]] self.send_command("HELLO", self.protocol, "AUTH", *auth_args) - response = self.read_response() + self.handshake_metadata = self.read_response() # if response.get(b"proto") != self.protocol and response.get( # "proto" # ) != self.protocol: @@ -398,10 +452,10 @@ def on_connect(self): self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES self._parser.on_connect(self) self.send_command("HELLO", self.protocol) - response = self.read_response() + self.handshake_metadata = self.read_response() if ( - response.get(b"proto") != self.protocol - and response.get("proto") != self.protocol + self.handshake_metadata.get(b"proto") != self.protocol + and self.handshake_metadata.get("proto") != self.protocol ): raise ConnectionError("Invalid RESP version") @@ -432,12 +486,6 @@ def on_connect(self): if str_if_bytes(self.read_response()) != "OK": raise ConnectionError("Invalid Database") - # if client caching is enabled, start tracking - if self.client_cache: - self.send_command("CLIENT", "TRACKING", "ON") - self.read_response() - self._parser.set_invalidation_push_handler(self._cache_invalidation_process) - def disconnect(self, *args): "Disconnects from the Redis server" self._parser.on_disconnect() @@ -458,9 +506,6 @@ def disconnect(self, *args): except OSError: pass - if self.client_cache: - self.client_cache.flush() - def _send_ping(self): """Send PING, expect PONG in return""" self.send_command("PING", check_health=False) @@ -608,60 +653,16 @@ def pack_commands(self, commands): output.append(SYM_EMPTY.join(pieces)) return output - def _cache_invalidation_process( - self, data: List[Union[str, Optional[List[str]]]] - ) -> None: - """ - Invalidate (delete) all redis commands associated with a specific key. - `data` is a list of strings, where the first string is the invalidation message - and the second string is the list of keys to invalidate. - (if the list of keys is None, then all keys are invalidated) - """ - if data[1] is None: - self.client_cache.flush() - else: - for key in data[1]: - self.client_cache.invalidate_key(str_if_bytes(key)) - - def _get_from_local_cache(self, command: Sequence[str]): - """ - If the command is in the local cache, return the response - """ - if ( - self.client_cache is None - or command[0] in self.cache_deny_list - or command[0] not in self.cache_allow_list - ): - return None - while self.can_read(): - self.read_response(push_request=True) - return self.client_cache.get(command) - - def _add_to_local_cache( - self, command: Sequence[str], response: ResponseT, keys: List[KeysT] - ): - """ - Add the command and response to the local cache if the command - is allowed to be cached - """ - if ( - self.client_cache is not None - and (self.cache_deny_list == [] or command[0] not in self.cache_deny_list) - and (self.cache_allow_list == [] or command[0] in self.cache_allow_list) - ): - self.client_cache.set(command, response, keys) - - def flush_cache(self): - if self.client_cache: - self.client_cache.flush() + def get_protocol(self) -> int or str: + return self.protocol - def delete_command_from_cache(self, command: Union[str, Sequence[str]]): - if self.client_cache: - self.client_cache.delete_command(command) + @property + def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]: + return self._handshake_metadata - def invalidate_key_from_cache(self, key: KeysT): - if self.client_cache: - self.client_cache.invalidate_key(key) + @handshake_metadata.setter + def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]): + self._handshake_metadata = value class Connection(AbstractConnection): @@ -734,6 +735,206 @@ def _host_error(self): return f"{self.host}:{self.port}" +class CacheProxyConnection(ConnectionInterface): + DUMMY_CACHE_VALUE = b"foo" + MIN_ALLOWED_VERSION = "7.4.0" + DEFAULT_SERVER_NAME = "redis" + + def __init__( + self, + conn: ConnectionInterface, + cache: CacheInterface, + pool_lock: threading.Lock, + ): + self.pid = os.getpid() + self._conn = conn + self.retry = self._conn.retry + self.host = self._conn.host + self.port = self._conn.port + self._pool_lock = pool_lock + self._cache = cache + self._cache_lock = threading.Lock() + self._current_command_cache_key = None + self._current_options = None + self.register_connect_callback(self._enable_tracking_callback) + + def repr_pieces(self): + return self._conn.repr_pieces() + + def register_connect_callback(self, callback): + self._conn.register_connect_callback(callback) + + def deregister_connect_callback(self, callback): + self._conn.deregister_connect_callback(callback) + + def set_parser(self, parser_class): + self._conn.set_parser(parser_class) + + def connect(self): + self._conn.connect() + + server_name = self._conn.handshake_metadata.get(b"server", None) + if server_name is None: + server_name = self._conn.handshake_metadata.get("server", None) + server_ver = self._conn.handshake_metadata.get(b"version", None) + if server_ver is None: + server_ver = self._conn.handshake_metadata.get("version", None) + if server_ver is None or server_ver is None: + raise ConnectionError("Cannot retrieve information about server version") + + server_ver = ensure_string(server_ver) + server_name = ensure_string(server_name) + + if ( + server_name != self.DEFAULT_SERVER_NAME + or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1 + ): + raise ConnectionError( + "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later" # noqa: E501 + ) + + def on_connect(self): + self._conn.on_connect() + + def disconnect(self, *args): + with self._cache_lock: + self._cache.flush() + self._conn.disconnect(*args) + + def check_health(self): + self._conn.check_health() + + def send_packed_command(self, command, check_health=True): + # TODO: Investigate if it's possible to unpack command + # or extract keys from packed command + self._conn.send_packed_command(command) + + def send_command(self, *args, **kwargs): + self._process_pending_invalidations() + + with self._cache_lock: + # Command is write command or not allowed + # to be cached. + if not self._cache.is_cachable(CacheKey(command=args[0], redis_keys=())): + self._current_command_cache_key = None + self._conn.send_command(*args, **kwargs) + return + + if kwargs.get("keys") is None: + raise ValueError("Cannot create cache key.") + + # Creates cache key. + self._current_command_cache_key = CacheKey( + command=args[0], redis_keys=tuple(kwargs.get("keys")) + ) + + with self._cache_lock: + # We have to trigger invalidation processing in case if + # it was cached by another connection to avoid + # queueing invalidations in stale connections. + if self._cache.get(self._current_command_cache_key): + entry = self._cache.get(self._current_command_cache_key) + + if entry.connection_ref != self._conn: + with self._pool_lock: + while entry.connection_ref.can_read(): + entry.connection_ref.read_response(push_request=True) + + return + + # Set temporary entry value to prevent + # race condition from another connection. + self._cache.set( + CacheEntry( + cache_key=self._current_command_cache_key, + cache_value=self.DUMMY_CACHE_VALUE, + status=CacheEntryStatus.IN_PROGRESS, + connection_ref=self._conn, + ) + ) + + # Send command over socket only if it's allowed + # read-only command that not yet cached. + self._conn.send_command(*args, **kwargs) + + def can_read(self, timeout=0): + return self._conn.can_read(timeout) + + def read_response( + self, disable_decoding=False, *, disconnect_on_error=True, push_request=False + ): + with self._cache_lock: + # Check if command response exists in a cache and it's not in progress. + if ( + self._current_command_cache_key is not None + and self._cache.get(self._current_command_cache_key) is not None + and self._cache.get(self._current_command_cache_key).status + != CacheEntryStatus.IN_PROGRESS + ): + return copy.deepcopy( + self._cache.get(self._current_command_cache_key).cache_value + ) + + response = self._conn.read_response( + disable_decoding=disable_decoding, + disconnect_on_error=disconnect_on_error, + push_request=push_request, + ) + + with self._cache_lock: + # Prevent not-allowed command from caching. + if self._current_command_cache_key is None: + return response + # If response is None prevent from caching. + if response is None: + self._cache.delete_by_cache_keys([self._current_command_cache_key]) + return response + + cache_entry = self._cache.get(self._current_command_cache_key) + + # Cache only responses that still valid + # and wasn't invalidated by another connection in meantime. + if cache_entry is not None: + cache_entry.status = CacheEntryStatus.VALID + cache_entry.cache_value = response + self._cache.set(cache_entry) + + return response + + def pack_command(self, *args): + return self._conn.pack_command(*args) + + def pack_commands(self, commands): + return self._conn.pack_commands(commands) + + @property + def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]: + return self._conn.handshake_metadata + + def _connect(self): + self._conn._connect() + + def _host_error(self): + self._conn._host_error() + + def _enable_tracking_callback(self, conn: ConnectionInterface) -> None: + conn.send_command("CLIENT", "TRACKING", "ON") + conn.read_response() + conn._parser.set_invalidation_push_handler(self._on_invalidation_callback) + + def _process_pending_invalidations(self): + while self.can_read(): + self._conn.read_response(push_request=True) + + def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]): + with self._cache_lock: + # Flush cache when DB flushed on server-side + if data[1] is None: + self._cache.flush() + else: + self._cache.delete_by_redis_keys(data[1]) + + class SSLConnection(Connection): """Manages SSL connections to and from the Redis server(s). This class extends the Connection class, adding SSL functionality, and making @@ -1083,6 +1284,7 @@ def __init__( self, connection_class=Connection, max_connections: Optional[int] = None, + cache_factory: Optional[CacheFactoryInterface] = None, **connection_kwargs, ): max_connections = max_connections or 2**31 @@ -1092,6 +1294,30 @@ def __init__( self.connection_class = connection_class self.connection_kwargs = connection_kwargs self.max_connections = max_connections + self.cache = None + self._cache_factory = cache_factory + + if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"): + if connection_kwargs.get("protocol") not in [3, "3"]: + raise RedisError("Client caching is only supported with RESP version 3") + + cache = self.connection_kwargs.get("cache") + + if cache is not None: + if not isinstance(cache, CacheInterface): + raise ValueError("Cache must implement CacheInterface") + + self.cache = cache + else: + if self._cache_factory is not None: + self.cache = self._cache_factory.get_cache() + else: + self.cache = CacheFactory( + self.connection_kwargs.get("cache_config") + ).get_cache() + + connection_kwargs.pop("cache", None) + connection_kwargs.pop("cache_config", None) # a lock to protect the critical section in _checkpid(). # this lock is acquired when the process id changes, such as @@ -1110,6 +1336,14 @@ def __repr__(self) -> (str, str): f"({repr(self.connection_class(**self.connection_kwargs))})>" ) + def get_protocol(self): + """ + Returns: + The RESP protocol version, or ``None`` if the protocol is not specified, + in which case the server default will be used. + """ + return self.connection_kwargs.get("protocol", None) + def reset(self) -> None: self._lock = threading.Lock() self._created_connections = 0 @@ -1187,15 +1421,12 @@ def get_connection(self, command_name: str, *keys, **options) -> "Connection": try: # ensure this connection is connected to Redis connection.connect() - # if client caching is not enabled connections that the pool - # provides should be ready to send a command. - # if not, the connection was either returned to the + # connections that the pool provides should be ready to send + # a command. if not, the connection was either returned to the # pool before all data has been read or the socket has been # closed. either way, reconnect and verify everything is good. - # (if caching enabled the connection will not always be ready - # to send a command because it may contain invalidation messages) try: - if connection.can_read() and connection.client_cache is None: + if connection.can_read() and self.cache is None: raise ConnectionError("Connection has data") except (ConnectionError, OSError): connection.disconnect() @@ -1219,11 +1450,17 @@ def get_encoder(self) -> Encoder: decode_responses=kwargs.get("decode_responses", False), ) - def make_connection(self) -> "Connection": + def make_connection(self) -> "ConnectionInterface": "Create a new connection" if self._created_connections >= self.max_connections: raise ConnectionError("Too many connections") self._created_connections += 1 + + if self.cache is not None: + return CacheProxyConnection( + self.connection_class(**self.connection_kwargs), self.cache, self._lock + ) + return self.connection_class(**self.connection_kwargs) def release(self, connection: "Connection") -> None: @@ -1281,27 +1518,6 @@ def set_retry(self, retry: "Retry") -> None: for conn in self._in_use_connections: conn.retry = retry - def flush_cache(self): - self._checkpid() - with self._lock: - connections = chain(self._available_connections, self._in_use_connections) - for connection in connections: - connection.flush_cache() - - def delete_command_from_cache(self, command: str): - self._checkpid() - with self._lock: - connections = chain(self._available_connections, self._in_use_connections) - for connection in connections: - connection.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key: str): - self._checkpid() - with self._lock: - connections = chain(self._available_connections, self._in_use_connections) - for connection in connections: - connection.invalidate_key_from_cache(key) - class BlockingConnectionPool(ConnectionPool): """ @@ -1379,7 +1595,12 @@ def reset(self): def make_connection(self): "Make a fresh connection." - connection = self.connection_class(**self.connection_kwargs) + if self.cache is not None: + connection = CacheProxyConnection( + self.connection_class(**self.connection_kwargs), self.cache, self._lock + ) + else: + connection = self.connection_class(**self.connection_kwargs) self._connections.append(connection) return connection diff --git a/redis/sentinel.py b/redis/sentinel.py index 72b5bef548..01e210794c 100644 --- a/redis/sentinel.py +++ b/redis/sentinel.py @@ -229,6 +229,7 @@ def __init__( sentinels, min_other_sentinels=0, sentinel_kwargs=None, + force_master_ip=None, **connection_kwargs, ): # if sentinel_kwargs isn't defined, use the socket_* options from @@ -245,6 +246,7 @@ def __init__( ] self.min_other_sentinels = min_other_sentinels self.connection_kwargs = connection_kwargs + self._force_master_ip = force_master_ip def execute_command(self, *args, **kwargs): """ @@ -252,7 +254,6 @@ def execute_command(self, *args, **kwargs): once - If set to True, then execute the resulting command on a single node at random, rather than across the entire sentinel cluster. """ - kwargs.pop("keys", None) # the keys are used only for client side caching once = bool(kwargs.get("once", False)) if "once" in kwargs.keys(): kwargs.pop("once") @@ -305,7 +306,13 @@ def discover_master(self, service_name): sentinel, self.sentinels[0], ) - return state["ip"], state["port"] + + ip = ( + self._force_master_ip + if self._force_master_ip is not None + else state["ip"] + ) + return ip, state["port"] error_info = "" if len(collected_errors) > 0: diff --git a/redis/utils.py b/redis/utils.py index a0f31f7ca4..b4e9afb054 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -153,3 +153,42 @@ def format_error_message(host_error: str, exception: BaseException) -> str: f"Error {exception.args[0]} connecting to {host_error}. " f"{exception.args[1]}." ) + + +def compare_versions(version1: str, version2: str) -> int: + """ + Compare two versions. + + :return: -1 if version1 > version2 + 0 if both versions are equal + 1 if version1 < version2 + """ + + num_versions1 = list(map(int, version1.split("."))) + num_versions2 = list(map(int, version2.split("."))) + + if len(num_versions1) > len(num_versions2): + diff = len(num_versions1) - len(num_versions2) + for _ in range(diff): + num_versions2.append(0) + elif len(num_versions1) < len(num_versions2): + diff = len(num_versions2) - len(num_versions1) + for _ in range(diff): + num_versions1.append(0) + + for i, ver in enumerate(num_versions1): + if num_versions1[i] > num_versions2[i]: + return -1 + elif num_versions1[i] < num_versions2[i]: + return 1 + + return 0 + + +def ensure_string(key): + if isinstance(key, bytes): + return key.decode("utf-8") + elif isinstance(key, str): + return key + else: + raise TypeError("Key must be either a string or bytes") diff --git a/requirements.txt b/requirements.txt index 3274a80f62..622f70b810 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -async-timeout>=4.0.3 +async-timeout>=4.0.3 \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index dd78bb6a2c..0c98eee4d8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,9 +11,17 @@ from packaging.version import Version from redis import Sentinel from redis.backoff import NoBackoff -from redis.connection import Connection, parse_url +from redis.cache import ( + CacheConfig, + CacheFactoryInterface, + CacheInterface, + CacheKey, + EvictionPolicy, +) +from redis.connection import Connection, ConnectionInterface, SSLConnection, parse_url from redis.exceptions import RedisClusterException from redis.retry import Retry +from tests.ssl_utils import get_ssl_filename REDIS_INFO = {} default_redis_url = "redis://localhost:6379/0" @@ -321,8 +329,22 @@ def _get_client( kwargs["protocol"] = request.config.getoption("--protocol") cluster_mode = REDIS_INFO["cluster_enabled"] + ssl = kwargs.pop("ssl", False) if not cluster_mode: url_options = parse_url(redis_url) + connection_class = Connection + if ssl: + connection_class = SSLConnection + kwargs["ssl_certfile"] = get_ssl_filename("client-cert.pem") + kwargs["ssl_keyfile"] = get_ssl_filename("client-key.pem") + # When you try to assign "required" as single string + # it assigns tuple instead of string. + # Probably some reserved keyword + # I can't explain how does it work -_- + kwargs["ssl_cert_reqs"] = "require" + "d" + kwargs["ssl_ca_certs"] = get_ssl_filename("ca-cert.pem") + kwargs["port"] = 6666 + kwargs["connection_class"] = connection_class url_options.update(kwargs) pool = redis.ConnectionPool(**url_options) client = cls(connection_pool=pool) @@ -410,18 +432,25 @@ def sslclient(request): @pytest.fixture() -def sentinel_setup(local_cache, request): +def sentinel_setup(request): sentinel_ips = request.config.getoption("--sentinels") sentinel_endpoints = [ (ip.strip(), int(port.strip())) for ip, port in (endpoint.split(":") for endpoint in sentinel_ips.split(",")) ] kwargs = request.param.get("kwargs", {}) if hasattr(request, "param") else {} + cache = request.param.get("cache", None) + cache_config = request.param.get("cache_config", None) + force_master_ip = request.param.get("force_master_ip", None) + decode_responses = request.param.get("decode_responses", False) sentinel = Sentinel( sentinel_endpoints, + force_master_ip=force_master_ip, socket_timeout=0.1, - client_cache=local_cache, + cache=cache, + cache_config=cache_config, protocol=3, + decode_responses=decode_responses, **kwargs, ) yield sentinel @@ -441,7 +470,6 @@ def _gen_cluster_mock_resp(r, response): connection = Mock(spec=Connection) connection.retry = Retry(NoBackoff(), 0) connection.read_response.return_value = response - connection._get_from_local_cache.return_value = None with mock.patch.object(r, "connection", connection): yield r @@ -514,6 +542,37 @@ def master_host(request): return parts.hostname, (parts.port or 6379) +@pytest.fixture() +def cache_conf() -> CacheConfig: + return CacheConfig(max_size=100, eviction_policy=EvictionPolicy.LRU) + + +@pytest.fixture() +def mock_cache_factory() -> CacheFactoryInterface: + mock_factory = Mock(spec=CacheFactoryInterface) + return mock_factory + + +@pytest.fixture() +def mock_cache() -> CacheInterface: + mock_cache = Mock(spec=CacheInterface) + return mock_cache + + +@pytest.fixture() +def mock_connection() -> ConnectionInterface: + mock_connection = Mock(spec=ConnectionInterface) + return mock_connection + + +@pytest.fixture() +def cache_key(request) -> CacheKey: + command = request.param.get("command") + keys = request.param.get("redis_keys") + + return CacheKey(command, keys) + + def wait_for_command(client, monitor, command, key=None): # issue a command with a key name that's local to this process. # if we find a command with our key before the command we're waiting diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 6e93407b4c..41b47b2268 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -146,7 +146,6 @@ def _gen_cluster_mock_resp(r, response): connection = mock.AsyncMock(spec=Connection) connection.retry = Retry(NoBackoff(), 0) connection.read_response.return_value = response - connection._get_from_local_cache.return_value = None with mock.patch.object(r, "connection", connection): yield r diff --git a/tests/test_asyncio/test_cache.py b/tests/test_asyncio/test_cache.py deleted file mode 100644 index 7a7f881ce2..0000000000 --- a/tests/test_asyncio/test_cache.py +++ /dev/null @@ -1,408 +0,0 @@ -import time - -import pytest -import pytest_asyncio -from redis._cache import EvictionPolicy, _LocalCache -from redis.utils import HIREDIS_AVAILABLE - - -@pytest_asyncio.fixture -async def r(request, create_redis): - cache = request.param.get("cache") - kwargs = request.param.get("kwargs", {}) - r = await create_redis(protocol=3, client_cache=cache, **kwargs) - yield r, cache - - -@pytest_asyncio.fixture() -async def local_cache(): - yield _LocalCache() - - -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -class TestLocalCache: - @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) - @pytest.mark.onlynoncluster - async def test_get_from_cache(self, r, r2): - r, cache = r - # add key to redis - await r.set("foo", "bar") - # get key from redis and save in local cache - assert await r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # change key in redis (cause invalidation) - await r2.set("foo", "barbar") - # send any command to redis (process invalidation in background) - await r.ping() - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert await r.get("foo") == b"barbar" - - @pytest.mark.parametrize("r", [{"cache": _LocalCache(max_size=3)}], indirect=True) - async def test_cache_lru_eviction(self, r): - r, cache = r - # add 3 keys to redis - await r.set("foo", "bar") - await r.set("foo2", "bar2") - await r.set("foo3", "bar3") - # get 3 keys from redis and save in local cache - assert await r.get("foo") == b"bar" - assert await r.get("foo2") == b"bar2" - assert await r.get("foo3") == b"bar3" - # get the 3 keys from local cache - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo2")) == b"bar2" - assert cache.get(("GET", "foo3")) == b"bar3" - # add 1 more key to redis (exceed the max size) - await r.set("foo4", "bar4") - assert await r.get("foo4") == b"bar4" - # the first key is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - - @pytest.mark.parametrize("r", [{"cache": _LocalCache(ttl=1)}], indirect=True) - async def test_cache_ttl(self, r): - r, cache = r - # add key to redis - await r.set("foo", "bar") - # get key from redis and save in local cache - assert await r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # wait for the key to expire - time.sleep(1) - # the key is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(max_size=3, eviction_policy=EvictionPolicy.LFU)}], - indirect=True, - ) - async def test_cache_lfu_eviction(self, r): - r, cache = r - # add 3 keys to redis - await r.set("foo", "bar") - await r.set("foo2", "bar2") - await r.set("foo3", "bar3") - # get 3 keys from redis and save in local cache - assert await r.get("foo") == b"bar" - assert await r.get("foo2") == b"bar2" - assert await r.get("foo3") == b"bar3" - # change the order of the keys in the cache - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo3")) == b"bar3" - # add 1 more key to redis (exceed the max size) - await r.set("foo4", "bar4") - assert await r.get("foo4") == b"bar4" - # test the eviction policy - assert len(cache.cache) == 3 - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo2")) is None - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - @pytest.mark.onlynoncluster - async def test_cache_decode_response(self, r): - r, cache = r - await r.set("foo", "bar") - # get key from redis and save in local cache - assert await r.get("foo") == "bar" - # get key from local cache - assert cache.get(("GET", "foo")) == "bar" - # change key in redis (cause invalidation) - await r.set("foo", "barbar") - # send any command to redis (process invalidation in background) - await r.ping() - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert await r.get("foo") == "barbar" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"cache_deny_list": ["LLEN"]}}], - indirect=True, - ) - async def test_cache_deny_list(self, r): - r, cache = r - # add list to redis - await r.lpush("mylist", "foo", "bar", "baz") - assert await r.llen("mylist") == 3 - assert await r.lindex("mylist", 1) == b"bar" - assert cache.get(("LLEN", "mylist")) is None - assert cache.get(("LINDEX", "mylist", 1)) == b"bar" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"cache_allow_list": ["LLEN"]}}], - indirect=True, - ) - async def test_cache_allow_list(self, r): - r, cache = r - # add list to redis - await r.lpush("mylist", "foo", "bar", "baz") - assert await r.llen("mylist") == 3 - assert await r.lindex("mylist", 1) == b"bar" - assert cache.get(("LLEN", "mylist")) == 3 - assert cache.get(("LINDEX", "mylist", 1)) is None - - @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) - async def test_cache_return_copy(self, r): - r, cache = r - await r.lpush("mylist", "foo", "bar", "baz") - assert await r.lrange("mylist", 0, -1) == [b"baz", b"bar", b"foo"] - res = cache.get(("LRANGE", "mylist", 0, -1)) - assert res == [b"baz", b"bar", b"foo"] - res.append(b"new") - check = cache.get(("LRANGE", "mylist", 0, -1)) - assert check == [b"baz", b"bar", b"foo"] - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - @pytest.mark.onlynoncluster - async def test_csc_not_cause_disconnects(self, r): - r, cache = r - id1 = await r.client_id() - await r.mset({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1}) - assert await r.mget("a", "b", "c", "d", "e") == ["1", "1", "1", "1", "1"] - id2 = await r.client_id() - - # client should get value from client cache - assert await r.mget("a", "b", "c", "d", "e") == ["1", "1", "1", "1", "1"] - assert cache.get(("MGET", "a", "b", "c", "d", "e")) == [ - "1", - "1", - "1", - "1", - "1", - ] - - await r.mset({"a": 2, "b": 2, "c": 2, "d": 2, "e": 2}) - id3 = await r.client_id() - # client should get value from redis server post invalidate messages - assert await r.mget("a", "b", "c", "d", "e") == ["2", "2", "2", "2", "2"] - - await r.mset({"a": 3, "b": 3, "c": 3, "d": 3, "e": 3}) - # need to check that we get correct value 3 and not 2 - assert await r.mget("a", "b", "c", "d", "e") == ["3", "3", "3", "3", "3"] - # client should get value from client cache - assert await r.mget("a", "b", "c", "d", "e") == ["3", "3", "3", "3", "3"] - - await r.mset({"a": 4, "b": 4, "c": 4, "d": 4, "e": 4}) - # need to check that we get correct value 4 and not 3 - assert await r.mget("a", "b", "c", "d", "e") == ["4", "4", "4", "4", "4"] - # client should get value from client cache - assert await r.mget("a", "b", "c", "d", "e") == ["4", "4", "4", "4", "4"] - id4 = await r.client_id() - assert id1 == id2 == id3 == id4 - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_execute_command_keys_provided(self, r): - r, cache = r - assert await r.execute_command("SET", "b", "2") is True - assert await r.execute_command("GET", "b", keys=["b"]) == "2" - assert cache.get(("GET", "b")) == "2" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_execute_command_keys_not_provided(self, r): - r, cache = r - assert await r.execute_command("SET", "b", "2") is True - assert ( - await r.execute_command("GET", "b") == "2" - ) # keys not provided, not cached - assert cache.get(("GET", "b")) is None - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_delete_one_command(self, r): - r, cache = r - assert await r.mset({"a{a}": 1, "b{a}": 1}) is True - assert await r.set("c", 1) is True - assert await r.mget("a{a}", "b{a}") == ["1", "1"] - assert await r.get("c") == "1" - # values should be in local cache - assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] - assert cache.get(("GET", "c")) == "1" - # delete one command from the cache - r.delete_command_from_cache(("MGET", "a{a}", "b{a}")) - # the other command is still in the local cache anymore - assert cache.get(("MGET", "a{a}", "b{a}")) is None - assert cache.get(("GET", "c")) == "1" - # get from redis - assert await r.mget("a{a}", "b{a}") == ["1", "1"] - assert await r.get("c") == "1" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_invalidate_key(self, r): - r, cache = r - assert await r.mset({"a{a}": 1, "b{a}": 1}) is True - assert await r.set("c", 1) is True - assert await r.mget("a{a}", "b{a}") == ["1", "1"] - assert await r.get("c") == "1" - # values should be in local cache - assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] - assert cache.get(("GET", "c")) == "1" - # invalidate one key from the cache - r.invalidate_key_from_cache("b{a}") - # one other command is still in the local cache anymore - assert cache.get(("MGET", "a{a}", "b{a}")) is None - assert cache.get(("GET", "c")) == "1" - # get from redis - assert await r.mget("a{a}", "b{a}") == ["1", "1"] - assert await r.get("c") == "1" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_flush_entire_cache(self, r): - r, cache = r - assert await r.mset({"a{a}": 1, "b{a}": 1}) is True - assert await r.set("c", 1) is True - assert await r.mget("a{a}", "b{a}") == ["1", "1"] - assert await r.get("c") == "1" - # values should be in local cache - assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] - assert cache.get(("GET", "c")) == "1" - # flush the local cache - r.flush_cache() - # the commands are not in the local cache anymore - assert cache.get(("MGET", "a{a}", "b{a}")) is None - assert cache.get(("GET", "c")) is None - # get from redis - assert await r.mget("a{a}", "b{a}") == ["1", "1"] - assert await r.get("c") == "1" - - -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -@pytest.mark.onlycluster -class TestClusterLocalCache: - @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) - async def test_get_from_cache(self, r, r2): - r, cache = r - # add key to redis - await r.set("foo", "bar") - # get key from redis and save in local cache - assert await r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # change key in redis (cause invalidation) - await r2.set("foo", "barbar") - # send any command to redis (process invalidation in background) - node = r.get_node_from_key("foo") - await r.ping(target_nodes=node) - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert await r.get("foo") == b"barbar" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_cache_decode_response(self, r): - r, cache = r - await r.set("foo", "bar") - # get key from redis and save in local cache - assert await r.get("foo") == "bar" - # get key from local cache - assert cache.get(("GET", "foo")) == "bar" - # change key in redis (cause invalidation) - await r.set("foo", "barbar") - # send any command to redis (process invalidation in background) - node = r.get_node_from_key("foo") - await r.ping(target_nodes=node) - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert await r.get("foo") == "barbar" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_execute_command_keys_provided(self, r): - r, cache = r - assert await r.execute_command("SET", "b", "2") is True - assert await r.execute_command("GET", "b", keys=["b"]) == "2" - assert cache.get(("GET", "b")) == "2" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_execute_command_keys_not_provided(self, r): - r, cache = r - assert await r.execute_command("SET", "b", "2") is True - assert ( - await r.execute_command("GET", "b") == "2" - ) # keys not provided, not cached - assert cache.get(("GET", "b")) is None - - -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -@pytest.mark.onlynoncluster -class TestSentinelLocalCache: - - async def test_get_from_cache(self, local_cache, master): - await master.set("foo", "bar") - # get key from redis and save in local cache - assert await master.get("foo") == b"bar" - # get key from local cache - assert local_cache.get(("GET", "foo")) == b"bar" - # change key in redis (cause invalidation) - await master.set("foo", "barbar") - # send any command to redis (process invalidation in background) - await master.ping() - # the command is not in the local cache anymore - assert local_cache.get(("GET", "foo")) is None - # get key from redis - assert await master.get("foo") == b"barbar" - - @pytest.mark.parametrize( - "sentinel_setup", - [{"kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_cache_decode_response(self, local_cache, sentinel_setup, master): - await master.set("foo", "bar") - # get key from redis and save in local cache - assert await master.get("foo") == "bar" - # get key from local cache - assert local_cache.get(("GET", "foo")) == "bar" - # change key in redis (cause invalidation) - await master.set("foo", "barbar") - # send any command to redis (process invalidation in background) - await master.ping() - # the command is not in the local cache anymore - assert local_cache.get(("GET", "foo")) is None - # get key from redis - assert await master.get("foo") == "barbar" diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index fefa4ef8f9..e480db332b 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -190,7 +190,6 @@ def mock_node_resp(node: ClusterNode, response: Any) -> ClusterNode: connection = mock.AsyncMock(spec=Connection) connection.is_connected = True connection.read_response.return_value = response - connection._get_from_local_cache.return_value = None while node._free: node._free.pop() node._free.append(connection) @@ -201,7 +200,6 @@ def mock_node_resp_exc(node: ClusterNode, exc: Exception) -> ClusterNode: connection = mock.AsyncMock(spec=Connection) connection.is_connected = True connection.read_response.side_effect = exc - connection._get_from_local_cache.return_value = None while node._free: node._free.pop() node._free.append(connection) diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 8f79f7d947..e584fc6999 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -75,7 +75,6 @@ async def call_with_retry(self, _, __): mock_conn = mock.AsyncMock(spec=Connection) mock_conn.retry = Retry_() - mock_conn._get_from_local_cache.return_value = None async def get_conn(_): # Validate only one client is created in single-client mode when diff --git a/tests/test_asyncio/test_hash.py b/tests/test_asyncio/test_hash.py index e31ea7eaf3..8d94799fbb 100644 --- a/tests/test_asyncio/test_hash.py +++ b/tests/test_asyncio/test_hash.py @@ -177,7 +177,7 @@ async def test_hexpireat_multiple_fields(r): ) exp_time = int((datetime.now() + timedelta(seconds=1)).timestamp()) assert await r.hexpireat("test:hash", exp_time, "field1", "field2") == [1, 1] - await asyncio.sleep(1.1) + await asyncio.sleep(1.5) assert await r.hexists("test:hash", "field1") is False assert await r.hexists("test:hash", "field2") is False assert await r.hexists("test:hash", "field3") is True diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index 19d4b1c650..13a6158b40 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -461,7 +461,7 @@ async def test_get_message_without_subscribe(self, r: redis.Redis, pubsub): @pytest.mark.onlynoncluster class TestPubSubRESP3Handler: - def my_handler(self, message): + async def my_handler(self, message): self.message = ["my handler", message] @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") diff --git a/tests/test_cache.py b/tests/test_cache.py index 022364e87a..1803646094 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,106 +1,186 @@ import time -from collections import defaultdict -from typing import List, Sequence, Union -import cachetools import pytest import redis -from redis import RedisError -from redis._cache import AbstractCache, EvictionPolicy, _LocalCache -from redis.typing import KeyT, ResponseT +from redis.cache import ( + CacheConfig, + CacheEntry, + CacheEntryStatus, + CacheKey, + DefaultCache, + EvictionPolicy, + EvictionPolicyType, + LRUPolicy, +) from redis.utils import HIREDIS_AVAILABLE -from tests.conftest import _get_client +from tests.conftest import _get_client, skip_if_resp_version, skip_if_server_version_lt @pytest.fixture() def r(request): cache = request.param.get("cache") + cache_config = request.param.get("cache_config") kwargs = request.param.get("kwargs", {}) protocol = request.param.get("protocol", 3) + ssl = request.param.get("ssl", False) single_connection_client = request.param.get("single_connection_client", False) + decode_responses = request.param.get("decode_responses", False) with _get_client( redis.Redis, request, - single_connection_client=single_connection_client, protocol=protocol, - client_cache=cache, + ssl=ssl, + single_connection_client=single_connection_client, + cache=cache, + cache_config=cache_config, + decode_responses=decode_responses, **kwargs, ) as client: - yield client, cache - - -@pytest.fixture() -def local_cache(): - return _LocalCache() + yield client @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -class TestLocalCache: - @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) +@pytest.mark.onlynoncluster +# @skip_if_resp_version(2) +@skip_if_server_version_lt("7.4.0") +class TestCache: + @pytest.mark.parametrize( + "r", + [ + { + "cache": DefaultCache(CacheConfig(max_size=5)), + "single_connection_client": True, + }, + { + "cache": DefaultCache(CacheConfig(max_size=5)), + "single_connection_client": False, + }, + { + "cache": DefaultCache(CacheConfig(max_size=5)), + "single_connection_client": False, + "decode_responses": True, + }, + ], + ids=["single", "pool", "decoded"], + indirect=True, + ) @pytest.mark.onlynoncluster - def test_get_from_cache(self, r, r2): - r, cache = r + def test_get_from_given_cache(self, r, r2): + cache = r.get_cache() # add key to redis r.set("foo", "bar") # get key from redis and save in local cache - assert r.get("foo") == b"bar" + assert r.get("foo") in [b"bar", "bar"] # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"bar", + "bar", + ] # change key in redis (cause invalidation) r2.set("foo", "barbar") - # send any command to redis (process invalidation in background) - r.ping() - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert r.get("foo") == b"barbar" + # Retrieves a new value from server and cache it + assert r.get("foo") in [b"barbar", "barbar"] + # Make sure that new value was cached + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"barbar", + "barbar", + ] @pytest.mark.parametrize( "r", - [{"cache": _LocalCache(max_size=3)}], + [ + { + "cache_config": CacheConfig(max_size=128), + "single_connection_client": True, + }, + { + "cache_config": CacheConfig(max_size=128), + "single_connection_client": False, + }, + { + "cache_config": CacheConfig(max_size=128), + "single_connection_client": False, + "decode_responses": True, + }, + ], + ids=["single", "pool", "decoded"], indirect=True, ) - def test_cache_lru_eviction(self, r): - r, cache = r - # add 3 keys to redis + @pytest.mark.onlynoncluster + def test_get_from_default_cache(self, r, r2): + cache = r.get_cache() + assert isinstance(cache.eviction_policy, LRUPolicy) + assert cache.config.get_max_size() == 128 + + # add key to redis r.set("foo", "bar") - r.set("foo2", "bar2") - r.set("foo3", "bar3") - # get 3 keys from redis and save in local cache - assert r.get("foo") == b"bar" - assert r.get("foo2") == b"bar2" - assert r.get("foo3") == b"bar3" - # get the 3 keys from local cache - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo2")) == b"bar2" - assert cache.get(("GET", "foo3")) == b"bar3" - # add 1 more key to redis (exceed the max size) - r.set("foo4", "bar4") - assert r.get("foo4") == b"bar4" - # the first key is not in the local cache anymore - assert cache.get(("GET", "foo")) is None + # get key from redis and save in local cache + assert r.get("foo") in [b"bar", "bar"] + # get key from local cache + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"bar", + "bar", + ] + # change key in redis (cause invalidation) + r2.set("foo", "barbar") + # Retrieves a new value from server and cache it + assert r.get("foo") in [b"barbar", "barbar"] + # Make sure that new value was cached + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"barbar", + "barbar", + ] - @pytest.mark.parametrize("r", [{"cache": _LocalCache(ttl=1)}], indirect=True) - def test_cache_ttl(self, r): - r, cache = r + @pytest.mark.parametrize( + "r", + [ + { + "cache_config": CacheConfig(max_size=128), + "single_connection_client": True, + }, + { + "cache_config": CacheConfig(max_size=128), + "single_connection_client": False, + }, + ], + ids=["single", "pool"], + indirect=True, + ) + @pytest.mark.onlynoncluster + def test_cache_clears_on_disconnect(self, r, cache): + cache = r.get_cache() # add key to redis r.set("foo", "bar") # get key from redis and save in local cache assert r.get("foo") == b"bar" # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # wait for the key to expire - time.sleep(1) - # the key is not in the local cache anymore - assert cache.get(("GET", "foo")) is None + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"bar" + ) + # Force disconnection + r.connection_pool.get_connection("_").disconnect() + # Make sure cache is empty + assert cache.size == 0 @pytest.mark.parametrize( "r", - [{"cache": _LocalCache(max_size=3, eviction_policy=EvictionPolicy.LFU)}], + [ + { + "cache_config": CacheConfig(max_size=3), + "single_connection_client": True, + }, + { + "cache_config": CacheConfig(max_size=3), + "single_connection_client": False, + }, + ], + ids=["single", "pool"], indirect=True, ) - def test_cache_lfu_eviction(self, r): - r, cache = r + @pytest.mark.onlynoncluster + def test_cache_lru_eviction(self, r, cache): + cache = r.get_cache() # add 3 keys to redis r.set("foo", "bar") r.set("foo2", "bar2") @@ -109,479 +189,1035 @@ def test_cache_lfu_eviction(self, r): assert r.get("foo") == b"bar" assert r.get("foo2") == b"bar2" assert r.get("foo3") == b"bar3" - # change the order of the keys in the cache - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo3")) == b"bar3" + # get the 3 keys from local cache + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"bar" + ) + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo2",))).cache_value + == b"bar2" + ) + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo3",))).cache_value + == b"bar3" + ) # add 1 more key to redis (exceed the max size) r.set("foo4", "bar4") assert r.get("foo4") == b"bar4" - # test the eviction policy - assert len(cache.cache) == 3 - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo2")) is None + # the first key is not in the local cache anymore + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))) is None + assert cache.size == 3 @pytest.mark.parametrize( "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + [ + { + "cache_config": CacheConfig(max_size=128), + "single_connection_client": True, + }, + { + "cache_config": CacheConfig(max_size=128), + "single_connection_client": False, + }, + ], + ids=["single", "pool"], indirect=True, ) @pytest.mark.onlynoncluster - def test_cache_decode_response(self, r): - r, cache = r - r.set("foo", "bar") - # get key from redis and save in local cache - assert r.get("foo") == "bar" - # get key from local cache - assert cache.get(("GET", "foo")) == "bar" - # change key in redis (cause invalidation) - r.set("foo", "barbar") - # send any command to redis (process invalidation in background) - r.ping() - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert r.get("foo") == "barbar" + def test_cache_ignore_not_allowed_command(self, r): + cache = r.get_cache() + # add fields to hash + assert r.hset("foo", "bar", "baz") + # get random field + assert r.hrandfield("foo") == b"bar" + assert cache.get(CacheKey(command="HRANDFIELD", redis_keys=("foo",))) is None @pytest.mark.parametrize( "r", - [{"cache": _LocalCache(), "kwargs": {"cache_deny_list": ["LLEN"]}}], + [ + { + "cache_config": CacheConfig(max_size=128), + "single_connection_client": True, + }, + { + "cache_config": CacheConfig(max_size=128), + "single_connection_client": False, + }, + ], + ids=["single", "pool"], indirect=True, ) - def test_cache_deny_list(self, r): - r, cache = r - # add list to redis - r.lpush("mylist", "foo", "bar", "baz") - assert r.llen("mylist") == 3 - assert r.lindex("mylist", 1) == b"bar" - assert cache.get(("LLEN", "mylist")) is None - assert cache.get(("LINDEX", "mylist", 1)) == b"bar" + @pytest.mark.onlynoncluster + def test_cache_invalidate_all_related_responses(self, r): + cache = r.get_cache() + # Add keys + assert r.set("foo", "bar") + assert r.set("bar", "foo") - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"cache_allow_list": ["LLEN"]}}], - indirect=True, - ) - def test_cache_allow_list(self, r): - r, cache = r - r.lpush("mylist", "foo", "bar", "baz") - assert r.llen("mylist") == 3 - assert r.lindex("mylist", 1) == b"bar" - assert cache.get(("LLEN", "mylist")) == 3 - assert cache.get(("LINDEX", "mylist", 1)) is None - - @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) - def test_cache_return_copy(self, r): - r, cache = r - r.lpush("mylist", "foo", "bar", "baz") - assert r.lrange("mylist", 0, -1) == [b"baz", b"bar", b"foo"] - res = cache.get(("LRANGE", "mylist", 0, -1)) - assert res == [b"baz", b"bar", b"foo"] - res.append(b"new") - check = cache.get(("LRANGE", "mylist", 0, -1)) - assert check == [b"baz", b"bar", b"foo"] + res = r.mget("foo", "bar") + # Make sure that replies was cached + assert res == [b"bar", b"foo"] + assert ( + cache.get(CacheKey(command="MGET", redis_keys=("foo", "bar"))).cache_value + == res + ) + + # Make sure that objects are immutable. + another_res = r.mget("foo", "bar") + res.append(b"baz") + assert another_res != res + + # Invalidate one of the keys and make sure that + # all associated cached entries was removed + assert r.set("foo", "baz") + assert r.get("foo") == b"baz" + assert cache.get(CacheKey(command="MGET", redis_keys=("foo", "bar"))) is None + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"baz" + ) @pytest.mark.parametrize( "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + [ + { + "cache_config": CacheConfig(max_size=128), + "single_connection_client": True, + }, + { + "cache_config": CacheConfig(max_size=128), + "single_connection_client": False, + }, + ], + ids=["single", "pool"], indirect=True, ) @pytest.mark.onlynoncluster - def test_csc_not_cause_disconnects(self, r): - r, cache = r - id1 = r.client_id() - r.mset({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1, "f": 1}) - assert r.mget("a", "b", "c", "d", "e", "f") == ["1", "1", "1", "1", "1", "1"] - id2 = r.client_id() - - # client should get value from client cache - assert r.mget("a", "b", "c", "d", "e", "f") == ["1", "1", "1", "1", "1", "1"] - assert cache.get(("MGET", "a", "b", "c", "d", "e", "f")) == [ - "1", - "1", - "1", - "1", - "1", - "1", - ] + def test_cache_flushed_on_server_flush(self, r): + cache = r.get_cache() + # Add keys + assert r.set("foo", "bar") + assert r.set("bar", "foo") + assert r.set("baz", "bar") + + # Make sure that replies was cached + assert r.get("foo") == b"bar" + assert r.get("bar") == b"foo" + assert r.get("baz") == b"bar" + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"bar" + ) + assert ( + cache.get(CacheKey(command="GET", redis_keys=("bar",))).cache_value + == b"foo" + ) + assert ( + cache.get(CacheKey(command="GET", redis_keys=("baz",))).cache_value + == b"bar" + ) + + # Flush server and trying to access cached entry + assert r.flushall() + assert r.get("foo") is None + assert cache.size == 0 - r.mset({"a": 2, "b": 2, "c": 2, "d": 2, "e": 2, "f": 2}) - id3 = r.client_id() - # client should get value from redis server post invalidate messages - assert r.mget("a", "b", "c", "d", "e", "f") == ["2", "2", "2", "2", "2", "2"] - - r.mset({"a": 3, "b": 3, "c": 3, "d": 3, "e": 3, "f": 3}) - # need to check that we get correct value 3 and not 2 - assert r.mget("a", "b", "c", "d", "e", "f") == ["3", "3", "3", "3", "3", "3"] - # client should get value from client cache - assert r.mget("a", "b", "c", "d", "e", "f") == ["3", "3", "3", "3", "3", "3"] - - r.mset({"a": 4, "b": 4, "c": 4, "d": 4, "e": 4, "f": 4}) - # need to check that we get correct value 4 and not 3 - assert r.mget("a", "b", "c", "d", "e", "f") == ["4", "4", "4", "4", "4", "4"] - # client should get value from client cache - assert r.mget("a", "b", "c", "d", "e", "f") == ["4", "4", "4", "4", "4", "4"] - id4 = r.client_id() - assert id1 == id2 == id3 == id4 +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +@pytest.mark.onlycluster +@skip_if_resp_version(2) +@skip_if_server_version_lt("7.4.0") +class TestClusterCache: @pytest.mark.parametrize( "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + [ + { + "cache": DefaultCache(CacheConfig(max_size=128)), + }, + { + "cache": DefaultCache(CacheConfig(max_size=128)), + "decode_responses": True, + }, + ], indirect=True, ) - @pytest.mark.onlynoncluster - def test_multiple_commands_same_key(self, r): - r, cache = r - r.mset({"a": 1, "b": 1}) - assert r.mget("a", "b") == ["1", "1"] - # value should be in local cache - assert cache.get(("MGET", "a", "b")) == ["1", "1"] - # set only one key - r.set("a", 2) - # send any command to redis (process invalidation in background) - r.ping() - # the command is not in the local cache anymore - assert cache.get(("MGET", "a", "b")) is None - # get from redis - assert r.mget("a", "b") == ["2", "1"] + @pytest.mark.onlycluster + def test_get_from_cache(self, r): + cache = r.nodes_manager.get_node_from_slot(12000).redis_connection.get_cache() + # add key to redis + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") in [b"bar", "bar"] + # get key from local cache + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"bar", + "bar", + ] + # change key in redis (cause invalidation) + r.set("foo", "barbar") + # Retrieves a new value from server and cache it + assert r.get("foo") in [b"barbar", "barbar"] + # Make sure that new value was cached + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"barbar", + "barbar", + ] + # Make sure that cache is shared between nodes. + assert ( + cache == r.nodes_manager.get_node_from_slot(1).redis_connection.get_cache() + ) @pytest.mark.parametrize( "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + [ + { + "cache_config": CacheConfig(max_size=128), + }, + { + "cache_config": CacheConfig(max_size=128), + "decode_responses": True, + }, + ], indirect=True, ) - def test_delete_one_command(self, r): - r, cache = r - r.mset({"a{a}": 1, "b{a}": 1}) - r.set("c", 1) - assert r.mget("a{a}", "b{a}") == ["1", "1"] - assert r.get("c") == "1" - # values should be in local cache - assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] - assert cache.get(("GET", "c")) == "1" - # delete one command from the cache - r.delete_command_from_cache(("MGET", "a{a}", "b{a}")) - # the other command is still in the local cache anymore - assert cache.get(("MGET", "a{a}", "b{a}")) is None - assert cache.get(("GET", "c")) == "1" - # get from redis - assert r.mget("a{a}", "b{a}") == ["1", "1"] - assert r.get("c") == "1" + def test_get_from_custom_cache(self, r, r2): + cache = r.nodes_manager.get_node_from_slot(12000).redis_connection.get_cache() + assert isinstance(cache.eviction_policy, LRUPolicy) + assert cache.config.get_max_size() == 128 + + # add key to redis + assert r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") in [b"bar", "bar"] + # get key from local cache + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"bar", + "bar", + ] + # change key in redis (cause invalidation) + r2.set("foo", "barbar") + # Retrieves a new value from server and cache it + assert r.get("foo") in [b"barbar", "barbar"] + # Make sure that new value was cached + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"barbar", + "barbar", + ] @pytest.mark.parametrize( "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + [ + { + "cache_config": CacheConfig(max_size=128), + }, + ], indirect=True, ) - def test_delete_several_commands(self, r): - r, cache = r - r.mset({"a{a}": 1, "b{a}": 1}) - r.set("c", 1) - assert r.mget("a{a}", "b{a}") == ["1", "1"] - assert r.get("c") == "1" - # values should be in local cache - assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] - assert cache.get(("GET", "c")) == "1" - # delete the commands from the cache - cache.delete_commands([("MGET", "a{a}", "b{a}"), ("GET", "c")]) - # the commands are not in the local cache anymore - assert cache.get(("MGET", "a{a}", "b{a}")) is None - assert cache.get(("GET", "c")) is None - # get from redis - assert r.mget("a{a}", "b{a}") == ["1", "1"] - assert r.get("c") == "1" + @pytest.mark.onlycluster + def test_cache_clears_on_disconnect(self, r, r2): + cache = r.nodes_manager.get_node_from_slot(12000).redis_connection.get_cache() + # add key to redis + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == b"bar" + # get key from local cache + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"bar" + ) + # Force disconnection + r.nodes_manager.get_node_from_slot( + 12000 + ).redis_connection.connection_pool.get_connection("_").disconnect() + # Make sure cache is empty + assert cache.size == 0 @pytest.mark.parametrize( "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + [ + { + "cache_config": CacheConfig(max_size=3), + }, + ], indirect=True, ) - def test_invalidate_key(self, r): - r, cache = r - r.mset({"a{a}": 1, "b{a}": 1}) - r.set("c", 1) - assert r.mget("a{a}", "b{a}") == ["1", "1"] - assert r.get("c") == "1" - # values should be in local cache - assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] - assert cache.get(("GET", "c")) == "1" - # invalidate one key from the cache - r.invalidate_key_from_cache("b{a}") - # one other command is still in the local cache anymore - assert cache.get(("MGET", "a{a}", "b{a}")) is None - assert cache.get(("GET", "c")) == "1" - # get from redis - assert r.mget("a{a}", "b{a}") == ["1", "1"] - assert r.get("c") == "1" + @pytest.mark.onlycluster + def test_cache_lru_eviction(self, r): + cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() + # add 3 keys to redis + r.set("foo{slot}", "bar") + r.set("foo2{slot}", "bar2") + r.set("foo3{slot}", "bar3") + # get 3 keys from redis and save in local cache + assert r.get("foo{slot}") == b"bar" + assert r.get("foo2{slot}") == b"bar2" + assert r.get("foo3{slot}") == b"bar3" + # get the 3 keys from local cache + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo{slot}",))).cache_value + == b"bar" + ) + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo2{slot}",))).cache_value + == b"bar2" + ) + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo3{slot}",))).cache_value + == b"bar3" + ) + # add 1 more key to redis (exceed the max size) + r.set("foo4{slot}", "bar4") + assert r.get("foo4{slot}") == b"bar4" + # the first key is not in the local cache_data anymore + assert cache.get(CacheKey(command="GET", redis_keys=("foo{slot}",))) is None @pytest.mark.parametrize( "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + [ + { + "cache_config": CacheConfig(max_size=128), + }, + ], indirect=True, ) - def test_flush_entire_cache(self, r): - r, cache = r - r.mset({"a{a}": 1, "b{a}": 1}) - r.set("c", 1) - assert r.mget("a{a}", "b{a}") == ["1", "1"] - assert r.get("c") == "1" - # values should be in local cache - assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] - assert cache.get(("GET", "c")) == "1" - # flush the local cache - r.flush_cache() - # the commands are not in the local cache anymore - assert cache.get(("MGET", "a{a}", "b{a}")) is None - assert cache.get(("GET", "c")) is None - # get from redis - assert r.mget("a{a}", "b{a}") == ["1", "1"] - assert r.get("c") == "1" - - @pytest.mark.onlynoncluster - def test_cache_not_available_with_resp2(self, request): - with pytest.raises(RedisError) as e: - _get_client(redis.Redis, request, protocol=2, client_cache=_LocalCache()) - assert "protocol version 3 or higher" in str(e.value) + @pytest.mark.onlycluster + def test_cache_ignore_not_allowed_command(self, r): + cache = r.nodes_manager.get_node_from_slot(12000).redis_connection.get_cache() + # add fields to hash + assert r.hset("foo", "bar", "baz") + # get random field + assert r.hrandfield("foo") == b"bar" + assert cache.get(CacheKey(command="HRANDFIELD", redis_keys=("foo",))) is None @pytest.mark.parametrize( "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + [ + { + "cache_config": CacheConfig(max_size=128), + }, + ], indirect=True, ) - @pytest.mark.onlynoncluster - def test_execute_command_args_not_split(self, r): - r, cache = r - assert r.execute_command("SET a 1") == "OK" - assert r.execute_command("GET a") == "1" - # "get a" is not whitelisted by default, the args should be separated - assert cache.get(("GET a",)) is None + @pytest.mark.onlycluster + def test_cache_invalidate_all_related_responses(self, r, cache): + cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() + # Add keys + assert r.set("foo{slot}", "bar") + assert r.set("bar{slot}", "foo") + + # Make sure that replies was cached + assert r.mget("foo{slot}", "bar{slot}") == [b"bar", b"foo"] + assert cache.get( + CacheKey(command="MGET", redis_keys=("foo{slot}", "bar{slot}")), + ).cache_value == [b"bar", b"foo"] + + # Invalidate one of the keys and make sure + # that all associated cached entries was removed + assert r.set("foo{slot}", "baz") + assert r.get("foo{slot}") == b"baz" + assert ( + cache.get( + CacheKey(command="MGET", redis_keys=("foo{slot}", "bar{slot}")), + ) + is None + ) + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo{slot}",))).cache_value + == b"baz" + ) @pytest.mark.parametrize( "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + [ + { + "cache_config": CacheConfig(max_size=128), + }, + ], indirect=True, ) - def test_execute_command_keys_provided(self, r): - r, cache = r - assert r.execute_command("SET", "b", "2") is True - assert r.execute_command("GET", "b", keys=["b"]) == "2" - assert cache.get(("GET", "b")) == "2" + @pytest.mark.onlycluster + def test_cache_flushed_on_server_flush(self, r, cache): + cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() + # Add keys + assert r.set("foo{slot}", "bar") + assert r.set("bar{slot}", "foo") + assert r.set("baz{slot}", "bar") + + # Make sure that replies was cached + assert r.get("foo{slot}") == b"bar" + assert r.get("bar{slot}") == b"foo" + assert r.get("baz{slot}") == b"bar" + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo{slot}",))).cache_value + == b"bar" + ) + assert ( + cache.get(CacheKey(command="GET", redis_keys=("bar{slot}",))).cache_value + == b"foo" + ) + assert ( + cache.get(CacheKey(command="GET", redis_keys=("baz{slot}",))).cache_value + == b"bar" + ) + # Flush server and trying to access cached entry + assert r.flushall() + assert r.get("foo{slot}") is None + assert cache.size == 0 + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +@pytest.mark.onlynoncluster +@skip_if_resp_version(2) +@skip_if_server_version_lt("7.4.0") +class TestSentinelCache: @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + "sentinel_setup", + [ + { + "cache": DefaultCache(CacheConfig(max_size=128)), + "force_master_ip": "localhost", + }, + { + "cache": DefaultCache(CacheConfig(max_size=128)), + "force_master_ip": "localhost", + "decode_responses": True, + }, + ], indirect=True, ) - def test_execute_command_keys_not_provided(self, r): - r, cache = r - assert r.execute_command("SET", "b", "2") is True - assert r.execute_command("GET", "b") == "2" # keys not provided, not cached - assert cache.get(("GET", "b")) is None + @pytest.mark.onlynoncluster + def test_get_from_cache(self, master): + cache = master.get_cache() + master.set("foo", "bar") + # get key from redis and save in local cache_data + assert master.get("foo") in [b"bar", "bar"] + # get key from local cache_data + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"bar", + "bar", + ] + # change key in redis (cause invalidation) + master.set("foo", "barbar") + # get key from redis + assert master.get("foo") in [b"barbar", "barbar"] + # Make sure that new value was cached + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"barbar", + "barbar", + ] @pytest.mark.parametrize( "r", - [{"cache": _LocalCache(), "single_connection_client": True}], + [ + { + "cache_config": CacheConfig(max_size=128), + }, + { + "cache_config": CacheConfig(max_size=128), + "decode_responses": True, + }, + ], indirect=True, ) - @pytest.mark.onlynoncluster - def test_single_connection(self, r): - r, cache = r - # add key to redis - r.set("foo", "bar") - # get key from redis and save in local cache - assert r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # change key in redis (cause invalidation) - r.set("foo", "barbar") - # send any command to redis (process invalidation in background) - r.ping() - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert r.get("foo") == b"barbar" + def test_get_from_default_cache(self, r, r2): + cache = r.get_cache() + assert isinstance(cache.eviction_policy, LRUPolicy) - @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) - def test_get_from_cache_invalidate_via_get(self, r, r2): - r, cache = r # add key to redis r.set("foo", "bar") - # get key from redis and save in local cache - assert r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" + # get key from redis and save in local cache_data + assert r.get("foo") in [b"bar", "bar"] + # get key from local cache_data + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"bar", + "bar", + ] # change key in redis (cause invalidation) r2.set("foo", "barbar") - # don't send any command to redis, just run another get - # it should process the invalidation in background - assert r.get("foo") == b"barbar" + # Retrieves a new value from server and cache_data it + assert r.get("foo") in [b"barbar", "barbar"] + # Make sure that new value was cached + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"barbar", + "barbar", + ] + + @pytest.mark.parametrize( + "sentinel_setup", + [ + { + "cache_config": CacheConfig(max_size=128), + "force_master_ip": "localhost", + } + ], + indirect=True, + ) + @pytest.mark.onlynoncluster + def test_cache_clears_on_disconnect(self, master, cache): + cache = master.get_cache() + # add key to redis + master.set("foo", "bar") + # get key from redis and save in local cache_data + assert master.get("foo") == b"bar" + # get key from local cache_data + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"bar" + ) + # Force disconnection + master.connection_pool.get_connection("_").disconnect() + # Make sure cache_data is empty + assert cache.size == 0 @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -@pytest.mark.onlycluster -class TestClusterLocalCache: - @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) - def test_get_from_cache(self, r, r2): - r, cache = r +@pytest.mark.onlynoncluster +@skip_if_resp_version(2) +@skip_if_server_version_lt("7.4.0") +class TestSSLCache: + @pytest.mark.parametrize( + "r", + [ + { + "cache": DefaultCache(CacheConfig(max_size=128)), + "ssl": True, + }, + { + "cache": DefaultCache(CacheConfig(max_size=128)), + "ssl": True, + "decode_responses": True, + }, + ], + indirect=True, + ) + @pytest.mark.onlynoncluster + def test_get_from_cache(self, r, r2, cache): + cache = r.get_cache() # add key to redis r.set("foo", "bar") - # get key from redis and save in local cache - assert r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" + # get key from redis and save in local cache_data + assert r.get("foo") in [b"bar", "bar"] + # get key from local cache_data + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"bar", + "bar", + ] # change key in redis (cause invalidation) - r2.set("foo", "barbar") - # send any command to redis (process invalidation in background) - node = r.get_node_from_key("foo") - r.ping(target_nodes=node) - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert r.get("foo") == b"barbar" + assert r2.set("foo", "barbar") + # Timeout needed for SSL connection because there's timeout + # between data appears in socket buffer + time.sleep(0.1) + # Retrieves a new value from server and cache_data it + assert r.get("foo") in [b"barbar", "barbar"] + # Make sure that new value was cached + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"barbar", + "barbar", + ] @pytest.mark.parametrize( "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + [ + { + "cache_config": CacheConfig(max_size=128), + "ssl": True, + }, + { + "cache_config": CacheConfig(max_size=128), + "ssl": True, + "decode_responses": True, + }, + ], indirect=True, ) - def test_cache_decode_response(self, r): - r, cache = r + def test_get_from_custom_cache(self, r, r2): + cache = r.get_cache() + assert isinstance(cache.eviction_policy, LRUPolicy) + + # add key to redis r.set("foo", "bar") - # get key from redis and save in local cache - assert r.get("foo") == "bar" - # get key from local cache - assert cache.get(("GET", "foo")) == "bar" + # get key from redis and save in local cache_data + assert r.get("foo") in [b"bar", "bar"] + # get key from local cache_data + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"bar", + "bar", + ] # change key in redis (cause invalidation) - r.set("foo", "barbar") - # send any command to redis (process invalidation in background) - node = r.get_node_from_key("foo") - r.ping(target_nodes=node) - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert r.get("foo") == "barbar" + r2.set("foo", "barbar") + # Timeout needed for SSL connection because there's timeout + # between data appears in socket buffer + time.sleep(0.1) + # Retrieves a new value from server and cache_data it + assert r.get("foo") in [b"barbar", "barbar"] + # Make sure that new value was cached + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"barbar", + "barbar", + ] @pytest.mark.parametrize( "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + [ + { + "cache_config": CacheConfig(max_size=128), + "ssl": True, + } + ], indirect=True, ) - def test_execute_command_keys_provided(self, r): - r, cache = r - assert r.execute_command("SET", "b", "2") is True - assert r.execute_command("GET", "b", keys=["b"]) == "2" - assert cache.get(("GET", "b")) == "2" + @pytest.mark.onlynoncluster + def test_cache_invalidate_all_related_responses(self, r): + cache = r.get_cache() + # Add keys + assert r.set("foo", "bar") + assert r.set("bar", "foo") + + # Make sure that replies was cached + assert r.mget("foo", "bar") == [b"bar", b"foo"] + assert cache.get( + CacheKey(command="MGET", redis_keys=("foo", "bar")) + ).cache_value == [b"bar", b"foo"] + + # Invalidate one of the keys and make sure + # that all associated cached entries was removed + assert r.set("foo", "baz") + # Timeout needed for SSL connection because there's timeout + # between data appears in socket buffer + time.sleep(0.1) + assert r.get("foo") == b"baz" + assert cache.get(CacheKey(command="MGET", redis_keys=("foo", "bar"))) is None + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"baz" + ) + + +class TestUnitDefaultCache: + def test_get_eviction_policy(self): + cache = DefaultCache(CacheConfig(max_size=5)) + assert isinstance(cache.eviction_policy, LRUPolicy) + + def test_get_max_size(self): + cache = DefaultCache(CacheConfig(max_size=5)) + assert cache.config.get_max_size() == 5 + + def test_get_size(self): + cache = DefaultCache(CacheConfig(max_size=5)) + assert cache.size == 0 @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, + "cache_key", [{"command": "GET", "redis_keys": ("bar",)}], indirect=True ) - def test_execute_command_keys_not_provided(self, r): - r, cache = r - assert r.execute_command("SET", "b", "2") is True - assert r.execute_command("GET", "b") == "2" # keys not provided, not cached - assert cache.get(("GET", "b")) is None + def test_set_non_existing_cache_key(self, cache_key, mock_connection): + cache = DefaultCache(CacheConfig(max_size=5)) + assert cache.set( + CacheEntry( + cache_key=cache_key, + cache_value=b"val", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + assert cache.get(cache_key).cache_value == b"val" -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -@pytest.mark.onlynoncluster -class TestSentinelLocalCache: + @pytest.mark.parametrize( + "cache_key", [{"command": "GET", "redis_keys": ("bar",)}], indirect=True + ) + def test_set_updates_existing_cache_key(self, cache_key, mock_connection): + cache = DefaultCache(CacheConfig(max_size=5)) - def test_get_from_cache(self, local_cache, master): - master.set("foo", "bar") - # get key from redis and save in local cache - assert master.get("foo") == b"bar" - # get key from local cache - assert local_cache.get(("GET", "foo")) == b"bar" - # change key in redis (cause invalidation) - master.set("foo", "barbar") - # send any command to redis (process invalidation in background) - master.ping() - # the command is not in the local cache anymore - assert local_cache.get(("GET", "foo")) is None - # get key from redis - assert master.get("foo") == b"barbar" + assert cache.set( + CacheEntry( + cache_key=cache_key, + cache_value=b"val", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + assert cache.get(cache_key).cache_value == b"val" + + cache.set( + CacheEntry( + cache_key=cache_key, + cache_value=b"new_val", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + assert cache.get(cache_key).cache_value == b"new_val" @pytest.mark.parametrize( - "sentinel_setup", - [{"kwargs": {"decode_responses": True}}], - indirect=True, + "cache_key", [{"command": "HRANDFIELD", "redis_keys": ("bar",)}], indirect=True ) - def test_cache_decode_response(self, local_cache, sentinel_setup, master): - master.set("foo", "bar") - # get key from redis and save in local cache - assert master.get("foo") == "bar" - # get key from local cache - assert local_cache.get(("GET", "foo")) == "bar" - # change key in redis (cause invalidation) - master.set("foo", "barbar") - # send any command to redis (process invalidation in background) - master.ping() - # the command is not in the local cache anymore - assert local_cache.get(("GET", "foo")) is None - # get key from redis - assert master.get("foo") == "barbar" + def test_set_does_not_store_not_allowed_key(self, cache_key, mock_connection): + cache = DefaultCache(CacheConfig(max_size=5)) + assert not cache.set( + CacheEntry( + cache_key=cache_key, + cache_value=b"val", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -@pytest.mark.onlynoncluster -class TestCustomCache: - class _CustomCache(AbstractCache): - def __init__(self): - self.responses = cachetools.LRUCache(maxsize=1000) - self.keys_to_commands = defaultdict(list) - self.commands_to_keys = defaultdict(list) - - def set( - self, - command: Union[str, Sequence[str]], - response: ResponseT, - keys_in_command: List[KeyT], + def test_set_evict_lru_cache_key_on_reaching_max_size(self, mock_connection): + cache = DefaultCache(CacheConfig(max_size=3)) + cache_key1 = CacheKey(command="GET", redis_keys=("foo",)) + cache_key2 = CacheKey(command="GET", redis_keys=("foo1",)) + cache_key3 = CacheKey(command="GET", redis_keys=("foo2",)) + + # Set 3 different keys + assert cache.set( + CacheEntry( + cache_key=cache_key1, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + assert cache.set( + CacheEntry( + cache_key=cache_key2, + cache_value=b"bar1", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + assert cache.set( + CacheEntry( + cache_key=cache_key3, + cache_value=b"bar2", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + + # Accessing key in the order that it makes 2nd key LRU + assert cache.get(cache_key1).cache_value == b"bar" + assert cache.get(cache_key2).cache_value == b"bar1" + assert cache.get(cache_key3).cache_value == b"bar2" + assert cache.get(cache_key1).cache_value == b"bar" + + cache_key4 = CacheKey(command="GET", redis_keys=("foo3",)) + assert cache.set( + CacheEntry( + cache_key=cache_key4, + cache_value=b"bar3", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + + # Make sure that new key was added and 2nd is evicted + assert cache.get(cache_key4).cache_value == b"bar3" + assert cache.get(cache_key2) is None + + @pytest.mark.parametrize( + "cache_key", [{"command": "GET", "redis_keys": ("bar",)}], indirect=True + ) + def test_get_return_correct_value(self, cache_key, mock_connection): + cache = DefaultCache(CacheConfig(max_size=5)) + + assert cache.set( + CacheEntry( + cache_key=cache_key, + cache_value=b"val", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + assert cache.get(cache_key).cache_value == b"val" + + wrong_key = CacheKey(command="HGET", redis_keys=("foo",)) + assert cache.get(wrong_key) is None + + result = cache.get(cache_key) + assert cache.set( + CacheEntry( + cache_key=cache_key, + cache_value=b"new_val", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + + # Make sure that result is immutable. + assert result.cache_value != cache.get(cache_key).cache_value + + def test_delete_by_cache_keys_removes_associated_entries(self, mock_connection): + cache = DefaultCache(CacheConfig(max_size=5)) + + cache_key1 = CacheKey(command="GET", redis_keys=("foo",)) + cache_key2 = CacheKey(command="GET", redis_keys=("foo1",)) + cache_key3 = CacheKey(command="GET", redis_keys=("foo2",)) + cache_key4 = CacheKey(command="GET", redis_keys=("foo3",)) + + # Set 3 different keys + assert cache.set( + CacheEntry( + cache_key=cache_key1, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + assert cache.set( + CacheEntry( + cache_key=cache_key2, + cache_value=b"bar1", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + assert cache.set( + CacheEntry( + cache_key=cache_key3, + cache_value=b"bar2", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + + assert cache.delete_by_cache_keys([cache_key1, cache_key2, cache_key4]) == [ + True, + True, + False, + ] + assert len(cache.collection) == 1 + assert cache.get(cache_key3).cache_value == b"bar2" + + def test_delete_by_redis_keys_removes_associated_entries(self, mock_connection): + cache = DefaultCache(CacheConfig(max_size=5)) + + cache_key1 = CacheKey(command="GET", redis_keys=("foo",)) + cache_key2 = CacheKey(command="GET", redis_keys=("foo1",)) + cache_key3 = CacheKey(command="MGET", redis_keys=("foo", "foo3")) + cache_key4 = CacheKey(command="MGET", redis_keys=("foo2", "foo3")) + + # Set 3 different keys + assert cache.set( + CacheEntry( + cache_key=cache_key1, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + assert cache.set( + CacheEntry( + cache_key=cache_key2, + cache_value=b"bar1", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + assert cache.set( + CacheEntry( + cache_key=cache_key3, + cache_value=b"bar2", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + assert cache.set( + CacheEntry( + cache_key=cache_key4, + cache_value=b"bar3", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + + assert cache.delete_by_redis_keys([b"foo", b"foo1"]) == [True, True, True] + assert len(cache.collection) == 1 + assert cache.get(cache_key4).cache_value == b"bar3" + + def test_flush(self, mock_connection): + cache = DefaultCache(CacheConfig(max_size=5)) + + cache_key1 = CacheKey(command="GET", redis_keys=("foo",)) + cache_key2 = CacheKey(command="GET", redis_keys=("foo1",)) + cache_key3 = CacheKey(command="GET", redis_keys=("foo2",)) + + # Set 3 different keys + assert cache.set( + CacheEntry( + cache_key=cache_key1, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + assert cache.set( + CacheEntry( + cache_key=cache_key2, + cache_value=b"bar1", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + assert cache.set( + CacheEntry( + cache_key=cache_key3, + cache_value=b"bar2", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + + assert cache.flush() == 3 + assert len(cache.collection) == 0 + + +class TestUnitLRUPolicy: + def test_type(self): + policy = LRUPolicy() + assert policy.type == EvictionPolicyType.time_based + + def test_evict_next(self, mock_connection): + cache = DefaultCache( + CacheConfig(max_size=5, eviction_policy=EvictionPolicy.LRU) + ) + policy = cache.eviction_policy + + cache_key1 = CacheKey(command="GET", redis_keys=("foo",)) + cache_key2 = CacheKey(command="GET", redis_keys=("bar",)) + + assert cache.set( + CacheEntry( + cache_key=cache_key1, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + assert cache.set( + CacheEntry( + cache_key=cache_key2, + cache_value=b"foo", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + + assert policy.evict_next() == cache_key1 + assert cache.get(cache_key1) is None + + def test_evict_many(self, mock_connection): + cache = DefaultCache( + CacheConfig(max_size=5, eviction_policy=EvictionPolicy.LRU) + ) + policy = cache.eviction_policy + cache_key1 = CacheKey(command="GET", redis_keys=("foo",)) + cache_key2 = CacheKey(command="GET", redis_keys=("bar",)) + cache_key3 = CacheKey(command="GET", redis_keys=("baz",)) + + assert cache.set( + CacheEntry( + cache_key=cache_key1, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + assert cache.set( + CacheEntry( + cache_key=cache_key2, + cache_value=b"foo", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + assert cache.set( + CacheEntry( + cache_key=cache_key3, + cache_value=b"baz", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + + assert policy.evict_many(2) == [cache_key1, cache_key2] + assert cache.get(cache_key1) is None + assert cache.get(cache_key2) is None + + with pytest.raises(ValueError, match="Evictions count is above cache size"): + policy.evict_many(99) + + def test_touch(self, mock_connection): + cache = DefaultCache( + CacheConfig(max_size=5, eviction_policy=EvictionPolicy.LRU) + ) + policy = cache.eviction_policy + + cache_key1 = CacheKey(command="GET", redis_keys=("foo",)) + cache_key2 = CacheKey(command="GET", redis_keys=("bar",)) + + cache.set( + CacheEntry( + cache_key=cache_key1, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + cache.set( + CacheEntry( + cache_key=cache_key2, + cache_value=b"foo", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + + assert cache.collection.popitem(last=True)[0] == cache_key2 + cache.set( + CacheEntry( + cache_key=cache_key2, + cache_value=b"foo", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + + policy.touch(cache_key1) + assert cache.collection.popitem(last=True)[0] == cache_key1 + + def test_throws_error_on_invalid_cache(self): + policy = LRUPolicy() + + with pytest.raises( + ValueError, match="Eviction policy should be associated with valid cache." ): - self.responses[command] = response - for key in keys_in_command: - self.keys_to_commands[key].append(tuple(command)) - self.commands_to_keys[command].append(tuple(keys_in_command)) - - def get(self, command: Union[str, Sequence[str]]) -> ResponseT: - return self.responses.get(command) - - def delete_command(self, command: Union[str, Sequence[str]]): - self.responses.pop(command, None) - keys = self.commands_to_keys.pop(command, []) - for key in keys: - if command in self.keys_to_commands[key]: - self.keys_to_commands[key].remove(command) - - def delete_commands(self, commands: List[Union[str, Sequence[str]]]): - for command in commands: - self.delete_command(command) - - def flush(self): - self.responses.clear() - self.commands_to_keys.clear() - self.keys_to_commands.clear() - - def invalidate_key(self, key: KeyT): - commands = self.keys_to_commands.pop(key, []) - for command in commands: - self.delete_command(command) - - @pytest.mark.parametrize("r", [{"cache": _CustomCache()}], indirect=True) - def test_get_from_cache(self, r, r2): - r, cache = r - # add key to redis - r.set("foo", "bar") - # get key from redis and save in local cache - assert r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # change key in redis (cause invalidation) - r2.set("foo", "barbar") - # send any command to redis (process invalidation in background) - r.ping() - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert r.get("foo") == b"barbar" + policy.evict_next() + + policy.cache = "wrong_type" + + with pytest.raises( + ValueError, match="Eviction policy should be associated with valid cache." + ): + policy.evict_next() + + +class TestUnitCacheConfiguration: + MAX_SIZE = 100 + EVICTION_POLICY = EvictionPolicy.LRU + + def test_get_max_size(self, cache_conf: CacheConfig): + assert self.MAX_SIZE == cache_conf.get_max_size() + + def test_get_eviction_policy(self, cache_conf: CacheConfig): + assert self.EVICTION_POLICY == cache_conf.get_eviction_policy() + + def test_is_exceeds_max_size(self, cache_conf: CacheConfig): + assert not cache_conf.is_exceeds_max_size(self.MAX_SIZE) + assert cache_conf.is_exceeds_max_size(self.MAX_SIZE + 1) + + def test_is_allowed_to_cache(self, cache_conf: CacheConfig): + assert cache_conf.is_allowed_to_cache("GET") + assert not cache_conf.is_allowed_to_cache("SET") diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 5a28f4cde5..c4b3188050 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -208,7 +208,6 @@ def cmd_init_mock(self, r): def mock_node_resp(node, response): connection = Mock() connection.read_response.return_value = response - connection._get_from_local_cache.return_value = None node.redis_connection.connection = connection return node @@ -216,7 +215,6 @@ def mock_node_resp(node, response): def mock_node_resp_func(node, func): connection = Mock() connection.read_response.side_effect = func - connection._get_from_local_cache.return_value = None node.redis_connection.connection = connection return node @@ -485,7 +483,6 @@ def mock_execute_command(*_args, **_kwargs): redis_mock_node.execute_command.side_effect = mock_execute_command # Mock response value for all other commands redis_mock_node.parse_response.return_value = "MOCK_OK" - redis_mock_node.connection._get_from_local_cache.return_value = None for node in r.get_nodes(): if node.port != primary.port: node.redis_connection = redis_mock_node @@ -646,10 +643,10 @@ def parse_response_mock_third(connection, *args, **options): mocks["send_command"].assert_has_calls( [ call("READONLY"), - call("GET", "foo"), + call("GET", "foo", keys=["foo"]), call("READONLY"), - call("GET", "foo"), - call("GET", "foo"), + call("GET", "foo", keys=["foo"]), + call("GET", "foo", keys=["foo"]), ] ) @@ -2695,7 +2692,7 @@ def test_init_slots_cache_slots_collision(self, request): def create_mocked_redis_node(host, port, **kwargs): """ - Helper function to return custom slots cache data from + Helper function to return custom slots cache_data data from different redis nodes """ if port == 7000: diff --git a/tests/test_connection.py b/tests/test_connection.py index 69275d58c0..a58703e3b5 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,20 +1,34 @@ +import copy +import platform import socket +import threading import types +from typing import Any from unittest import mock -from unittest.mock import patch +from unittest.mock import call, patch import pytest import redis from redis import ConnectionPool, Redis from redis._parsers import _HiredisParser, _RESP2Parser, _RESP3Parser from redis.backoff import NoBackoff +from redis.cache import ( + CacheConfig, + CacheEntry, + CacheEntryStatus, + CacheInterface, + CacheKey, + DefaultCache, + LRUPolicy, +) from redis.connection import ( + CacheProxyConnection, Connection, SSLConnection, UnixDomainSocketConnection, parse_url, ) -from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError +from redis.exceptions import ConnectionError, InvalidResponse, RedisError, TimeoutError from redis.retry import Retry from redis.utils import HIREDIS_AVAILABLE @@ -346,3 +360,206 @@ def test_unix_socket_connection_failure(): str(e.value) == "Error 2 connecting to unix:///tmp/a.sock. No such file or directory." ) + + +class TestUnitConnectionPool: + + @pytest.mark.parametrize( + "max_conn", (-1, "str"), ids=("non-positive", "wrong type") + ) + def test_throws_error_on_incorrect_max_connections(self, max_conn): + with pytest.raises( + ValueError, match='"max_connections" must be a positive integer' + ): + ConnectionPool( + max_connections=max_conn, + ) + + def test_throws_error_on_cache_enable_in_resp2(self): + with pytest.raises( + RedisError, match="Client caching is only supported with RESP version 3" + ): + ConnectionPool(protocol=2, cache_config=CacheConfig()) + + def test_throws_error_on_incorrect_cache_implementation(self): + with pytest.raises(ValueError, match="Cache must implement CacheInterface"): + ConnectionPool(protocol=3, cache="wrong") + + def test_returns_custom_cache_implementation(self, mock_cache): + connection_pool = ConnectionPool(protocol=3, cache=mock_cache) + + assert mock_cache == connection_pool.cache + connection_pool.disconnect() + + def test_creates_cache_with_custom_cache_factory( + self, mock_cache_factory, mock_cache + ): + mock_cache_factory.get_cache.return_value = mock_cache + + connection_pool = ConnectionPool( + protocol=3, + cache_config=CacheConfig(max_size=5), + cache_factory=mock_cache_factory, + ) + + assert connection_pool.cache == mock_cache + connection_pool.disconnect() + + def test_creates_cache_with_given_configuration(self, mock_cache): + connection_pool = ConnectionPool( + protocol=3, cache_config=CacheConfig(max_size=100) + ) + + assert isinstance(connection_pool.cache, CacheInterface) + assert connection_pool.cache.config.get_max_size() == 100 + assert isinstance(connection_pool.cache.eviction_policy, LRUPolicy) + connection_pool.disconnect() + + def test_make_connection_proxy_connection_on_given_cache(self): + connection_pool = ConnectionPool(protocol=3, cache_config=CacheConfig()) + + assert isinstance(connection_pool.make_connection(), CacheProxyConnection) + connection_pool.disconnect() + + +class TestUnitCacheProxyConnection: + def test_clears_cache_on_disconnect(self, mock_connection, cache_conf): + cache = DefaultCache(CacheConfig(max_size=10)) + cache_key = CacheKey(command="GET", redis_keys=("foo",)) + + cache.set( + CacheEntry( + cache_key=cache_key, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + assert cache.get(cache_key).cache_value == b"bar" + + mock_connection.disconnect.return_value = None + mock_connection.retry = "mock" + mock_connection.host = "mock" + mock_connection.port = "mock" + + proxy_connection = CacheProxyConnection( + mock_connection, cache, threading.Lock() + ) + proxy_connection.disconnect() + + assert len(cache.collection) == 0 + + @pytest.mark.skipif( + platform.python_implementation() == "PyPy", + reason="Pypy doesn't support side_effect", + ) + def test_read_response_returns_cached_reply(self, mock_cache, mock_connection): + mock_connection.retry = "mock" + mock_connection.host = "mock" + mock_connection.port = "mock" + + mock_cache.is_cachable.return_value = True + mock_cache.get.side_effect = [ + None, + None, + CacheEntry( + cache_key=CacheKey(command="GET", redis_keys=("foo",)), + cache_value=CacheProxyConnection.DUMMY_CACHE_VALUE, + status=CacheEntryStatus.IN_PROGRESS, + connection_ref=mock_connection, + ), + CacheEntry( + cache_key=CacheKey(command="GET", redis_keys=("foo",)), + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ), + CacheEntry( + cache_key=CacheKey(command="GET", redis_keys=("foo",)), + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ), + CacheEntry( + cache_key=CacheKey(command="GET", redis_keys=("foo",)), + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ), + ] + mock_connection.send_command.return_value = Any + mock_connection.read_response.return_value = b"bar" + mock_connection.can_read.return_value = False + + proxy_connection = CacheProxyConnection( + mock_connection, mock_cache, threading.Lock() + ) + proxy_connection.send_command(*["GET", "foo"], **{"keys": ["foo"]}) + assert proxy_connection.read_response() == b"bar" + assert proxy_connection.read_response() == b"bar" + + mock_connection.read_response.assert_called_once() + mock_cache.set.assert_has_calls( + [ + call( + CacheEntry( + cache_key=CacheKey(command="GET", redis_keys=("foo",)), + cache_value=CacheProxyConnection.DUMMY_CACHE_VALUE, + status=CacheEntryStatus.IN_PROGRESS, + connection_ref=mock_connection, + ) + ), + call( + CacheEntry( + cache_key=CacheKey(command="GET", redis_keys=("foo",)), + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ), + ] + ) + + mock_cache.get.assert_has_calls( + [ + call(CacheKey(command="GET", redis_keys=("foo",))), + call(CacheKey(command="GET", redis_keys=("foo",))), + call(CacheKey(command="GET", redis_keys=("foo",))), + call(CacheKey(command="GET", redis_keys=("foo",))), + call(CacheKey(command="GET", redis_keys=("foo",))), + call(CacheKey(command="GET", redis_keys=("foo",))), + ] + ) + + @pytest.mark.skipif( + platform.python_implementation() == "PyPy", + reason="Pypy doesn't support side_effect", + ) + def test_triggers_invalidation_processing_on_another_connection( + self, mock_cache, mock_connection + ): + mock_connection.retry = "mock" + mock_connection.host = "mock" + mock_connection.port = "mock" + + another_conn = copy.deepcopy(mock_connection) + another_conn.can_read.side_effect = [True, False] + another_conn.read_response.return_value = None + cache_entry = CacheEntry( + cache_key=CacheKey(command="GET", redis_keys=("foo",)), + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=another_conn, + ) + mock_cache.is_cachable.return_value = True + mock_cache.get.return_value = cache_entry + mock_connection.can_read.return_value = False + + proxy_connection = CacheProxyConnection( + mock_connection, mock_cache, threading.Lock() + ) + proxy_connection.send_command(*["GET", "foo"], **{"keys": ["foo"]}) + + assert proxy_connection.read_response() == b"bar" + assert another_conn.can_read.call_count == 2 + another_conn.read_response.assert_called_once() diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000000..764ef5d0a9 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,27 @@ +import pytest +from redis.utils import compare_versions + + +@pytest.mark.parametrize( + "version1,version2,expected_res", + [ + ("1.0.0", "0.9.0", -1), + ("1.0.0", "1.0.0", 0), + ("0.9.0", "1.0.0", 1), + ("1.09.0", "1.9.0", 0), + ("1.090.0", "1.9.0", -1), + ("1", "0.9.0", -1), + ("1", "1.0.0", 0), + ], + ids=[ + "version1 > version2", + "version1 == version2", + "version1 < version2", + "version1 == version2 - different minor format", + "version1 > version2 - different minor format", + "version1 > version2 - major version only", + "version1 == version2 - major version only", + ], +) +def test_compare_versions(version1, version2, expected_res): + assert compare_versions(version1, version2) == expected_res