Only invoke setup() once, not in both trainer.fit() and trainer.test() - #2620 follow up #9865

AlexHarn opened this issue Oct 7, 2021 · 2 comments

AlexHarn commented Oct 7, 2021

🐛 Bug

It seems like the exact bug described in #2620 is back: when calling trainer.test() after trainer.fit(), LightningDataModule.setup() is called twice. This is especially problematic when using random_split in setup(), because it leads to training samples leaking into the test set (in the worst case of extremely large training data compared to the test data, the test set will most likely consist almost exclusively of training samples).
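
To illustrate the leak mechanism (my addition, not part of the original report): each call to random_split draws a fresh random permutation, so indices that the first call assigned to the training subset can end up in the test subset of the second call. A minimal sketch:

# Minimal sketch (illustrative only): two independent random_split calls over
# the same 96-element dataset, mirroring setup() running once for fit() and
# again for test().
import torch
from torch.utils.data import random_split

dataset = torch.arange(96)  # stand-in for RandomDataset(32, 96)

train_first, _, _ = random_split(dataset, [32, 32, 32])   # split used by fit()
_, _, test_second = random_split(dataset, [32, 32, 32])   # split used by test()

overlap = set(train_first.indices) & set(test_second.indices)
print(f"{len(overlap)} of 32 test indices were training indices in the first split")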

To Reproduce

import os

import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset, random_split

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class MyDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()

    def prepare_data(self):
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed
        pass

    def setup(self, stage):
        # make assignments here (val/train/test split)
        # called on every process in DDP
        print("\nSetting up!!!\n")
        dataset = RandomDataset(32, 96)
        self.data_train, self.data_val, self.data_test = random_split(
            dataset, [32, 32, 32]
        )

    def train_dataloader(self):
        return DataLoader(self.data_train, batch_size=2)

    def val_dataloader(self):
        return DataLoader(self.data_val, batch_size=2)

    def test_dataloader(self):
        return DataLoader(self.data_test, batch_size=2)


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    datamodule = MyDataModule()

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        weights_summary=None,
    )
    trainer.fit(model=model, datamodule=datamodule)
    trainer.test(model=model, datamodule=datamodule)


if __name__ == "__main__":
    run()

Expected behavior

I expect MyDataModule.setup() to be called once, but it is called twice (easily verified by the print statement in setup()).
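
For reference, a variant of the repro's setup() that counts invocations instead of printing a banner (my addition; the attribute name setup_calls is arbitrary):

    # Sketch: drop this into MyDataModule above to confirm the hook runs once
    # during trainer.fit() and again during trainer.test().
    def setup(self, stage):
        self.setup_calls = getattr(self, "setup_calls", 0) + 1
        print(f"setup() call #{self.setup_calls} (stage={stage})")
        dataset = RandomDataset(32, 96)
        self.data_train, self.data_val, self.data_test = random_split(
            dataset, [32, 32, 32]
        )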

Environment

  • CUDA:
    - GPU:
    - Tesla V100S-PCIE-32GB
    - Tesla V100S-PCIE-32GB
    - Tesla V100S-PCIE-32GB
    - Tesla V100S-PCIE-32GB
    - available: True
    - version: 11.1
  • Packages:
    - numpy: 1.20.3
    - pyTorch_debug: False
    - pyTorch_version: 1.9.1
    - pytorch-lightning: 1.4.8
    - tqdm: 4.60.0
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.8.11
    - version: #1 SMP Wed Jul 21 11:57:15 UTC 2021

Additional context

Maybe I misunderstand how prepare_data() and setup() are supposed to be used, especially in DDP mode, but the comment in the docs clearly states that the val/train/test split belongs in setup(), which currently leads to the described problem of training data ending up in the test split when using a random split.

AlexHarn added the bug and help wanted labels on Oct 7, 2021

ananthsub commented Oct 7, 2021

setup is currently called once per fit or test, not once overall. However, this behavior is changing in Lightning v1.6, where setup will be called unconditionally across these methods. Why is it changing? See #7301.

Imagine the situation where you make edits to your datamodule in between calls to the trainer; this can lead to silent errors:

dm = MyDataModule()
trainer = Trainer(...)
model = MyLightningModule()
trainer.fit(model, dm)
# update dm with new training dataset
dm.train_dataset = ...

trainer.fit(model, dm) # <-- datamodule.setup() isn't called again! 

So instead, you can write your datamodule so that the setup hooks are idempotent:

class MyDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.data_train = None
        self.data_val = None
        self.data_test = None

    def setup(self, stage):
        if self.data_train is None:
            self.data_train = ...
        if self.data_val is None:
            self.data_val = ...
        if self.data_test is None:
            self.data_test = ...

You can further optimize this by using the stage argument passed to setup to determine which attributes need to be initialized. For instance, you only need to initialize self.data_test if you're testing; you don't need to initialize it during fitting. This could save you some extra memory and time.
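
A short sketch of that stage-aware variant (my phrasing of the suggestion above; it reuses the attribute names from the idempotent datamodule and leaves the actual dataset construction as ... placeholders):

    # Sketch: only build the splits the current stage needs; stage is a string
    # such as "fit" or "test" (or None) in this Lightning version.
    def setup(self, stage):
        if stage in (None, "fit") and self.data_train is None:
            self.data_train = ...  # training split
            self.data_val = ...    # validation split used during fit
        if stage in (None, "test") and self.data_test is None:
            self.data_test = ...   # test split, only needed for trainer.test()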

#6420 is an issue to discuss whether we should have dedicated setup and prepare_data functions for each of the entry point functions to better isolate this.


AlexHarn commented Oct 8, 2021

I see, that makes a lot of sense to use it that way, thank you!

It just was not obvious to me from reading the docs/examples.

AlexHarn closed this as completed on Oct 8, 2021