Skip to content

Commit

Permalink
Terrier weighting models expressed in Python (#215)
Browse files Browse the repository at this point in the history
  • Loading branch information
cmacdonald authored Aug 10, 2021
1 parent 507ae91 commit 6374c8b
Show file tree
Hide file tree
Showing 14 changed files with 408 additions and 15 deletions.
30 changes: 30 additions & 0 deletions docs/terrier-retrieval.rst
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,33 @@ By default, PyTerrier is configured for indexing and retrieval in English. See
`our notebook <https://github.com/terrier-org/pyterrier/blob/master/examples/notebooks/non_en_retrieval.ipynb>`_
(`colab <https://colab.research.google.com/github/terrier-org/pyterrier/blob/master/examples/notebooks/non_en_retrieval.ipynb>`_)
for details on how to configure PyTerrier in other languages.

Custom Weighting Models
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Normally, weighting models are specified as a string class names. Terrier then loads the Java class of that name (it will search
the `org.terrier.matching.models package <http://terrier.org/docs/current/javadoc/org/terrier/matching/models/package-summary.html>`_
unless the class name is fully qualified (e.g. `"com.example.MyTF"`).

If you have your own Java weighting model instance (which extends the
`WeightingModel abstract class <http://terrier.org/docs/current/javadoc/org/terrier/matching/models/WeightingModel.html>`_,
you can load it and pass it directly to BatchRetrieve::

mymodel = pt.autoclass("com.example.MyTF")()
retr = pt.BatchRetrieve(indexref, wmodel=mymodel)

More usefully, it is possible to express a weighting model entirely in Python, as a function or a lambda expression, that can be
used by Terrier for scoring. In this example, we create a Terrier BatchRetrieve instance that scores based solely on term frequency::

Tf = lambda keyFreq, posting, entryStats, collStats: posting.getFrequency()
retr = pt.BatchRetrieve(indexref, wmodel=Tf)

All functions passed must accept 4 arguments, as follows:

- keyFrequency(float): the weight of the term in the query, usually 1 except during PRF.
- posting(`Posting <http://terrier.org/docs/current/javadoc/org/terrier/structures/postings/Posting.html>`_): access to the information about the occurrence of the term in the current document (frequency, document length etc).
- entryStats(`EntryStatistics <http://terrier.org/docs/current/javadoc/org/terrier/structures/EntryStatistics.html>`_): access to the information about the occurrence of the term in the whole index (document frequency, etc.).
- collStats(`CollectionStatistics <http://terrier.org/docs/current/javadoc/org/terrier/structures/CollectionStatistics.html>`_): access to the information about the index as a whole (number of documents, etc).

Note that due to the overheads of continually traversing the JNI boundary, using a Python function for scoring has a marked efficiency overhead. This is probably too slow for retrieval using most indices of any significant size,
but allows simple explanation of weighting models and exploratory weighting model development.
2 changes: 1 addition & 1 deletion pyterrier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def init(version=None, mem=None, packages=[], jvm_opts=[], redirect_io=True, log
os.mkdir(HOME_DIR)

# get the initial classpath for the JVM
classpathTrJars = setup_terrier(HOME_DIR, version, boot_packages=boot_packages)
classpathTrJars = setup_terrier(HOME_DIR, version, boot_packages=boot_packages, helper_version="0.0.6")

if is_windows():
if "JAVA_HOME" in os.environ:
Expand Down
66 changes: 62 additions & 4 deletions pyterrier/batchretrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from warnings import warn
from .index import Indexer
from .datasets import Dataset
from .transformer import TransformerBase, Symbol
from .transformer import TransformerBase, Symbol, is_lambda
from .model import coerce_queries_dataframe, FIRST_RANK
import deprecation
import concurrent
Expand All @@ -25,6 +25,35 @@ def _matchop(query):
return True
return False

def _function2wmodel(function):
from . import autoclass
from jnius import PythonJavaClass, java_method

class PythonWmodelFunction(PythonJavaClass):
__javainterfaces__ = ['org/terrier/python/CallableWeightingModel$Callback']

def __init__(self, fn):
super(PythonWmodelFunction, self).__init__()
self.fn = fn

@java_method('(DLorg/terrier/structures/postings/Posting;Lorg/terrier/structures/EntryStatistics;Lorg/terrier/structures/CollectionStatistics;)D', name='score')
def score(self, keyFreq, posting, entryStats, collStats):
return self.fn(keyFreq, posting, entryStats, collStats)

@java_method('()Ljava/nio/ByteBuffer;')
def serializeFn(self):
import dill as pickle
#see https://github.com/SeldonIO/alibi/issues/447#issuecomment-881552005
from dill import extend
extend(use_dill=False)
byterep = pickle.dumps(self.fn)
byterep = autoclass("java.nio.ByteBuffer").wrap(byterep)
return byterep

callback = PythonWmodelFunction(function)
wmodel = autoclass("org.terrier.python.CallableWeightingModel")( callback )
return callback, wmodel

def _mergeDicts(defaults, settings):
KV = defaults.copy()
if settings is not None and len(settings) > 0:
Expand Down Expand Up @@ -138,7 +167,7 @@ def from_dataset(dataset : Union[str,Dataset],

#: default_properties(dict): stores the default properties
default_properties = {
"querying.processes": "terrierql:TerrierQLParser,parsecontrols:TerrierQLToControls,parseql:TerrierQLToMatchingQueryTerms,matchopql:MatchingOpQLParser,applypipeline:ApplyTermPipeline,localmatching:LocalManager$ApplyLocalMatching,qe:QueryExpansion,labels:org.terrier.learning.LabelDecorator,filters:LocalManager$PostFilterProcess",
"querying.processes": "terrierql:TerrierQLParser,parsecontrols:TerrierQLToControls,parseql:TerrierQLToMatchingQueryTerms,matchopql:MatchingOpQLParser,applypipeline:ApplyTermPipeline,context_wmodel:org.terrier.python.WmodelFromContextProcess,localmatching:LocalManager$ApplyLocalMatching,qe:QueryExpansion,labels:org.terrier.learning.LabelDecorator,filters:LocalManager$PostFilterProcess",
"querying.postfilters": "decorate:SimpleDecorate,site:SiteFilter,scope:Scope",
"querying.default.controls": "wmodel:DPH,parsecontrols:on,parseql:on,applypipeline:on,terrierql:on,localmatching:on,filters:on,decorate:on",
"querying.allowed.controls": "scope,qe,qemodel,start,end,site,scope,applypipeline",
Expand Down Expand Up @@ -168,6 +197,7 @@ def __init__(self, index_location, controls=None, properties=None, metadata=["do
self.metadata = metadata
self.threads = threads
self.RequestContextMatching = autoclass("org.terrier.python.RequestContextMatching")
self.search_context = {}

if props is None:
importProps()
Expand All @@ -176,8 +206,21 @@ def __init__(self, index_location, controls=None, properties=None, metadata=["do

self.controls = _mergeDicts(BatchRetrieve.default_controls, controls)
if wmodel is not None:
self.controls["wmodel"] = wmodel

from .transformer import is_lambda, is_function
if isinstance(wmodel, str):
self.controls["wmodel"] = wmodel
elif is_lambda(wmodel) or is_function(wmodel):
callback, wmodelinstance = _function2wmodel(wmodel)
#save the callback instance in this object to prevent being GCd by Python
self._callback = callback
self.search_context['context_wmodel'] = wmodelinstance
self.controls['context_wmodel'] = 'on'
elif isinstance(wmodel, autoclass("org.terrier.matching.models.WeightingModel")):
self.search_context['context_wmodel'] = wmodel
self.controls['context_wmodel'] = 'on'
else:
raise ValueError("Unknown parameter type passed for wmodel argument: %s" % str(wmodel))

if self.threads > 1:
warn("Multi-threaded retrieval is experimental, YMMV.")
assert check_version(5.5), "Terrier 5.5 is required for multi-threaded retrieval"
Expand Down Expand Up @@ -226,6 +269,7 @@ def __reduce__(self):

def __getstate__(self):
return {
'context' : self.search_context,
'controls' : self.controls,
'properties' : self.properties,
'metadata' : self.metadata,
Expand All @@ -234,6 +278,7 @@ def __getstate__(self):
def __setstate__(self, d):
self.controls = d["controls"]
self.metadata = d["metadata"]
self.search_context = d["context"]
self.properties.update(d["properties"])
for key,value in d["properties"].items():
self.appSetup.setProperty(key, str(value))
Expand All @@ -251,6 +296,9 @@ def _retrieve_one(self, row, input_results=None, docno_provided=False, docid_pro
for control, value in self.controls.items():
srq.setControl(control, str(value))

for key, value in self.search_context.items():
srq.setContextObject(key, value)

# this is needed until terrier-core issue #106 lands
if "applypipeline:off" in query:
srq.setControl("applypipeline", "off")
Expand Down Expand Up @@ -538,6 +586,7 @@ def __init__(self, index_location, features, controls=None, properties=None, thr
# record the weighting model
self.wmodel = None
if "wmodel" in kwargs:
assert isinstance(kwargs["wmodel"], str), "Non-string weighting models not yet supported by FBR"
self.wmodel = kwargs["wmodel"]
if "wmodel" in controls:
self.wmodel = controls["wmodel"]
Expand All @@ -560,6 +609,7 @@ def __getstate__(self):
'metadata' : self.metadata,
'features' : self.features,
'wmodel' : self.wmodel
#TODO consider the context state?
}

def __setstate__(self, d):
Expand All @@ -570,6 +620,14 @@ def __setstate__(self, d):
self.properties.update(d["properties"])
for key,value in d["properties"].items():
self.appSetup.setProperty(key, str(value))
#TODO consider the context state?

@staticmethod
def from_dataset(dataset : Union[str,Dataset],
variant : str = None,
version='latest',
**kwargs):
return _from_dataset(dataset, variant=variant, version=version, clz=FeaturesBatchRetrieve, **kwargs)

@staticmethod
def from_dataset(dataset : Union[str,Dataset],
Expand Down
60 changes: 59 additions & 1 deletion pyterrier/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
stdout_ref = None
stderr_ref = None
TERRIER_PKG = "org.terrier"

SAVED_FNS=[]

@deprecation.deprecated(deprecated_in="0.1.3",
# remove_id="",
Expand All @@ -23,6 +23,31 @@ def new_indexref(s):
from . import IndexRef
return IndexRef.of(s)

def new_wmodel(bytes):
from . import autoclass
serUtils = autoclass("org.terrier.python.Serialization")
return serUtils.deserialize(bytes, autoclass("org.terrier.utility.ApplicationSetup").getClass("org.terrier.matching.models.WeightingModel") )

def new_callable_wmodel(byterep):
import dill as pickle
from dill import extend
#see https://github.com/SeldonIO/alibi/issues/447#issuecomment-881552005
extend(use_dill=False)
fn = pickle.loads(byterep)
#we need to prevent these functions from being GCd.
global SAVED_FNS
SAVED_FNS.append(fn)
from .batchretrieve import _function2wmodel
callback, wmodel = _function2wmodel(fn)
SAVED_FNS.append(callback)
#print("Stored lambda fn %s and callback in SAVED_FNS, now %d stored" % (str(fn), len(SAVED_FNS)))
return wmodel

def javabytebuffer2array(buffer):
def unsign(signed):
return signed + 256 if signed < 0 else signed
return bytearray([ unsign(buffer.get(offset)) for offset in range(buffer.capacity()) ])

def setup_jnius():
from jnius import protocol_map # , autoclass

Expand Down Expand Up @@ -73,6 +98,39 @@ def index_ref_reduce(self):
'__getstate__' : lambda self : None,
}


# handles the pickling of WeightingModel classes, which are themselves usually Serializable in Java
def wmodel_reduce(self):
from . import autoclass
serUtils = autoclass("org.terrier.python.Serialization")
serialized = bytes(serUtils.serialize(self))
return (
new_wmodel,
(serialized, ),
None
)

protocol_map["org.terrier.matching.models.WeightingModel"] = {
'__reduce__' : wmodel_reduce,
'__getstate__' : lambda self : None,
}

def callable_wmodel_reduce(self):
from . import autoclass
# get bytebuffer representation of lambda
# convert bytebyffer to python bytearray
bytesrep = javabytebuffer2array(self.scoringClass.serializeFn())
return (
new_callable_wmodel,
(bytesrep, ),
None
)

protocol_map["org.terrier.python.CallableWeightingModel"] = {
'__reduce__' : callable_wmodel_reduce,
'__getstate__' : lambda self : None,
}

def _index_add(self, other):
from . import autoclass
fields_1 = self.getCollectionStatistics().getNumberOfFields()
Expand Down
7 changes: 5 additions & 2 deletions pyterrier/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
LAMBDA = lambda:0
def is_lambda(v):
return isinstance(v, type(LAMBDA)) and v.__name__ == LAMBDA.__name__


def is_function(v):
return isinstance(v, types.FunctionType)

def is_transformer(v):
if isinstance(v, TransformerBase):
return True
Expand All @@ -29,7 +32,7 @@ def get_transformer(v):
return v
if is_lambda(v):
return ApplyGenericTransformer(v)
if isinstance(v, types.FunctionType):
if is_function(v):
return ApplyGenericTransformer(v)
if isinstance(v, pd.DataFrame):
return SourceTransformer(v)
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ ir_datasets>=0.3.2
jinja2
statsmodels
ir_measures==0.1.4
dill
8 changes: 4 additions & 4 deletions terrier-python-helper/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@

<groupId>org.terrier</groupId>
<artifactId>terrier-python-helper</artifactId>
<version>0.0.5</version>
<version>0.0.6</version>
<url>http://terrier.org</url>
<name>terrier-python-helper</name>
<description>Python bindings for the Terrier IR platform</description>

<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.7</maven.compiler.source>
<maven.compiler.target>1.7</maven.compiler.target>
<build.terrier.version>5.3</build.terrier.version>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<build.terrier.version>5.5</build.terrier.version>
</properties>

<organization>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package org.terrier.python;
import org.terrier.matching.models.WeightingModel;
import org.terrier.structures.postings.Posting;
import org.terrier.structures.EntryStatistics;
import org.terrier.structures.CollectionStatistics;
import java.io.IOException;
import java.io.ObjectStreamException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.ByteBuffer;

/** A weighting model class that includes a Callback interface that can be implemented in Python.
*/
public class CallableWeightingModel extends WeightingModel {

public static interface Callback {
public double score(double keyFrequency, Posting posting, EntryStatistics entryStats, CollectionStatistics collStats);
public default double score1(Posting posting, EntryStatistics entryStats, CollectionStatistics collStats) {
return score(1.0d, posting, entryStats, collStats);
}
public ByteBuffer serializeFn();
}

public Callback scoringClass;

private CallableWeightingModel() {}

public CallableWeightingModel(Callback _scoringClass) {
scoringClass = _scoringClass;
}

@Override
public double score(Posting p) {
return scoringClass.score(super.keyFrequency, p, super.es, super.cs);
}

@Override
public double score(double a, double b) {
throw new UnsupportedOperationException();
}

@Override
public String getInfo() {
return this.getClass().getSimpleName();
}

private void writeObject(java.io.ObjectOutputStream out) throws IOException {}
private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {}
private void readObjectNoData() throws ObjectStreamException {}
}
Loading

0 comments on commit 6374c8b

Please sign in to comment.