Skip to content

Commit

Permalink
FIX: preserve original keys after key_builder transformation (#564)
Browse files Browse the repository at this point in the history
Co-authored-by: Hai Zhou <haizhou@users.noreply.github.com>
Co-authored-by: Sam Bull <aa6bs0@sambull.org>
  • Loading branch information
3 people authored Jan 13, 2023
1 parent d2d91c3 commit 5007dc2
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 16 deletions.
20 changes: 9 additions & 11 deletions aiocache/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from aiocache.factory import Cache, caches
from aiocache.lock import RedLock


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -344,19 +343,19 @@ async def decorator(
):
missing_keys = []
partial = {}
keys, new_args, args_index = self.get_cache_keys(f, args, kwargs)
orig_keys, cache_keys, new_args, args_index = self.get_cache_keys(f, args, kwargs)

if cache_read:
values = await self.get_from_cache(*keys)
for key, value in zip(keys, values):
values = await self.get_from_cache(*cache_keys)
for orig_key, value in zip(orig_keys, values):
if value is None:
missing_keys.append(key)
missing_keys.append(orig_key)
else:
partial[key] = value
partial[orig_key] = value
if values and None not in values:
return partial
else:
missing_keys = list(keys)
missing_keys = list(orig_keys)

if args_index > -1:
new_args[args_index] = missing_keys
Expand All @@ -377,17 +376,16 @@ async def decorator(

def get_cache_keys(self, f, args, kwargs):
args_dict = _get_args_dict(f, args, kwargs)
keys = args_dict.get(self.keys_from_attr, []) or []
keys = [self.key_builder(key, f, *args, **kwargs) for key in keys]
orig_keys = args_dict.get(self.keys_from_attr, []) or []
cache_keys = [self.key_builder(key, f, *args, **kwargs) for key in orig_keys]

args_names = f.__code__.co_varnames[: f.__code__.co_argcount]
new_args = list(args)
keys_index = -1
if self.keys_from_attr in args_names and self.keys_from_attr not in kwargs:
keys_index = args_names.index(self.keys_from_attr)
new_args[keys_index] = keys

return keys, new_args, keys_index
return orig_keys, cache_keys, new_args, keys_index

async def get_from_cache(self, *keys):
if not keys:
Expand Down
21 changes: 16 additions & 5 deletions tests/ut/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,26 +395,27 @@ def test_alias_takes_precedence(self, mock_cache):

def test_get_cache_keys(self, decorator):
keys = decorator.get_cache_keys(stub_dict, (), {"keys": ["a", "b"]})
assert keys == (["a", "b"], [], -1)
assert keys == (["a", "b"], ["a", "b"], [], -1)

def test_get_cache_keys_empty_list(self, decorator):
assert decorator.get_cache_keys(stub_dict, (), {"keys": []}) == ([], [], -1)
assert decorator.get_cache_keys(stub_dict, (), {"keys": []}) == ([], [], [], -1)

def test_get_cache_keys_missing_kwarg(self, decorator):
assert decorator.get_cache_keys(stub_dict, (), {}) == ([], [], -1)
assert decorator.get_cache_keys(stub_dict, (), {}) == ([], [], [], -1)

def test_get_cache_keys_arg_key_from_attr(self, decorator):
def fake(keys, a=1, b=2):
"""Dummy function."""

assert decorator.get_cache_keys(fake, (["a"]), {}) == (["a"], [["a"]], 0)
assert decorator.get_cache_keys(fake, (["a"],), {}) == (["a"], ["a"], [["a"]], 0)

def test_get_cache_keys_with_none(self, decorator):
assert decorator.get_cache_keys(stub_dict, (), {"keys": None}) == ([], [], -1)
assert decorator.get_cache_keys(stub_dict, (), {"keys": None}) == ([], [], [], -1)

def test_get_cache_keys_with_key_builder(self, decorator):
decorator.key_builder = lambda key, *args, **kwargs: kwargs["market"] + "_" + key.upper()
assert decorator.get_cache_keys(stub_dict, (), {"keys": ["a", "b"], "market": "ES"}) == (
["a", "b"],
["ES_A", "ES_B"],
[],
-1,
Expand Down Expand Up @@ -590,6 +591,16 @@ async def bar():

assert foo.cache != bar.cache

async def test_key_builder(self):
@multi_cached("keys", key_builder=lambda key, _, keys: key + 1)
async def f(keys=None):
return {k: k * 3 for k in keys}

assert await f(keys=(1,)) == {1: 3}
cached_value = await f.cache.get(2)
assert cached_value == 3
assert not await f.cache.exists(1)


def test_get_args_dict():
def fn(a, b, *args, keys=None, **kwargs):
Expand Down

0 comments on commit 5007dc2

Please sign in to comment.