diff --git a/pyterrier/text.py b/pyterrier/text.py index 3ec9914f..1f7ea627 100644 --- a/pyterrier/text.py +++ b/pyterrier/text.py @@ -1,4 +1,5 @@ from pyterrier.transformer import TransformerBase +from . import Transformer from pyterrier.datasets import IRDSDataset import more_itertools from collections import defaultdict @@ -248,16 +249,16 @@ def slidingWindow(sequence : list, winSize : int, step : int) -> list: return [x for x in list(more_itertools.windowed(sequence,n=winSize, step=step)) if x[-1] is not None] def snippets( - text_scorer_pipe : TransformerBase, + text_scorer_pipe : Transformer, text_attr : str = "text", summary_attr : str = "summary", num_psgs : int = 5, - joinstr : str ='...') -> TransformerBase: + joinstr : str ='...') -> Transformer: """ Applies query-biased summarisation (snippet), by applying the specified text scoring pipeline. Arguments: - - text_scorer_pipe(TransformerBase): the pipeline for scoring passages in response to the query. Normally this applies passaging. + - text_scorer_pipe(Transformer): the pipeline for scoring passages in response to the query. Normally this applies passaging. - text_attr(str): what is the name of the attribute that contains the text of the document - summary_attr(str): what is the name of the attribute that should contain the query-biased summary for that document - num_psgs(int): how many passages to select for the summary of each document