diff --git a/pyterrier/cache.py b/pyterrier/cache.py index ed081d7e..7c76e80a 100644 --- a/pyterrier/cache.py +++ b/pyterrier/cache.py @@ -5,11 +5,13 @@ import os from os import path CACHE_DIR = None +DEFAULT_CACHE_STORE = "shelve" #or "chest" import pandas as pd import pickle from functools import partial import datetime from warnings import warn +from typing import List DEFINITION_FILE = ".transformer" @@ -46,11 +48,18 @@ def list_cache(): if path.exists(def_file): with open(def_file, "r") as f: elem["transformer"] = f.readline() - elem["size"] = sum(d.stat().st_size for d in os.scandir(dir) if d.is_file()) - elem["size_str"] = sizeof_fmt(elem["size"]) - elem["queries"] = len(os.listdir(dir)) -2 #subtract .keys and DEFINITION_FILE - elem["lastmodified"] = path.getmtime(dir) + shelve_file = path.join(dir, "shelve.db") + if path.exists(shelve_file): + elem["size"] = path.getsize(shelve_file) + elem["lastmodified"] = path.getmtime(shelve_file) + else: + #we assume it is a chest + elem["size"] = sum(d.stat().st_size for d in os.scandir(dir) if d.is_file()) + elem["queries"] = len(os.listdir(dir)) -2 #subtract .keys and DEFINITION_FILE + elem["lastmodified"] = path.getmtime(dir) + elem["lastmodified_str"] = datetime.datetime.fromtimestamp(elem["lastmodified"]).strftime('%Y-%m-%dT%H:%M:%S') + elem["size_str"] = sizeof_fmt(elem["size"]) rtr[dirname] = elem return rtr @@ -61,15 +70,18 @@ def clear_cache(): import shutil shutil.rmtree(CACHE_DIR) -class ChestCacheTransformer(TransformerBase): + +class GenericCacheTransformer(TransformerBase): """ A transformer that cache the results of the consituent (inner) transformer. This is instantiated using the `~` operator on any transformer. - Caching is unqiue based on the configuration of the pipeline, as read by executing - retr() on the pipeline. Caching lookup is based on the qid, so any change in query + Caching is based on the configuration of the pipeline, as read by executing + repr() on the pipeline. 
Caching lookup is by default based on the qid, so any change in query _formulation_ will not be reflected in a cache's results. + Caching lookup can be changed by altering the `on` attribute in the cache object. + Example Usage:: dataset = pt.get_dataset("trec-robust-2004") @@ -90,39 +102,42 @@ class ChestCacheTransformer(TransformerBase): In the above example, we use the `~` operator on the first pass retrieval using BM25, but not on the 2nd pass retrieval, as the query formulation will differ during the second pass. - Caching is not supported for re-ranking transformers. + """ - def __init__(self, inner, **kwargs): + def __init__(self, inner, on=["qid"], verbose=False, debug=False, **kwargs): super().__init__(**kwargs) - on="qid" + self.on = on self.inner = inner self.disable = False + self.hits = 0 + self.requests = 0 + self.debug = debug + self.verbose = verbose + if CACHE_DIR is None: init() # we take the md5 of the __repr__ of the pipeline to make a unique identifier for the pipeline # all different pipelines should return unique __repr_() values, as these are intended to be # unambiguous - trepr = repr(self.inner) - if "object at 0x" in trepr: - warn("Cannot cache pipeline %s has a component has not overridden __repr__" % trepr) - self.disable = True + self.trepr = repr(self.inner) + if "object at 0x" in self.trepr: + warn("Cannot cache pipeline %s across PyTerrier sessions, as it has a transient component, which has not overridden __repr__()" % self.trepr) + #return + #self.disable = True - uid = hashlib.md5( bytes(trepr, "utf-8") ).hexdigest() - destdir = path.join(CACHE_DIR, uid) - os.makedirs(destdir, exist_ok=True) - definition_file=path.join(destdir, DEFINITION_FILE) + uid = hashlib.md5( bytes(self.trepr, "utf-8") ).hexdigest() + self.destdir = path.join(CACHE_DIR, uid) + os.makedirs(self.destdir, exist_ok=True) + + definition_file=path.join(self.destdir, DEFINITION_FILE) if not path.exists(definition_file): + if self.debug: + print("Creating new cache 
store at %s for %s" % (self.destdir, self.trepr)) with open(definition_file, "w") as f: - f.write(trepr) - self.chest = Chest(path=destdir, - dump=lambda data, filename: pd.DataFrame.to_pickle(data, filename) if isinstance(data, pd.DataFrame) else pickle.dump(data, filename, protocol=1), - load=lambda filehandle: pickle.load(filehandle) if ".keys" in filehandle.name else pd.read_pickle(filehandle) - ) - self.hits = 0 - self.requests = 0 - + f.write(self.trepr) + def stats(self): return self.hits / self.requests if self.requests > 0 else 0 @@ -136,44 +151,128 @@ def __repr__(self): def __str__(self): return "Cache("+str(self.inner)+")" + def __del__(self): + self.close() + @property def NOCACHE(self): return self.inner + def flush(self): + self.chest.flush() + + def close(self): + pass + def transform(self, input_res): if self.disable: return self.inner.transform(input_res) - if "docid" in input_res.columns or "docno" in input_res.columns: - raise ValueError("Caching of %s for re-ranking is not supported. Caching currently only supports input dataframes with queries as inputs and cannot be used for re-rankers." % self.inner.__repr__()) + for col in self.on: + if col not in input_res.columns: + raise ValueError("Caching on %s, but did not find column %s among input columns %s" + % (str(self.on)), col, str(input_res.columns)) + for col in ["docno"]: + if col in input_res.columns and not col in self.on and len(self.on) == 1: + warn(("Caching on=%s, but found column %s among input columns %s. 
You may want " % (str(self.on)), col, str(input_res.columns) ) + + "to update the on attribute for the cache transformer") return self._transform_qid(input_res) def _transform_qid(self, input_res): + # output dataframes to /return/ rtr = [] + # input rows to execute on the inner transformer todo=[] - - # We cannot remove this iterrows() without knowing how to take named tuples into a dataframe - for index, row in input_res.iterrows(): - qid = str(row["qid"]) + import pyterrier as pt + iter = input_res.itertuples(index=False) + for row in pt.tqdm( + iter, + desc="%s lookups" % self, + unit='row', + total=len(input_res)) if self.verbose else iter: + # we calculate what we will key this cache on + key = ''.join([getattr(row, k) for k in self.on]) + qid = str(row.qid) self.requests += 1 try: - df = self.chest.get(qid, None) + df = self.chest.get(key, None) except: # occasionally we have file not founds, # lets remove from the cache and continue - del self.chest[qid] + del self.chest[key] df = None if df is None: - todo.append(row.to_frame().T) + if self.debug: + print("%s cache miss for key %s" % (self, key)) + todo.append(row) else: + if self.debug: + print("%s cache hit for key %s" % (self, key)) self.hits += 1 rtr.append(df) if len(todo) > 0: - todo_df = pd.concat(todo) + todo_df = pd.DataFrame(todo) todo_res = self.inner.transform(todo_df) - for row in todo_df.itertuples(): - qid = row.qid - this_query_res = todo_res[todo_res["qid"] == qid] - self.chest[qid] = this_query_res - rtr.append(this_query_res) - self.chest.flush() + for key_vals, group in todo_res.groupby(self.on): + key = ''.join(key_vals) + self.chest[key] = group + if self.debug: + print("%s caching %d results for key %s" % (self, len(group), key)) + rtr.append(todo_res) + self.flush() return pd.concat(rtr) + + +class ChestCacheTransformer(GenericCacheTransformer): + """ + A cache transformer based on `chest `_. 
+ """ + def __init__(self, inner, **kwargs): + super().__init__(inner, **kwargs) + + self.chest = Chest(path=self.destdir, + dump=lambda data, filename: pd.DataFrame.to_pickle(data, filename) if isinstance(data, pd.DataFrame) else pickle.dump(data, filename, protocol=1), + load=lambda filehandle: pickle.load(filehandle) if ".keys" in filehandle.name else pd.read_pickle(filehandle) + ) + +class ShelveCacheTransformer(GenericCacheTransformer): + """ + A cache transformer based on Python's `shelve `_ library. Compares to the + chest-based cache, this transformer MUST be closed before cached instances can be seen by other instances. + """ + def __init__(self, inner, **kwargs): + super().__init__(inner, **kwargs) + filename = os.path.join(self.destdir, "shelve") + import shelve + if os.path.exists(filename) and os.path.getsize(filename) == 0: + warn("Cache file exists but has 0 size - perhaps a previous transformer cache should have been closed") + self.chest = shelve.open(filename) + + def flush(self): + self.chest.sync() + + def close(self): + self.chest.close() + +CACHE_STORES={ + "shelve" : ShelveCacheTransformer, + "chest" : ChestCacheTransformer +} + +def of( + inner : TransformerBase, + on : List[str] = ["qid"], + store : str= DEFAULT_CACHE_STORE, **kwargs + ) -> GenericCacheTransformer: + """ + Returns a transformer that caches the inner transformer. + Arguments: + inner(TransformerBase): which transformer should be cached + on(List[str]): which attributes to use as keys when caching + store(str): name of a cache type, either "shelve" or "chest". Defaults to "shelve". 
+ """ + if not store in CACHE_STORES: + raise ValueError("cache store type %s unknown, known types %s" % (store, list(CACHE_STORES.keys()))) + clz = CACHE_STORES[store] + return clz(inner, on=on, **kwargs) + + \ No newline at end of file diff --git a/pyterrier/transformer.py b/pyterrier/transformer.py index b471600b..fdec046d 100644 --- a/pyterrier/transformer.py +++ b/pyterrier/transformer.py @@ -6,7 +6,7 @@ from .model import add_ranks from . import tqdm import deprecation -from typing import Iterable +from typing import Iterable, List LAMBDA = lambda:0 def is_lambda(v): @@ -151,24 +151,24 @@ def transform_gen(self, input : pd.DataFrame, batch_size=1) -> pd.DataFrame: def search(self, query : str, qid : str = "1", sort=True) -> pd.DataFrame: """ - Method for executing a transformer (pipeline) for a single query. - Returns a dataframe with the results for the specified query. This - is a utility method, and most uses are expected to use the transform() - method passing a dataframe. + Method for executing a transformer (pipeline) for a single query. + Returns a dataframe with the results for the specified query. This + is a utility method, and most uses are expected to use the transform() + method passing a dataframe. - Arguments: - - query(str): String form of the query to run - - qid(str): the query id to associate to this request. defaults to 1. - - sort(bool): ensures the results are sorted by descending rank (defaults to True) + Arguments: + - query(str): String form of the query to run + - qid(str): the query id to associate to this request. defaults to 1. 
+ - sort(bool): ensures the results are sorted by descending rank (defaults to True) - Example:: + Example:: - bm25 = pt.BatchRetrieve(index, wmodel="BM25") - res = bm25.search("example query") + bm25 = pt.BatchRetrieve(index, wmodel="BM25") + res = bm25.search("example query") - # is equivalent to - queryDf = pd.DataFrame([["1", "example query"]], columns=["qid", "query"]) - res = bm25.transform(queryDf) + # is equivalent to + queryDf = pd.DataFrame([["1", "example query"]], columns=["qid", "query"]) + res = bm25.transform(queryDf) """ @@ -179,6 +179,17 @@ def search(self, query : str, qid : str = "1", sort=True) -> pd.DataFrame: rtr = rtr.sort_values(["qid", "rank"], ascending=[True,True]) return rtr + def cache(self, on : List[str] = ['qid'], **kwargs): + """ + Provides an instance of this pipeline that caches results. + + Arguments: + - on(List[str]): List of attributes that define the key to cache on. + - store(str): which cache store type to use. Either "shelve" or "chest" are supported. Defaults to "shelve". + """ + from .cache import of + return of(self, on, **kwargs) + def compile(self): """ Rewrites this pipeline by applying of the Matchpy rules in rewrite_rules. 
Pipeline @@ -263,8 +274,7 @@ def __xor__(self, right): return ConcatenateTransformer(self, right) def __invert__(self): - from .cache import ChestCacheTransformer - return ChestCacheTransformer(self) + return self.cache() def __hash__(self): return hash(repr(self)) diff --git a/tests/test_cache.py b/tests/test_cache.py index 8964b6c6..e1a00b50 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -7,18 +7,87 @@ import shutil import os +def compare(df1, df2): + df1 = df1.sort_values(["qid", "rank"]) + df2 = df2.sort_values(["qid", "rank"]) + import numpy as np + for i, (rowA, rowB) in enumerate( zip(df1.itertuples(), df2.itertuples())): + for col in ["qid", "query", "docno", "score", "rank"]: + assert getattr(rowA, col) == getattr(rowB, col), (i,col, rowA, rowB) + if hasattr(rowA, "features") or hasattr(rowB, "features"): + assert np.array_equal(getattr(rowA, "features"), getattr(rowB, "features")), (i,"features", rowA, rowB) + return True + class TestCache(TempDirTestCase): + def test_complex(self): + pt.cache.CACHE_DIR = self.test_dir + "/test_complex" + dataset = pt.get_dataset("vaswani") + index = dataset.get_index() + firstpassUB = pt.BatchRetrieve(index, wmodel="PL2") + features = [ + "SAMPLE", #ie PL2 + "WMODEL:BM25", + ] + stdfeatures = pt.FeaturesBatchRetrieve(index, features) + stage12 = firstpassUB >> stdfeatures + CfirstpassUB = ~firstpassUB + Cstdfeatures = ~stdfeatures + Cstdfeatures.on=['qid', 'docno'] + Cstage12 = CfirstpassUB >> Cstdfeatures + COstage12 = ~stage12 + + num_topics = 5 + test_topics = dataset.get_topics().head(num_topics) + + #res0 is the ground truth + res0 = stage12(test_topics) + Cstage12(test_topics) + res1 = Cstage12(test_topics).reset_index(drop=True) + self.assertEqual(num_topics, Cstage12[0].hits) + COstage12(test_topics) + res2 = COstage12(test_topics) + self.assertEqual(num_topics, COstage12.hits) + + self.assertTrue(compare(res1, res0)) + self.assertTrue(compare(res2, res0)) + + def test_cache_reranker(self): + 
pt.cache.CACHE_DIR = self.test_dir + "/test_cache_reranker" + class MyT(pt.transformer.TransformerBase): + def transform(self, docs): + docs = docs.copy() + docs["score"] = docs.apply(lambda doc_row: len(doc_row["text"]), axis=1) + return pt.model.add_ranks(docs) + def __repr__(self): + return "MyT" + p = MyT() + testDF = pd.DataFrame([["q1", "hello", "d1", "aa"]], columns=["qid", "query", "docno", "text"]) + rtr = p(testDF) + + cached = ~p + cached.on = ["qid", "text"] + #cached.debug = True + rtr2 = cached(testDF) + self.assertTrue(rtr.equals(rtr2)) + rtr3 = cached(testDF) + #print(rtr) + #print(rtr3) + self.assertTrue(rtr.equals(rtr3)) + self.assertEqual(cached.requests, 2) + self.assertEqual(cached.hits, 1) + def test_cache_br(self): - pt.cache.CACHE_DIR = self.test_dir + pt.cache.CACHE_DIR = self.test_dir + "/test_cache_br" import pandas as pd queries = pd.DataFrame([["q1", "chemical"]], columns=["qid", "query"]) br = pt.BatchRetrieve(pt.get_dataset("vaswani").get_index()) cache = ~br - self.assertEqual(0, len(cache.chest._keys)) + self.assertEqual(0, len(cache.chest.keys())) cache(queries) cache(queries) self.assertEqual(0.5, cache.stats()) + cache.close() #lets see if another cache of the same object would see the same cache entries. cache2 = ~br @@ -28,16 +97,17 @@ def test_cache_br(self): pt.cache.CACHE_DIR = None def test_cache_compose(self): - pt.cache.CACHE_DIR = self.test_dir + pt.cache.CACHE_DIR = self.test_dir + "/test_cache_compose" import pandas as pd queries = pd.DataFrame([["q1", "chemical"]], columns=["qid", "query"]) br1 = pt.BatchRetrieve(pt.get_dataset("vaswani").get_index(), wmodel="TF_IDF") br2 = pt.BatchRetrieve(pt.get_dataset("vaswani").get_index(), wmodel="BM25") cache = ~ (br1 >> br2) - self.assertEqual(0, len(cache.chest._keys)) + self.assertEqual(0, len(cache.chest.keys())) cache(queries) cache(queries) self.assertEqual(0.5, cache.stats()) + del(cache) #lets see if another cache of the same object would see the same cache entries. 
cache2 = ~(br1 >> br2) @@ -47,27 +117,33 @@ def test_cache_compose(self): pt.cache.CACHE_DIR = None def test_cache_compose_cache(self): - pt.cache.CACHE_DIR = self.test_dir + pt.cache.CACHE_DIR = self.test_dir + "/test_cache_compose_cache" import pandas as pd queries = pd.DataFrame([["q1", "chemical"]], columns=["qid", "query"]) br1 = pt.BatchRetrieve(pt.get_dataset("vaswani").get_index(), wmodel="TF_IDF") br2 = pt.BatchRetrieve(pt.get_dataset("vaswani").get_index(), wmodel="BM25") cache = ~ (~br1 >> br2) - self.assertEqual(0, len(cache.chest._keys)) + self.assertEqual(0, len(cache.chest.keys())) cache(queries) cache(queries) self.assertEqual(0.5, cache.stats()) + #this is required for shelve + cache.close() + #lets see if another cache of the same object would see the same cache entries. cache2 = ~(~br1 >> br2) + cache2.debug = True + #print("found keys in cache 2") + #print(cache2.chest.keys()) cache2(queries) + #print(cache2.hits) self.assertEqual(1, cache2.stats()) # check that the cache report works all_report = pt.cache.list_cache() self.assertTrue(len(all_report) > 0) report = list(all_report.values())[0] - self.assertEqual(1, report["queries"]) self.assertTrue("transformer" in report) self.assertTrue("size" in report) self.assertTrue("lastmodified" in report)