Skip to content

Commit

Permalink
Merge pull request #454 from terrier-org/java_backport_453
Browse files Browse the repository at this point in the history
Backport #453 to Java branch
  • Loading branch information
seanmacavaney authored Aug 16, 2024
2 parents bc6cb51 + d249116 commit 1c46751
Show file tree
Hide file tree
Showing 7 changed files with 242 additions and 44 deletions.
10 changes: 8 additions & 2 deletions .github/workflows/push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ name: Continuous Testing
on:
push:
branches: [ master ]
pull_request:
branches: [ master ]
pull_request: {}

jobs:
build:
Expand Down Expand Up @@ -79,6 +78,13 @@ jobs:
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
#flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: RM3 unit tests
env:
TERRIER_VERSION: ${{ matrix.terrier }}
run: |
pytest -p no:faulthandler tests/test_rewrite_rm3.py
# Hide underlying Jnius problem by disabling faulthandler: https://github.com/pytest-dev/pytest/issues/7634
- name: Flash unit tests
env:
TERRIER_VERSION: ${{ matrix.terrier }}
Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Welcome to PyTerrier's documentation!
PyTerrier_T5 <https://github.com/terrierteam/pyterrier_t5>
PyTerrier_GenRank <https://github.com/emory-irlab/pyterrier_genrank>
PyTerrier_ColBERT <https://github.com/terrierteam/pyterrier_colbert>
PyTerrier_ChatNoir <https://github.com/chatnoir-eu/chatnoir-pyterrier>
PyTerrier_ANCE <https://github.com/terrierteam/pyterrier_ance>
PyTerrier_doc2query <https://github.com/terrierteam/pyterrier_doc2query>
PyTerrier_DeepCT <https://github.com/terrierteam/pyterrier_deepct>
Expand Down
7 changes: 5 additions & 2 deletions pyterrier/java/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,11 @@ def legacy_init(version=None, mem=None, packages=[], jvm_opts=[], redirect_io=Tr
deprecated_calls.append(f'pt.java.set_log_level({logging!r})')

for package in boot_packages:
pt.java.add_package(*package.split(':')) # format: org:package:version:filetype (where version and filetype are optional)
deprecated_calls.append(f'pt.java.add_package({package!r})')
# format: org:package:version:filetype (where version and filetype are optional)
pkg_split = package.split(':')
pkg_string = ", ".join(f'{w!r}' for w in pkg_split)
pt.java.add_package(*pkg_split)
deprecated_calls.append(f'pt.java.add_package({pkg_string})')

for opt in jvm_opts:
pt.java.add_option(opt)
Expand Down
9 changes: 6 additions & 3 deletions pyterrier/terrier/java.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,20 @@

@pt.java.before_init
def set_version(version: Optional[str] = None):
configure['terrier_version'] = version
if version is not None:
configure['terrier_version'] = version


@pt.java.before_init
def set_helper_version(version: Optional[str] = None):
configure['helper_version'] = version
if version is not None:
configure['helper_version'] = version


@pt.java.before_init
def set_prf_version(version: Optional[str] = None):
configure['prf_version'] = version
if version is not None:
configure['prf_version'] = version


class TerrierJavaInit(pt.java.JavaInitializer):
Expand Down
27 changes: 18 additions & 9 deletions pyterrier/terrier/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class QueryExpansion(pt.Transformer):
'''

def __init__(self, index_like, fb_terms=10, fb_docs=3, qeclass="org.terrier.querying.QueryExpansion", verbose=0, properties={}, **kwargs):
def __init__(self, index_like, fb_terms=10, fb_docs=3, qeclass="org.terrier.querying.QueryExpansion", verbose=0, properties={}, requires_scores=False, **kwargs):
super().__init__(**kwargs)
self.verbose = verbose
if isinstance(qeclass, str):
Expand All @@ -185,6 +185,7 @@ def __init__(self, index_like, fb_terms=10, fb_docs=3, qeclass="org.terrier.quer
self.fb_terms = fb_terms
self.fb_docs = fb_docs
self.manager = pt.terrier.J.ManagerFactory._from_(self.indexref)
self.requires_scores = requires_scores

def __reduce__(self):
return (
Expand Down Expand Up @@ -216,30 +217,38 @@ def __setstate__(self, d):

def _populate_resultset(self, topics_and_res, qid, index):

docids=None
scores=None
occurrences=None
docids = None
scores = None
occurrences = None
if "docid" in topics_and_res.columns:
# we need .tolist() as jnius cannot convert numpy arrays
docids = topics_and_res[topics_and_res["qid"] == qid]["docid"].values.tolist()
topics_and_res_for_qid = topics_and_res[topics_and_res["qid"] == qid]
docids = topics_and_res_for_qid["docid"].values.tolist()
scores = [0.0] * len(docids)
if self.requires_scores:
scores = topics_and_res_for_qid["score"].values.tolist()
occurrences = [0] * len(docids)

elif "docno" in topics_and_res.columns:
docnos = topics_and_res[topics_and_res["qid"] == qid]["docno"].values
topics_and_res_for_qid = topics_and_res[topics_and_res["qid"] == qid]
docnos = topics_and_res_for_qid["docno"].values
docids = []
scores = []
_scores = [0.0] * len(docids)
if self.requires_scores:
_scores = topics_and_res_for_qid["score"].values.tolist()

occurrences = []
metaindex = index.getMetaIndex()
skipped = 0
for docno in docnos:
for docno, docscore in zip(docnos, _scores):
docid = metaindex.getDocument("docno", docno)
if docid == -1:
skipped +=1
assert docid != -1, "could not match docno" + docno + " to a docid for query " + qid
docids.append(docid)
scores.append(0.0)
occurrences.append(0)
scores.append(docscore)
if skipped > 0:
if skipped == len(docnos):
warn("*ALL* %d feedback docnos for qid %s could not be found in the index" % (skipped, qid))
Expand Down Expand Up @@ -384,7 +393,7 @@ def __init__(self, *args, fb_terms=10, fb_docs=3, fb_lambda=0.6, **kwargs):
rm = pt.terrier.J.RM3()
self.fb_lambda = fb_lambda
kwargs["qeclass"] = rm
super().__init__(*args, fb_terms=fb_terms, fb_docs=fb_docs, **kwargs)
super().__init__(*args, fb_terms=fb_terms, fb_docs=fb_docs, requires_scores=True, **kwargs)

def __getstate__(self):
rtr = super().__getstate__()
Expand Down
58 changes: 30 additions & 28 deletions tests/test_rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,39 +221,41 @@ def _sdm(self, freq):
pt.Evaluate(br_sdm.transform(t), dataset.get_qrels(), metrics=["map"])["map"],
places=4)

def test_rm3(self):
dataset = pt.datasets.get_dataset("vaswani")
indexref = dataset.get_index()

qe = pt.rewrite.RM3(indexref)
br = pt.BatchRetrieve(indexref)

queriesIn = pd.DataFrame([["1", "compact"]], columns=["qid", "query"])
res = br.transform(queriesIn)

queriesOut = qe.transform(res)
self.assertEqual(len(queriesOut), 1)
query = queriesOut.iloc[0]["query"]
#self.assertTrue("compact^1.82230972" in query)
self.assertTrue("applypipeline:off " in query)
# RM3 cannot be tested with current jnius, as it must be placed into the boot classpath
# As workaround for the moment, those RM3 tests are implemented in a single file tests/test_rewrite_rm3.py that is skipped when executing the complete pipeline, but are executed when run in isolation.
# def test_rm3(self):
# dataset = pt.datasets.get_dataset("vaswani")
# indexref = dataset.get_index()

# qe = pt.rewrite.RM3(indexref)
# br = pt.BatchRetrieve(indexref)

# queriesIn = pd.DataFrame([["1", "compact"]], columns=["qid", "query"])
# res = br.transform(queriesIn)

# queriesOut = qe.transform(res)
# self.assertEqual(len(queriesOut), 1)
# query = queriesOut.iloc[0]["query"]
# #self.assertTrue("compact^1.82230972" in query)
# self.assertTrue("applypipeline:off " in query)

pipe = br >> qe >> br
# pipe = br >> qe >> br

# lets go faster, we only need 18 topics. qid 16 had a tricky case
t = dataset.get_topics().head(18)
# # lets go faster, we only need 18 topics. qid 16 had a tricky case
# t = dataset.get_topics().head(18)

all_qe_res = pipe.transform(t)
map_pipe = pt.Evaluate(all_qe_res, dataset.get_qrels(), metrics=["map"])["map"]
# all_qe_res = pipe.transform(t)
# map_pipe = pt.Evaluate(all_qe_res, dataset.get_qrels(), metrics=["map"])["map"]

br_qe = pt.BatchRetrieve(indexref,
controls={"qe":"on"},
properties={"querying.processes" : "terrierql:TerrierQLParser,parsecontrols:TerrierQLToControls,"\
+"parseql:TerrierQLToMatchingQueryTerms,matchopql:MatchingOpQLParser,applypipeline:ApplyTermPipeline,"\
+"sd:DependenceModelPreProcess,localmatching:LocalManager$ApplyLocalMatching,qe:RM3,"\
+"labels:org.terrier.learning.LabelDecorator,filters:LocalManager$PostFilterProcess"})
map_qe = pt.Evaluate(br_qe.transform(t), dataset.get_qrels(), metrics=["map"])["map"]
# br_qe = pt.BatchRetrieve(indexref,
# controls={"qe":"on"},
# properties={"querying.processes" : "terrierql:TerrierQLParser,parsecontrols:TerrierQLToControls,"\
# +"parseql:TerrierQLToMatchingQueryTerms,matchopql:MatchingOpQLParser,applypipeline:ApplyTermPipeline,"\
# +"sd:DependenceModelPreProcess,localmatching:LocalManager$ApplyLocalMatching,qe:RM3,"\
# +"labels:org.terrier.learning.LabelDecorator,filters:LocalManager$PostFilterProcess"})
# map_qe = pt.Evaluate(br_qe.transform(t), dataset.get_qrels(), metrics=["map"])["map"]

self.assertAlmostEqual(map_qe, map_pipe, places=2)
# self.assertAlmostEqual(map_qe, map_pipe, places=2)

def test_linear_terrierql(self):
pipe = pt.apply.query(lambda row: "az") >> pt.rewrite.linear(0.75, 0.25)
Expand Down
174 changes: 174 additions & 0 deletions tests/test_rewrite_rm3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
import pandas as pd
import pyterrier as pt
import os
from matchpy import *
from .base import TempDirTestCase
import pytest

if not pt.java.started():
terrier_version = os.environ.get("TERRIER_VERSION", None)
terrier_helper_version = os.environ.get("TERRIER_HELPER_VERSION", None)
pt.java.set_log_level('DEBUG')
pt.terrier.set_version(terrier_version)
pt.terrier.set_helper_version(terrier_helper_version)
pt.terrier.set_prf_version('rm_tiebreak-SNAPSHOT')
pt.java.init() # optional, forces java initialisation
TERRIER_PRF_ON_CLASSPATH = True
else:
TERRIER_PRF_ON_CLASSPATH = False


def normalize_term_weights(term_weights, digits=7):
ret = ''
for i in term_weights.split():
if '^' in i:
i = i.split('^')
i = i[0] + '^' + i[1][:digits]
ret += ' ' + i
return ret.strip()

class TestRewriteRm3(TempDirTestCase):
"""This is a set of unit tests for RM3 that can currently not run in the complete test suite, as the "com.github.terrierteam:terrier-prf:-SNAPSHOT" would have to be added to the boot classpath.
As workaround, the RM3 tests that can not be executed within the complete test suite are added to this dedicated file, so that they can be executed in isolation by running pytest tests/test_rewrite_rm3.py.
"""

@pytest.mark.skipif(not TERRIER_PRF_ON_CLASSPATH, reason="This test only works in isolation when terrier-prf is on the jnius classpath.")
def test_rm3_expansion_for_query_compact_on_tf_idf(self):
# top-retrieval results of TF-IDF and BM25 below change, so the RM3 weights differ
expected = 'applypipeline:off equip^0.037346367 ferrit^0.027371584 modul^0.027371584 suppli^0.037346367 design^0.056739070 microwav^0.027371584 anod^0.037346367 unit^0.037346367 compact^0.674414337 stabil^0.037346367'

indexref = pt.datasets.get_dataset("vaswani").get_index()
queriesIn = pd.DataFrame([["1", "compact"]], columns=["qid", "query"])

qe = pt.rewrite.RM3(indexref)
br = pt.BatchRetrieve(indexref, wmodel='TF_IDF')

actual = qe.transform(br.transform(queriesIn))

self.assertEqual(len(actual), 1)
self.assertEqual(normalize_term_weights(expected), normalize_term_weights(actual.iloc[0]["query"]))

@pytest.mark.skipif(not TERRIER_PRF_ON_CLASSPATH, reason="This test only works in isolation when terrier-prf is on the jnius classpath.")
def test_rm3_expansion_for_query_compact_on_bm25(self):
# top-retrieval results of BM25 and TF-IDF above change, so the RM3 weights differ
expected = 'applypipeline:off equip^0.040264644 ferrit^0.025508024 modul^0.025508024 suppli^0.040264644 design^0.051008239 microwav^0.025508024 anod^0.040264644 unit^0.040264644 compact^0.671144485 stabil^0.040264644'

indexref = pt.datasets.get_dataset("vaswani").get_index()
queriesIn = pd.DataFrame([["1", "compact"]], columns=["qid", "query"])

qe = pt.rewrite.RM3(indexref)
br = pt.BatchRetrieve(indexref, wmodel='BM25')

actual = qe.transform(br.transform(queriesIn))

self.assertEqual(len(actual), 1)
self.assertEqual(normalize_term_weights(expected), normalize_term_weights(actual.iloc[0]["query"]))

@pytest.mark.skipif(not TERRIER_PRF_ON_CLASSPATH, reason="This test only works in isolation when terrier-prf is on the jnius classpath.")
def test_axiomatic_qe_expansion_for_query_compact_on_bm25(self):
# just ensure that AxiomaticQE results do not change
expected = 'applypipeline:off compact^1.000000000'

indexref = pt.datasets.get_dataset("vaswani").get_index()
queriesIn = pd.DataFrame([["1", "compact"]], columns=["qid", "query"])

qe = pt.rewrite.AxiomaticQE(indexref)
br = pt.BatchRetrieve(indexref, wmodel='BM25')

actual = qe.transform(br.transform(queriesIn))

self.assertEqual(len(actual), 1)
self.assertEqual(expected, actual.iloc[0]["query"])

def test_kl_qe_expansion_for_query_compact_on_bm25(self):
# just ensure that KLQueryExpansion results do not change
expected = 'applypipeline:off compact^1.840895333 design^0.348370740 equip^0.000000000 purpos^0.000000000 instrument^0.000000000 ferrit^0.000000000 anod^0.000000000 aircraft^0.000000000 microwav^0.000000000 sideband^0.000000000'

indexref = pt.datasets.get_dataset("vaswani").get_index()
queriesIn = pd.DataFrame([["1", "compact"]], columns=["qid", "query"])

qe = pt.rewrite.KLQueryExpansion(indexref)
br = pt.BatchRetrieve(indexref, wmodel='BM25')

actual = qe.transform(br.transform(queriesIn))

self.assertEqual(len(actual), 1)
self.assertEqual(expected, actual.iloc[0]["query"])

def test_bo1_qe_expansion_for_query_compact_on_bm25(self):
# just ensure that Bo1QueryExpansion results do not change
expected = 'applypipeline:off compact^1.822309726 design^0.287992096 equip^0.000000000 purpos^0.000000000 instrument^0.000000000 ferrit^0.000000000 anod^0.000000000 aircraft^0.000000000 microwav^0.000000000 sideband^0.000000000'

indexref = pt.datasets.get_dataset("vaswani").get_index()
queriesIn = pd.DataFrame([["1", "compact"]], columns=["qid", "query"])

qe = pt.rewrite.Bo1QueryExpansion(indexref)
br = pt.BatchRetrieve(indexref, wmodel='BM25')

actual = qe.transform(br.transform(queriesIn))

self.assertEqual(len(actual), 1)
self.assertEqual(expected, actual.iloc[0]["query"])

def test_dfr_qe_expansion_for_query_compact_on_bm25(self):
# just ensure that DFRQueryExpansion results do not change
expected = 'applypipeline:off compact^1.822309726 design^0.287992096 equip^0.000000000 purpos^0.000000000 instrument^0.000000000 ferrit^0.000000000 anod^0.000000000 aircraft^0.000000000 microwav^0.000000000 sideband^0.000000000'

indexref = pt.datasets.get_dataset("vaswani").get_index()
queriesIn = pd.DataFrame([["1", "compact"]], columns=["qid", "query"])

qe = pt.rewrite.DFRQueryExpansion(indexref)
br = pt.BatchRetrieve(indexref, wmodel='BM25')

actual = qe.transform(br.transform(queriesIn))

self.assertEqual(len(actual), 1)
self.assertEqual(expected, actual.iloc[0]["query"])

@pytest.mark.skipif(not TERRIER_PRF_ON_CLASSPATH, reason="This test only works in isolation when terrier-prf is on the jnius classpath.")
def test_rm3_end_to_end(self):
"""An end-to-end test, contrasting the smaller tests (that fail faster) from above.
"""
dataset = pt.datasets.get_dataset("vaswani")
indexref = dataset.get_index()

qe = pt.rewrite.RM3(indexref)
br = pt.BatchRetrieve(indexref)

queriesIn = pd.DataFrame([["1", "compact"]], columns=["qid", "query"])
res = br.transform(queriesIn)

queriesOut = qe.transform(res)
self.assertEqual(len(queriesOut), 1)
query = queriesOut.iloc[0]["query"]
#self.assertTrue("compact^1.82230972" in query)
self.assertTrue("applypipeline:off " in query)

pipe = br >> qe >> br

# lets go faster, we only need 18 topics. qid 16 had a tricky case
t = dataset.get_topics().head(18)

all_qe_res = pipe.transform(t)
map_pipe = pt.Evaluate(all_qe_res, dataset.get_qrels(), metrics=["map"])["map"]

br_qe = pt.BatchRetrieve(indexref,
controls={"qe":"on"},
properties={"querying.processes" : "terrierql:TerrierQLParser,parsecontrols:TerrierQLToControls,"\
+"parseql:TerrierQLToMatchingQueryTerms,matchopql:MatchingOpQLParser,applypipeline:ApplyTermPipeline,"\
+"sd:DependenceModelPreProcess,localmatching:LocalManager$ApplyLocalMatching,qe:RM3,"\
+"labels:org.terrier.learning.LabelDecorator,filters:LocalManager$PostFilterProcess"})
map_qe = pt.Evaluate(br_qe.transform(t), dataset.get_qrels(), metrics=["map"])["map"]

self.assertAlmostEqual(map_qe, map_pipe, places=2)

@pytest.mark.skipif(not TERRIER_PRF_ON_CLASSPATH, reason="This test only works in isolation when terrier-prf is on the jnius classpath.")
def test_scoring_rm3_qe(self):
expected = 'applypipeline:off fox^0.600000024'
input = pd.DataFrame([["q1", "fox", "d1", "all the fox were fox", 3], ["q1", "fox", "d2", "brown fox jumps", 2]], columns=["qid", "query", "docno", "body", "score"])
scorer = pt.terrier.retriever.TextIndexProcessor(pt.rewrite.RM3, takes="docs", returns="queries")
rtr = scorer(input)
self.assertTrue("qid" in rtr.columns)
self.assertTrue("query" in rtr.columns)
self.assertTrue("docno" not in rtr.columns)
self.assertTrue(expected, rtr.iloc[0]["query"])

0 comments on commit 1c46751

Please sign in to comment.