Skip to content

Commit

Permalink
use __new__ for checking impls
Browse files Browse the repository at this point in the history
  • Loading branch information
cmacdonald committed Sep 5, 2024
1 parent be5eb5f commit 6f788d0
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions pyterrier/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,11 @@ class Transformer:
does not apply for indexers, which instead implement ``.index()``.
"""

def __init__(self, *args, **kwargs):
    """
    Forward all arguments up the MRO, then record which of ``transform()`` and
    ``transform_iter()`` the concrete class has overridden.

    The two flags are stored on the instance as ``_transform_implemented`` and
    ``_transform_iter_implemented``.
    """
    super().__init__(*args, **kwargs)
    concrete = type(self)
    self._transform_implemented = concrete.transform != Transformer.transform
    self._transform_iter_implemented = concrete.transform_iter != Transformer.transform_iter
    # We can't raise here when neither flag is set, because of indexers.

def __new__(cls, *args, **kwargs):
    """
    Create the instance, refusing classes that override neither ``transform()``
    nor ``transform_iter()``.

    Subclasses of ``Indexer`` are exempt from the check.

    Raises:
        NotImplementedError: if ``cls`` is not an ``Indexer`` subclass and both
            ``transform`` and ``transform_iter`` are still the ``Transformer`` defaults.
    """
    if not issubclass(cls, Indexer):
        inherits_transform = cls.transform == Transformer.transform
        # Only look at transform_iter when transform itself is not overridden.
        if inherits_transform and cls.transform_iter == Transformer.transform_iter:
            raise NotImplementedError("You need to implement either .transform() or .transform_iter() in %s" % str(cls))
    return super().__new__(cls)

@staticmethod
def identity() -> 'Transformer':
"""
Expand Down Expand Up @@ -102,8 +101,7 @@ def transform(self, topics_or_res : pd.DataFrame) -> pd.DataFrame:
Abstract method for all transformations. Typically takes as input a Pandas
DataFrame, and also returns one.
"""
if not self._transform_iter_implemented:
raise NotImplementedError("You need to implement either .transform() and .transform_iter() in %s" % str(type(self)))
# The __new__ check should rule out infinite transform() <-> transform_iter() recursion, UNLESS .transform() or .transform_iter() is called on an Indexer that implements neither.
return pd.DataFrame(self.transform_iter(topics_or_res.to_dict(orient='records')))

def transform_iter(self, input: Iterable[dict]) -> Iterable[dict]:
Expand All @@ -113,9 +111,7 @@ def transform_iter(self, input: Iterable[dict]) -> Iterable[dict]:
handier version of ``transform()`` that avoids constructing a dataframe by hand. Also used in the
implementation of ``index()`` on a composed pipeline.
"""
if not self._transform_implemented:
raise NotImplementedError("You need to implement either .transform() and .transform_iter() in %s" % str(type(self)))

# The __new__ check should rule out infinite transform() <-> transform_iter() recursion, UNLESS .transform() or .transform_iter() is called on an Indexer that implements neither.
return self.transform(pd.DataFrame(list(input))).to_dict(orient='records')

def transform_gen(self, input : pd.DataFrame, batch_size=1, output_topics=False) -> Iterator[pd.DataFrame]:
Expand Down

0 comments on commit 6f788d0

Please sign in to comment.