improved caching #165

Closed · wants to merge 7 commits
183 changes: 141 additions & 42 deletions pyterrier/cache.py
@@ -5,11 +5,13 @@
import os
from os import path
CACHE_DIR = None
DEFAULT_CACHE_STORE = "shelve" # or "chest"
import pandas as pd
import pickle
from functools import partial
import datetime
from warnings import warn
from typing import List

DEFINITION_FILE = ".transformer"

@@ -46,11 +48,18 @@ def list_cache():
if path.exists(def_file):
with open(def_file, "r") as f:
elem["transformer"] = f.readline()
elem["size"] = sum(d.stat().st_size for d in os.scandir(dir) if d.is_file())
elem["size_str"] = sizeof_fmt(elem["size"])
elem["queries"] = len(os.listdir(dir)) -2 #subtract .keys and DEFINITION_FILE
elem["lastmodified"] = path.getmtime(dir)
shelve_file = path.join(dir, "shelve.db")
if path.exists(shelve_file):
elem["size"] = path.getsize(shelve_file)
elem["lastmodified"] = path.getmtime(shelve_file)
else:
#we assume it is a chest
elem["size"] = sum(d.stat().st_size for d in os.scandir(dir) if d.is_file())
elem["queries"] = len(os.listdir(dir)) -2 #subtract .keys and DEFINITION_FILE
elem["lastmodified"] = path.getmtime(dir)

elem["lastmodified_str"] = datetime.datetime.fromtimestamp(elem["lastmodified"]).strftime('%Y-%m-%dT%H:%M:%S')
elem["size_str"] = sizeof_fmt(elem["size"])
rtr[dirname] = elem
return rtr

@@ -61,15 +70,18 @@ def clear_cache():
import shutil
shutil.rmtree(CACHE_DIR)


class GenericCacheTransformer(TransformerBase):
"""
A transformer that caches the results of the constituent (inner) transformer.
This is instantiated using the `~` operator on any transformer.

Caching is based on the configuration of the pipeline, as read by executing
repr() on the pipeline. Caching lookup is by default based on the qid, so any change in query
_formulation_ will not be reflected in a cache's results.

Caching lookup can be changed by altering the `on` attribute in the cache object.

Example Usage::

dataset = pt.get_dataset("trec-robust-2004")
@@ -90,39 +102,42 @@ class ChestCacheTransformer(TransformerBase):
In the above example, we use the `~` operator on the first pass retrieval using BM25, but not on the 2nd pass retrieval,
as the query formulation will differ during the second pass.

Caching is not supported for re-ranking transformers.

"""

def __init__(self, inner, on=["qid"], verbose=False, debug=False, **kwargs):
super().__init__(**kwargs)
on="qid"
self.on = on
self.inner = inner
self.disable = False
self.hits = 0
self.requests = 0
self.debug = debug
self.verbose = verbose

if CACHE_DIR is None:
init()

# we take the md5 of the __repr__ of the pipeline to make a unique identifier for the pipeline
# all different pipelines should return unique __repr_() values, as these are intended to be
# unambiguous
self.trepr = repr(self.inner)
if "object at 0x" in self.trepr:
warn("Cannot cache pipeline %s across PyTerrier sessions, as it has a transient component, which has not overridden __repr__()" % self.trepr)
#return
#self.disable = True

uid = hashlib.md5( bytes(self.trepr, "utf-8") ).hexdigest()
self.destdir = path.join(CACHE_DIR, uid)
os.makedirs(self.destdir, exist_ok=True)

definition_file = path.join(self.destdir, DEFINITION_FILE)
if not path.exists(definition_file):
if self.debug:
print("Creating new cache store at %s for %s" % (self.destdir, self.trepr))
with open(definition_file, "w") as f:

f.write(self.trepr)

def stats(self):
return self.hits / self.requests if self.requests > 0 else 0

@@ -136,44 +151,128 @@ def __repr__(self):
def __str__(self):
return "Cache("+str(self.inner)+")"

def __del__(self):
self.close()

@property
def NOCACHE(self):
return self.inner

def flush(self):
self.chest.flush()

def close(self):
pass

def transform(self, input_res):
if self.disable:
return self.inner.transform(input_res)
if "docid" in input_res.columns or "docno" in input_res.columns:
raise ValueError("Caching of %s for re-ranking is not supported. Caching currently only supports input dataframes with queries as inputs and cannot be used for re-rankers." % self.inner.__repr__())
for col in self.on:
if col not in input_res.columns:
raise ValueError("Caching on %s, but did not find column %s among input columns %s"
% (str(self.on), col, str(input_res.columns)))
for col in ["docno"]:
if col in input_res.columns and col not in self.on and len(self.on) == 1:
warn(("Caching on=%s, but found column %s among input columns %s. You may want " % (str(self.on), col, str(input_res.columns))) +
"to update the on attribute for the cache transformer")
return self._transform_qid(input_res)

def _transform_qid(self, input_res):
# output dataframes to /return/
rtr = []
# input rows to execute on the inner transformer
todo=[]

import pyterrier as pt
iter = input_res.itertuples(index=False)
[Collaborator review comment] the global iter function is shadowed here; rename to it or iterator?
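# a rename along the reviewer's suggestion might look like this (a sketch, not applied in this diff):
# row_iter = input_res.itertuples(index=False)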

for row in pt.tqdm(
iter,
desc="%s lookups" % self,
unit='row',
total=len(input_res)) if self.verbose else iter:
# we calculate what we will key this cache on
key = ''.join([getattr(row, k) for k in self.on])
qid = str(row.qid)
self.requests += 1
try:
df = self.chest.get(key, None)
except:
[Collaborator review comment] more specific exception handler here?
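# a narrower handler per the review comment could be, for example (a sketch, not applied in this diff):
# except (KeyError, FileNotFoundError):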

# occasionally we have file not founds,
# lets remove from the cache and continue
del self.chest[key]
df = None
if df is None:
if self.debug:
print("%s cache miss for key %s" % (self, key))
todo.append(row)
else:
if self.debug:
print("%s cache hit for key %s" % (self, key))
self.hits += 1
rtr.append(df)
if len(todo) > 0:
todo_df = pd.DataFrame(todo)
todo_res = self.inner.transform(todo_df)
for key_vals, group in todo_res.groupby(self.on):
key = ''.join(key_vals)
self.chest[key] = group
if self.debug:
print("%s caching %d results for key %s" % (self, len(group), key))
rtr.append(todo_res)
self.flush()
return pd.concat(rtr)


class ChestCacheTransformer(GenericCacheTransformer):
"""
A cache transformer based on `chest <https://github.com/blaze/chest>`_.
"""
def __init__(self, inner, **kwargs):
super().__init__(inner, **kwargs)

self.chest = Chest(path=self.destdir,
dump=lambda data, filename: pd.DataFrame.to_pickle(data, filename) if isinstance(data, pd.DataFrame) else pickle.dump(data, filename, protocol=1),
load=lambda filehandle: pickle.load(filehandle) if ".keys" in filehandle.name else pd.read_pickle(filehandle)
)

class ShelveCacheTransformer(GenericCacheTransformer):
"""
A cache transformer based on Python's `shelve <https://docs.python.org/3/library/shelve.html>`_ module. Compared to the
chest-based cache, this transformer MUST be closed before cached results become visible to other instances.
"""
def __init__(self, inner, **kwargs):
super().__init__(inner, **kwargs)
filename = os.path.join(self.destdir, "shelve")
import shelve
if os.path.exists(filename) and os.path.getsize(filename) == 0:
warn("Cache file exists but has 0 size - perhaps a previous transformer cache should have been closed")
self.chest = shelve.open(filename)

def flush(self):
self.chest.sync()

def close(self):
self.chest.close()

CACHE_STORES = {
"shelve" : ShelveCacheTransformer,
"chest" : ChestCacheTransformer
}

def of(
inner : TransformerBase,
on : List[str] = ["qid"],
store : str = DEFAULT_CACHE_STORE, **kwargs
) -> GenericCacheTransformer:
"""
Returns a transformer that caches the inner transformer.
Arguments:
inner(TransformerBase): which transformer should be cached
on(List[str]): which attributes to use as keys when caching
store(str): name of a cache type, either "shelve" or "chest". Defaults to "shelve".
"""
if store not in CACHE_STORES:
raise ValueError("cache store type %s unknown, known types %s" % (store, list(CACHE_STORES.keys())))
clz = CACHE_STORES[store]
return clz(inner, on=on, **kwargs)
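
Pulling the new cache.py pieces together, here is a minimal usage sketch of the of() factory and the two stores. It assumes this branch is installed as pyterrier; the vaswani dataset and pt.BatchRetrieve come from standard PyTerrier and are used purely for illustration.

import pyterrier as pt
from pyterrier.cache import of

if not pt.started():
    pt.init()

# an illustrative retrieval pipeline over the small vaswani test collection
dataset = pt.get_dataset("vaswani")
bm25 = pt.BatchRetrieve(dataset.get_index(), wmodel="BM25")

# shelve-backed cache (the new default), keyed on the qid column
cached = of(bm25, on=["qid"], store="shelve")
res = cached.transform(dataset.get_topics())

print(cached.stats())  # fraction of lookups served from the cache
cached.close()         # shelve stores must be closed before other instances can read them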


44 changes: 27 additions & 17 deletions pyterrier/transformer.py
@@ -6,7 +6,7 @@
from .model import add_ranks
from . import tqdm
import deprecation
from typing import Iterable, List

LAMBDA = lambda:0
def is_lambda(v):
@@ -151,24 +151,24 @@ def transform_gen(self, input : pd.DataFrame, batch_size=1) -> pd.DataFrame:

def search(self, query : str, qid : str = "1", sort=True) -> pd.DataFrame:
"""
Method for executing a transformer (pipeline) for a single query.
Returns a dataframe with the results for the specified query. This
is a utility method, and most uses are expected to use the transform()
method passing a dataframe.
Method for executing a transformer (pipeline) for a single query.
Returns a dataframe with the results for the specified query. This
is a utility method, and most uses are expected to use the transform()
method passing a dataframe.

Arguments:
- query(str): String form of the query to run
- qid(str): the query id to associate to this request. defaults to 1.
- sort(bool): ensures the results are sorted by descending rank (defaults to True)
Arguments:
- query(str): String form of the query to run
- qid(str): the query id to associate to this request. defaults to 1.
- sort(bool): ensures the results are sorted by descending rank (defaults to True)

Example::
Example::

bm25 = pt.BatchRetrieve(index, wmodel="BM25")
res = bm25.search("example query")
bm25 = pt.BatchRetrieve(index, wmodel="BM25")
res = bm25.search("example query")

# is equivalent to
queryDf = pd.DataFrame([["1", "example query"]], columns=["qid", "query"])
res = bm25.transform(queryDf)
# is equivalent to
queryDf = pd.DataFrame([["1", "example query"]], columns=["qid", "query"])
res = bm25.transform(queryDf)


"""
@@ -179,6 +179,17 @@ def search(self, query : str, qid : str = "1", sort=True) -> pd.DataFrame:
rtr = rtr.sort_values(["qid", "rank"], ascending=[True,True])
return rtr

def cache(self, on : List[str] = ['qid'], **kwargs):
"""
Provides an instance of this pipeline that caches results.

Arguments:
- on(List[str]): List of attributes that define the key to cache on.
- store(str): name of the cache store type, either "shelve" or "chest". Defaults to "shelve".
"""
from .cache import of
return of(self, on, **kwargs)

def compile(self):
"""
Rewrites this pipeline by applying of the Matchpy rules in rewrite_rules. Pipeline
@@ -263,8 +274,7 @@ def __xor__(self, right):
return ConcatenateTransformer(self, right)

def __invert__(self):
return self.cache()

def __hash__(self):
return hash(repr(self))
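
As a companion sketch for the transformer.py side, the following shows how the new cache() method relates to the ~ operator; index is a placeholder for any initialised PyTerrier index.

import pyterrier as pt

bm25 = pt.BatchRetrieve(index, wmodel="BM25")  # index is a placeholder

cached_default = ~bm25                                # ~ now delegates to cache() with its defaults
cached_explicit = bm25.cache(on=["qid"])              # the equivalent explicit form
cached_chest = bm25.cache(on=["qid"], store="chest")  # the store kwarg is forwarded to cache.of()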