
[DataModule] prepare_data() and setup() not called #2742

Closed
remisphere opened this issue Jul 28, 2020 · 3 comments · Fixed by #2755
Labels: bug (Something isn't working) · data handling (Generic data-related topic) · help wanted (Open to be worked on)

remisphere commented Jul 28, 2020

🐛 Bug

It seems that when using a DataModule to separate training logic from data loading,
of the five methods that should be called, namely
prepare_data(), setup(), train_dataloader(), val_dataloader() and test_dataloader(),
only the last three are actually invoked, which is problematic since the datasets used by the dataloaders are supposed to be assigned in setup().

To Reproduce

Steps to reproduce the behavior:
Run this:

Code sample

from pytorch_lightning import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer import Trainer
from torch.nn import L1Loss, Linear
from torch.optim import SGD
from torch.utils.data import DataLoader


class MyDataModule(LightningDataModule):

    def __init__(self):
        super().__init__()

    def prepare_data(self):
        print('in prepare_data, '
              'this should be called before train_dataloader() but is not.')

    def setup(self, stage=None):  # default matches the base-class signature, so a manual dm.setup() also works
        print('in setup, '
              'this should be called before train_dataloader() but is not.')
        self.train_dataset = 'whatever'

    def train_dataloader(self):
        print('in train_dataloader')
        return DataLoader(self.train_dataset)


class MyLightningModule(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = Linear(1, 1)
        self.loss_function = L1Loss()

    def forward(self, x):
        return self.layer(x)

    def configure_optimizers(self):
        return SGD(self.parameters(), lr=0.01)

    def training_step(self, batch, batch_idx):
        print("you won't even get here")
        raise NotImplementedError


data_module = MyDataModule()
model = MyLightningModule()
trainer = Trainer(gpus=1)
trainer.fit(model, data_module)

Running this raises AttributeError: 'MyDataModule' object has no attribute 'train_dataset'.

Expected behavior

When entering train_dataloader(), prepare_data() and setup() should already have been executed, and thus the train_dataset attribute should exist.
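
To make the expected ordering concrete, here is a minimal sketch of a driver in the shape being described (the hook names are the real LightningDataModule API; the fit_with_datamodule function itself is hypothetical, only to illustrate the order):

def fit_with_datamodule(model, dm):
    # 1. one-time work such as downloads; runs before anything else
    dm.prepare_data()
    # 2. per-process state assignment, e.g. dm.train_dataset
    dm.setup('fit')
    # 3. only now is the loader requested, so setup()'s attributes exist
    loader = dm.train_dataloader()
    # ... the training loop then consumes `loader` ...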

Additional context

IMHO, it comes from here

Environment

  • CUDA:
    • GPU:
      • GeForce RTX 2080 Ti
      • GeForce RTX 2080 Ti
    • available: True
    • version: 10.1
  • Packages:
    • numpy: 1.19.1
    • pyTorch_debug: False
    • pyTorch_version: 1.5.1+cu101
    • pytorch-lightning: 0.9.0rc2
    • tensorboard: 2.3.0
    • tqdm: 4.48.0
  • System:
@remisphere remisphere added bug Something isn't working help wanted Open to be worked on labels Jul 28, 2020
nateraw (Contributor) commented Jul 28, 2020

  1. You're not specifying the datamodule kwarg in trainer.fit(); your last line should look like this: trainer.fit(model, datamodule=data_module)

  2. In this first iteration of LightningDataModule, you have to call setup() and prepare_data() manually on the datamodule instance. We set it up this way so that if you don't want to use Lightning, you can still use your datamodule's loaders with pure PyTorch. I thought of having them called implicitly in the PR, but landed on this for now; I'm not sure users would always want these to run implicitly.

TL;DR: you can update your code to look like this:

# Init a datamodule
dm = MyDataModule()

# Manually call prepare_data and setup. You could put these at the end of __init__ if you want
dm.prepare_data()
dm.setup()

model = MyLightningModule()
trainer = Trainer(gpus=1)
trainer.fit(model, datamodule=dm)
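
For reference, the decoupling described in point 2 means the same datamodule can also drive a plain PyTorch loop with no Trainer at all; a minimal sketch, assuming the MyDataModule from the report:

dm = MyDataModule()
dm.prepare_data()   # explicit, since no Trainer is orchestrating the hooks
dm.setup('fit')

# ordinary PyTorch iteration over the loader, no Lightning involved
for batch in dm.train_dataloader():
    ...  # your own training step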

That being said, we're open to any ideas on making this more intuitive, so feel free to throw out some alternatives. 😄

remisphere (Author) commented Jul 28, 2020

  1. That is not true in 0.9.0rc2: a datamodule passed as the second positional argument is taken care of here (see the sketch after this list).

  2. I don't have a global enough view to know what other users might want, so if it's an intended feature, I'm fine with it.
    I just saw that the manual call is in the docs; my bad for not looking far enough.
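
A rough sketch of the dispatch being referred to in point 1 (the real implementation lives in the 0.9.0rc2 Trainer; the argument handling shown here is illustrative only):

def fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None):
    # if the second positional argument is actually a LightningDataModule,
    # treat it as the datamodule rather than as a DataLoader
    if isinstance(train_dataloader, LightningDataModule):
        datamodule = train_dataloader
        train_dataloader = None
    ...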

Anyway, thank you for the clear answer ^^

nateraw (Contributor) commented Jul 29, 2020

@remisphere I totally didn't notice! You were completely right on the dm arg. Things move fast, haha.

Reopening actually, as I think your intended use is more user-friendly.

@nateraw nateraw reopened this Jul 29, 2020
@nateraw nateraw linked a pull request Jul 29, 2020 that will close this issue
@Borda Borda added the data handling Generic data-related topic label Jul 31, 2020