
Pass all running stages to DataModule.setup #5658

Closed
carmocca opened this issue Jan 26, 2021 · 5 comments · Fixed by #6386
Labels
feature (Is an improvement or enhancement), help wanted (Open to be worked on), refactor

Comments

@carmocca
Contributor

🚀 Feature

Currently, DataModule.setup is only called with the stages fit or test, but there are several more:

Stages:

https://github.com/PyTorchLightning/pytorch-lightning/blob/5f3372871a333c3229968f1af1b10a925d7ec3ec/pytorch_lightning/trainer/states.py#L39-L49

Note that this is a bit tricky because fit is not a RunningStage: it spans both the train and eval stages.
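A minimal sketch of why fit is awkward here: one fit call maps onto more than one running stage, so a single setup('fit') call has to cover both. ENTRY_POINT_STAGES and running_stages are hypothetical names for illustration only (not Lightning API), and the stage strings are assumptions modeled on the linked RunningStage enum.

```python
from typing import Tuple

# Hypothetical mapping from trainer entry points to running stages.
# The stage names are assumptions modeled on the linked RunningStage enum.
ENTRY_POINT_STAGES = {
    "fit": ("train", "eval"),  # fit runs the training loop and the validation loop
    "validate": ("eval",),
    "test": ("test",),
    "predict": ("predict",),
}


def running_stages(entry_point: str) -> Tuple[str, ...]:
    """Return the running stages a trainer entry point would go through."""
    try:
        return ENTRY_POINT_STAGES[entry_point]
    except KeyError:
        raise ValueError(f"unknown entry point: {entry_point!r}") from None


print(running_stages("fit"))  # ('train', 'eval')
```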

Motivation

This would allow custom setup logic for each stage.

Pitch

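from typing import Optional
from pytorch_lightning.trainer.states import RunningStage

def setup(self, stage: Optional[str] = None):
    assert stage is None or stage in list(RunningStage)
    ...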
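To illustrate the per-stage logic this pitch would enable, here is a self-contained sketch that does not depend on Lightning. Stage is a hypothetical stand-in for RunningStage (its values are assumptions), and ToyDataModule is a hypothetical class, not a LightningDataModule; it only prepares the data the requested stage needs, and treats stage=None as "set up everything".

```python
from enum import Enum
from typing import Optional


class Stage(str, Enum):
    """Hypothetical stand-in for RunningStage; these values are assumptions."""

    TRAINING = "train"
    EVALUATING = "eval"
    TESTING = "test"
    PREDICTING = "predict"


class ToyDataModule:
    """Hypothetical datamodule that prepares only the splits a stage needs."""

    def __init__(self) -> None:
        self.train_data: Optional[list] = None
        self.val_data: Optional[list] = None
        self.test_data: Optional[list] = None
        self.predict_data: Optional[list] = None

    def setup(self, stage: Optional[str] = None) -> None:
        # Normalise the string to a Stage member; unknown names raise ValueError.
        # None keeps the "set up everything" behaviour of a manual call.
        current = None if stage is None else Stage(stage)
        if current in (None, Stage.TRAINING):
            self.train_data = list(range(8))
        if current in (None, Stage.TRAINING, Stage.EVALUATING):
            # Training also needs validation data in this sketch.
            self.val_data = list(range(8, 10))
        if current in (None, Stage.TESTING):
            self.test_data = list(range(10, 12))
        if current in (None, Stage.PREDICTING):
            self.predict_data = list(range(12, 14))


dm = ToyDataModule()
dm.setup("predict")
print(dm.predict_data, dm.train_data)  # [12, 13] None
```

With a dedicated predict stage, the datamodule can return prediction data without also loading the training splits.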

Additional context

We currently pass 'test' when predicting, as seen in #5579:
https://github.com/PyTorchLightning/pytorch-lightning/blob/9137b16068fe03e6db8df548235363e5f5476aac/pytorch_lightning/trainer/trainer.py#L909

carmocca added the feature, help wanted and refactor labels on Jan 26, 2021
@rohitgr7
Contributor

I don't think setup is called again during evaluation (eval).

@carmocca
Contributor Author

Not currently, but we will want to once we have trainer.validate().

@rohitgr7
Contributor

True, I forgot about that one.

@leifdenby

This would be great to have! I just got bitten by this when trying to call model.predict(datamodule=...): because .setup(stage='predict') isn't called, there's nothing for my datamodule to return. (Here's the notebook I created based on "The BoringModule", in case it is useful: https://colab.research.google.com/drive/1-D7EcIDMeONje2aIr6nEfDnqhibj12OB?usp=sharing). I use a datamodule because it lets me encapsulate the transforms applied to my data and reuse them during training, testing and prediction.

@rohitgr7
Contributor

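def setup(self, stage):
    # Workaround: branch on the current running stage instead of on `stage`
    if self.running_stage == 'predict':
        ...  # prediction-specific setup

@leifdenby This is just a workaround for now.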
