ModelCheckpoint not picking up metrics logged from lightning module #3811

Closed
ananthsub opened this issue Oct 3, 2020 · 1 comment · Fixed by #3812
Labels: bug (Something isn't working), help wanted (Open to be worked on)

🐛 Bug

ModelCheckpoint raises a MisconfigurationException because the metrics logged in validation_epoch_end are not available to the callback.

To Reproduce

from typing import Optional
import torch
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data.dataset import Dataset
class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)
    def __getitem__(self, index):
        return self.data[index]
    def __len__(self):
        return self.len
class TestModule(LightningModule):
    def __init__(self, epoch_min_loss_override: Optional[int] = None):
        """LightningModule for testing purposes
        Args:
            epoch_min_loss_override (int, optional): Pass in an epoch that will be set to the minimum
                validation loss for testing purposes (zero based). If None this is ignored. Defaults to None.
        """
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.epoch_min_loss_override = epoch_min_loss_override
    def forward(self, x):
        return self.layer(x)
    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))
    def training_step(self, batch, batch_idx):
        output = self.forward(batch)
        loss = self.loss(batch, output)
        return {"output": output, "loss": loss, "checkpoint_on": loss}
    def validation_step(self, batch, batch_idx):
        output = self.forward(batch)
        loss = self.loss(batch, output)
        return {"output": output, "loss": loss, "checkpoint_on": loss}
    def test_step(self, batch, batch_idx):
        output = self.forward(batch)
        loss = self.loss(batch, output)
        return {"output": output, "loss": loss}
    def training_epoch_end(self, outputs) -> None:
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("avg_loss", avg_loss)
    def validation_epoch_end(self, outputs) -> None:
        avg_val_loss = torch.stack(
            [torch.randn(1, requires_grad=True) for _ in outputs]
        ).mean()
        # For testing purposes allow a nominated epoch to have a low loss
        if self.current_epoch == self.epoch_min_loss_override:
            avg_val_loss -= 1e10
        self.log("avg_val_loss", avg_val_loss)
        self.log("checkpoint_on", avg_val_loss)
    def test_epoch_end(self, outputs) -> None:
        avg_loss = torch.stack(
            [torch.randn(1, requires_grad=True) for _ in outputs]
        ).mean()
        self.log("val_loss", avg_loss)
    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]
    def train_dataloader(self):
        return torch.utils.data.DataLoader(RandomDataset(32, 64))
    def val_dataloader(self):
        return torch.utils.data.DataLoader(RandomDataset(32, 64))
    def test_dataloader(self):
        return torch.utils.data.DataLoader(RandomDataset(32, 64))
def train():
    # Epoch (zero based) that is forced to have the lowest validation loss
    epoch_min_loss_override = 2
    checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="avg_val_loss")
    lightning_trainer = Trainer(
        max_epochs=epoch_min_loss_override + 2,
        logger=False,
        checkpoint_callback=checkpoint_callback,
    )
    model = TestModule(epoch_min_loss_override=epoch_min_loss_override)
    lightning_trainer.fit(model)

This is the error I see:

raise MisconfigurationException(m)
pytorch_lightning.utilities.exceptions.MisconfigurationException: ModelCheckpoint(monitor='avg_val_loss') not found in the returned metrics: ['avg_loss']. HINT: Did you call self.log('avg_val_loss', tensor) in the LightningModule?

Full stacktrace:

    lightning_trainer.fit(model)
  File "pytorch_lightning/trainer/trainer.py", line 442, in fit
    results = self.accelerator_backend.train()
  File "pytorch_lightning/accelerators/cpu_backend.py", line 47, in train
    results = self.train_or_test()
  File "pytorch_lightning/accelerators/base_backend.py", line 43, in train_or_test
    results = self.trainer.train()
  File "pytorch_lightning/trainer/trainer.py", line 489, in train
    self.train_loop.run_training_epoch()
  File "pytorch_lightning/trainer/training_loop.py", line 538, in run_training_epoch
    self.trainer.run_evaluation(test_mode=False)
  File "pytorch_lightning/trainer/trainer.py", line 611, in run_evaluation
    self.evaluation_loop.on_evaluation_end()
  File "pytorch_lightning/trainer/evaluation_loop.py", line 95, in on_evaluation_end
    self.trainer.call_hook('on_validation_end', *args, **kwargs)
  File "pytorch_lightning/trainer/trainer.py", line 800, in call_hook
    trainer_hook(*args, **kwargs)
  File "pytorch_lightning/trainer/callback_hook.py", line 177, in on_validation_end
    callback.on_validation_end(self, self.get_model())
  File "pytorch_lightning/callbacks/model_checkpoint.py", line 167, in on_validation_end
    self.save_checkpoint(trainer, pl_module)
  File "pytorch_lightning/callbacks/model_checkpoint.py", line 197, in save_checkpoint
    self._validate_monitor_key(trainer)
  File "pytorch_lightning/callbacks/model_checkpoint.py", line 440, in _validate_monitor_key
    raise MisconfigurationException(m)
pytorch_lightning.utilities.exceptions.MisconfigurationException: ModelCheckpoint(monitor='avg_val_loss') not found in the returned metrics: ['avg_loss']. HINT: Did you call self.log('avg_val_loss', tensor) in the LightningModule?
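
The error message suggests that the callback looks up the monitor key among the metrics the trainer has collected and raises when the key is absent. Here is a minimal standalone sketch of that kind of check. It is not Lightning's actual _validate_monitor_key code; the validate_monitor_key function and the plain RuntimeError are hypothetical stand-ins.

```python
def validate_monitor_key(monitor, metrics):
    """Raise if `monitor` is not among the logged metric names.

    Hypothetical stand-in for the check that fails in the stacktrace above;
    it only mirrors the shape of the error message, not Lightning's code.
    """
    if monitor is not None and monitor not in metrics:
        raise RuntimeError(
            f"ModelCheckpoint(monitor={monitor!r}) not found in the returned "
            f"metrics: {list(metrics)}."
        )


# Metrics as reported in the error: only the training-epoch metric is present.
# The value 0.25 is made up for illustration.
reported_metrics = {"avg_loss": 0.25}

try:
    validate_monitor_key("avg_val_loss", reported_metrics)
except RuntimeError as err:
    error_message = str(err)
```

In this sketch the lookup fails because "avg_val_loss" never reached the metrics dictionary, which matches the symptom reported here: only 'avg_loss' is listed.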

Expected behavior

We should be able to save the top-1 checkpoint with the monitor set to "avg_val_loss".
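
As a rough illustration of that outcome: with save_top_k=1 and a monitored loss, the checkpoint kept should be the one from the epoch with the lowest monitored value, assuming the callback minimizes the metric. The sketch below is plain Python and does not use Lightning. The epoch_losses values are made up, with epoch 2 receiving the large negative offset that epoch_min_loss_override=2 applies in the reproduction.

```python
def best_epoch(epoch_losses):
    """Return the index of the epoch with the smallest monitored value."""
    return min(range(len(epoch_losses)), key=lambda epoch: epoch_losses[epoch])


# Made-up avg_val_loss values for four epochs (illustrative only).
# Epoch 2 mimics `epoch_min_loss_override=2`, where 1e10 is subtracted.
epoch_losses = [0.31, -0.12, 0.05 - 1e10, 0.47]

top_1 = best_epoch(epoch_losses)
```

Under these assumptions top_1 is the overridden epoch, so a test built on the reproduction can check that the saved checkpoint comes from that epoch.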

Environment

This is based on Lightning git revision 0c12065

Additional context

ananthsub added the bug and help wanted labels Oct 3, 2020

rohitgr7 commented Oct 3, 2020

For me this is happening only when no logs are created in validation_step: #3000 (comment)
