
Trainer.test() in combination with resume_from_checkpoint is broken #5091

Closed
ORippler opened this issue Dec 11, 2020 · 4 comments · Fixed by #5161
Labels: bug Something isn't working · checkpointing Related to checkpointing · help wanted Open to be worked on · priority: 0 High priority task · waiting on author Waiting on user action, correction, or update

Comments


ORippler commented Dec 11, 2020

🐛 Bug

When passing resume_from_checkpoint to Trainer and then training (e.g. via a call to trainer.fit()), the state used for trainer.test() is always the checkpoint initially given to resume_from_checkpoint, never the newer, better one produced during training.

trainer = Trainer(resume_from_checkpoint="path_to_ckpt") # pass ckpt to Trainer for resuming
trainer.fit() # do some fine-tuning/resume training
trainer.test() # should make use of "best" checkpoint, however uses ckpt passed to resume_from_checkpoint

Please reproduce using the BoringModel and post here

https://colab.research.google.com/drive/1ABXnUP10QUqHeUQmFy-FX26cV2w1JILA?usp=sharing

Expected behavior

After fine-tuning, the best model state (the lookup introduced by #2190) should be restored internally before running on the test dataset.

Environment

  • CUDA:
    • GPU:
      • Tesla T4
    • available: True
    • version: 10.1
  • Packages:
    • numpy: 1.18.5
    • pyTorch_debug: True
    • pyTorch_version: 1.7.0+cu101
    • pytorch-lightning: 1.1.0
    • tqdm: 4.41.1
  • System:
    • OS: Linux
    • architecture:
      • 64bit
    • processor: x86_64
    • python: 3.6.9
    • version: 1 SMP Thu Jul 23 08:00:38 PDT 2020

Additional context

A hotfix is to manually set trainer.resume_from_checkpoint = None between calls to trainer.fit() and trainer.test().

trainer = Trainer(resume_from_checkpoint="path_to_ckpt") # pass ckpt to Trainer for resuming
trainer.fit()
trainer.resume_from_checkpoint = None
trainer.test()

The cause of the issue is that Trainer.test() is implemented internally as a call to Trainer.fit() for all configurations, so the same checkpoint-restore path runs again and re-loads the checkpoint passed to resume_from_checkpoint.

Long term, the checkpoint passed via resume_from_checkpoint should most likely be consumed internally (i.e. reset to None) once the state has been restored. Alternatively, one could use the Trainer.testing attribute to restrict CheckpointConnector's use of Trainer.resume_from_checkpoint to training only.
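The "consume after restore" proposal can be sketched with a toy model in plain Python (no Lightning involved; ToyTrainer and its attributes are illustrative stand-ins, not the real internals):

```python
# Toy model of the reported behavior: the restore logic always prefers
# resume_from_checkpoint, even when invoked from test().

class ToyTrainer:
    def __init__(self, resume_from_checkpoint=None):
        self.resume_from_checkpoint = resume_from_checkpoint
        self.best_model_path = None
        self.restored_from = None

    def _restore(self):
        # Mirrors the reported cause: resume_from_checkpoint wins
        # unconditionally whenever it is set.
        if self.resume_from_checkpoint is not None:
            self.restored_from = self.resume_from_checkpoint
        elif self.best_model_path is not None:
            self.restored_from = self.best_model_path

    def fit(self):
        self._restore()                     # resume training from the given ckpt
        self.best_model_path = "best.ckpt"  # fine-tuning produces a better ckpt

    def test(self):
        self._restore()  # test() goes through the same restore path as fit()
        return self.restored_from


# Buggy behavior: test() reloads the stale checkpoint.
t = ToyTrainer(resume_from_checkpoint="old.ckpt")
t.fit()
assert t.test() == "old.ckpt"  # the bug: not "best.ckpt"

# Proposed fix: consume the checkpoint once fit() has restored it.
t = ToyTrainer(resume_from_checkpoint="old.ckpt")
t.fit()
t.resume_from_checkpoint = None  # "consumed" after restore
assert t.test() == "best.ckpt"
```

The Trainer.testing alternative mentioned above would instead add a guard inside `_restore` so the resume path is skipped whenever the trainer is in testing mode.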

@ORippler ORippler added bug Something isn't working help wanted Open to be worked on labels Dec 11, 2020
@edenlightning
Contributor

@awaelchli thoughts?

@edenlightning edenlightning added the checkpointing Related to checkpointing label Dec 11, 2020
@tchaton tchaton added the priority: 0 High priority task label Dec 14, 2020
@tchaton
Contributor

tchaton commented Dec 16, 2020

Hey @ORippler,

Thanks for reporting the bug. Would you mind making the notebook public?
In the meantime, I will try to reproduce the bug locally and will update you if I manage to.

Best regards,
Thomas Chaton.

@tchaton tchaton added the waiting on author Waiting on user action, correction, or update label Dec 16, 2020
@ORippler
Contributor Author

ORippler commented Dec 16, 2020

@tchaton

Link should be updated.

Cheers!

@ananthsub
Contributor

For future reference, #9405 and https://github.com/PyTorchLi.../pytorch-lightning/pull/10061 unified the checkpoint loading paths to avoid this confusion, by deprecating resume_from_checkpoint in the Trainer constructor.
