From 246d5ae53389e43fa4b25106b16bcfdfe340a9ce Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Sun, 23 Jan 2022 14:05:46 +0000 Subject: [PATCH] fixed bug with is_transformer --- pyterrier/transformer.py | 2 +- tests/test_transformer.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) create mode 100644 tests/test_transformer.py diff --git a/pyterrier/transformer.py b/pyterrier/transformer.py index ee52bea5..4f4e9bb9 100644 --- a/pyterrier/transformer.py +++ b/pyterrier/transformer.py @@ -16,7 +16,7 @@ def is_function(v): return isinstance(v, types.FunctionType) def is_transformer(v): - if isinstance(v, TransformerBase): + if isinstance(v, Transformer): return True return False diff --git a/tests/test_transformer.py b/tests/test_transformer.py new file mode 100644 index 00000000..1b6c7687 --- /dev/null +++ b/tests/test_transformer.py @@ -0,0 +1,19 @@ +import pyterrier as pt +import unittest +from .base import BaseTestCase +import os +import pandas as pd + +class TestTransformer(BaseTestCase): + + def test_is_transformer(self): + class MyTransformer1(pt.Transformer): + pass + class MyTransformer2(pt.transformer.TransformerBase): + pass + class MyTransformer3(pt.transformer.IterDictIndexerBase): + pass + class MyTransformer4(pt.transformer.EstimatorBase): + pass + for T in [MyTransformer1, MyTransformer2, MyTransformer3, MyTransformer4]: + self.assertTrue(pt.transformer.is_transformer(T()))