Commit

🚧 .
nateraw committed Jul 22, 2020
1 parent ea314f4 commit 081b2a8
Showing 3 changed files with 37 additions and 3 deletions.
6 changes: 6 additions & 0 deletions pytorch_lightning/core/datamodule.py
@@ -101,6 +101,12 @@ def prepare_data(self):
             cache_imagenet()
         """

+    @abstractmethod
+    def setup(self, stage):
+        """
+        Use this to make assignments to the class.
+        """
+
     @abstractmethod
     def train_dataloader(self, *args, **kwargs) -> DataLoader:
         """
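For orientation, the new abstract setup hook sits alongside prepare_data and the dataloader hooks on LightningDataModule. A minimal sketch of a concrete datamodule written against this interface follows; the random TensorDataset and the 80/20 split are illustrative stand-ins, not part of the commit.

import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

from pytorch_lightning.core.datamodule import LightningDataModule


class RandomDataModule(LightningDataModule):

    def prepare_data(self):
        # called once to download/cache data; nothing needed for random tensors
        pass

    def setup(self, stage):
        # make assignments to the class: build the datasets each stage will use
        data = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))
        self.train_set, self.val_set = random_split(data, [80, 20])

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=32)

    def test_dataloader(self):
        # reuse the val split purely for illustration
        return DataLoader(self.val_set, batch_size=32)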
5 changes: 4 additions & 1 deletion pytorch_lightning/trainer/model_hooks.py
@@ -1,6 +1,7 @@
 import inspect
 from abc import ABC, abstractmethod

+from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.lightning import LightningModule


@@ -15,7 +16,9 @@ def is_function_implemented(self, f_name, model=None):
     def is_overridden(self, method_name: str, model: LightningModule = None) -> bool:
         if model is None:
             model = self.get_model()
-        super_object = LightningModule
+        # if a DataModule is passed instead of None or a LightningModule, compare against LightningDataModule
+        # TODO: refactor this function to accept model_name, instance, parent so it makes more sense
+        super_object = LightningModule if not isinstance(model, LightningDataModule) else LightningDataModule

         # assert model, 'no model passes'
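For context, the override check itself (whose body is mostly outside this hunk) boils down to comparing a hook's code object against the default implementation on the chosen parent class; the change above only swaps which parent is consulted when a datamodule is passed. A simplified sketch of the idea, not the exact body of is_overridden:

from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule


def is_overridden_sketch(method_name, instance):
    # pick the parent whose default implementation we compare against
    if isinstance(instance, LightningDataModule):
        super_object = LightningDataModule
    else:
        super_object = LightningModule

    if not hasattr(instance, method_name) or not hasattr(super_object, method_name):
        return False

    # the hook counts as overridden when its code object differs from the parent's
    instance_code = getattr(instance, method_name).__code__
    parent_code = getattr(super_object, method_name).__code__
    return instance_code is not parent_code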
29 changes: 27 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -9,6 +9,7 @@
 from torch.utils.data import DataLoader

 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback
+from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.memory import ModelSummary
 from pytorch_lightning.loggers import LightningLoggerBase
@@ -890,7 +891,8 @@ def fit(
         self,
         model: LightningModule,
         train_dataloader: Optional[DataLoader] = None,
-        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None
+        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
+        datamodule: Optional[LightningDataModule] = None
     ):
         r"""
         Runs the full optimization routine.
@@ -939,6 +941,7 @@ def fit(

         # set up the passed in dataloaders (if needed)
         self.__attach_dataloaders(model, train_dataloader, val_dataloaders)
+        self.__attach_datamodule(model, datamodule)

         # check that model is configured correctly
         self.check_model_configuration(model)
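With the keyword argument wired into fit and the attach call in place, a datamodule can be passed straight to the trainer. A usage sketch, assuming model is any LightningModule and RandomDataModule is the illustrative class sketched earlier:

from pytorch_lightning import Trainer

dm = RandomDataModule()
trainer = Trainer(max_epochs=1)

# __attach_datamodule copies the datamodule's overridden hooks onto the
# model, so the trainer pulls its dataloaders from dm during training
trainer.fit(model, datamodule=dm)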
@@ -1111,6 +1114,24 @@ def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None
         if test_dataloaders is not None:
             model.test_dataloader = _PatchDataLoader(test_dataloaders)

+    def __attach_datamodule(self, model, datamodule=None):
+
+        # use the datamodule if one was provided to .fit or .test; otherwise check the model for one
+        datamodule = datamodule or getattr(model, 'datamodule', None)
+
+        # if we have a datamodule, attach the necessary hooks + dataloaders
+        if datamodule:
+            if self.is_overridden('setup', datamodule):
+                model.setup = datamodule.setup
+            if self.is_overridden('prepare_data', datamodule):
+                model.prepare_data = datamodule.prepare_data
+            if self.is_overridden('train_dataloader', datamodule):
+                model.train_dataloader = datamodule.train_dataloader
+            if self.is_overridden('val_dataloader', datamodule):
+                model.val_dataloader = datamodule.val_dataloader
+            if self.is_overridden('test_dataloader', datamodule):
+                model.test_dataloader = datamodule.test_dataloader
+
     def run_pretrain_routine(self, model: LightningModule):
         """Sanity check a few things before starting actual training.
@@ -1241,7 +1262,8 @@ def test(
         model: Optional[LightningModule] = None,
         test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
         ckpt_path: Optional[str] = 'best',
-        verbose: bool = True
+        verbose: bool = True,
+        datamodule: Optional[LightningDataModule] = None
     ):
         r"""
@@ -1305,6 +1327,9 @@
         if self.global_rank != 0:
             return

+        # attach the datamodule so its setup/prepare_data hooks are added to the model before the call below
+        self.__attach_datamodule(model or self.get_model(), datamodule)
+
         self.setup('test')

         if model is not None:
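test gains the same keyword, and the datamodule is attached before self.setup('test') runs, so its setup, prepare_data, and test_dataloader overrides are in place for evaluation. A usage sketch continuing the example above:

# the datamodule's test_dataloader (if overridden) is patched onto the model
trainer.test(model, datamodule=dm)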
