Test pass shouldn't require both test_step and test_end #909

Closed
MattPainter01 opened this issue Feb 21, 2020 · 2 comments · Fixed by #918 or #926
Labels: bug (Something isn't working), good first issue (Good for newcomers), help wanted (Open to be worked on)

Comments

@MattPainter01
Contributor

🐛 Bug

trainer.test(...) requires both test_step and test_end to be implemented, but the warning says that implementing either one (or both) is enough.

https://github.com/PyTorchLightning/pytorch-lightning/blob/56dddf970825b1ad2b598c9b9b23a8a77add8964/pytorch_lightning/trainer/evaluation_loop.py#L291

To Reproduce

Run .test() on any LightningModule with only test_step or test_end implemented.

Code sample

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

import pytorch_lightning as pl


class CoolSystem(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        # Flatten each 28x28 image into a 784-element vector before the linear layer
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        return {'loss': loss}

    # Only test_step is implemented; test_end is deliberately left out
    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        return {'test_loss': F.cross_entropy(y_hat, y)}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

    # ToTensor converts the PIL images to tensors so that batches can be collated
    @pl.data_loader
    def train_dataloader(self):
        return DataLoader(MNIST('./', train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    @pl.data_loader
    def test_dataloader(self):
        return DataLoader(MNIST('./', train=False, download=True, transform=transforms.ToTensor()), batch_size=32)


model = CoolSystem()
trainer = pl.Trainer(max_epochs=2, val_percent_check=1)
trainer.test(model)
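
The sample above implements only test_step. For reference, here is a rough sketch of the kind of aggregation a test_end would typically do with the per-batch outputs. It is written as a standalone function on plain floats so it runs without Lightning; the name aggregate_test_outputs and the 'avg_test_loss' key are assumptions made for illustration, not part of the library.

```python
from statistics import mean
from typing import Dict, List


def aggregate_test_outputs(outputs: List[Dict[str, float]]) -> Dict[str, float]:
    """Average the per-batch 'test_loss' values, as a test_end hook might.

    `outputs` stands in for the list of dicts returned by test_step,
    one entry per batch (plain floats here instead of tensors).
    """
    if not outputs:
        raise ValueError("no test outputs to aggregate")
    return {"avg_test_loss": mean(o["test_loss"] for o in outputs)}


if __name__ == "__main__":
    per_batch = [{"test_loss": 1.0}, {"test_loss": 3.0}]
    print(aggregate_test_outputs(per_batch))
```

In a real LightningModule the values would be tensors, so the averaging would more likely use torch.stack(...).mean() on the collected losses.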

Expected behaviour

The test pass should run when either test_step or test_end is implemented, or at least when test_step is.
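
To make the expected behaviour concrete, here is a minimal sketch of an "either hook is enough" check. It does not reproduce the actual code in evaluation_loop.py; BaseModule, is_overridden and can_run_test are hypothetical names used only for this illustration.

```python
class BaseModule:
    """Hypothetical stand-in for a base class with optional test hooks."""

    def test_step(self, batch, batch_idx):
        pass

    def test_end(self, outputs):
        pass


def is_overridden(name: str, model: BaseModule) -> bool:
    """Return True if the model's class replaces the base implementation."""
    return getattr(type(model), name) is not getattr(BaseModule, name)


def can_run_test(model: BaseModule) -> bool:
    # Expected behaviour: one overridden hook is enough ("or", not "and").
    return is_overridden("test_step", model) or is_overridden("test_end", model)


class OnlyStep(BaseModule):
    """Implements test_step only, like CoolSystem in the code sample."""

    def test_step(self, batch, batch_idx):
        return {"test_loss": 0.0}


if __name__ == "__main__":
    print(can_run_test(OnlyStep()))    # only test_step is overridden
    print(can_run_test(BaseModule()))  # neither hook is overridden
```

With an "and" in can_run_test, OnlyStep would be rejected, which matches the behaviour reported here.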

@MattPainter01 MattPainter01 added bug Something isn't working help wanted Open to be worked on labels Feb 21, 2020
@github-actions
Contributor

Hey, thanks for your contribution! Great first issue!

@MattPainter01 MattPainter01 removed the help wanted Open to be worked on label Feb 21, 2020
@Borda
Member

Borda commented Feb 21, 2020

Great catch, would you send a PR? 🤖

@Borda Borda added good first issue Good for newcomers help wanted Open to be worked on need fix labels Feb 21, 2020
@MattPainter01 MattPainter01 self-assigned this Feb 21, 2020
williamFalcon added a commit that referenced this issue Feb 25, 2020
* added get dataloaders directly using a getter
* deleted decorator
* added prepare_data hook
* refactored dataloader init
* added dataloader reset flag and main loop
* made changes
* fixed bad loaders
* fixed error in .fit with loaders
* fixes #909
* bug fix
* Fixes #902
tullie pushed a commit to tullie/pytorch-lightning that referenced this issue Apr 3, 2020
* added get dataloaders directly using a getter
* deleted decorator
* added prepare_data hook
* refactored dataloader init
* added dataloader reset flag and main loop
* made changes
* fixed bad loaders
* fixed error in .fit with loaders
* fixes Lightning-AI#909
* bug fix
* Fixes Lightning-AI#902