Skip to content

Commit

Permalink
pt.debug utility transformers, issues #219 #220 (PR #221)
Browse files Browse the repository at this point in the history
* fix empty_Q()
* addresses pt.apply.new_column fails on empty dataframes #219
* fixes for version='snapshot'
* commit of debug transformer #220
  • Loading branch information
cmacdonald authored Sep 6, 2021
1 parent 4ac9221 commit 02c4b64
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 6 deletions.
13 changes: 13 additions & 0 deletions docs/debug.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
.. _pyterrier.debug:

pyterrier.debug - Transformers for Debugging
--------------------------------------------

It's very easy to write complex pipelines with PyTerrier. Sometimes you need to inspect dataframes in the middle of a pipeline.
The pt.debug transformers display the columns or the data, and can be inserted into pipelines during development.

Debug Methods
=============

.. automodule:: pyterrier.debug
:members:
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Welcome to PyTerrier's documentation!
apply
anserini
new
debug

.. toctree::
:maxdepth: 1
Expand Down
3 changes: 2 additions & 1 deletion pyterrier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
anserini = None
apply = None
cache = None
debug = None
index = None
io = None
ltr = None
Expand Down Expand Up @@ -136,7 +137,7 @@ def init(version=None, mem=None, packages=[], jvm_opts=[], redirect_io=True, log
from .apply import _apply
globals()['apply'] = _apply()

for sub_module_name in ['anserini', 'cache', 'index', 'io', 'measures', 'model', 'new', 'ltr', 'parallel', 'pipelines', 'rewrite', 'text', 'transformer']:
for sub_module_name in ['anserini', 'cache', 'debug', 'index', 'io', 'measures', 'model', 'new', 'ltr', 'parallel', 'pipelines', 'rewrite', 'text', 'transformer']:
globals()[sub_module_name] = importlib.import_module('.' + sub_module_name, package='pyterrier')

# append the python helpers
Expand Down
2 changes: 1 addition & 1 deletion pyterrier/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,6 @@ def generic_apply(name, *args, drop=False, **kwargs) -> TransformerBase:
fn = args[0]
args=[]
def _new_column(df):
df[name] = df.apply(fn, axis=1)
df[name] = df.apply(fn, axis=1, result_type='reduce')
return df
return ApplyGenericTransformer(_new_column, *args, **kwargs)
79 changes: 79 additions & 0 deletions pyterrier/debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from .transformer import TransformerBase
from typing import List

def print_columns(by_query : bool = False, message : str = None) -> TransformerBase:
    """
    Returns a pass-through transformer that prints the column names of the dataframe
    reaching this stage of a pipeline. The dataframe itself is returned unchanged.

    Arguments:
     - by_query(bool): whether to display for each query. Defaults to False.
     - message(str): a message to print before the column names. Defaults to None, which means
       no message. This is useful when print_columns() is being used multiple times within a pipeline.

    Example::

        pipe = (
            bm25
            >> pt.debug.print_columns()
            >> pt.rewrite.RM3()
            >> pt.debug.print_columns()
            >> bm25
        )

    When the above pipeline is executed, two sets of columns will be displayed

     - `["qid", "query", "docno", "rank", "score"]` - the output of BM25, a ranking of documents
     - `["qid", "query", "query_0"]` - the output of RM3, a reformulated query
    """
    # Imported lazily: pt.apply is only populated once pyterrier has been initialised.
    import pyterrier as pt

    def _show_columns(frame):
        # Optional label first, so several debug stages can be told apart.
        if message is not None:
            print(message)
        print(frame.columns)
        return frame

    if by_query:
        return pt.apply.by_query(_show_columns)
    return pt.apply.generic(_show_columns)

def print_rows(
        by_query : bool = True,
        jupyter: bool = True,
        head : int = 2,
        message : str = None,
        columns : List[str] = None) -> TransformerBase:
    """
    Returns a pass-through transformer that prints some of the dataframe reaching this
    stage of a pipeline. The dataframe itself is returned unchanged.

    Arguments:
     - by_query(bool): whether to display for each query. Defaults to True.
     - jupyter(bool): Whether to use IPython's display function to display the dataframe. Defaults to True.
       If IPython cannot be imported, the dataframe is printed with print() instead.
     - head(int): The number of rows to display. Defaults to 2. None means all rows.
     - message(str): a message to print before the rows. Defaults to None, which means no message. This
       is useful when print_rows() is being used multiple times within a pipeline.
     - columns(List[str]): Limit the columns for which data is displayed. Default of None displays all columns.

    Example::

        pipe = (
            bm25
            >> pt.debug.print_rows()
            >> pt.rewrite.RM3()
            >> pt.debug.print_rows()
            >> bm25
        )
    """
    # Imported lazily: pt.apply is only populated once pyterrier has been initialised.
    import pyterrier as pt

    def _do_print(df):
        if message is not None:
            print(message)

        # Trim rows first, then columns; the untouched df is what flows on down the pipeline.
        render = df if head is None else df.head(head)
        if columns is not None:
            render = render[columns]

        show = print
        if jupyter:
            try:
                from IPython.display import display as show
            except ImportError:
                # IPython is an optional dependency and jupyter=True is the default;
                # a debugging aid should not abort the pipeline, so fall back to print().
                pass
        show(render)
        return df

    return pt.apply.by_query(_do_print) if by_query else pt.apply.generic(_do_print)
12 changes: 9 additions & 3 deletions pyterrier/mavenresolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,15 @@ def downloadfile(orgName, packageName, version, file_path, artifact="jar", force
filename = packageName + "-" + version + suffix + "." + ext

filelocation = orgName + "/" + packageName + "/" + version + "/" + filename

if os.path.isfile(os.path.join(file_path, filename)) and not force_download:
return os.path.join(file_path, filename)

target_file = os.path.join(file_path, filename)
file_exists = os.path.isfile(target_file)
if file_exists:
if not force_download:
return target_file
else:
# ensure that wget doesnt put the file in a different name
os.remove(target_file)

# check local Maven repo, and use that if it exists
from os.path import expanduser
Expand Down
2 changes: 1 addition & 1 deletion pyterrier/new.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def empty_Q() -> pd.DataFrame:
"""
Returns an empty dataframe with columns `["qid", "query"]`.
"""
return pd.DataFrame([[]], columns=["qid", "query"])
return pd.DataFrame(columns=["qid", "query"])

def queries(queries : Union[str, Sequence[str]], qid : Union[str, Sequence[str]] = None, **others) -> pd.DataFrame:
"""
Expand Down
3 changes: 3 additions & 0 deletions tests/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def test_make_columns(self):
rtr = p(testDF)
self.assertTrue("BlaB" in rtr.columns)
self.assertEqual(rtr.iloc[0]["BlaB"], 2)
emptyQs = pt.new.empty_Q()
rtr = p(emptyQs)
self.assertTrue("BlaB" in rtr.columns)

def test_rename_columns(self):
from pyterrier.transformer import TransformerBase
Expand Down
5 changes: 5 additions & 0 deletions tests/test_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

class TestModel(BaseTestCase):

def test_empty(self):
    # The empty query frame must still expose the standard query columns.
    empty_queries = pt.new.empty_Q()
    for expected_column in ("qid", "query"):
        self.assertTrue(expected_column in empty_queries.columns)

def test_newR1(self):
df = pt.new.ranked_documents([[1]])
self.assertEqual(1, len(df))
Expand Down

0 comments on commit 02c4b64

Please sign in to comment.