Skip to content

Commit

Permalink
wip - transform_iter return changes
Browse files Browse the repository at this point in the history
  • Loading branch information
cmacdonald committed Sep 4, 2024
1 parent 2c8552a commit 304943c
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 13 deletions.
9 changes: 6 additions & 3 deletions pyterrier/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,15 +325,18 @@ def index(self, iter : Iterable[dict], batch_size=100):

def gen():
for batch in chunked(iter, batch_size):
batch_df = prev_transformer.transform_iter(batch)
for row in batch_df.itertuples(index=False):
yield row._asdict()
yield from prev_transformer.transform_iter(batch)
return last_transformer.index(gen())

def transform(self, topics):
    """
    Apply each transformer in ``self.models`` in order, feeding the output of one
    stage into ``transform()`` of the next, and return the output of the last stage.
    If ``self.models`` is empty, the input is returned unchanged.
    """
    result = topics
    for stage in self.models:
        result = stage.transform(result)
    return result

def transform_iter(self, topics):
    """
    Iterable counterpart of ``transform()``: pass ``topics`` through
    ``transform_iter()`` of each transformer in ``self.models`` in order, and
    return whatever the final stage returns. If ``self.models`` is empty, the
    input is returned unchanged.
    """
    current = topics
    for stage in self.models:
        current = stage.transform_iter(current)
    return current

def fit(self, topics_or_res_tr, qrels_tr, topics_or_res_va=None, qrels_va=None):
"""
Expand Down
35 changes: 25 additions & 10 deletions pyterrier/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,17 @@ class Transformer:
"""
Base class for all transformers. Implements the various operators ``>>`` ``+`` ``*`` ``|`` ``&``
as well as ``search()`` for executing a single query and ``compile()`` for rewriting complex pipelines into simpler ones.
It is expected that either ``.transform()`` or ``.transform_iter()`` be implemented by any class extending this - this rule
does not apply to indexers, which instead implement ``.index()``.
"""

def __init__(self, *args, **kwargs):
    """
    Forward all arguments to the next ``__init__`` in the MRO, then record which of
    ``transform()`` and ``transform_iter()`` the concrete class has overridden
    relative to the ``Transformer`` base implementations.
    """
    super().__init__(*args, **kwargs)
    concrete = type(self)
    # No error is raised here when neither method is overridden: indexers
    # override neither of them, so the check cannot be made at construction time.
    self._transform_iter_implemented = concrete.transform_iter != Transformer.transform_iter
    self._transform_implemented = concrete.transform != Transformer.transform

@staticmethod
def identity() -> 'Transformer':
"""
Expand Down Expand Up @@ -93,16 +102,21 @@ def transform(self, topics_or_res : pd.DataFrame) -> pd.DataFrame:
Abstract method for all transformations. Typically takes as input a Pandas
DataFrame, and also returns one.
"""
pass
if not self._transform_iter_implemented:
raise NotImplementedError("You need to implement either .transform() and .transform_iter() in %s" % str(type(self)))
return pd.DataFrame(self.transform_iter(topics_or_res.to_records(orient='dict')))

def transform_iter(self, input: Iterable[dict]) -> pd.DataFrame:
def transform_iter(self, input: Iterable[dict]) -> Iterable[dict]:
"""
Method that processes an iter-dict by instantiating it as a dataframe and calling transform().
Returns the DataFrame returned by transform(). This can be a handier version of transform()
that avoids constructing a dataframe by hand. Also used in the implementation of index() on a composed
pipeline.
Method that processes an iter-dict by instantiating it as a dataframe and calling ``transform()``.
Returns an Iterable[dict] equivalent to the DataFrame returned by ``transform()``. This can be a
handier version of ``transform()`` that avoids constructing a dataframe by hand. Also used in the
implementation of ``index()`` on a composed pipeline.
"""
return self.transform(pd.DataFrame(list(input)))
if not self._transform_implemented:
raise NotImplementedError("You need to implement either .transform() and .transform_iter() in %s" % str(type(self)))

return self.transform(pd.DataFrame(list(input))).to_records(orient='dict')

def transform_gen(self, input : pd.DataFrame, batch_size=1, output_topics=False) -> Iterator[pd.DataFrame]:
"""
Expand Down Expand Up @@ -214,10 +228,11 @@ def set_parameter(self, name : str, value):
raise ValueError(('Invalid parameter name %s for transformer %s. '+
'Check the list of available parameters') %(name, str(self)))

def __call__(self, input : Union[pd.DataFrame, Iterable[dict]]) -> pd.DataFrame:
def __call__(self, input : Union[pd.DataFrame, Iterable[dict]]) ->Union[pd.DataFrame, Iterable[dict]]:
"""
Sets up a default method for every transformer, which is aliased to transform() (for DataFrames)
or transform_iter() (for iterable dictionaries) depending on the type of input.
Sets up a default method for every transformer, which is aliased to ``transform()`` (for DataFrames)
or ``transform_iter()`` (for iterable dictionaries) depending on the type of input. The return type
matches the input type.
"""
if isinstance(input, pd.DataFrame):
return self.transform(input)
Expand Down

0 comments on commit 304943c

Please sign in to comment.