save_dir for pt.Experiment #163 (#247)
cmacdonald authored Dec 9, 2021
1 parent 515ac3b commit 5a69676
Showing 4 changed files with 155 additions and 35 deletions.
42 changes: 40 additions & 2 deletions docs/experiments.rst
@@ -147,6 +147,44 @@ This provides a dataframe where each row is the performance of a given system fo

NB: For brevity, we only show the top 5 rows of the returned table.

Saving and Reusing Results
~~~~~~~~~~~~~~~~~~~~~~~~~~

For some research tasks, it is considered good practice to save your results files when conducting experiments. Doing so
has several advantages:

- It permits additional evaluation (e.g. more measures, more significance tests) without re-applying potentially slow transformer pipelines.
- It allows transformer results to be made available for other experiments, perhaps as a virtual data appendix in a paper.

Saving can be enabled by passing the ``save_dir`` kwarg to ``pt.Experiment``::

pt.Experiment(
[tfidf, bm25],
dataset.get_topics(),
dataset.get_qrels(),
eval_metrics=["map", "recip_rank"],
names=["TF_IDF", "BM25"],
save_dir="./",
)

This will save two files, namely TF_IDF.res.gz and BM25.res.gz, to the current directory. If these files already exist,
they will be "reused", i.e. loaded and evaluated rather than re-applying the tfidf and/or bm25 transformers.
If experiments are being conducted on multiple different topic sets, care should be taken to ensure that previous
results for a different topic set are not reused for evaluation.
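
A saved results file can later be reloaded for further evaluation without re-running the transformer. The following is a
minimal sketch (not part of this commit): it assumes a TF_IDF.res.gz file produced by an earlier experiment, and relies on
the fact that a results dataframe can be passed to ``pt.Experiment`` in place of a transformer::

    # load the results saved by an earlier experiment
    saved_res = pt.io.read_results("./TF_IDF.res.gz")
    # evaluate the saved results with additional measures
    pt.Experiment(
        [saved_res],
        dataset.get_topics(),
        dataset.get_qrels(),
        eval_metrics=["map", "ndcg", "recip_rank"],
        names=["TF_IDF"],
    )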

If a transformer has been updated, outdated results files can be mistakenly used. To prevent this, set the ``save_mode``
kwarg to ``"overwrite"``::

pt.Experiment(
[tfidf, bm25],
dataset.get_topics(),
dataset.get_qrels(),
eval_metrics=["map", "recip_rank"],
names=["TF_IDF", "BM25"],
save_dir="./",
save_mode="overwrite"
)
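
When the same pipelines are evaluated on several topic sets, one way to avoid accidentally reusing results from the wrong
topic set is to use a separate ``save_dir`` per topic set. A possible sketch (not part of this commit), assuming the dataset
provides "dev" and "test" variants of its topics and qrels::

    import os
    for subset in ["dev", "test"]:
        subset_dir = os.path.join("./results", subset)  # one directory per topic set
        os.makedirs(subset_dir, exist_ok=True)
        pt.Experiment(
            [tfidf, bm25],
            dataset.get_topics(subset),
            dataset.get_qrels(subset),
            eval_metrics=["map", "recip_rank"],
            names=["TF_IDF", "BM25"],
            save_dir=subset_dir,
        )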

Missing Topics and/or Qrels
~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -156,9 +194,9 @@ such as in sparsely labeled datasets or shared tasks that choose to omit some to
Qids that appear in qrels but not in topics can happen when running a subset of topics for testing purposes
(e.g., ``topics.head(5)``).

-The ``filter_by_qrels`` and ``fitler_by_topics`` parameters control the behaviour of an experiment when topics and qrels
+The ``filter_by_qrels`` and ``filter_by_topics`` parameters control the behaviour of an experiment when topics and qrels
do not perfectly overlap. When ``filter_by_qrels=True``, topics are filtered down to only the ones that have qids in the
-qrels. Similarly, when ``fitler_by_topics=True``, qrels are filtered down to only the ones that have qids in the topics.
+qrels. Similarly, when ``filter_by_topics=True``, qrels are filtered down to only the ones that have qids in the topics.

For example, consider topics that include qids ``A`` and ``B`` and qrels that include ``B`` and ``C``. The results with
each combination of settings are:
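
To make this behaviour concrete, the following is a small sketch (not part of this commit) using the qids above; ``bm25``
is assumed to be an existing transformer::

    import pandas as pd
    topics = pd.DataFrame([["A", "chemical"], ["B", "reactions"]], columns=["qid", "query"])
    qrels = pd.DataFrame([["B", "d1", 1], ["C", "d2", 1]], columns=["qid", "docno", "label"])

    # filter_by_qrels=True drops topic A (it has no qrels) before retrieval;
    # filter_by_topics=True drops the qrels for qid C (it has no topic) before evaluation.
    pt.Experiment(
        [bm25],
        topics,
        qrels,
        eval_metrics=["map"],
        names=["BM25"],
        filter_by_qrels=True,
        filter_by_topics=True,
    )
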
27 changes: 19 additions & 8 deletions pyterrier/io.py
@@ -113,6 +113,15 @@ def finalized_autoopen(path: str, mode: str):
"""
return _finalized_open_base(path, mode, autoopen)

def ok_filename(fname) -> bool:
"""
Checks to see if a filename is valid.
"""
BAD_CHARS = ':"%/<>^|?' + os.sep
for c in BAD_CHARS:
if c in fname:
return False
return True
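# Illustrative examples (not part of this commit):
#   ok_filename("BM25.res.gz")   -> True
#   ok_filename("BM25/DPH.res")  -> False (contains a path separator)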

def touch(fname, mode=0o666, dir_fd=None, **kwargs):
"""
@@ -192,14 +201,15 @@ def _read_results_trec(filename):
df["score"] = df["score"].astype(float)
return df

-def write_results(res, filename, format="trec", **kwargs):
+def write_results(res, filename, format="trec", append=False, **kwargs):
"""
Write a results dataframe to a file.
Parameters:
res (DataFrame): A results dataframe, with usual columns of qid, docno etc
filename (str): The filename of the file to be written. Compressed files are handled automatically.
format (str): The format of the results file: one of "trec", "letor", "minimal"
append (bool): Append to an existing file. Defaults to False.
**kwargs (dict): Other arguments for the internal method
Supported Formats:
@@ -212,22 +222,23 @@ def write_results(res, filename, format="trec", **kwargs):
raise ValueError("Format %s not known, supported types are %s" % (format, str(SUPPORTED_RESULTS_FORMATS.keys())))
# convert generators to results
res = coerce_dataframe(res)
-return SUPPORTED_RESULTS_FORMATS[format][1](res, filename, **kwargs)
+return SUPPORTED_RESULTS_FORMATS[format][1](res, filename, append=append, **kwargs)

-def _write_results_trec(res, filename, run_name="pyterrier"):
+def _write_results_trec(res, filename, run_name="pyterrier", append=False):
res_copy = res.copy()[["qid", "docno", "rank", "score"]]
res_copy.insert(1, "Q0", "Q0")
res_copy.insert(5, "run_name", run_name)
-res_copy.to_csv(filename, sep=" ", header=False, index=False)
+res_copy.to_csv(filename, sep=" ", mode='a' if append else 'w', header=False, index=False)

-def _write_results_minimal(res, filename, run_name="pyterrier"):
+def _write_results_minimal(res, filename, run_name="pyterrier", append=False):
res_copy = res.copy()[["qid", "docno", "rank"]]
-res_copy.to_csv(filename, sep="\t", header=False, index=False)
+res_copy.to_csv(filename, sep="\t", mode='a' if append else 'w', header=False, index=False)

-def _write_results_letor(res, filename, qrels=None, default_label=0):
+def _write_results_letor(res, filename, qrels=None, default_label=0, append=False):
if qrels is not None:
res = res.merge(qrels, on=['qid', 'docno'], how='left').fillna(default_label)
-with autoopen(filename, "wt") as f:
+mode = 'at' if append else 'wt'
+with autoopen(filename, mode) as f:
for row in res.itertuples():
values = row.features
label = row.label if qrels is not None else default_label
96 changes: 75 additions & 21 deletions pyterrier/pipelines.py
@@ -126,10 +126,14 @@ def _run_and_evaluate(
qrels: pd.DataFrame,
metrics : MEASURES_TYPE,
pbar = None,
save_mode = None,
save_file = None,
perquery : bool = False,
batch_size = None,
backfill_qids : Sequence[str] = None):

from .io import read_results, write_results

if pbar is None:
from . import tqdm
pbar = tqdm(disable=True)
@@ -139,6 +143,14 @@
from timeit import default_timer as timer
runtime = 0
num_q = qrels['query_id'].nunique()
if save_file is not None and os.path.exists(save_file):
if save_mode == "reuse":
system = read_results(save_file)
elif save_mode == "overwrite":
os.remove(save_file)
else:
raise ValueError("Unknown save_mode argument '%s', valid options are 'reuse' or 'overwrite'" % save_mode)

# if its a DataFrame, use it as the results
if isinstance(system, pd.DataFrame):
res = system
@@ -162,6 +174,10 @@
endtime = timer()
runtime = (endtime - starttime) * 1000.

# write results to save_file; any pre-existing file has already been loaded (reuse) or removed (overwrite) above
if save_file is not None:
write_results(res, save_file)

res = coerce_dataframe_types(res)

if len(res) == 0:
@@ -181,25 +197,36 @@
starttime = timer()
evalMeasuresDict = {}
remaining_qrel_qids = set(qrels.query_id)
-for i, (res, batch_topics) in enumerate( system.transform_gen(topics, batch_size=batch_size, output_topics=True)):
-if len(res) == 0:
-raise ValueError("batch of %d topics, but no results received in batch %d from %s" % (len(batch_topics), i, str(system) ) )
-endtime = timer()
-runtime += (endtime - starttime) * 1000.
-res = coerce_dataframe_types(res)
-batch_qids = set(batch_topics.qid)
-batch_qrels = qrels[qrels.query_id.isin(batch_qids)] # filter qrels down to just the qids that appear in this batch
-remaining_qrel_qids.difference_update(batch_qids)
-batch_backfill = [qid for qid in backfill_qids if qid in batch_qids] if backfill_qids is not None else None
-evalMeasuresDict.update(_ir_measures_to_dict(
-ir_measures.iter_calc(metrics, batch_qrels, res.rename(columns=_irmeasures_columns)),
-metrics,
-rev_mapping,
-num_q,
-perquery=True,
-backfill_qids=batch_backfill))
-pbar.update()
-starttime = timer()
+try:
+for i, (res, batch_topics) in enumerate( system.transform_gen(topics, batch_size=batch_size, output_topics=True)):
+if len(res) == 0:
+raise ValueError("batch of %d topics, but no results received in batch %d from %s" % (len(batch_topics), i, str(system) ) )
+endtime = timer()
+runtime += (endtime - starttime) * 1000.
+
+# write results to save_file; we will append for subsequent batches
+if save_file is not None:
+write_results(res, save_file, append=True)
+
+res = coerce_dataframe_types(res)
+batch_qids = set(batch_topics.qid)
+batch_qrels = qrels[qrels.query_id.isin(batch_qids)] # filter qrels down to just the qids that appear in this batch
+remaining_qrel_qids.difference_update(batch_qids)
+batch_backfill = [qid for qid in backfill_qids if qid in batch_qids] if backfill_qids is not None else None
+evalMeasuresDict.update(_ir_measures_to_dict(
+ir_measures.iter_calc(metrics, batch_qrels, res.rename(columns=_irmeasures_columns)),
+metrics,
+rev_mapping,
+num_q,
+perquery=True,
+backfill_qids=batch_backfill))
+pbar.update()
+starttime = timer()
+except:
+# if an error is thrown, we need to clean up our existing file
+if save_file is not None and os.path.exists(save_file):
+os.remove(save_file)
+raise
if remaining_qrel_qids:
# there are some qids in the qrels that were not in the topics. Get the default values for these and update evalMeasuresDict
missing_qrels = qrels[qrels.query_id.isin(remaining_qrel_qids)]
@@ -241,6 +268,8 @@ def Experiment(
highlight : str = None,
round : Union[int,Dict[str,int]] = None,
verbose : bool = False,
save_dir : str = None,
save_mode : str = 'reuse',
**kwargs):
"""
Allows easy comparison of multiple retrieval transformer pipelines using a common set of topics, and
@@ -263,6 +292,11 @@
filter_by_qrels(bool): If True, will drop topics from the topics dataframe that have qids not appearing in the qrels dataframe.
filter_by_topics(bool): If True, will drop rows from the qrels dataframe that have qids not appearing in the topics dataframe.
perquery(bool): If True return each metric for each query, else return mean metrics across all queries. Default=False.
save_dir(str): If set to the name of a directory, the results of each transformer will be saved in a TREC-formatted results file, whose
filename is based on the system names (as specified by the ``names`` kwarg). If the file exists and ``save_mode`` is set to "reuse", then the file
will be used for evaluation rather than the transformer. Default is None, such that saving and loading from files is disabled.
save_mode(str): Defines how existing files are used when ``save_dir`` is set. If set to "reuse", then files will be preferred
over transformers for evaluation. If set to "overwrite", existing files will be replaced. Default is "reuse".
dataframe(bool): If True return results as a dataframe, else as a dictionary of dictionaries. Default=True.
baseline(int): If set to the index of an item of the retr_system list, will calculate the number of queries
improved, degraded and the statistical significance (paired t-test p value) for each measure.
@@ -365,6 +399,19 @@ def _apply_round(measure, value):
elif len(names) != len(retr_systems):
raise ValueError("names should be the same length as retr_systems")

# validate save_dir and resulting filenames
if save_dir is not None:
if not os.path.exists(save_dir):
raise ValueError("save_dir %s does not exist" % save_dir)
if not os.path.isdir(save_dir):
raise ValueError("save_dir %s is not a directory" % save_dir)
from .io import ok_filename
for n in names:
if not ok_filename(n):
raise ValueError("Name contains bad characters and save_dir is set, name is %s" % n)
if len(set(names)) < len(names):
raise ValueError("save_dir is set, but names are not unique. Use names= to set unique names")

all_topic_qids = topics["qid"].values

evalsRows=[]
Expand All @@ -376,7 +423,7 @@ def _apply_round(measure, value):
mrt_needed = True
eval_metrics.remove("mrt")

-# run and evaluate each system
+# progress bar construction
from . import tqdm
tqdm_args={
'disable' : not verbose,
@@ -389,12 +436,19 @@ def _apply_round(measure, value):
tqdm_args['total'] = (len(topics) / batch_size) * len(retr_systems)

with tqdm(**tqdm_args) as pbar:
-for name,system in zip(names, retr_systems):
+# run and evaluate each system
+for name, system in zip(names, retr_systems):
save_file = None
if save_dir is not None:
save_file = os.path.join(save_dir, "%s.res.gz" % name)

time, evalMeasuresDict = _run_and_evaluate(
system, topics, qrels, eval_metrics,
perquery=perquery or baseline is not None,
batch_size=batch_size,
backfill_qids=all_topic_qids if perquery else None,
save_file=save_file,
save_mode=save_mode,
pbar=pbar)

if baseline is not None:
25 changes: 21 additions & 4 deletions tests/test_experiment.py
@@ -5,9 +5,9 @@
import warnings
import math
from pyterrier.measures import *
-from .base import BaseTestCase
+from .base import TempDirTestCase

-class TestExperiment(BaseTestCase):
+class TestExperiment(TempDirTestCase):

def test_irm_APrel2(self):
topics = pd.DataFrame([["q1", "q1"], ["q2", "q1"] ], columns=["qid", "query"])
@@ -127,16 +127,33 @@ def test_wrong(self):


def test_mrt(self):
index = pt.datasets.get_dataset("vaswani").get_index()
brs = [
-pt.BatchRetrieve(pt.datasets.get_dataset("vaswani").get_index(), wmodel="DPH"),
-pt.BatchRetrieve(pt.datasets.get_dataset("vaswani").get_index(), wmodel="BM25")
+pt.BatchRetrieve(index, wmodel="DPH"),
+pt.BatchRetrieve(index, wmodel="BM25")
]
topics = pt.datasets.get_dataset("vaswani").get_topics().head(10)
qrels = pt.datasets.get_dataset("vaswani").get_qrels()
pt.Experiment(brs, topics, qrels, eval_metrics=["map", "mrt"])
pt.Experiment(brs, topics, qrels, eval_metrics=["map", "mrt"], highlight="color")
pt.Experiment(brs, topics, qrels, eval_metrics=["map", "mrt"], baseline=0, highlight="color")

def test_save(self):
index = pt.datasets.get_dataset("vaswani").get_index()
brs = [
pt.BatchRetrieve(index, wmodel="DPH"),
pt.BatchRetrieve(index, wmodel="BM25")
]
topics = pt.datasets.get_dataset("vaswani").get_topics().head(10)
qrels = pt.datasets.get_dataset("vaswani").get_qrels()
df1 = pt.Experiment(brs, topics, qrels, eval_metrics=["map", "mrt"], save_dir=self.test_dir)
# check save_dir files are there
self.assertTrue(os.path.exists(os.path.join(self.test_dir, "BR(DPH).res.gz")))
self.assertTrue(os.path.exists(os.path.join(self.test_dir, "BR(BM25).res.gz")))
df2 = pt.Experiment(brs, topics, qrels, eval_metrics=["map", "mrt"], save_dir=self.test_dir)
# a successful experiment using save_dir should be faster
self.assertTrue(df2.iloc[0]["mrt"] < df1.iloc[0]["mrt"])

def test_empty(self):
df1 = pt.new.ranked_documents([[1]]).head(0)
t1 = pt.transformer.SourceTransformer(df1)
