Support limit_mode_batches (int) for infinite dataloader (#2787)
* Support limit_mode_batches (int) for infinite dataloader

* flake8

* revert and update

* add and update tests

* pep8

* chlog

* Update CHANGELOG.md

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Add suggestions by @awaelchli

* docs

* Apply suggestions from code review

Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>

* Apply suggestions from code review

* fix

* max

* check

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>
Co-authored-by: Jirka Borovec <jirka@pytorchlightning.ai>
6 people committed Aug 5, 2020
1 parent b2a7d75 commit de9c9f0
Showing 7 changed files with 112 additions and 41 deletions.
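At a glance, the change lets `limit_train_batches`, `limit_val_batches` and `limit_test_batches` accept an int (an absolute number of batches) even when the dataloader is built on an `IterableDataset` and therefore has no `__len__`. A minimal usage sketch under that assumption; `RandomStream` and the commented-out `fit` call are illustrative placeholders, not code from this commit:

```python
import torch
from torch.utils.data import DataLoader, IterableDataset
from pytorch_lightning import Trainer


class RandomStream(IterableDataset):
    """Hypothetical infinite dataset: yields random vectors and defines no __len__."""

    def __iter__(self):
        while True:
            yield torch.randn(32)


train_loader = DataLoader(RandomStream(), batch_size=8)

# With this change an int limit caps the batches per epoch even though the
# dataloader cannot report a length; 0 (or 0.0) disables the loop, 1.0 runs
# over the whole (infinite) stream, and any other float raises an error.
trainer = Trainer(max_epochs=1, limit_train_batches=10, limit_val_batches=0)
# trainer.fit(model, train_loader)  # `model` stands in for any LightningModule
```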
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -31,6 +31,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added remaining `sklearn` metrics: `AveragePrecision`, `BalancedAccuracy`, `CohenKappaScore`, `DCG`, `Hamming`, `Hinge`, `Jaccard`, `MeanAbsoluteError`, `MeanSquaredError`, `MeanSquaredLogError`, `MedianAbsoluteError`, `R2Score`, `MeanPoissonDeviance`, `MeanGammaDeviance`, `MeanTweedieDeviance`, `ExplainedVariance` ([#2562](https://github.com/PyTorchLightning/pytorch-lightning/pull/2562))

- Added support for `limit_{mode}_batches (int)` to work with infinite dataloader (IterableDataset) ([#2787](https://github.com/PyTorchLightning/pytorch-lightning/pull/2787))

### Changed

- Truncated long version numbers in progress bar ([#2594](https://github.com/PyTorchLightning/pytorch-lightning/pull/2594))
21 changes: 15 additions & 6 deletions docs/source/sequences.rst
@@ -49,8 +49,8 @@ Lightning can handle TBTT automatically via this flag.
.. note:: If you need to modify how the batch is split,
override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`.

.. note:: Using this feature requires updating your LightningModule's :meth:`pytorch_lightning.core.LightningModule.training_step` to include
a `hiddens` arg.
.. note:: Using this feature requires updating your LightningModule's
:meth:`pytorch_lightning.core.LightningModule.training_step` to include a `hiddens` arg.

----------

@@ -59,10 +59,13 @@ Iterable Datasets
Lightning supports using IterableDatasets as well as map-style Datasets. IterableDatasets provide a more natural
option when using sequential data.

.. note:: When using an IterableDataset you must set the val_check_interval to 1.0 (the default) or to an int
(specifying the number of training batches to run before validation) when initializing the Trainer.
This is due to the fact that the IterableDataset does not have a __len__ and Lightning requires this to calculate
the validation interval when val_check_interval is less than one.
.. note:: When using an IterableDataset you must set ``val_check_interval`` to 1.0 (the default) or to an int
(specifying the number of training batches to run before validation) when initializing the Trainer. This is
because the IterableDataset does not have a ``__len__`` and Lightning needs it to calculate the validation
interval when ``val_check_interval`` is less than one. Similarly, you can set ``limit_{mode}_batches`` to a float or
an int. If it is set to 0.0 or 0, ``num_{mode}_batches`` is set to 0; if it is an int, ``num_{mode}_batches`` is set
to ``limit_{mode}_batches``; if it is set to 1.0, the whole dataset is used; any other float raises an exception.
Here ``mode`` can be train/val/test.

.. testcode::

@@ -87,3 +90,9 @@ option when using sequential data.

# Set val_check_interval
trainer = Trainer(val_check_interval=100)

# Set limit_val_batches to 0.0 or 0
trainer = Trainer(limit_val_batches=0.0)

# Set limit_val_batches as an int
trainer = Trainer(limit_val_batches=100)
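To summarize the rules documented in the note above, here is an illustrative helper (the function name and signature are made up for this sketch; it is not part of the Trainer API) showing how a `limit_{mode}_batches` value resolves to a batch count:

```python
def resolve_num_batches(limit, dataloader_len=float('inf')):
    """Illustrative only: mirrors the documented limit_{mode}_batches rules."""
    if isinstance(limit, int) or limit == 0.0:
        # 0 / 0.0 disables the loop; any other int is an absolute cap
        return min(dataloader_len, int(limit))
    if dataloader_len != float('inf'):
        # a fractional float only works when the dataloader length is known
        return int(dataloader_len * limit)
    if limit == 1.0:
        # 1.0 means "run through the whole (possibly infinite) stream"
        return dataloader_len
    raise ValueError('with an infinite dataloader, the limit must be 0.0, 1.0 or an int')


assert resolve_num_batches(10) == 10          # int cap on an infinite stream
assert resolve_num_batches(0.0) == 0          # disabled
assert resolve_num_batches(0.25, 400) == 100  # fraction of a known length
```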
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -1754,7 +1754,7 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg
elif self.example_input_array is not None:
input_data = self.example_input_array
else:
raise ValueError(f'input_sample and example_input_array tensors are both missing.')
raise ValueError('`input_sample` and `example_input_array` tensors are both missing.')

if 'example_outputs' not in kwargs:
self.eval()
45 changes: 21 additions & 24 deletions pytorch_lightning/trainer/data_loading.py
@@ -212,18 +212,19 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
# automatically add samplers
self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)

self.num_training_batches = len(self.train_dataloader) if _has_len(self.train_dataloader) else float('inf')
self._worker_check(self.train_dataloader, 'train dataloader')
self._check_batch_limits('limit_train_batches')

if not _has_len(self.train_dataloader):
self.num_training_batches = float('inf')
else:
# try getting the length
if isinstance(self.limit_train_batches, float):
self.num_training_batches = len(self.train_dataloader)
self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
else:
self.num_training_batches = min(len(self.train_dataloader), self.limit_train_batches)
if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
elif self.num_training_batches != float('inf'):
self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
elif self.limit_train_batches != 1.0:
raise MisconfigurationException(
'When using an IterableDataset for `limit_train_batches`,'
' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
' `num_training_batches` to use.')

# determine when to check validation
# if int passed in, val checks that often
@@ -241,8 +242,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
self.val_check_batch = float('inf')
else:
raise MisconfigurationException(
'When using an infinite DataLoader (e.g. with an IterableDataset'
' or when DataLoader does not implement `__len__`) for `train_dataloader`,'
'When using an IterableDataset for `train_dataloader`,'
' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
' checking validation every k training batches.')
else:
@@ -304,24 +304,21 @@ def _reset_eval_dataloader(
for i, dataloader in enumerate(dataloaders):
num_batches = len(dataloader) if _has_len(dataloader) else float('inf')
self._worker_check(dataloader, f'{mode} dataloader {i}')
self._check_batch_limits(f'limit_{mode}_batches')

# percent or num_steps
limit_eval_batches = getattr(self, f'limit_{mode}_batches')

if num_batches != float('inf'):
self._check_batch_limits(f'limit_{mode}_batches')

# limit num batches either as a percent or num steps
if isinstance(limit_eval_batches, float):
num_batches = int(num_batches * limit_eval_batches)
else:
num_batches = min(len(dataloader), limit_eval_batches)

elif limit_eval_batches not in (0.0, 1.0):
# limit num batches either as a percent or num steps
if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
num_batches = min(num_batches, int(limit_eval_batches))
elif num_batches != float('inf'):
num_batches = int(num_batches * limit_eval_batches)
elif limit_eval_batches != 1.0:
raise MisconfigurationException(
'When using an infinite DataLoader (e.g. with an IterableDataset'
f' or when DataLoader does not implement `__len__`) for `limit_{mode}_batches`,'
f' `Trainer(limit_{mode}_batches)` must be `0.0` or `1.0`.')
f'When using an IterableDataset for `limit_{mode}_batches`,'
f' `Trainer(limit_{mode}_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
f' `num_{mode}_batches` to use.')

if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
min_pct = 1.0 / len(dataloader)
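The practical effect of the rewritten branch above: for an eval dataloader without a `__len__`, only `0.0`, `1.0` or an int are accepted, and any other float is rejected at fit/test time. A hedged reproduction sketch, written as an illustrative companion to the error tests further down in this diff and reusing the `EvalModelTemplate` helpers they rely on:

```python
import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate


def test_fractional_limit_rejected_for_infinite_val_loader(tmpdir):
    # Illustrative only: a fractional limit cannot be resolved without a length.
    model = EvalModelTemplate()
    model.val_dataloader = model.val_dataloader__infinite  # dataloader with no __len__

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.5)
    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
        trainer.fit(model)
```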
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_tricks.py
@@ -269,7 +269,7 @@ def _adjust_batch_size(trainer,
if hasattr(model, batch_arg_name):
setattr(model, batch_arg_name, value)
else:
setattr(model.hparams, batch_arg_name, value)
setattr(model.hparams, batch_arg_name, value)
new_size = value
if desc:
log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
2 changes: 1 addition & 1 deletion tests/models/test_onnx_save.py
@@ -84,7 +84,7 @@ def test_error_if_no_input(tmpdir):
model = EvalModelTemplate()
model.example_input_array = None
file_path = os.path.join(tmpdir, "model.onxx")
with pytest.raises(ValueError, match=r'input_sample and example_input_array tensors are both missing'):
with pytest.raises(ValueError, match=r'`input_sample` and `example_input_array` tensors are both missing'):
model.to_onnx(file_path)


79 changes: 71 additions & 8 deletions tests/trainer/test_dataloaders.py
@@ -256,6 +256,69 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'


@pytest.mark.parametrize(
['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
[
pytest.param(0.0, 0.0, 0.0),
pytest.param(1.0, 1.0, 1.0),
]
)
def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches,
limit_val_batches, limit_test_batches):
"""Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit in percent"""
model = EvalModelTemplate()
model.train_dataloader = model.train_dataloader__infinite
model.val_dataloader = model.val_dataloader__infinite
model.test_dataloader = model.test_dataloader__infinite

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=limit_train_batches,
limit_val_batches=limit_val_batches,
limit_test_batches=limit_test_batches,
)

results = trainer.fit(model)
assert results == 1
assert trainer.num_training_batches == (0 if limit_train_batches == 0.0 else float('inf'))
assert trainer.num_val_batches[0] == (0 if limit_val_batches == 0.0 else float('inf'))

trainer.test(ckpt_path=None)
assert trainer.num_test_batches[0] == (0 if limit_test_batches == 0.0 else float('inf'))


@pytest.mark.parametrize(
['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
[
pytest.param(0, 0, 0),
pytest.param(10, 10, 10),
]
)
def test_inf_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
"""Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number"""
model = EvalModelTemplate()
model.train_dataloader = model.train_dataloader__infinite
model.val_dataloader = model.val_dataloader__infinite
model.test_dataloader = model.test_dataloader__infinite

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=limit_train_batches,
limit_val_batches=limit_val_batches,
limit_test_batches=limit_test_batches,
)

results = trainer.fit(model)
assert results
assert trainer.num_training_batches == limit_train_batches
assert trainer.num_val_batches[0] == limit_val_batches

trainer.test(ckpt_path=None)
assert trainer.num_test_batches[0] == limit_test_batches


@pytest.mark.parametrize(
['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
[
@@ -266,7 +329,7 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
]
)
def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
"""Verify num_batches for val & test dataloaders passed with batch limit in percent"""
"""Verify num_batches for train, val & test dataloaders passed with batch limit in percent"""
model = EvalModelTemplate()
model.val_dataloader = model.val_dataloader__multiple_mixed_length
model.test_dataloader = model.test_dataloader__multiple_mixed_length
@@ -307,7 +370,7 @@ def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, lim
]
)
def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
"""Verify num_batches for val & test dataloaders passed with batch limit as number"""
"""Verify num_batches for train, val & test dataloaders passed with batch limit as number"""
os.environ['PL_DEV_DEBUG'] = '1'

model = EvalModelTemplate()
@@ -436,7 +499,7 @@ def test_train_inf_dataloader_error(tmpdir):

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=0.5)

with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
trainer.fit(model)


@@ -447,7 +510,7 @@ def test_val_inf_dataloader_error(tmpdir):

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.5)

with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
trainer.fit(model)


@@ -458,7 +521,7 @@ def test_test_inf_dataloader_error(tmpdir):

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=0.5)

with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
trainer.test(model)


@@ -774,7 +837,7 @@ def test_train_dataloader_not_implemented_error_failed(tmpdir):

trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, val_check_interval=0.5)

with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
trainer.fit(model)


@@ -785,7 +848,7 @@ def test_val_dataloader_not_implemented_error_failed(tmpdir):

trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_val_batches=0.5)

with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
trainer.fit(model)


@@ -796,5 +859,5 @@ def test_test_dataloader_not_implemented_error_failed(tmpdir):

trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_test_batches=0.5)

with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
trainer.test(model)
