Run full validation epoch before training #1715

Closed
simonepri opened this issue May 3, 2020 · 19 comments · Fixed by #2246
Labels
question Further information is requested

Comments

@simonepri

❓ Questions and Help

What is your question?

How can I manually trigger a full validation step before training?

I want to compute and log the validation metrics before I start training (ideally also updating the progress bar dictionary).

The reason I want to do this is that I am fine-tuning a pre-trained model, and I want to check its performance before training.

@simonepri simonepri added the question Further information is requested label May 3, 2020
@awaelchli
Member

To evaluate the performance of an existing model in your case, the best practice is to implement the test methods in the LightningModule and then call trainer.test(). So I imagine your workflow will roughly be:

model = YourLightningModule.load_from_checkpoint(...)

trainer = Trainer()  # options for testing
trainer.test(model)  # in the future, this will return the results directly
# for now:
metrics = trainer.progress_bar_metrics  # for example, to print them

# now training
model = YourLightningModule.load_from_checkpoint(...)
trainer = Trainer()  # args for training
trainer.fit(model)

There is a reason why validation is tied to training and cannot be run so easily from the outside: validation is conceptually not the same as testing and does not reflect the true performance of a model, because we do things like early stopping based on the validation loss.

@simonepri
Author

simonepri commented May 4, 2020

Hi @awaelchli
I understand that I can run the test before training**, but that's a bit different from what I am trying to achieve.

Currently, the Trainer class accepts the num_sanity_val_steps argument, which lets users define how many validation steps to execute before training starts.
It would be great to be able to set this to a special value, say -1, to tell the Trainer to run a full validation epoch before training instead of just a fixed number of steps.
Then, inside on_sanity_check_end, people would be able to use trainer.callback_metrics to do whatever they want with the validation results***.

**: Currently you can't directly call test() on a pretrained model unless you manually call prepare_data() before the test() call, see #1562.

***: Currently, trainer.callback_metrics inside on_sanity_check_end is not populated with the validation metrics.
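
For illustration, a minimal sketch of how the proposed sentinel would look at the call site, assuming -1 were accepted (model stands in for your LightningModule):

from pytorch_lightning import Trainer

# hypothetical usage of the proposed sentinel: -1 would mean
# "run the full validation set as the pre-training sanity check"
trainer = Trainer(num_sanity_val_steps=-1)
trainer.fit(model)  # model: your LightningModule instance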

@oyj0594

oyj0594 commented May 21, 2020

I agree with @simonepri; it would be good to run a full validation sanity check before training.
If we could know the full num_sanity_val_steps before declaring the trainer it would be simple, but I don't think that's possible.
There is no option for such a feature right now, is there?

@awaelchli
Member

Yes, maybe we could think about adding the option num_sanity_val_steps=-1.
@PyTorchLightning/core-contributors

@dvirginz

I agree. It's good practice to run validation before training, and I'll surely use it!

@dvirginz

dvirginz commented Jun 11, 2020

It's pretty straightforward; I guess I'm not the first to write this :)

from pytorch_lightning import Callback, Trainer


class run_validation_on_start(Callback):
    def on_train_start(self, trainer: Trainer, pl_module):
        # run a full validation pass as soon as training starts
        # (run_evaluation was the trainer's evaluation entry point at the time)
        return trainer.run_evaluation(test_mode=False)
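
A quick sketch of how you would attach it, assuming model is your LightningModule and using the imports from the snippet above:

trainer = Trainer(callbacks=[run_validation_on_start()])
trainer.fit(model)  # validation now runs once before the first training epoch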

@williamFalcon
Contributor

@awaelchli yes, perfect. Let's do it that way.

num_sanity_val_steps=-1

PR for 0.8.0?

@awaelchli
Member

Happy to make it, but it would be easier to do after #1920.

@nmerty

nmerty commented Jun 25, 2021

These changes still do not cover the case where we want to run a full validation pass and also log the validation results before training, right? That is, the case where num_sanity_val_steps=-1 and logging is enabled.

@awaelchli
Member

@nmerty The sanity check is there to make sure your code is functional and not crashing. The logging code also runs; it's just that the final values are not sent to the logger. We don't actually want to log anything during the sanity check, because that would introduce an unwanted offset in all the figures for the logged keys in the normal epochs that follow.

If you want to validate the model in advance, you can do:

trainer.validate(model)
trainer.fit(model)
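
As a slightly fuller, hedged sketch of that pattern (model and dm are placeholders for your LightningModule and LightningDataModule):

from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=10)
# full validation pass before any training; results are logged as usual
trainer.validate(model, datamodule=dm)
print(trainer.callback_metrics)  # e.g. inspect the pre-training validation loss
# then train with the same trainer
trainer.fit(model, datamodule=dm)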

@odedbd

odedbd commented Aug 2, 2022

I am using the solution suggested by @awaelchli in my training script, passing a DataModule as the data source for both validation and training. The pre-training validation score is logged, but I noticed that the validation DataLoader (and validation workers) are being created twice when used in this manner.

In my use case, the creation of workers is a relatively costly process, and I wonder if there is a way to log the validation score before training, while using a DataModule, without causing the validation DataLoader to be initialized twice.

@awaelchli
Member

That is probably because the sanity check validation also runs when calling fit, leading to the val loader being reset an additional time. Since you are doing this:

trainer.validate(model)
trainer.fit(model)

I would suggest setting Trainer(num_sanity_val_steps=0), because you already make an explicit validate call. Let me know whether that reduces the overhead you see.
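
Concretely, something like this (model and dm again stand in for your module and datamodule):

from pytorch_lightning import Trainer

# skip the sanity check, since the explicit validate() call below already covers it
trainer = Trainer(num_sanity_val_steps=0)
trainer.validate(model, datamodule=dm)
trainer.fit(model, datamodule=dm)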

@odedbd

odedbd commented Aug 4, 2022

Thanks for the suggestion, @awaelchli. I tested this, and it does not seem to be the cause; I see the same behavior with num_sanity_val_steps=0. It seems the Trainer instance either does not persist the validation dataloader when running trainer.validate, or it ignores such a persisted dataloader when running trainer.fit afterwards.

@awaelchli
Member

Lightning can't assume that the dataloader used during the validate call is the same one used in fit. In general, the trainer has no information about whether it should reuse dataloaders from a previous stage. Each stage is separate, and this makes sense in most cases. If you really want to load the val dataloader only once, ever, I suggest you implement that directly in the datamodule, like so:

def val_dataloader(self):
    if self._loaded_val_dataloader is None:
        # first call: create the loader the expensive way
        self._loaded_val_dataloader = ...  # your expensive loader creation here
    # subsequent calls reuse the cached loader
    return self._loaded_val_dataloader

@odedbd

odedbd commented Aug 4, 2022

This makes a lot of sense, I will try it out. Could doing so be unsafe in the context of distributed data loading? I am currently using a single GPU, so it is of no immediate concern, just wondering whether I should be aware of a possible issue in case I scale up at a later time.

@rohitgr7
Contributor

rohitgr7 commented Aug 4, 2022

@odedbd It won't be a problem. During each .fit or .validate call, the loaded dataloaders are detached from the trainer, so they are simply reloaded again.

@odedbd

odedbd commented Aug 7, 2022

I tested the suggested method, and the validation DataLoader is still recreated with new workers initialized. I tried debugging the PyTorch Lightning code but wasn't able to pinpoint where the dataloader is reset. I did verify that the val_dataloader method was called twice: the first time _loaded_val_dataloader was None, and the second time it returned the saved dataloader.

@rohitgr7
Contributor

rohitgr7 commented Aug 7, 2022

If it returned the same dataloader, are you saying it is still reinitializing the workers? Can you share a reproducible script to verify this behavior?

@odedbd

odedbd commented Aug 7, 2022

Yes, that's what I think I am seeing. Below is a repro script based on one of the PL examples. You'll notice that the worker init function of the val dataloader is called a second time when fit starts.


# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MNIST backbone image classifier example.

To run: python backbone_image_classifier.py --trainer.max_epochs=50
"""
import logging

from os import path
from typing import Optional

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split

from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.demos.mnist_datamodule import MNIST
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
    from torchvision import transforms

DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")

def wif(worker_id):
    print(f'worker_id {worker_id} init')

class Backbone(torch.nn.Module):
    """
    Backbone(
      (l1): Linear(...)
      (l2): Linear(...)
    )
    """

    def __init__(self, hidden_dim=128):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, hidden_dim)
        self.l2 = torch.nn.Linear(hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x


class LitClassifier(LightningModule):
    """
    LitClassifier(
      (backbone): ...
    )
    """

    def __init__(self, backbone: Optional[Backbone] = None, learning_rate: float = 0.0001):
        super().__init__()
        self.save_hyperparameters(ignore=["backbone"])
        if backbone is None:
            backbone = Backbone()
        self.backbone = backbone

    def forward(self, x):
        # use forward for inference/predictions
        embedding = self.backbone(x)
        return embedding

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("valid_loss", loss, on_step=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("test_loss", loss)

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        x, y = batch
        return self(x)

    def configure_optimizers(self):
        # self.hparams available because we called self.save_hyperparameters()
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)


class MyDataModule(LightningDataModule):
    def __init__(self, batch_size: int = 32):
        super().__init__()
        dataset = MNIST(DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor())
        self.mnist_test = MNIST(DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor())
        self.mnist_train, self.mnist_val = random_split(dataset, [55000, 5000])
        self.batch_size = batch_size
        self._val_dataloader = None

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        print('val_dataloader: in')
        if self._val_dataloader is None:
            print('val_dataloader: self._val_dataloader is None')
            self._val_dataloader = DataLoader(self.mnist_val, batch_size=self.batch_size, worker_init_fn=wif,
                                              num_workers=1, persistent_workers=True)
        print('val_dataloader: out')
        return self._val_dataloader

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)


def cli_main():
    cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False)
    cli.trainer.validate(cli.model, datamodule=cli.datamodule)
    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
    cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule)
    predictions = cli.trainer.predict(ckpt_path="best", datamodule=cli.datamodule)
    print(predictions[0])


if __name__ == "__main__":
    cli_lightning_logo()
    cli_main()
