
Add dataloader arg to Trainer.test() #1393

Closed
Anjum48 opened this issue Apr 6, 2020 · 8 comments · Fixed by #1434
Labels
discussion In a discussion stage feature Is an improvement or enhancement help wanted Open to be worked on let's do it! approved to implement priority: 0 High priority task


Anjum48 commented Apr 6, 2020

🚀 Feature

It would be nice if you could use a model for inference using:
Trainer.test(model, test_dataloaders=test_loader)

Motivation

This would match the calling signature of Trainer.fit() and allow test to be run on any dataset, multiple times.

Pitch

Here's a use case: after training a model with 5-fold cross-validation, you may want to stack the 5 checkpoints across multiple models, which requires (a) the out-of-fold (OOF) predictions and (b) the 5 test-set predictions (which are then averaged). It would be cool if (a) and (b) could be generated as follows:

for f in folds:
    # load_from_checkpoint returns a new model instance
    model1 = model1.load_from_checkpoint(f'path/to/model1_fold{f}.ckpt')
    trainer.test(model1, test_dataloaders=valid_loader)  # (a) OOF predictions
    trainer.test(model1, test_dataloaders=test_loader)   # (b) test predictions

    model2 = model2.load_from_checkpoint(f'path/to/model2_fold{f}.ckpt')
    trainer.test(model2, test_dataloaders=valid_loader)
    trainer.test(model2, test_dataloaders=test_loader)

Alternatives

Maybe I'm misunderstanding how test works and there is an easier way? Or perhaps the best way to do this is to write an inference function as you would in pure PyTorch?
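The "pure PyTorch" alternative boils down to running the model in eval mode and collecting outputs per fold; averaging the per-fold test predictions (part b of the pitch) is then plain bookkeeping. A minimal, framework-free sketch of just that averaging step (the helper name is hypothetical):

```python
def average_fold_predictions(fold_preds):
    """Average per-sample predictions across folds.

    fold_preds: one list of predictions per fold, all the same
    length (one float per test sample).
    """
    n_folds = len(fold_preds)
    # zip(*fold_preds) groups the k fold predictions for each sample
    return [sum(sample) / n_folds for sample in zip(*fold_preds)]

# three folds, two test samples
blended = average_fold_predictions([[0.2, 0.8], [0.4, 0.6], [0.6, 0.4]])
```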

Additional context

@Anjum48 Anjum48 added feature Is an improvement or enhancement help wanted Open to be worked on labels Apr 6, 2020

github-actions bot commented Apr 6, 2020

Hi! Thanks for your contribution, great first issue!


Borda commented Apr 8, 2020

I am in favour of adding this option, but first, let's see how it fits the API.
@williamFalcon any strong suggestions against? cc: @PyTorchLightning/core-contributors

@Borda Borda added the discussion In a discussion stage label Apr 8, 2020
@williamFalcon
Contributor

test is meant to ONLY operate on the test set. It's meant to keep people from using the test set when they shouldn't haha (i.e. only right before publication or right before production use).

Additions that I'm not sure align well:

  1. Trainer.test as a static call on the class. Why wouldn't you just init the trainer? Otherwise you won't be able to test in distributed environments or configure the things you need, like apex.

Additions that are good:

  1. Allowing the test function to take in a dataloader; this also aligns with how fit works.
  2. fit should also not take a test dataloader (not sure if it does now).
  3. The current .test already uses the test dataloader defined in the LightningModule, so the ONLY addition we're talking about here is allowing test to ALSO take in a dataloader and use that one only.


Ir1d commented Apr 9, 2020

btw I'm interested in how to "train a model using 5-fold cross-validation" in PL.

@williamFalcon williamFalcon added priority: 0 High priority task let's do it! approved to implement labels Apr 9, 2020
@williamFalcon
Contributor

Let's do this:

  1. Add a test_dataloaders argument to .test()
  2. Remove the test_dataloaders argument from .fit()?
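As a signature-level sketch, the two steps above would look roughly like this (a simplified stub, not the real Trainer):

```python
class Trainer:
    def fit(self, model, train_dataloader=None, val_dataloaders=None):
        # step 2: no test_dataloaders parameter on fit()
        ...

    def test(self, model=None, test_dataloaders=None):
        # step 1: new optional argument; when omitted, .test() keeps
        # using the LightningModule's test_dataloader()
        ...
```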


rohitgr7 commented Apr 9, 2020

btw I'm interested in how to "train a model using 5-fold cross-validation" in PL.

@Ir1d Try this:
https://www.kaggle.com/rohitgr/quest-bert
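Since notebook links can rot, the framework-agnostic core of k-fold CV is just generating the index splits; a stdlib-only sketch (helper name hypothetical):

```python
def kfold_splits(n_samples, k):
    """Yield (train_indices, val_indices) for k roughly equal folds."""
    indices = list(range(n_samples))
    # distribute the remainder over the first n_samples % k folds
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0)
                  for i in range(k)]
    start = 0
    for size in fold_sizes:
        val = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        yield train, val
        start += size
```

Each fold's validation indices are disjoint from its training indices, and the k validation sets together cover every sample exactly once.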

williamFalcon pushed a commit that referenced this issue Apr 10, 2020
* Add test_dataloaders to test method

* Remove test_dataloaders from .fit()

* Fix code comment

* Fix tests

* Add test_dataloaders to test method (#1393)

* Fix failing tests

* Update docs (#1393)
@Borda Borda added this to the 0.7.3 milestone Apr 10, 2020
tullie pushed a commit to tullie/pytorch-lightning that referenced this issue Jun 7, 2020
* Add test_dataloaders to test method

* Remove test_dataloaders from .fit()

* Fix code comment

* Fix tests

* Add test_dataloaders to test method (Lightning-AI#1393)

* Fix failing tests

* Update docs (Lightning-AI#1393)

ArthDh commented Jun 8, 2020

https://www.kaggle.com/rohitgr/quest-bert

Hey @rohitgr7! The link seems to be broken, could you point to any other resource? Thanks!


rohitgr7 commented Jun 8, 2020

@ArthDh Try this one: https://www.kaggle.com/rohitgr/roberta-with-pytorch-lightning-train-test-lb-0-710

@Borda Borda modified the milestones: 0.7.3, v0.7.x Apr 18, 2021