diff --git a/CHANGELOG.md b/CHANGELOG.md index 861e68102dd5e..21df3834bd368 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- +- Added support for IterableDataset in validation and testing ([#1104](https://github.com/PyTorchLightning/pytorch-lightning/pull/1104)) ### Changed diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 6b8d484ec5bc1..0dcd31b3db05e 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -1,9 +1,11 @@ from abc import ABC, abstractmethod +from typing import Union, List, Tuple, Callable import torch.distributed as dist from torch.utils.data import SequentialSampler, DataLoader from torch.utils.data.distributed import DistributedSampler +from pytorch_lightning.core import LightningModule from pytorch_lightning.utilities.debugging import MisconfigurationException try: @@ -23,6 +25,15 @@ XLA_AVAILABLE = True +def _has_len(dataloader: DataLoader) -> bool: + try: + # try getting the length + _ = len(dataloader) + return True + except TypeError: + return False + + class TrainerDataLoadingMixin(ABC): # this is just a summary on variables used in this abstract class, @@ -35,27 +46,30 @@ class TrainerDataLoadingMixin(ABC): use_tpu: bool tpu_local_core_rank: int train_dataloader: DataLoader - num_training_batches: int + num_training_batches: Union[int, float] val_check_batch: ... - val_dataloaders: DataLoader - num_val_batches: int - test_dataloaders: DataLoader - num_test_batches: int + val_dataloaders: List[DataLoader] + num_val_batches: Union[int, float] + test_dataloaders: List[DataLoader] + num_test_batches: Union[int, float] + train_percent_check: float + val_percent_check: float + test_percent_check: float @abstractmethod def is_overriden(self, *args): """Warning: this is just empty shell for code implemented in other class.""" - def _percent_range_check(self, name): + def _percent_range_check(self, name: str) -> None: value = getattr(self, name) - msg = f"`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}." - if name == "val_check_interval": - msg += " If you want to disable validation set `val_percent_check` to 0.0 instead." + msg = f'`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}.' + if name == 'val_check_interval': + msg += ' If you want to disable validation set `val_percent_check` to 0.0 instead.' if not 0. <= value <= 1.: raise ValueError(msg) - def auto_add_sampler(self, dataloader, train): + def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader: if self.use_ddp or self.use_ddp2 or self.use_tpu: dl_args = { 'dataset': dataloader.dataset, @@ -88,14 +102,14 @@ def auto_add_sampler(self, dataloader, train): dataloader = DataLoader(**dl_args) return dataloader - def reset_train_dataloader(self, model): - """ - Dataloaders are provided by the model - :param model: - :return: - """ + def reset_train_dataloader(self, model: LightningModule) -> None: + """Resets the train dataloader and initialises required variables + (number of batches, when to validate, etc.). - self.train_dataloader = self.request_data_loader(model.train_dataloader) + Args: + model: The current `LightningModule` + """ + self.train_dataloader = self.request_dataloader(model.train_dataloader) self.num_training_batches = 0 # automatically add samplers @@ -103,7 +117,7 @@ def reset_train_dataloader(self, model): self._percent_range_check('train_percent_check') - if self.is_infinite_dataloader(self.train_dataloader): + if not _has_len(self.train_dataloader): self.num_training_batches = float('inf') else: # try getting the length @@ -117,122 +131,119 @@ def reset_train_dataloader(self, model): self.val_check_batch = self.val_check_interval if self.val_check_batch > self.num_training_batches: raise ValueError( - f"`val_check_interval` ({self.val_check_interval}) must be less than or equal " - f"to the number of the training batches ({self.num_training_batches}). " - f"If you want to disable validation set `val_percent_check` to 0.0 instead.") + f'`val_check_interval` ({self.val_check_interval}) must be less than or equal ' + f'to the number of the training batches ({self.num_training_batches}). ' + 'If you want to disable validation set `val_percent_check` to 0.0 instead.') else: - if self.is_infinite_dataloader(self.train_dataloader): - m = ''' - When using an infinite DataLoader (e.g. with an IterableDataset or when DataLoader - does not implement `__len__`) for `train_dataloader`, `Trainer(val_check_interval)` - must be an int. An int k specifies checking validation every k training batches. - ''' - raise MisconfigurationException(m) + if not _has_len(self.train_dataloader): + raise MisconfigurationException( + 'When using an infinite DataLoader (e.g. with an IterableDataset or when ' + 'DataLoader does not implement `__len__`) for `train_dataloader`, ' + '`Trainer(val_check_interval)` must be an int. An int k specifies checking ' + 'validation every k training batches.') self._percent_range_check('val_check_interval') self.val_check_batch = int(self.num_training_batches * self.val_check_interval) self.val_check_batch = max(1, self.val_check_batch) - def is_infinite_dataloader(self, dataloader): - try: - # try getting the length - _ = len(dataloader) - return False - except TypeError as e: - return True + def _reset_eval_dataloader(self, model: LightningModule, + mode: str) -> Tuple[int, List[DataLoader]]: + """Generic method to reset a dataloader for evaluation. - def reset_val_dataloader(self, model): - """ - Dataloaders are provided by the model - :param model: - :return: + Args: + model: The current `LightningModule` + mode: Either `'val'` or `'test'` + + Returns: + Tuple (num_batches, dataloaders) """ - if not self.is_overriden('validation_step'): - return + dataloaders = self.request_dataloader(getattr(model, f'{mode}_dataloader')) - self.val_dataloaders = self.request_data_loader(model.val_dataloader) - if not isinstance(self.val_dataloaders, list): - self.val_dataloaders = [self.val_dataloaders] - self.num_val_batches = 0 + if not isinstance(dataloaders, list): + dataloaders = [dataloaders] # add samplers - self.val_dataloaders = [self.auto_add_sampler(dl, train=False) - for dl in self.val_dataloaders if dl] + dataloaders = [self.auto_add_sampler(dl, train=False) for dl in dataloaders if dl] - # determine number of validation batches - # val datasets could be none, 1 or 2+ - if self.val_dataloaders is not None: - self._percent_range_check('val_percent_check') + num_batches = 0 - self.num_val_batches = sum(len(dataloader) for dataloader in self.val_dataloaders) - self.num_val_batches = int(self.num_val_batches * self.val_percent_check) + # determine number of batches + # datasets could be none, 1 or 2+ + if len(dataloaders) != 0: + for dataloader in dataloaders: + if not _has_len(dataloader): + num_batches = float('inf') + break - def reset_test_dataloader(self, model): - """Dataloaders are provided by the model. + percent_check = getattr(self, f'{mode}_percent_check') - :param model: - """ - if not self.is_overriden('test_step'): - return + if num_batches != float('inf'): + self._percent_range_check(f'{mode}_percent_check') - # get actual loader - self.test_dataloaders = self.request_data_loader(model.test_dataloader) - if not isinstance(self.test_dataloaders, list): - self.test_dataloaders = [self.test_dataloaders] - self.num_test_batches = 0 + num_batches = sum(len(dataloader) for dataloader in dataloaders) + num_batches = int(num_batches * percent_check) + elif percent_check not in (0.0, 1.0): + raise MisconfigurationException( + 'When using an infinite DataLoader (e.g. with an IterableDataset or when ' + f'DataLoader does not implement `__len__`) for `{mode}_dataloader`, ' + f'`Trainer({mode}_percent_check)` must be `0.0` or `1.0`.') + return num_batches, dataloaders - # add samplers - self.test_dataloaders = [self.auto_add_sampler(dl, train=False) - for dl in self.test_dataloaders if dl] + def reset_val_dataloader(self, model: LightningModule) -> None: + """Resets the validation dataloader and determines the number of batches. - # determine number of test batches - if self.test_dataloaders is not None: - self._percent_range_check('test_percent_check') + Args: + model: The current `LightningModule` + """ + if self.is_overriden('validation_step'): + self.num_val_batches, self.val_dataloaders =\ + self._reset_eval_dataloader(model, 'val') - len_sum = sum(len(dataloader) for dataloader in self.test_dataloaders) - self.num_test_batches = len_sum - self.num_test_batches = int(self.num_test_batches * self.test_percent_check) + def reset_test_dataloader(self, model) -> None: + """Resets the validation dataloader and determines the number of batches. - def request_data_loader(self, data_loader_fx): + Args: + model: The current `LightningModule` """ - Handles downloading data in the GPU or TPU case. + if self.is_overriden('test_step'): + self.num_test_batches, self.test_dataloaders =\ + self._reset_eval_dataloader(model, 'test') - :param data_loader_fx: - :return: + def request_dataloader(self, dataloader_fx: Callable) -> DataLoader: + """Handles downloading data in the GPU or TPU case. + + Args: + dataloader_fx: The bound dataloader getter + + Returns: + The dataloader """ + dataloader = dataloader_fx() + # get the function we'll use to get data if self.use_ddp or self.use_ddp2: - data_loader = data_loader_fx() - # all processes wait until data download has happened dist.barrier() # data download/load on TPU elif self.use_tpu and XLA_AVAILABLE: - data_loader = data_loader_fx() - # all processes wait until data download has happened - torch_xla.core.xla_model.rendezvous("pl.TrainerDataLoadingMixin.get_dataloaders") + torch_xla.core.xla_model.rendezvous('pl.TrainerDataLoadingMixin.get_dataloaders') - # regular start - else: - data_loader = data_loader_fx() - - return data_loader + return dataloader - def determine_data_use_amount(self, train_percent_check, val_percent_check, - test_percent_check, overfit_pct): - """ - Use less data for debugging purposes + def determine_data_use_amount(self, train_percent_check: float, val_percent_check: float, + test_percent_check: float, overfit_pct: float) -> None: + """Use less data for debugging purposes """ self.train_percent_check = train_percent_check self.val_percent_check = val_percent_check self.test_percent_check = test_percent_check if overfit_pct > 0: if overfit_pct > 1: - raise ValueError(f"`overfit_pct` must be not greater than 1.0, but got " - f"{overfit_pct:.3f}.") + raise ValueError( + f'`overfit_pct` must be not greater than 1.0, but got {overfit_pct:.3f}.') self.train_percent_check = overfit_pct self.val_percent_check = overfit_pct diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 2e9c4cc6355b4..0bc46b99ea3a1 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -358,9 +358,9 @@ def run_evaluation(self, test_mode: bool = False): # main progress bar will already be closed when testing so initial position is free position = 2 * self.process_position + (not test_mode) desc = 'Testing' if test_mode else 'Validating' - pbar = tqdm(desc=desc, total=max_batches, leave=test_mode, position=position, - disable=not self.show_progress_bar, dynamic_ncols=True, - file=sys.stdout) + total = max_batches if max_batches != float('inf') else None + pbar = tqdm(desc=desc, total=total, leave=test_mode, position=position, + disable=not self.show_progress_bar, dynamic_ncols=True, file=sys.stdout) setattr(self, f'{"test" if test_mode else "val"}_progress_bar', pbar) # run evaluation diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 9691c3c8edac4..220cdef4ec853 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -222,10 +222,6 @@ def get_model(self): def is_function_implemented(self, *args): """Warning: this is just empty shell for code implemented in other class.""" - @abstractmethod - def is_infinite_dataloader(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - @abstractmethod def run_evaluation(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @@ -309,7 +305,7 @@ def train(self): total_val_batches = 0 is_val_epoch = False - if not self.disable_validation: + if not self.disable_validation and self.num_training_batches != float('inf'): # val can be checked multiple times in epoch is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0 val_checks_per_epoch = self.num_training_batches // self.val_check_batch @@ -323,8 +319,8 @@ def train(self): if self.fast_dev_run: # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run num_iterations = 2 - elif self.is_infinite_dataloader(self.train_dataloader): - # for infinite train loader, the progress bar never ends + elif self.total_batches == float('inf'): + # for infinite train or val loader, the progress bar never ends num_iterations = None else: num_iterations = self.total_batches @@ -333,7 +329,7 @@ def train(self): # .reset() doesn't work on disabled progress bar so we should check if not self.main_progress_bar.disable: self.main_progress_bar.reset(num_iterations) - desc = f'Epoch {epoch + 1}' if not self.is_infinite_dataloader(self.train_dataloader) else '' + desc = f'Epoch {epoch + 1}' self.main_progress_bar.set_description(desc) # ----------------- diff --git a/tests/models/__init__.py b/tests/models/__init__.py index f48b6ab4db351..4992e70a2ec30 100644 --- a/tests/models/__init__.py +++ b/tests/models/__init__.py @@ -19,6 +19,9 @@ LightValStepFitMultipleDataloadersMixin, LightTrainDataloader, LightTestDataloader, + LightInfTrainDataloader, + LightInfValDataloader, + LightInfTestDataloader, LightTestOptimizerWithSchedulingMixin, LightTestMultipleOptimizersWithSchedulingMixin, LightTestOptimizersWithMixedSchedulingMixin diff --git a/tests/models/mixins.py b/tests/models/mixins.py index 1a59cb8576857..fd3f0ddea1b9f 100644 --- a/tests/models/mixins.py +++ b/tests/models/mixins.py @@ -213,6 +213,48 @@ def test_dataloader(self): return self._dataloader(train=False) +class CustomInfDataloader: + def __init__(self, dataloader): + self.dataloader = dataloader + self.iter = iter(dataloader) + self.count = 0 + + def __iter__(self): + self.count = 0 + return self + + def __next__(self): + if self.count >= 50: + raise StopIteration + self.count = self.count + 1 + try: + return next(self.iter) + except StopIteration: + self.iter = iter(self.dataloader) + return next(self.iter) + + +class LightInfTrainDataloader: + """Simple test dataloader.""" + + def train_dataloader(self): + return CustomInfDataloader(self._dataloader(train=True)) + + +class LightInfValDataloader: + """Simple test dataloader.""" + + def val_dataloader(self): + return CustomInfDataloader(self._dataloader(train=False)) + + +class LightInfTestDataloader: + """Simple test dataloader.""" + + def test_dataloader(self): + return CustomInfDataloader(self._dataloader(train=False)) + + class LightEmptyTestStep: """Empty test step.""" diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 00709409890b3..40670dafdf4bf 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -14,9 +14,85 @@ LightValStepFitMultipleDataloadersMixin, LightValStepFitSingleDataloaderMixin, LightTrainDataloader, + LightInfTrainDataloader, + LightInfValDataloader, + LightInfTestDataloader ) +def test_dataloader_config_errors(tmpdir): + tutils.reset_seed() + + class CurrentTestModel( + LightTrainDataloader, + TestModelBase, + ): + pass + + hparams = tutils.get_hparams() + model = CurrentTestModel(hparams) + + # percent check < 0 + + # logger file to get meta + trainer_options = dict( + default_save_path=tmpdir, + max_epochs=1, + train_percent_check=-0.1, + ) + + # fit model + trainer = Trainer(**trainer_options) + + with pytest.raises(ValueError): + trainer.fit(model) + + # percent check > 1 + + # logger file to get meta + trainer_options = dict( + default_save_path=tmpdir, + max_epochs=1, + train_percent_check=1.1, + ) + + # fit model + trainer = Trainer(**trainer_options) + + with pytest.raises(ValueError): + trainer.fit(model) + + # int val_check_interval > num batches + + # logger file to get meta + trainer_options = dict( + default_save_path=tmpdir, + max_epochs=1, + val_check_interval=10000 + ) + + # fit model + trainer = Trainer(**trainer_options) + + with pytest.raises(ValueError): + trainer.fit(model) + + # float val_check_interval > 1 + + # logger file to get meta + trainer_options = dict( + default_save_path=tmpdir, + max_epochs=1, + val_check_interval=1.1 + ) + + # fit model + trainer = Trainer(**trainer_options) + + with pytest.raises(ValueError): + trainer.fit(model) + + def test_multiple_val_dataloader(tmpdir): """Verify multiple val_dataloader.""" tutils.reset_seed() @@ -96,7 +172,7 @@ class CurrentTestModel( def test_train_dataloaders_passed_to_fit(tmpdir): - """ Verify that train dataloader can be passed to fit """ + """Verify that train dataloader can be passed to fit """ tutils.reset_seed() class CurrentTestModel(LightTrainDataloader, TestModelBase): @@ -116,7 +192,9 @@ class CurrentTestModel(LightTrainDataloader, TestModelBase): model = CurrentTestModel(hparams) trainer = Trainer(**trainer_options) fit_options = dict(train_dataloader=model._dataloader(train=True)) - results = trainer.fit(model, **fit_options) + result = trainer.fit(model, **fit_options) + + assert result == 1 def test_train_val_dataloaders_passed_to_fit(tmpdir): @@ -146,13 +224,14 @@ class CurrentTestModel( fit_options = dict(train_dataloader=model._dataloader(train=True), val_dataloaders=model._dataloader(train=False)) - results = trainer.fit(model, **fit_options) + result = trainer.fit(model, **fit_options) + assert result == 1 assert len(trainer.val_dataloaders) == 1, \ - f"`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}" + f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' def test_all_dataloaders_passed_to_fit(tmpdir): - """ Verify train, val & test dataloader can be passed to fit """ + """Verify train, val & test dataloader can be passed to fit """ tutils.reset_seed() class CurrentTestModel( @@ -181,14 +260,15 @@ class CurrentTestModel( val_dataloaders=model._dataloader(train=False), test_dataloaders=model._dataloader(train=False)) - results = trainer.fit(model, **fit_options) + result = trainer.fit(model, **fit_options) trainer.test() + assert result == 1 assert len(trainer.val_dataloaders) == 1, \ - f"val_dataloaders` not initiated properly, got {trainer.val_dataloaders}" + f'val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' assert len(trainer.test_dataloaders) == 1, \ - f"test_dataloaders` not initiated properly, got {trainer.test_dataloaders}" + f'test_dataloaders` not initiated properly, got {trainer.test_dataloaders}' def test_multiple_dataloaders_passed_to_fit(tmpdir): @@ -224,9 +304,9 @@ class CurrentTestModel( trainer.test() assert len(trainer.val_dataloaders) == 2, \ - f"Multiple `val_dataloaders` not initiated properly, got {trainer.val_dataloaders}" + f'Multiple `val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' assert len(trainer.test_dataloaders) == 2, \ - f"Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}" + f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}' def test_mixing_of_dataloader_options(tmpdir): @@ -265,40 +345,54 @@ class CurrentTestModel( trainer.test() assert len(trainer.val_dataloaders) == 1, \ - f"`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}" + f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' assert len(trainer.test_dataloaders) == 1, \ - f"test_dataloaders` not initiated properly, got {trainer.test_dataloaders}" + f'`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}' def test_inf_train_dataloader(tmpdir): """Test inf train data loader (e.g. IterableDataset)""" tutils.reset_seed() - class CurrentTestModel(LightningTestModel): - def train_dataloader(self): - dataloader = self._dataloader(train=True) + class CurrentTestModel( + LightInfTrainDataloader, + LightningTestModel + ): + pass - class CustomInfDataLoader: - def __init__(self, dataloader): - self.dataloader = dataloader - self.iter = iter(dataloader) - self.count = 0 + hparams = tutils.get_hparams() + model = CurrentTestModel(hparams) - def __iter__(self): - self.count = 0 - return self + # fit model + with pytest.raises(MisconfigurationException): + trainer = Trainer( + default_save_path=tmpdir, + max_epochs=1, + val_check_interval=0.5 + ) + trainer.fit(model) - def __next__(self): - if self.count >= 5: - raise StopIteration - self.count = self.count + 1 - try: - return next(self.iter) - except StopIteration: - self.iter = iter(self.dataloader) - return next(self.iter) + # logger file to get meta + trainer = Trainer( + default_save_path=tmpdir, + max_epochs=1, + val_check_interval=50 + ) + result = trainer.fit(model) - return CustomInfDataLoader(dataloader) + # verify training completed + assert result == 1 + + +def test_inf_val_dataloader(tmpdir): + """Test inf val data loader (e.g. IterableDataset)""" + tutils.reset_seed() + + class CurrentTestModel( + LightInfValDataloader, + LightningTestModel + ): + pass hparams = tutils.get_hparams() model = CurrentTestModel(hparams) @@ -308,17 +402,51 @@ def __next__(self): trainer = Trainer( default_save_path=tmpdir, max_epochs=1, - val_check_interval=0.5 + val_percent_check=0.5 ) trainer.fit(model) # logger file to get meta trainer = Trainer( default_save_path=tmpdir, - max_epochs=1, - val_check_interval=50, + max_epochs=1 + ) + result = trainer.fit(model) + + # verify training completed + assert result == 1 + + +def test_inf_test_dataloader(tmpdir): + """Test inf test data loader (e.g. IterableDataset)""" + tutils.reset_seed() + + class CurrentTestModel( + LightInfTestDataloader, + LightningTestModel, + LightTestFitSingleTestDataloadersMixin + ): + pass + + hparams = tutils.get_hparams() + model = CurrentTestModel(hparams) + + # fit model + with pytest.raises(MisconfigurationException): + trainer = Trainer( + default_save_path=tmpdir, + max_epochs=1, + test_percent_check=0.5 + ) + trainer.test(model) + + # logger file to get meta + trainer = Trainer( + default_save_path=tmpdir, + max_epochs=1 ) result = trainer.fit(model) + trainer.test(model) # verify training completed assert result == 1