Skip to content

Commit

Permalink
Fix/test pass overrides (#918)
Browse files Browse the repository at this point in the history
* Fix test requiring both test_step and test_end

* Add test

Co-authored-by: William Falcon <waf2107@columbia.edu>
  • Loading branch information
MattPainter01 and williamFalcon committed Feb 25, 2020
1 parent 2b5293d commit 6b667b1
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,8 @@ def evaluate(self, model, dataloaders, max_batches, test=False):

def run_evaluation(self, test=False):
# when testing make sure user defined a test step
if test and not (self.is_overriden('test_step')):
m = '''You called `.test()` without defining model's `.test_step()`.
if test and not (self.is_overriden('test_step') or self.is_overriden('test_end')):
m = '''You called `.test()` without defining model's `.test_step()` or `.test_end()`.
Please define and try again'''
raise MisconfigurationException(m)

Expand Down
34 changes: 34 additions & 0 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,5 +782,39 @@ def test_trainer_min_steps_and_epochs(tmpdir):
trainer.current_epoch > 0, "Model did not train for at least min_steps"


def test_testpass_overrides(tmpdir):
hparams = tutils.get_hparams()
from pytorch_lightning.utilities.debugging import MisconfigurationException

class TestModelNoEnd(LightningTestModelBase):
def test_step(self, *args, **kwargs):
return {}

def test_dataloader(self):
return self.train_dataloader()

class TestModelNoStep(LightningTestModelBase):
def test_end(self, outputs):
return {}

def test_dataloader(self):
return self.train_dataloader()

# Misconfig when neither test_step or test_end is implemented
with pytest.raises(MisconfigurationException):
model = LightningTestModelBase(hparams)
Trainer().test(model)

# No exceptions when one or both of test_step or test_end are implemented
model = TestModelNoStep(hparams)
Trainer().test(model)

model = TestModelNoEnd(hparams)
Trainer().test(model)

model = LightningTestModel(hparams)
Trainer().test(model)


# if __name__ == '__main__':
# pytest.main([__file__])

0 comments on commit 6b667b1

Please sign in to comment.