
Early stopping callback #2151

Closed
adeboissiere opened this issue Jun 11, 2020 · 10 comments · Fixed by #2391
Labels: bug (Something isn't working), help wanted (Open to be worked on)

Comments

@adeboissiere

🐛 Bug

Early stopping does not have the desired effect when using a custom callback. Even when I create a custom callback with the default values, training stops before the early stopping conditions are met.

To Reproduce

  1. Create callback
early_stop_callback = EarlyStopping(
        monitor='val_loss',
        min_delta=0.00,
        patience=3,
        verbose=False,
        mode='min',
        strict=True
    )
  2. Create trainer
trainer = Trainer.from_argparse_args(Namespace(**dict(train_config)), early_stop_callback=early_stop_callback)
  3. Train
trainer.fit(model)
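
For reference, here is the full snippet in one place (the imports reflect how I pull in EarlyStopping and Trainer; train_config and model come from the rest of my project and are not shown):

    from argparse import Namespace

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        min_delta=0.00,
        patience=3,
        verbose=False,
        mode='min',
        strict=True
    )

    # train_config (a dict-like config) and model (a LightningModule) are defined elsewhere
    trainer = Trainer.from_argparse_args(
        Namespace(**dict(train_config)),
        early_stop_callback=early_stop_callback
    )
    trainer.fit(model)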

Here are the validation steps in the model:

    def validation_step(self, batch, batch_idx):
        batch, y = batch
        y_hat = self(batch)

        loss = F.cross_entropy(y_hat, y.long())
        labels_hat = torch.argmax(y_hat, dim=1)
        n_correct_pred = torch.sum(y == labels_hat).item()

        return {'val_loss': loss, "n_correct_pred": n_correct_pred, "n_pred": len(y)}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        val_acc = sum([x['n_correct_pred'] for x in outputs]) / sum(x['n_pred'] for x in outputs)
        tensorboard_logs = {'val_loss': avg_loss, 'val_acc': val_acc}

        return {'val_loss': avg_loss, 'log': tensorboard_logs}

Expected behavior

In my case, training stops after 2 epochs, whether the validation loss increases or not. The callback behavior should be the same as the default. When I don't pass a custom callback, it works fine. I'm probably doing something wrong.

Environment

  • PyTorch Version: 1.4.0+cu100
  • OS: Ubuntu 18.04
  • How you installed PyTorch (conda, pip, source): pip
  • Python version: 3.6.9
  • CUDA/cuDNN version: 10.0.130/7.6.4
  • GPU models and configuration: GeForce GTX 860M

Thanks!

@adeboissiere added the help wanted (Open to be worked on) label on Jun 11, 2020
@github-actions
Contributor

Hi! Thanks for your contribution, great first issue!

@Borda
Member

Borda commented Jun 11, 2020

@jeremyjordan mind having a look? ^^

@Borda added the bug (Something isn't working) label on Jun 11, 2020
@alekseynp

What version of pytorch-lightning are you using?
I see a recent commit, 479ab49, which is only available in master and 0.8.0rc1:
https://github.com/PyTorchLightning/pytorch-lightning/tree/0.8.0rc1

@alekseynp

For the record, I arrived at this issue because in version 0.7.6 I observe early stopping not behaving properly.

@tpanum

tpanum commented Jun 12, 2020

I have been puzzled by a custom callback never stopping training in max mode (even though it clearly should have, judging by the tensorboard logs). I am on 0.7.6. Thanks for pointing it out, @alekseynp. I'll see if an upgrade fixes it.

@DavidRuhe

DavidRuhe commented Jun 14, 2020

Hi. I'm not sure if I should create a separate issue, but there is a very confusing bug regarding early stopping (still in the current master branch). The documentation states

By default early stopping will be enabled if ‘val_loss’ is found in validation_epoch_end()’s return dict. Otherwise training will proceed with early stopping disabled.

However, this is not true due to the following bug.
In callback_config.py we see the following code.

def configure_early_stopping(self, early_stop_callback):
        if early_stop_callback is True or None:
            self.early_stop_callback = EarlyStopping(
                monitor='val_loss',
                patience=3,
                strict=True,
                verbose=True,
                mode='min'
            )
            self.enable_early_stop = True
        elif not early_stop_callback:
            self.early_stop_callback = None
            self.enable_early_stop = False
        else:
            self.early_stop_callback = early_stop_callback
            self.enable_early_stop = True

Unless I'm misunderstanding something, to get the behaviour the documentation describes, the condition should be
if early_stop_callback is True or early_stop_callback is None: and the default argument should be set to None:

early_stop_callback: Optional[Union[EarlyStopping, bool]] = None,

In any case, the 'or None' clause will never make the condition True, so as it stands it is redundant.
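
To illustrate, here is a quick standalone check (plain Python, no Lightning involved) comparing the current condition with the one the documentation implies:

    # `x is True or None` parses as `(x is True) or None`, so the `or None`
    # part only ever contributes a falsy value and can never enable the branch.
    for x in (True, False, None):
        buggy = x is True or None          # condition currently in callback_config.py
        intended = x is True or x is None  # condition the documentation implies
        print(x, bool(buggy), bool(intended))

    # Prints:
    #   True True True
    #   False False False
    #   None False True   <- with the current condition, a None value falls through
    #                        to the elif branch and early stopping stays disabled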

David

@adeboissiere
Author

> What version of pytorch-lightning are you using?
> I see a recent commit, 479ab49, which is only available in master and 0.8.0rc1:
> https://github.com/PyTorchLightning/pytorch-lightning/tree/0.8.0rc1

Hi. I'm using version 0.7.6.

@jeremyjordan
Contributor

jeremyjordan commented Jun 16, 2020

There are a number of issues with early stopping; I have a PR (#1504) out to fix them. I have added a new test to cover your case, @adeboissiere.

@brucemuller

I'm also having issues. It seems to only use the default values regardless of what I pass, and even the default values do not cause any early stopping to occur.

@williamFalcon
Contributor

Currently being worked on in #1504
