diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9069b116ffb3c..cfc5e2d89b4e1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added a tool for profiling training runs ([#782](https://github.com/PyTorchLightning/pytorch-lightning/pull/782))
 - Improved flexibility for naming of TensorBoard logs, can now set `version` to a `str` to just save to that directory, and use `name=''` to prevent experiment-name directory ([#804](https://github.com/PyTorchLightning/pytorch-lightning/pull/804))
 - Added option to specify `step` key when logging metrics ([#808](https://github.com/PyTorchLightning/pytorch-lightning/pull/808))
+- Added `train_dataloader`, `val_dataloader` and `test_dataloader` arguments to `Trainer.fit()`, for alternative data parsing ([#759](https://github.com/PyTorchLightning/pytorch-lightning/pull/759))
 ### Changed
 - Changed default TQDM to use `tqdm.auto` for prettier outputs in IPython notebooks ([#752](https://github.com/PyTorchLightning/pytorch-lightning/pull/752))
 - Changed `pytorch_lightning.logging` to `pytorch_lightning.loggers` ([#767](https://github.com/PyTorchLightning/pytorch-lightning/pull/767))
diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py
index 743aac656d18c..3448fda4d4864 100644
--- a/pytorch_lightning/core/decorators.py
+++ b/pytorch_lightning/core/decorators.py
@@ -10,7 +10,7 @@ def data_loader(fn):
     """
     wraps(fn)
     attr_name = '_lazy_' + fn.__name__
-
+    @wraps(fn)
     def _get_data_loader(self):
         try:
             value = getattr(self, attr_name)
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 8a550e5df9119..8790946b6281c 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -869,7 +869,6 @@ def tbptt_split_batch(self, batch, split_size):
         return splits
 
     @data_loader
-    @abstractmethod
     def train_dataloader(self):
         """Implement a PyTorch DataLoader
 
@@ -895,8 +894,8 @@ def train_dataloader(self):
                 )
                 return loader
 
-
         """
+        return None
 
     @data_loader
     def tng_dataloader(self):  # todo: remove in v0.8.0
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index ddcfd5e0c9629..6cd87b17e04d6 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -18,7 +18,7 @@
     parse_gpu_ids,
     determine_root_gpu_device
 )
-
+from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
@@ -840,17 +840,56 @@ def tng_tqdm_dic(self):
     # -----------------------------
     # MODEL TRAINING
     # -----------------------------
-    def fit(self, model):
+    def fit(self, model, train_dataloader=None, val_dataloader=None, test_dataloader=None):
         r"""
         Runs the full optimization routine.
 
+        Args:
+            model (LightningModule): Model to fit.
+
+            train_dataloader (:class:`torch.utils.data.DataLoader`): A PyTorch
+                DataLoader with training samples. If the model has
+                a predefined train_dataloader method, this will be skipped.
+
+            val_dataloader (:class:`torch.utils.data.DataLoader`): Either a single
+                PyTorch DataLoader or a list of them, specifying validation samples.
+                If the model has a predefined val_dataloader method, this will be skipped.
+
+            test_dataloader (:class:`torch.utils.data.DataLoader`): Either a single
+                PyTorch DataLoader or a list of them, specifying test samples.
+                If the model has a predefined test_dataloader method, this will be skipped.
+
         Example::
 
+            # Option 1
+            # Define the train_dataloader(), test_dataloader() and val_dataloader() methods
+            # in the LightningModule
+            # RECOMMENDED FOR MOST RESEARCH AND APPLICATIONS TO MAINTAIN READABILITY
             trainer = Trainer()
             model = LightningModule()
+            trainer.fit(model)
+
+            # Option 2
+            # in production cases we might want to pass different datasets to the same model
+            # Recommended for PRODUCTION SYSTEMS
+            train, val, test = DataLoader(...), DataLoader(...), DataLoader(...)
+            trainer = Trainer()
+            model = LightningModule()
+            trainer.fit(model, train_dataloader=train,
+                        val_dataloader=val, test_dataloader=test)
+
+            # Option 1 & 2 can be mixed, for example the training set can be
+            # defined as part of the model, and validation/test can then be
+            # fed to .fit()
 
-            trainer.fit()
         """
+
+        # Update the dataloader attributes of the model with the ones supplied here,
+        # if they are not already defined in the model
+        _set_dataloader(model, train_dataloader, 'train_dataloader')
+        _set_dataloader(model, val_dataloader, 'val_dataloader')
+        _set_dataloader(model, test_dataloader, 'test_dataloader')
+
         # when using multi-node or DDP within a node start each module in a separate process
         if self.use_ddp2:
             task = int(os.environ['SLURM_LOCALID'])
@@ -1048,3 +1087,49 @@ def test(self, model=None):
             self.fit(model)
         else:
             self.run_evaluation(test=True)
+
+
+def _set_dataloader(model, dataloader, attribute):
+    r'''
+    Check dataloaders passed to the .fit() method, verify that they are PyTorch
+    DataLoader objects, and decide whether or not to overwrite the corresponding
+    dataloader in the model.
+
+    Args:
+        model (LightningModule): The model to check
+
+        dataloader: If a PyTorch DataLoader (or a list of PyTorch DataLoaders)
+            is passed, it will be incorporated into the model as model.attribute.
+            If the attribute already exists, it will warn the user and skip it.
+            If it is not a dataloader, it will throw an error.
+
+        attribute (str): The attribute to save the dataloader under
+
+    '''
+    # Check if attribute comes directly from base class or
+    # derived in user subclass
+    if LightningModule.__qualname__ in getattr(model, attribute).__qualname__:
+        # Val and test should be list of dataloaders
+        dataloader = dataloader if attribute == 'train_dataloader' or \
+            (attribute != 'train_dataloader' and isinstance(dataloader, list)) else [dataloader]
+
+        # Check we are given valid dataloaders
+        is_dataloader = isinstance(dataloader, torch.utils.data.DataLoader)
+        is_dataloader_list = isinstance(dataloader, list)
+        if is_dataloader_list:
+            valid_loaders = all(isinstance(d, torch.utils.data.DataLoader) for d in dataloader)
+        if is_dataloader or is_dataloader_list and valid_loaders:
+
+            # Overwrite abstract methods
+            dl = lambda: dataloader
+            dl.__name__ = attribute
+            setattr(model, attribute, dl)
+
+        elif dataloader and dataloader != [None]:
+            raise ValueError(f'`{attribute}` needs to be an instance of '
+                             '`torch.utils.data.DataLoader` or a list of '
+                             'DataLoaders, instead got %r' % dataloader)
+
+    elif dataloader:  # if default (None) is passed, do not warn the user
+        warnings.warn(f'Model has predefined `{attribute}`,'
+                      f' will skip `{attribute}={dataloader}` passed to fit method.')
diff --git a/tests/models/__init__.py b/tests/models/__init__.py
index 5da8521e0f618..3e6424bd4982a 100644
--- a/tests/models/__init__.py
+++ b/tests/models/__init__.py
@@ -2,7 +2,7 @@
 
 import torch
 
-from .base import LightningTestModelBase
+from .base import LightningTestModelBase, LightningTestModelBaseWithoutDataloader
 from .mixins import (
     LightningValidationStepMixin,
     LightningValidationMixin,
diff --git a/tests/models/base.py b/tests/models/base.py
index 949a39ef26ca2..d33f0118dfd05 100644
--- a/tests/models/base.py
+++ b/tests/models/base.py
@@ -36,7 +36,7 @@ def __init__(self, root, train=True, transform=None, target_transform=None,
         self.targets = self.targets[:num_samples]
 
 
-class LightningTestModelBase(LightningModule):
+class TestModelBase(LightningModule):
     """
     Base LightningModule for testing.
     Implements only the required interface
@@ -48,7 +48,7 @@ def __init__(self, hparams, force_remove_distributed_sampler=False):
         :param hparams:
         """
         # init superclass
-        super(LightningTestModelBase, self).__init__()
+        super(TestModelBase, self).__init__()
         self.hparams = hparams
 
         self.batch_size = hparams.batch_size
@@ -178,10 +178,6 @@ def _dataloader(self, train):
 
         return loader
 
-    @data_loader
-    def train_dataloader(self):
-        return self._dataloader(train=True)
-
     @staticmethod
     def add_model_specific_args(parent_parser, root_dir):  # pragma: no cover
         """
@@ -218,3 +214,15 @@ def add_model_specific_args(parent_parser, root_dir):  # pragma: no cover
                 options=[32, 64, 128, 256], tunable=False,
                 help='batch size will be divided over all gpus being used across all nodes')
         return parser
+
+
+class LightningTestModelBase(TestModelBase):
+    """ with pre-defined train dataloader """
+    @data_loader
+    def train_dataloader(self):
+        return self._dataloader(train=True)
+
+
+class LightningTestModelBaseWithoutDataloader(TestModelBase):
+    """ without pre-defined train dataloader """
+    pass
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index 2b064357ea7de..283b332aa0d7f 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -13,6 +13,7 @@
 from tests.models import (
     LightningTestModel,
     LightningTestModelBase,
+    LightningTestModelBaseWithoutDataloader,
     LightningValidationStepMixin,
     LightningValidationMultipleDataloadersMixin,
     LightningTestMultipleDataloadersMixin,
@@ -449,6 +450,165 @@ class CurrentTestModel(
     trainer.test()
 
 
+def test_train_dataloaders_passed_to_fit(tmpdir):
+    """ Verify that train dataloader can be passed to fit """
+    tutils.reset_seed()
+
+    class CurrentTestModel(
+        LightningTestModelBaseWithoutDataloader
+    ):
+        pass
+
+    hparams = tutils.get_hparams()
+
+    # logger file to get meta
+    trainer_options = dict(
+        default_save_path=tmpdir,
+        max_epochs=1,
+        val_percent_check=0.1,
+        train_percent_check=0.2
+    )
+
+    # only train passed to fit
+    model = CurrentTestModel(hparams)
+    trainer = Trainer(**trainer_options)
+    fit_options = dict(train_dataloader=model._dataloader(train=True))
+    results = trainer.fit(model, **fit_options)
+
+
+def test_train_val_dataloaders_passed_to_fit(tmpdir):
+    """ Verify that train & val dataloader can be passed to fit """
+    tutils.reset_seed()
+
+    class CurrentTestModel(
+        LightningTestModelBaseWithoutDataloader
+    ):
+        pass
+
+    hparams = tutils.get_hparams()
+
+    # logger file to get meta
+    trainer_options = dict(
+        default_save_path=tmpdir,
+        max_epochs=1,
+        val_percent_check=0.1,
+        train_percent_check=0.2
+    )
+
+    # train, val passed to fit
+    model = CurrentTestModel(hparams)
+    trainer = Trainer(**trainer_options)
+    fit_options = dict(train_dataloader=model._dataloader(train=True),
+                       val_dataloader=model._dataloader(train=False))
+    results = trainer.fit(model, **fit_options)
+    assert len(trainer.get_val_dataloaders()) == 1, \
+        f'`val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}'
+
+
+def test_all_dataloaders_passed_to_fit(tmpdir):
+    """ Verify train, val & test dataloader can be passed to fit """
+    tutils.reset_seed()
+
+    class CurrentTestModel(
+        LightningTestModelBaseWithoutDataloader
+    ):
+        pass
+
+    hparams = tutils.get_hparams()
+
+    # logger file to get meta
+    trainer_options = dict(
+        default_save_path=tmpdir,
+        max_epochs=1,
+        val_percent_check=0.1,
+        train_percent_check=0.2
+    )
+
+    # train, val and test passed to fit
+    model = CurrentTestModel(hparams)
+    trainer = Trainer(**trainer_options)
+    fit_options = dict(train_dataloader=model._dataloader(train=True),
+                       val_dataloader=model._dataloader(train=False),
+                       test_dataloader=model._dataloader(train=False))
+    results = trainer.fit(model, **fit_options)
+
+    assert len(trainer.get_val_dataloaders()) == 1, \
+        f'`val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}'
+    assert len(trainer.get_test_dataloaders()) == 1, \
+        f'`test_dataloaders` not initiated properly, got {trainer.get_test_dataloaders()}'
+
+
+def test_multiple_dataloaders_passed_to_fit(tmpdir):
+    """ Verify that multiple val & test dataloaders can be passed to fit """
+    tutils.reset_seed()
+
+    class CurrentTestModel(
+        LightningTestModelBaseWithoutDataloader
+    ):
+        pass
+
+    hparams = tutils.get_hparams()
+
+    # logger file to get meta
+    trainer_options = dict(
+        default_save_path=tmpdir,
+        max_epochs=1,
+        val_percent_check=0.1,
+        train_percent_check=0.2
+    )
+
+    # train, multiple val and multiple test passed to fit
+    model = CurrentTestModel(hparams)
+    trainer = Trainer(**trainer_options)
+    fit_options = dict(train_dataloader=model._dataloader(train=True),
+                       val_dataloader=[model._dataloader(train=False),
+                                       model._dataloader(train=False)],
+                       test_dataloader=[model._dataloader(train=False),
+                                        model._dataloader(train=False)])
+    results = trainer.fit(model, **fit_options)
+
+    assert len(trainer.get_val_dataloaders()) == 2, \
+        f'Multiple `val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}'
+    assert len(trainer.get_test_dataloaders()) == 2, \
+        f'Multiple `test_dataloaders` not initiated properly, got {trainer.get_test_dataloaders()}'
+
+
+def test_mixing_of_dataloader_options(tmpdir):
+    """Verify that dataloaders can be passed to fit"""
+    tutils.reset_seed()
+
+    class CurrentTestModel(
+        LightningTestModelBase
+    ):
+        pass
+
+    hparams = tutils.get_hparams()
+    model = CurrentTestModel(hparams)
+
+    # logger file to get meta
+    trainer_options = dict(
+        default_save_path=tmpdir,
+        max_epochs=1,
+        val_percent_check=0.1,
+        train_percent_check=0.2
+    )
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+    fit_options = dict(val_dataloader=model._dataloader(train=False))
+    results = trainer.fit(model, **fit_options)
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+    fit_options = dict(val_dataloader=model._dataloader(train=False),
+                       test_dataloader=model._dataloader(train=False))
+    results = trainer.fit(model, **fit_options)
+    assert len(trainer.get_val_dataloaders()) == 1, \
+        f'`val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}'
+    assert len(trainer.get_test_dataloaders()) == 1, \
+        f'`test_dataloaders` not initiated properly, got {trainer.get_test_dataloaders()}'
+
+
 def _init_steps_model():
     """private method for initializing a model with 5% train epochs"""
     tutils.reset_seed()
@@ -533,5 +693,6 @@ def test_trainer_min_steps_and_epochs(tmpdir):
     assert trainer.global_step >= math.floor(num_train_samples * 1.5) and \
         trainer.current_epoch > 0, "Model did not train for at least min_steps"
 
+
 # if __name__ == '__main__':
 #     pytest.main([__file__])
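
Reviewer note: below is a minimal usage sketch of the new `fit()` signature outside the test suite. It is illustrative only; `ToyModel`, the random tensor data, and the hyperparameter values are hypothetical and not part of this diff. It assumes only the behaviour introduced above: a dataloader passed to `fit()` is attached to the model unless the model already defines the corresponding `*_dataloader` method.

```python
# Illustrative sketch only -- the module and data below are made up for this note;
# the train_dataloader keyword argument is the one added in this PR.
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import Trainer, LightningModule


class ToyModel(LightningModule):
    """Minimal module that defines no *_dataloader methods of its own."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)


if __name__ == '__main__':
    # Random tensors standing in for a real dataset
    dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
    train_loader = DataLoader(dataset, batch_size=8)

    # New in this PR: the loader goes straight to fit() instead of being
    # returned from a train_dataloader() method on the module.
    trainer = Trainer(max_epochs=1)
    trainer.fit(ToyModel(), train_dataloader=train_loader)
```

Passing `val_dataloader` or `test_dataloader` works the same way, but the module would additionally need the corresponding validation/test step hooks for those loaders to be used.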