From 5007dc25278877a47b94a4dbe43a4afec27951ac Mon Sep 17 00:00:00 2001 From: Hai Zhou <2293634+haizhou@users.noreply.github.com> Date: Fri, 13 Jan 2023 17:09:02 -0500 Subject: [PATCH] FIX: preserve original keys after key_builder transformation (#564) Co-authored-by: Hai Zhou Co-authored-by: Sam Bull --- aiocache/decorators.py | 20 +++++++++----------- tests/ut/test_decorators.py | 21 ++++++++++++++++----- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/aiocache/decorators.py b/aiocache/decorators.py index db97319e..0fb1f329 100644 --- a/aiocache/decorators.py +++ b/aiocache/decorators.py @@ -7,7 +7,6 @@ from aiocache.factory import Cache, caches from aiocache.lock import RedLock - logger = logging.getLogger(__name__) @@ -344,19 +343,19 @@ async def decorator( ): missing_keys = [] partial = {} - keys, new_args, args_index = self.get_cache_keys(f, args, kwargs) + orig_keys, cache_keys, new_args, args_index = self.get_cache_keys(f, args, kwargs) if cache_read: - values = await self.get_from_cache(*keys) - for key, value in zip(keys, values): + values = await self.get_from_cache(*cache_keys) + for orig_key, value in zip(orig_keys, values): if value is None: - missing_keys.append(key) + missing_keys.append(orig_key) else: - partial[key] = value + partial[orig_key] = value if values and None not in values: return partial else: - missing_keys = list(keys) + missing_keys = list(orig_keys) if args_index > -1: new_args[args_index] = missing_keys @@ -377,17 +376,16 @@ async def decorator( def get_cache_keys(self, f, args, kwargs): args_dict = _get_args_dict(f, args, kwargs) - keys = args_dict.get(self.keys_from_attr, []) or [] - keys = [self.key_builder(key, f, *args, **kwargs) for key in keys] + orig_keys = args_dict.get(self.keys_from_attr, []) or [] + cache_keys = [self.key_builder(key, f, *args, **kwargs) for key in orig_keys] args_names = f.__code__.co_varnames[: f.__code__.co_argcount] new_args = list(args) keys_index = -1 if self.keys_from_attr in args_names and self.keys_from_attr not in kwargs: keys_index = args_names.index(self.keys_from_attr) - new_args[keys_index] = keys - return keys, new_args, keys_index + return orig_keys, cache_keys, new_args, keys_index async def get_from_cache(self, *keys): if not keys: diff --git a/tests/ut/test_decorators.py b/tests/ut/test_decorators.py index e4a1a07e..d01a9fac 100644 --- a/tests/ut/test_decorators.py +++ b/tests/ut/test_decorators.py @@ -395,26 +395,27 @@ def test_alias_takes_precedence(self, mock_cache): def test_get_cache_keys(self, decorator): keys = decorator.get_cache_keys(stub_dict, (), {"keys": ["a", "b"]}) - assert keys == (["a", "b"], [], -1) + assert keys == (["a", "b"], ["a", "b"], [], -1) def test_get_cache_keys_empty_list(self, decorator): - assert decorator.get_cache_keys(stub_dict, (), {"keys": []}) == ([], [], -1) + assert decorator.get_cache_keys(stub_dict, (), {"keys": []}) == ([], [], [], -1) def test_get_cache_keys_missing_kwarg(self, decorator): - assert decorator.get_cache_keys(stub_dict, (), {}) == ([], [], -1) + assert decorator.get_cache_keys(stub_dict, (), {}) == ([], [], [], -1) def test_get_cache_keys_arg_key_from_attr(self, decorator): def fake(keys, a=1, b=2): """Dummy function.""" - assert decorator.get_cache_keys(fake, (["a"]), {}) == (["a"], [["a"]], 0) + assert decorator.get_cache_keys(fake, (["a"],), {}) == (["a"], ["a"], [["a"]], 0) def test_get_cache_keys_with_none(self, decorator): - assert decorator.get_cache_keys(stub_dict, (), {"keys": None}) == ([], [], -1) + assert decorator.get_cache_keys(stub_dict, (), {"keys": None}) == ([], [], [], -1) def test_get_cache_keys_with_key_builder(self, decorator): decorator.key_builder = lambda key, *args, **kwargs: kwargs["market"] + "_" + key.upper() assert decorator.get_cache_keys(stub_dict, (), {"keys": ["a", "b"], "market": "ES"}) == ( + ["a", "b"], ["ES_A", "ES_B"], [], -1, @@ -590,6 +591,16 @@ async def bar(): assert foo.cache != bar.cache + async def test_key_builder(self): + @multi_cached("keys", key_builder=lambda key, _, keys: key + 1) + async def f(keys=None): + return {k: k * 3 for k in keys} + + assert await f(keys=(1,)) == {1: 3} + cached_value = await f.cache.get(2) + assert cached_value == 3 + assert not await f.cache.exists(1) + def test_get_args_dict(): def fn(a, b, *args, keys=None, **kwargs):