EarlyStopping checkpointed state is lagging one epoch behind #1464

Closed
lizhitwo opened this issue Apr 12, 2020 · 15 comments · Fixed by #2391
@lizhitwo

🐛 Bug

Currently EarlyStopping's state is updated after the checkpoint callback runs, so what gets saved in the checkpoint is the previous epoch's state.
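One way to see it is to inspect what actually ends up in a saved checkpoint after running the reproduction below (the key filter is an assumption about how this version serializes the early stopping state):

import torch

ckpt = torch.load('model_ckpt/_ckpt_epoch_7.ckpt', map_location='cpu')
# show whatever early-stopping state was serialized into the checkpoint
print({k: v for k, v in ckpt.items() if 'early_stop' in k})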

To Reproduce

This is somewhat related to #1463 so I am going to use the same code.

Steps to reproduce the behavior:
Install using pip install git+https://github.com/PytorchLightning/pytorch-lightning.git@master --upgrade

import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

import pytorch_lightning as pl

class CoolSystem(pl.LightningModule):

    def __init__(self):
        super(CoolSystem, self).__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        return {'loss': F.cross_entropy(y_hat, y)}

    def validation_step(self, batch, batch_nb):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_epoch_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def train_dataloader(self):
        # REQUIRED
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    def val_dataloader(self):
        # OPTIONAL
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

model = CoolSystem()

checkpoint_callback = ModelCheckpoint(
    filepath='./model_ckpt/whatever_the_name_is_gonna_be_auto_chosen',
    save_top_k=-1,
    verbose=True,
    monitor='val_loss',
    mode='auto'
)

class EarlyStoppingPrinting(EarlyStopping):

    def on_train_start(self, trainer, pl_module):
        print('EarlyStoppingPrinting before on_train_start')
        print('self.wait = ', self.wait)
        super().on_train_start(trainer, pl_module)
        print('EarlyStoppingPrinting after on_train_start')
        print('self.wait = ', self.wait)

    def on_epoch_end(self, trainer, pl_module):
        ret = super().on_epoch_end(trainer, pl_module)
        if self.wait:
            print('Early stopping patience: %d/%d' % (self.patience-self.wait, self.patience))
        return ret


early_stopping = EarlyStoppingPrinting(
    monitor='val_loss',
    patience=5,
    verbose=True,
    mode='auto'
)

trainer = Trainer(max_nb_epochs=1000, train_percent_check=0.1, 
                  checkpoint_callback=checkpoint_callback, 
                  early_stop_callback=early_stopping)

trainer.fit(model)

Let the model train until convergence, then reload the model and see how it continues:

trainer = Trainer(max_nb_epochs=1000, train_percent_check=0.1, 
                  checkpoint_callback=None, 
                  resume_from_checkpoint = 'model_ckpt/_ckpt_epoch_7.ckpt',
                  early_stop_callback=early_stopping)
trainer.fit(model)

The early_stopping callback would print:

EarlyStoppingPrinting before on_train_start
self.wait =  4
...

and keeps training.

Expected behavior

The early_stopping callback should print:

EarlyStoppingPrinting before on_train_start
self.wait =  5
...

and the model should not train again at all, since self.wait >= self.patience.

If the model is loaded from an interrupted save, then it should still train after resuming, but with corrected self.wait.

Environment

This was run on Google Colab.
https://colab.research.google.com/drive/1ZdiFf6ksNpgsqOdSKM6lMO0yIhqpnTHD

Additional context

Somewhat related to #1463.

@lizhitwo added the bug (Something isn't working) and help wanted (Open to be worked on) labels on Apr 12, 2020
@jeremyjordan (Contributor) commented Apr 16, 2020

@lizhitwo thanks for this very detailed bug report! looking into it...

at a high level, i think that we should keep the concern of the callback state contained within the callback itself. we can follow the pytorch convention of having methods for state_dict() and load_state_dict(). the trainer can just call those methods rather than reaching in and saving individual attributes. (this more so addresses #1463 but i plan to fix both in a single PR)
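for illustration, a rough sketch of what that could look like (attribute names like wait/patience/best are placeholders, not the final API):

from pytorch_lightning.callbacks import EarlyStopping

class StatefulEarlyStopping(EarlyStopping):
    """sketch only: expose the callback state the pytorch way"""

    def state_dict(self):
        # everything needed to resume the callback exactly where it left off
        return {'wait': self.wait, 'patience': self.patience, 'best': self.best}

    def load_state_dict(self, state_dict):
        self.wait = state_dict['wait']
        self.patience = state_dict['patience']
        self.best = state_dict['best']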

@jeremyjordan (Contributor) commented Apr 16, 2020

@Borda do you know why we need on_train_start? why can't we just set the values in __init__?

edit: going to remove on_train_start

@jeremyjordan (Contributor) commented Apr 17, 2020

the core problem for this issue is that the CheckpointCallback runs on on_validation_end, which occurs before the EarlyStoppingCallback runs during on_epoch_end. the checkpoint callback is not run again after early stopping halts training. the checkpoint includes a state dict of the early stopping values, and as @lizhitwo points out, the last saved checkpoint contains the early stopping state of the second-to-last epoch.

  • we could move the checkpoint callback to also run during on_epoch_end, but this might not always be desired (eg. if you run validation multiple times per epoch and want all checkpoints).

  • we could also just write the checkpoint callback to re-run at the end of an epoch, but not sure how we want to handle saving the k best models in this case

cc @PyTorchLightning/core-contributors any suggestions? i'm not a huge fan of either of these approaches
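for anyone who wants to confirm the ordering for themselves, a throwaway callback like this (just a sketch, not part of the fix) prints the hook order during a run when passed via Trainer(callbacks=[HookOrderLogger()]):

from pytorch_lightning.callbacks import Callback

class HookOrderLogger(Callback):
    """sketch: show that on_validation_end fires before on_epoch_end"""

    def on_validation_end(self, trainer, pl_module):
        print('on_validation_end  (checkpointing happens around here)')

    def on_epoch_end(self, trainer, pl_module):
        print('on_epoch_end       (EarlyStopping updates its state here)')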

@jeremyjordan (Contributor)

I would like to propose that we do the following:

  • when we trigger early stopping, check to see if a checkpoint callback is being used
  • if so, create a new checkpoint that follows the standard naming scheme, plus "_early_stopping" at the end (eg. ckpt_epoch_7_early_stopping.ckpt), roughly as sketched below
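roughly something like this (sketch only; the helper name and save directory are made up):

import os

def save_early_stopping_checkpoint(trainer, save_dir='./model_ckpt'):
    # only write the extra checkpoint if a checkpoint callback is actually in use
    if trainer.checkpoint_callback is not None:
        filepath = os.path.join(
            save_dir, 'ckpt_epoch_{}_early_stopping.ckpt'.format(trainer.current_epoch))
        trainer.save_checkpoint(filepath)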

@lizhitwo would this be a suitable solution for your needs?

@lizhitwo (Author)

This would be correct only when the training ends normally. When the user hits Ctrl+C to interrupt, the previous checkpoints that are already saved are still lagging one epoch behind.

I would also advise against putting more undocumented checkpoint naming conventions into Lightning. I am already confused about why Lightning overrides my checkpoint name unless I put e.g. {epoch} in it.

Question: is there a reason why the early stopping callback can't be split into two, with the status update moved before the checkpoint? You could update its state before checkpointing, run the checkpoint, and then query its state to decide whether early stopping should be performed, e.g. via a should_early_stop property or something.
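Something along these lines (just a sketch; attribute names mirror the current callback, but none of this is the real API):

from pytorch_lightning.callbacks import EarlyStopping

class SplitEarlyStopping(EarlyStopping):
    """Sketch of the split: update state before checkpointing, query it afterwards."""

    def update_state(self, current):
        # called before the checkpoint callback runs, so the saved state is current
        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
        else:
            self.wait += 1

    @property
    def should_early_stop(self):
        # queried by the trainer after checkpointing to decide whether to stop
        return self.wait >= self.patience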

@jeremyjordan (Contributor)

yeah you're right, we need a better solution here. i think we can also improve how the checkpoint naming is done but i'll leave that for a separate issue.

Question: is there a reason why early stopping callback can't be split into two, and the status update moved before checkpoint?

it's a good question. this is a bit challenging because our checkpoint callback is currently set up to run every time we iterate over the validation set. this can happen once per epoch, multiple times per epoch, or once every n epochs, depending on how the user has defined various arguments in their Trainer.

here's how i think we should proceed (rough sketch below):

  • add a new argument to the early stopping callback which specifies whether we monitor a value from the training_step output or validation_step output (this is cleaner than guessing imo)
  • if we're monitoring the training_step, perform a check after each batch (eg. on_batch_end) which always runs before the checkpoint callback
  • if we're monitoring the validation_step, perform a check after running through the validation set (eg. on_validation_end) but before the checkpoint callback runs
  • switch from having early stopping return a value to instead set a Trainer attribute
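concretely, a rough sketch of what i mean (the monitor_on argument and trainer.should_stop are part of the proposal, not anything that exists today):

from pytorch_lightning.callbacks import EarlyStopping

class ProposedEarlyStopping(EarlyStopping):
    """sketch of the proposal above, not a working implementation"""

    def __init__(self, monitor='val_loss', monitor_on='validation', **kwargs):
        super().__init__(monitor=monitor, **kwargs)
        self.monitor_on = monitor_on  # 'train' or 'validation'

    def on_batch_end(self, trainer, pl_module):
        if self.monitor_on == 'train':
            self._run_check(trainer)

    def on_validation_end(self, trainer, pl_module):
        # per the proposal this runs before the checkpoint callback saves
        if self.monitor_on == 'validation':
            self._run_check(trainer)

    def _run_check(self, trainer):
        current = trainer.callback_metrics.get(self.monitor)
        if current is None:
            return
        if self.monitor_op(current - self.min_delta, self.best):
            self.best, self.wait = current, 0
        else:
            self.wait += 1
        if self.wait >= self.patience:
            trainer.should_stop = True  # set an attribute instead of returning a value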

cc @lizhitwo and @Borda want to weigh in if this is a good plan?

@Borda added the discussion (In a discussion stage) label on Apr 24, 2020
@Borda added this to the 0.7.5 milestone on Apr 24, 2020
@Borda (Member) commented Apr 24, 2020

I think that in some cases you want to monitor both train and valid...

switch from having early stopping return a value to instead set a Trainer attribute

do you mean a Trainer state like we discussed several times before?

I think that there is one more thing we should think about: the "unit" for evaluation, i.e. whether e.g. patience is counted in epochs or in batches/steps.

cc: @PyTorchLightning/core-contributors

@jeremyjordan (Contributor)

I think that in some cases you want to monitor both train and valid...

could you provide an example? in 90% of the cases the user is going to want to monitor something like val_loss. it seems very odd to condition the early stopping criterion on something which bridges both train and valid

do you mean a Trainer state like we discussed several times before?

i'm thinking of setting an attribute.

https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/training_loop.py#L366

if self.enable_early_stop:
    if (met_min_epochs and met_min_steps) or self.fast_dev_run:
        should_stop = self.early_stop_callback.on_epoch_end(self, self.get_model())
        # stop training
        stop = should_stop and met_min_epochs
        if stop:
            self.run_training_teardown()
            return

to

if self.should_stop:  # the checks against met_min_epochs and met_min_steps happen in the callback
    return 

@lizhitwo (Author) commented Apr 24, 2020

I think this works, as long as

  • an attribute is written to either the trainer or the early stopping callback, either by one of the trainer's methods or one of the callback's methods
  • there is no checkpointing between "this attribute writing is done" and "the early stopping callback has the ability to update this value"
  • the attribute is preserved when checkpointing

then no matter when the checkpoint is done, the early-stopping's state is clean.

This would break compatibility with older code, though, so maybe set the attribute inside the Trainer according to early stopping's return value, instead of in the early stopping callback, which may be overridden a lot?


I think monitoring both train and val is better left to users, since they need to specify how the criterion is computed anyway. They can choose to compute it during either val or train and add it to the log in one of them.
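For example, roughly (a sketch; the combined_loss key and the 0.5 weights are made up, and it assumes the monitored value ends up in the metrics the callback sees):

# inside the LightningModule from the reproduction above
def training_step(self, batch, batch_nb):
    x, y = batch
    loss = F.cross_entropy(self.forward(x), y)
    self.last_train_loss = loss.detach()  # remember it for validation_epoch_end
    return {'loss': loss}

def validation_epoch_end(self, outputs):
    avg_val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
    last_train = getattr(self, 'last_train_loss', avg_val_loss)  # fallback for the sanity check
    combined = 0.5 * avg_val_loss + 0.5 * last_train
    return {'val_loss': avg_val_loss, 'combined_loss': combined}

# then monitor the user-defined key:
# early_stopping = EarlyStoppingPrinting(monitor='combined_loss', patience=5, mode='min')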

@lizhitwo (Author)

By the way, the newest version of Lightning has an additional bug: the early stopping is evaluated twice per epoch, here and here. You can run my code again to see that the trainer.wait decreases by 2 each epoch.

@jeremyjordan (Contributor)

yes, that was a behavioral regression introduced by #1528. i will fix it as well, thanks for catching it! we clearly need better tests to spot these errors

@williamFalcon (Contributor)

@jeremyjordan submit asap so we can get it in 0.7.4? @lizhitwo thanks for catching that!

@Borda modified the milestones: 0.7.5, 0.7.4 on Apr 24, 2020
@williamFalcon (Contributor)

actually... this may not be a trivial fix

@williamFalcon (Contributor) commented Apr 24, 2020

in (https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/training_loop.py#L459) we need to check it whenever we save weights. In this case, if it is true, then we need to stop training but make sure we still perform all the remaining actions that need to run.

In (https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/training_loop.py#L369) we actually want to exit.

But anyhow, @jeremyjordan if you figure this out let's get this into 0.7.4

@jeremyjordan (Contributor)

yes i'm hoping to get this all wrapped up this weekend, it's a bit of a tricky one

@Borda added this to the 0.7.7 milestone on May 15, 2020
@Borda modified the milestones: 0.7.7, 0.8.0 on May 26, 2020
@Borda modified the milestones: 0.8.0, 0.9.0 on Jun 9, 2020