Learning rate scheduler's epoch off by one when resuming from checkpoint #1772

Closed
lizhitwo opened this issue May 10, 2020 · 3 comments
Labels
bug: Something isn't working
duplicate: This issue or pull request already exists
help wanted: Open to be worked on

Comments

@lizhitwo

lizhitwo commented May 10, 2020

🐛 Bug

Currently the lr_scheduler's state is updated after the checkpoint callback runs, so what gets saved in the checkpoint is the previous epoch's scheduler state.

Note: I think this has the same fix as #1464, but I'm posting it here because (1) I got bitten by this again, (2) it may not be exactly the same bug, and (3) #1464 is not fixed yet.
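To illustrate the ordering problem outside of Lightning, here is a minimal plain-PyTorch sketch (the epoch loop and save point are illustrative, not Lightning's actual internals): saving the scheduler's state_dict before calling step() leaves the stored last_epoch one behind the epoch counter that the trainer restores on resume.

import torch

model = torch.nn.Linear(1, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.02)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [100], gamma=0.1)

for epoch in range(3):
    # ... train for one epoch ...
    # checkpoint captured BEFORE the scheduler is stepped for this epoch
    checkpoint = {'epoch': epoch, 'lr_scheduler': scheduler.state_dict()}
    scheduler.step()
    print('saved last_epoch:', checkpoint['lr_scheduler']['last_epoch'],
          '| scheduler is now at:', scheduler.last_epoch)

Restoring the epoch-2 checkpoint brings back last_epoch=2 while the trainer resumes at epoch 3, which is exactly the mismatch shown in the output below.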

To Reproduce

Steps to reproduce the behavior:
Install using pip install git+https://github.com/PytorchLightning/pytorch-lightning.git@master --upgrade

import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

import pytorch_lightning as pl

class CoolSystem(pl.LightningModule):

    def __init__(self):
        super(CoolSystem, self).__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        return {'loss': F.cross_entropy(y_hat, y)}

    def validation_step(self, batch, batch_nb):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_epoch_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        optimizer = torch.optim.Adam(self.parameters(), lr=0.02)
        return [optimizer], [torch.optim.lr_scheduler.MultiStepLR(optimizer, [100], gamma=0.1)]

    def train_dataloader(self):
        # REQUIRED
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    def val_dataloader(self):
        # OPTIONAL
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

model = CoolSystem()

checkpoint_callback = ModelCheckpoint(
    filepath='./model_ckpt/whatever_the_name_is_gonna_be_auto_chosen',
    save_top_k=-1,
    verbose=True,
    monitor='val_loss',
    mode='auto'
)

early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=5,
    verbose=True,
    mode='auto'
)

class PrintingCallback(pl.Callback):
    def on_epoch_start(self, trainer, pl_module):
        print('Scheduler epoch %d' % trainer.lr_schedulers[0]['scheduler'].last_epoch)
        print('Trainer epoch %d' % trainer.current_epoch)
        print('-'*80)

trainer = Trainer(max_nb_epochs=1000, train_percent_check=0.1, 
                  checkpoint_callback=checkpoint_callback, 
                  early_stop_callback=early_stopping,
                  callbacks=[PrintingCallback()])

trainer.fit(model)

Let the model train until convergence, then reload a saved checkpoint and watch how it continues:

trainer = Trainer(max_nb_epochs=1000, train_percent_check=0.1, 
                  checkpoint_callback=None, 
                  resume_from_checkpoint = 'model_ckpt/_ckpt_epoch_2.ckpt',
                  early_stop_callback=early_stopping,
                  callbacks=[PrintingCallback()])
trainer.fit(model)

The PrintingCallback would print:

Scheduler epoch 2
Trainer epoch 3
--------------------------------------------------------------------------------

Scheduler epoch 3
Trainer epoch 4
--------------------------------------------------------------------------------
...

and so on.

Expected behavior

The PrintingCallback should print:

Scheduler epoch 3
Trainer epoch 3
--------------------------------------------------------------------------------

Scheduler epoch 4
Trainer epoch 4
--------------------------------------------------------------------------------
...

Environment

This was run on Google Colab.
https://colab.research.google.com/drive/1pkCSMaApyjH40jwrdl4aQLVYjnGP3JzD?usp=sharing

Additional context

Related to #1463 and #1464.
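
Until the ordering is fixed upstream, a possible workaround is to fast-forward the schedulers after resuming. This is an untested sketch; it assumes the trainer.lr_schedulers layout already used by the PrintingCallback above:

class FixSchedulerEpochCallback(pl.Callback):
    # Untested workaround sketch: after resuming from a checkpoint, step each
    # scheduler until its last_epoch catches up with the trainer's epoch.
    def on_train_start(self, trainer, pl_module):
        for config in trainer.lr_schedulers:
            scheduler = config['scheduler']
            while scheduler.last_epoch < trainer.current_epoch:
                scheduler.step()

Passing this callback in callbacks=[...] alongside the PrintingCallback should make the two epoch counters agree after resume_from_checkpoint.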

lizhitwo added the bug and help wanted labels on May 10, 2020
@Borda
Member

Borda commented May 11, 2020

@SkafteNicki mind having a look? ^^

@SkafteNicki
Member

After looking at this a bit, I think it will automatically be solved by #1504, since it is the exact same problem as with early stopping.

edenlightning added the duplicate label on Jun 8, 2020
@edenlightning
Contributor

Closing this, since the fix is tracked in #1504.
