Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Implement and use an @lru_cache decorator #8595

Merged
merged 6 commits into from
Oct 30, 2020
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 80 additions & 38 deletions synapse/util/caches/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,19 @@
import functools
import inspect
import logging
from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
from typing import (
Any,
Callable,
Generic,
Iterable,
Mapping,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
cast,
)
from weakref import WeakValueDictionary

from twisted.internet import defer
Expand Down Expand Up @@ -97,6 +109,10 @@ def __init__(self, orig: _CachedFunction, num_args, cache_context=False):

self.add_cache_context = cache_context

self.cache_key_builder = get_cache_key_builder(
self.arg_names, self.arg_defaults
)


class CacheDescriptor(_CacheDescriptorBase):
""" A method decorator that applies a memoizing cache around the function.
Expand Down Expand Up @@ -141,7 +157,6 @@ def __init__(
cache_context=False,
iterable=False,
):

super().__init__(orig, num_args=num_args, cache_context=cache_context)

self.max_entries = max_entries
Expand All @@ -157,41 +172,7 @@ def __get__(self, obj, owner):
iterable=self.iterable,
) # type: DeferredCache[CacheKey, Any]

def get_cache_key_gen(args, kwargs):
"""Given some args/kwargs return a generator that resolves into
the cache_key.

We loop through each arg name, looking up if its in the `kwargs`,
otherwise using the next argument in `args`. If there are no more
args then we try looking the arg name up in the defaults
"""
pos = 0
for nm in self.arg_names:
if nm in kwargs:
yield kwargs[nm]
elif pos < len(args):
yield args[pos]
pos += 1
else:
yield self.arg_defaults[nm]

# By default our cache key is a tuple, but if there is only one item
# then don't bother wrapping in a tuple. This is to save memory.
if self.num_args == 1:
nm = self.arg_names[0]

def get_cache_key(args, kwargs):
if nm in kwargs:
return kwargs[nm]
elif len(args):
return args[0]
else:
return self.arg_defaults[nm]

else:

def get_cache_key(args, kwargs):
return tuple(get_cache_key_gen(args, kwargs))
get_cache_key = self.cache_key_builder

@functools.wraps(self.orig)
def _wrapped(*args, **kwargs):
Expand Down Expand Up @@ -223,7 +204,6 @@ def _wrapped(*args, **kwargs):
wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
else:
wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a duplicate of a line just below: removing while I'm here.

wrapped.invalidate_many = cache.invalidate_many
wrapped.prefill = cache.prefill

Expand Down Expand Up @@ -468,3 +448,65 @@ def batch_do_something(self, first_arg, second_args):
)

return cast(Callable[[F], _CachedFunction[F]], func)


def get_cache_key_builder(
param_names: Sequence[str], param_defaults: Mapping[str, Any]
) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]:
"""Construct a function which will build cache keys suitable for a cached function

Args:
param_names: list of formal parameter names for the cached function
param_defaults: a mapping from parameter name to default value for that param

Returns:
A function which will take an (args, kwargs) pair and return a cache key
"""

# By default our cache key is a tuple, but if there is only one item
# then don't bother wrapping in a tuple. This is to save memory.

if len(param_names) == 1:
nm = param_names[0]

def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
if nm in kwargs:
return kwargs[nm]
elif len(args):
return args[0]
else:
return param_defaults[nm]

else:

def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs))

return get_cache_key


def _get_cache_key_gen(
param_names: Iterable[str],
param_defaults: Mapping[str, Any],
args: Sequence[Any],
kwargs: Mapping[str, Any],
) -> Iterable[Any]:
"""Given some args/kwargs return a generator that resolves into
the cache_key.

This is essentially the same operation as `inspect.getcallargs`, but optimised so
that we don't need to inspect the target function for each call.
"""

# We loop through each arg name, looking up if its in the `kwargs`,
# otherwise using the next argument in `args`. If there are no more
# args then we try looking the arg name up in the defaults.
pos = 0
for nm in param_names:
if nm in kwargs:
yield kwargs[nm]
elif pos < len(args):
yield args[pos]
pos += 1
else:
yield param_defaults[nm]