Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Speed up @cachedList #13591

Merged
merged 10 commits into from
Aug 23, 2022
195 changes: 106 additions & 89 deletions synapse/util/caches/deferred_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import enum
import threading
from typing import (
Any,
Callable,
Collection,
Dict,
Generic,
Iterable,
MutableMapping,
Optional,
Set,
Sized,
TypeVar,
Union,
Expand All @@ -31,7 +36,6 @@
from prometheus_client import Gauge

from twisted.internet import defer
from twisted.python import failure
from twisted.python.failure import Failure

from synapse.util.async_helpers import ObservableDeferred
Expand Down Expand Up @@ -159,15 +163,16 @@ def get(
Raises:
KeyError if the key is not found in the cache
"""
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
if val is not _Sentinel.sentinel:
val.callbacks.update(callbacks)
val.add_callback(key, callback)
if update_metrics:
m = self.cache.metrics
assert m # we always have a name, so should always have metrics
m.inc_hits()
return val.deferred.observe()
return val.deferred(key)

callbacks = (callback,) if callback else ()

val2 = self.cache.get(
key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
Expand Down Expand Up @@ -218,84 +223,70 @@ def set(
value: a deferred which will complete with a result to add to the cache
callback: An optional callback to be called when the entry is invalidated
"""
if not isinstance(value, defer.Deferred):
raise TypeError("not a Deferred")

callbacks = [callback] if callback else []
self.check_thread()

existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
existing_entry.invalidate()
self._pending_deferred_cache.pop(key, None)

# XXX: why don't we invalidate the entry in `self.cache` yet?

# we can save a whole load of effort if the deferred is ready.
if value.called:
result = value.result
if not isinstance(result, failure.Failure):
self.cache.set(key, cast(VT, result), callbacks)
return value

# otherwise, we'll add an entry to the _pending_deferred_cache for now,
# and add callbacks to add it to the cache properly later.
entry = CacheEntrySingle[KT, VT](value)
entry.add_callback(key, callback)
self._pending_deferred_cache[key] = entry
deferred = entry.deferred(key).addCallbacks(
self._set_completed_callback,
self._error_callback,
callbackArgs=(entry, key),
errbackArgs=(entry, key),
)

observable = ObservableDeferred(value, consumeErrors=True)
observer = observable.observe()
entry = CacheEntry(deferred=observable, callbacks=callbacks)
# we return a new Deferred which will be called before any subsequent observers.
return deferred

def _set_completed_callback(
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
self, value: VT, entry: "CacheEntry[KT, VT]", key: KT
) -> VT:
"""Called when a deferred is completed."""
# We check if the current entry matches the entry associated with the
# deferred. If they don't match then it got invalidated.
current_entry = self._pending_deferred_cache.pop(key, None)
if current_entry is not entry:
if current_entry:
self._pending_deferred_cache[key] = current_entry
return value

self._pending_deferred_cache[key] = entry
self.cache.set(key, value, entry.get_callbacks(key))

def compare_and_pop() -> bool:
"""Check if our entry is still the one in _pending_deferred_cache, and
if so, pop it.

Returns true if the entries matched.
"""
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
return True

# oops, the _pending_deferred_cache has been updated since
# we started our query, so we are out of date.
#
# Better put back whatever we took out. (We do it this way
# round, rather than peeking into the _pending_deferred_cache
# and then removing on a match, to make the common case faster)
if existing_entry is not None:
self._pending_deferred_cache[key] = existing_entry

return False

def cb(result: VT) -> None:
if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
# we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was
# updated, but it's possible that `set` was called without
# `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now.
entry.invalidate()

def eb(_fail: Failure) -> None:
compare_and_pop()
entry.invalidate()

# once the deferred completes, we can move the entry from the
# _pending_deferred_cache to the real cache.
#
observer.addCallbacks(cb, eb)
return value

# we return a new Deferred which will be called before any subsequent observers.
return observable.observe()
def _error_callback(
self,
failure: Failure,
entry: "CacheEntry[KT, VT]",
key: KT,
) -> Failure:
"""Called when a deferred errors."""

# We check if the current entry matches the entry associated with the
# deferred. If they don't match then it got invalidated.
current_entry = self._pending_deferred_cache.pop(key, None)
if current_entry is not entry:
if current_entry:
self._pending_deferred_cache[key] = current_entry
return failure

for cb in entry.get_callbacks(key):
cb()

return failure

def prefill(
self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None
) -> None:
callbacks = [callback] if callback else []
callbacks = (callback,) if callback else ()
self.cache.set(key, value, callbacks=callbacks)
self._pending_deferred_cache.pop(key, None)

def invalidate(self, key: KT) -> None:
"""Delete a key, or tree of entries
Expand All @@ -311,41 +302,67 @@ def invalidate(self, key: KT) -> None:
self.cache.del_multi(key)

# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, which will (a) stop it being returned
# for future queries and (b) stop it being persisted as a proper entry
# _pending_deferred_cache, which will (a) stop it being returned for
# future queries and (b) stop it being persisted as a proper entry
# in self.cache.
entry = self._pending_deferred_cache.pop(key, None)

# run the invalidation callbacks now, rather than waiting for the
# deferred to resolve.
if entry:
# _pending_deferred_cache.pop should either return a CacheEntry, or, in the
# case of a TreeCache, a dict of keys to cache entries. Either way calling
# iterate_tree_cache_entry on it will do the right thing.
for entry in iterate_tree_cache_entry(entry):
entry.invalidate()
for cb in entry.get_callbacks(key):
cb()

def invalidate_all(self) -> None:
self.check_thread()
self.cache.clear()
for entry in self._pending_deferred_cache.values():
entry.invalidate()
for key, entry in self._pending_deferred_cache.items():
for cb in entry.get_callbacks(key):
cb()

self._pending_deferred_cache.clear()


class CacheEntry:
__slots__ = ["deferred", "callbacks", "invalidated"]
class CacheEntry(Generic[KT, VT], metaclass=abc.ABCMeta):
"""Abstract class for entries in `DeferredCache[KT, VT]`"""

def __init__(
self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
):
self.deferred = deferred
self.callbacks = set(callbacks)
self.invalidated = False

def invalidate(self) -> None:
if not self.invalidated:
self.invalidated = True
for callback in self.callbacks:
callback()
self.callbacks.clear()
@abc.abstractmethod
def deferred(self, key: KT) -> "defer.Deferred[VT]":
"""Get a deferred that a caller can wait on to get the value at the
given key"""
...

@abc.abstractmethod
def add_callback(self, key: KT, callback: Optional[Callable[[], None]]) -> None:
"""Add an invalidation callback"""
...

@abc.abstractmethod
def get_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
"""Get all invalidation callbacks"""
...


class CacheEntrySingle(CacheEntry[KT, VT]):
"""An implementation of `CacheEntry` wrapping a deferred that results in a
single cache entry.
"""

__slots__ = ["_deferred", "_callbacks"]

def __init__(self, deferred: "defer.Deferred[VT]") -> None:
self._deferred = ObservableDeferred(deferred, consumeErrors=True)
self._callbacks: Set[Callable[[], None]] = set()

def deferred(self, key: KT) -> "defer.Deferred[VT]":
return self._deferred.observe()

def add_callback(self, key: KT, callback: Optional[Callable[[], None]]) -> None:
if callback is None:
return

self._callbacks.add(callback)

def get_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
return self._callbacks
3 changes: 3 additions & 0 deletions synapse/util/caches/treecache.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ def pop(self, key, default=None):
def values(self):
return iterate_tree_cache_entry(self.root)

def items(self):
return iterate_tree_cache_items((), self.root)

def __len__(self) -> int:
return self.size

Expand Down