diff --git a/CHANGELOG.md b/CHANGELOG.md
index eed7f8e7955d0..f29c17ddfea2f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -35,6 +35,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added remaining `sklearn` metrics: `AveragePrecision`, `BalancedAccuracy`, `CohenKappaScore`, `DCG`, `Hamming`, `Hinge`, `Jaccard`, `MeanAbsoluteError`, `MeanSquaredError`, `MeanSquaredLogError`, `MedianAbsoluteError`, `R2Score`, `MeanPoissonDeviance`, `MeanGammaDeviance`, `MeanTweedieDeviance`, `ExplainedVariance` ([#2562](https://github.com/PyTorchLightning/pytorch-lightning/pull/2562))
 
+- Added support for `limit_{mode}_batches (int)` to work with infinite dataloader (IterableDataset) ([#2840](https://github.com/PyTorchLightning/pytorch-lightning/pull/2840))
+
 - Added support returning python scalars in DP ([#1935](https://github.com/PyTorchLightning/pytorch-lightning/pull/1935))
 
 ### Changed
diff --git a/docs/source/sequences.rst b/docs/source/sequences.rst
index e24ee5bbca1cc..b9a8f2ee642aa 100644
--- a/docs/source/sequences.rst
+++ b/docs/source/sequences.rst
@@ -49,8 +49,8 @@ Lightning can handle TBTT automatically via this flag.
 
 .. note:: If you need to modify how the batch is split, override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`.
 
-.. note:: Using this feature requires updating your LightningModule's :meth:`pytorch_lightning.core.LightningModule.training_step` to include
-    a `hiddens` arg.
+.. note:: Using this feature requires updating your LightningModule's
+    :meth:`pytorch_lightning.core.LightningModule.training_step` to include a `hiddens` arg.
 
 ----------
 
@@ -59,10 +59,13 @@ Iterable Datasets
 Lightning supports using IterableDatasets as well as map-style Datasets. IterableDatasets provide a more natural
 option when using sequential data.
 
-.. note:: When using an IterableDataset you must set the val_check_interval to 1.0 (the default) or to an int
-    (specifying the number of training batches to run before validation) when initializing the Trainer.
-    This is due to the fact that the IterableDataset does not have a __len__ and Lightning requires this to calculate
-    the validation interval when val_check_interval is less than one.
+.. note:: When using an IterableDataset you must set ``val_check_interval`` to 1.0 (the default) or an int
+    (specifying the number of training batches to run before validation) when initializing the Trainer. This is
+    because the IterableDataset does not have a ``__len__`` and Lightning needs it to calculate the validation
+    interval when ``val_check_interval`` is less than one. Similarly, you can set ``limit_{mode}_batches`` to a float or
+    an int. If it is set to 0.0 or 0, ``num_{mode}_batches`` is set to 0; if it is an int, ``num_{mode}_batches``
+    is set to ``limit_{mode}_batches``; if it is set to 1.0, the whole dataset is used; any other value raises
+    an exception. Here ``mode`` can be train/val/test.
 
 .. testcode::
 
@@ -87,3 +90,9 @@ option when using sequential data.
 
     # Set val_check_interval
     trainer = Trainer(val_check_interval=100)
+
+    # Set limit_val_batches to 0.0 or 0
+    trainer = Trainer(limit_val_batches=0.0)
+
+    # Set limit_val_batches as an int
+    trainer = Trainer(limit_val_batches=100)
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index d7a503c07aca4..d11c8ed9e70af 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1771,7 +1771,7 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg
         elif self.example_input_array is not None:
             input_data = self.example_input_array
         else:
-            raise ValueError('input_sample and example_input_array tensors are both missing.')
+            raise ValueError('`input_sample` and `example_input_array` tensors are both missing.')
 
         if 'example_outputs' not in kwargs:
             self.eval()
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 09186765c6eee..38a1118118a40 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -103,21 +103,6 @@ class TrainerDataLoadingMixin(ABC):
     def is_overridden(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
 
-    def _check_batch_limits(self, name: str) -> None:
-        # TODO: verify it is still needed and deprecate it..
-        value = getattr(self, name)
-
-        # ints are fine
-        if isinstance(value, int):
-            return
-
-        msg = f'`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}. (or pass in an int)'
-        if name == 'val_check_interval':
-            msg += ' If you want to disable validation set `limit_val_batches` to 0.0 instead.'
-
-        if not 0. <= value <= 1.:
-            raise ValueError(msg)
-
     def _worker_check(self, dataloader: DataLoader, name: str) -> None:
         on_windows = platform.system() == 'Windows'
 
@@ -212,18 +197,18 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
         # automatically add samplers
         self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)
 
+        self.num_training_batches = len(self.train_dataloader) if _has_len(self.train_dataloader) else float('inf')
         self._worker_check(self.train_dataloader, 'train dataloader')
-        self._check_batch_limits('limit_train_batches')
 
-        if not _has_len(self.train_dataloader):
-            self.num_training_batches = float('inf')
-        else:
-            # try getting the length
-            if isinstance(self.limit_train_batches, float):
-                self.num_training_batches = len(self.train_dataloader)
-                self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
-            else:
-                self.num_training_batches = min(len(self.train_dataloader), self.limit_train_batches)
+        if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
+            self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
+        elif self.num_training_batches != float('inf'):
+            self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
+        elif self.limit_train_batches != 1.0:
+            raise MisconfigurationException(
+                'When using an IterableDataset for `limit_train_batches`,'
+                ' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
+                ' `num_training_batches` to use.')
 
         # determine when to check validation
         # if int passed in, val checks that often
@@ -241,13 +226,10 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
                 self.val_check_batch = float('inf')
             else:
                 raise MisconfigurationException(
-                    'When using an infinite DataLoader (e.g. with an IterableDataset'
-                    ' or when DataLoader does not implement `__len__`) for `train_dataloader`,'
+                    'When using an IterableDataset for `train_dataloader`,'
                     ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
                     ' checking validation every k training batches.')
         else:
-            self._check_batch_limits('val_check_interval')
-
             self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
             self.val_check_batch = max(1, self.val_check_batch)
 
@@ -308,20 +290,16 @@ def _reset_eval_dataloader(
             # percent or num_steps
             limit_eval_batches = getattr(self, f'limit_{mode}_batches')
 
-            if num_batches != float('inf'):
-                self._check_batch_limits(f'limit_{mode}_batches')
-
-                # limit num batches either as a percent or num steps
-                if isinstance(limit_eval_batches, float):
-                    num_batches = int(num_batches * limit_eval_batches)
-                else:
-                    num_batches = min(len(dataloader), limit_eval_batches)
-
-            elif limit_eval_batches not in (0.0, 1.0):
+            # limit num batches either as a percent or num steps
+            if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
+                num_batches = min(num_batches, int(limit_eval_batches))
+            elif num_batches != float('inf'):
+                num_batches = int(num_batches * limit_eval_batches)
+            elif limit_eval_batches != 1.0:
                 raise MisconfigurationException(
-                    'When using an infinite DataLoader (e.g. with an IterableDataset'
-                    f' or when DataLoader does not implement `__len__`) for `limit_{mode}_batches`,'
-                    f' `Trainer(limit_{mode}_batches)` must be `0.0` or `1.0`.')
+                    f'When using an IterableDataset for `limit_{mode}_batches`,'
+                    f' `Trainer(limit_{mode}_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
+                    f' `num_{mode}_batches` to use.')
 
             if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
                 min_pct = 1.0 / len(dataloader)
@@ -388,9 +366,6 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
     def determine_data_use_amount(self, overfit_batches: float) -> None:
         """Use less data for debugging purposes"""
        if overfit_batches > 0:
-            if isinstance(overfit_batches, float) and overfit_batches > 1:
-                raise ValueError('`overfit_batches` when used as a percentage must'
-                                 f' be in range 0.0 < x < 1.0 but got {overfit_batches:.3f}.')
             self.limit_train_batches = overfit_batches
             self.limit_val_batches = overfit_batches
             self.limit_test_batches = overfit_batches
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 4b342328df297..ea8d7941d95f6 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -534,7 +534,6 @@
         # logging
         self.configure_logger(logger)
         self.log_save_interval = log_save_interval
-        self.val_check_interval = val_check_interval
         self.row_log_interval = row_log_interval
 
         # how much of the data to use
@@ -547,9 +546,6 @@
            )
            overfit_batches = overfit_pct
 
-        # convert floats to ints
-        self.overfit_batches = _determine_limit_batches(overfit_batches)
-
        # TODO: remove in 0.10.0
        if val_percent_check is not None:
            rank_zero_warn(
@@ -577,9 +573,11 @@
            )
            limit_train_batches = train_percent_check
 
-        self.limit_test_batches = _determine_limit_batches(limit_test_batches)
-        self.limit_val_batches = _determine_limit_batches(limit_val_batches)
-        self.limit_train_batches = _determine_limit_batches(limit_train_batches)
+        self.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches')
+        self.limit_val_batches = _determine_batch_limits(limit_val_batches, 'limit_val_batches')
+        self.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches')
+        self.val_check_interval = _determine_batch_limits(val_check_interval, 'val_check_interval')
+        self.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches')
         self.determine_data_use_amount(self.overfit_batches)
 
         # AMP init
@@ -1430,12 +1428,12 @@ def __call__(self) -> Union[List[DataLoader], DataLoader]:
         return self.dataloader
 
 
-def _determine_limit_batches(batches: Union[int, float]) -> Union[int, float]:
+def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
     if 0 <= batches <= 1:
         return batches
     elif batches > 1 and batches % 1.0 == 0:
         return int(batches)
     else:
         raise MisconfigurationException(
-            f'You have passed invalid value {batches}, it has to be in (0, 1) or nature number.'
+            f'You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int.'
         )
diff --git a/tests/models/test_onnx.py b/tests/models/test_onnx.py
index d7cc7cffaec3f..278465941a043 100644
--- a/tests/models/test_onnx.py
+++ b/tests/models/test_onnx.py
@@ -85,7 +85,7 @@ def test_error_if_no_input(tmpdir):
     model = EvalModelTemplate()
     model.example_input_array = None
     file_path = os.path.join(tmpdir, "model.onxx")
-    with pytest.raises(ValueError, match=r'input_sample and example_input_array tensors are both missing'):
+    with pytest.raises(ValueError, match=r'`input_sample` and `example_input_array` tensors are both missing'):
         model.to_onnx(file_path)
 
 
diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index 1c7e21b7a72bb..d9e2500707fc8 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -53,19 +53,15 @@ def test_fit_val_loader_only(tmpdir):
 
 
 @pytest.mark.parametrize("dataloader_options", [
-    dict(val_check_interval=1.1),
     dict(val_check_interval=10000),
 ])
 def test_dataloader_config_errors_runtime(tmpdir, dataloader_options):
-
     model = EvalModelTemplate()
-
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
         **dataloader_options,
     )
-
     with pytest.raises(ValueError):
         # fit model
         trainer.fit(model)
@@ -78,9 +74,13 @@ def test_dataloader_config_errors_runtime(tmpdir, dataloader_options):
     dict(limit_val_batches=1.2),
     dict(limit_test_batches=-0.1),
     dict(limit_test_batches=1.2),
+    dict(val_check_interval=-0.1),
+    dict(val_check_interval=1.2),
+    dict(overfit_batches=-0.1),
+    dict(overfit_batches=1.2),
 ])
 def test_dataloader_config_errors_init(tmpdir, dataloader_options):
-    with pytest.raises(MisconfigurationException):
+    with pytest.raises(MisconfigurationException, match='passed invalid value'):
         Trainer(
             default_root_dir=tmpdir,
             max_epochs=1,
@@ -256,6 +256,62 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
         f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
 
 
+@pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [
+    pytest.param(0.0, 0.0, 0.0),
+    pytest.param(1.0, 1.0, 1.0),
+])
+def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
+    """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit in percent"""
+    model = EvalModelTemplate()
+    model.train_dataloader = model.train_dataloader__infinite
+    model.val_dataloader = model.val_dataloader__infinite
+    model.test_dataloader = model.test_dataloader__infinite
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=limit_val_batches,
+        limit_test_batches=limit_test_batches,
+    )
+
+    results = trainer.fit(model)
+    assert results == 1
+    assert trainer.num_training_batches == (0 if limit_train_batches == 0.0 else float('inf'))
+    assert trainer.num_val_batches[0] == (0 if limit_val_batches == 0.0 else float('inf'))
+
+    trainer.test(ckpt_path=None)
+    assert trainer.num_test_batches[0] == (0 if limit_test_batches == 0.0 else float('inf'))
+
+
+@pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [
+    pytest.param(0, 0, 0),
+    pytest.param(10, 10, 10),
+])
+def test_inf_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
+    """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number"""
+    model = EvalModelTemplate()
+    model.train_dataloader = model.train_dataloader__infinite
+    model.val_dataloader = model.val_dataloader__infinite
+    model.test_dataloader = model.test_dataloader__infinite
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=limit_val_batches,
+        limit_test_batches=limit_test_batches,
+    )
+
+    results = trainer.fit(model)
+    assert results
+    assert trainer.num_training_batches == limit_train_batches
+    assert trainer.num_val_batches[0] == limit_val_batches
+
+    trainer.test(ckpt_path=None)
+    assert trainer.num_test_batches[0] == limit_test_batches
+
+
 @pytest.mark.parametrize(
     ['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
     [
@@ -266,7 +322,7 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
     ]
 )
 def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
-    """Verify num_batches for val & test dataloaders passed with batch limit in percent"""
+    """Verify num_batches for train, val & test dataloaders passed with batch limit in percent"""
     model = EvalModelTemplate()
     model.val_dataloader = model.val_dataloader__multiple_mixed_length
     model.test_dataloader = model.test_dataloader__multiple_mixed_length
@@ -307,7 +363,7 @@ def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, lim
     ]
 )
 def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
-    """Verify num_batches for val & test dataloaders passed with batch limit as number"""
+    """Verify num_batches for train, val & test dataloaders passed with batch limit as number"""
     os.environ['PL_DEV_DEBUG'] = '1'
 
     model = EvalModelTemplate()
@@ -436,7 +492,7 @@ def test_train_inf_dataloader_error(tmpdir):
 
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=0.5)
 
-    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
+    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
         trainer.fit(model)
 
 
@@ -447,7 +503,7 @@ def test_val_inf_dataloader_error(tmpdir):
 
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.5)
 
-    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
+    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
         trainer.fit(model)
 
 
@@ -458,7 +514,7 @@ def test_test_inf_dataloader_error(tmpdir):
 
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=0.5)
 
-    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
+    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
         trainer.test(model)
 
 
@@ -774,7 +830,7 @@ def test_train_dataloader_not_implemented_error_failed(tmpdir):
 
     trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, val_check_interval=0.5)
 
-    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
+    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
         trainer.fit(model)
 
 
@@ -785,7 +841,7 @@ def test_val_dataloader_not_implemented_error_failed(tmpdir):
 
     trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_val_batches=0.5)
 
-    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
+    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
         trainer.fit(model)
 
 
@@ -796,5 +852,5 @@ def test_test_dataloader_not_implemented_error_failed(tmpdir):
 
     trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_test_batches=0.5)
 
-    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
+    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
         trainer.test(model)
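
For orientation, a minimal sketch (not part of the patch) of the behaviour this change documents and tests: an infinite IterableDataset combined with `limit_train_batches` given as an int, with validation disabled. `InfiniteDataset` and `BoringModel` are hypothetical stand-ins for the `EvalModelTemplate.*__infinite` fixtures used in the tests above; only the Trainer arguments matter here.

import torch
from torch.utils.data import DataLoader, IterableDataset

import pytorch_lightning as pl


class InfiniteDataset(IterableDataset):
    """Hypothetical stream with no __len__, i.e. an 'infinite' dataloader."""

    def __iter__(self):
        while True:
            yield torch.randn(32)


class BoringModel(pl.LightningModule):
    """Hypothetical minimal model; only the dataloader/limit interplay matters."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        return {'loss': loss}

    def train_dataloader(self):
        return DataLoader(InfiniteDataset(), batch_size=16)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# With an IterableDataset, limit_train_batches may be 0.0, 1.0 or an int;
# any other float raises a MisconfigurationException. An int caps
# num_training_batches directly, so this run stops after 10 batches.
trainer = pl.Trainer(
    max_epochs=1,
    limit_train_batches=10,
    limit_val_batches=0,      # no validation dataloader is defined anyway
    val_check_interval=1.0,   # must stay 1.0 or an int without __len__
)
trainer.fit(BoringModel())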