Support limit_mode_batches (int) for infinite dataloader #2787

Merged
merged 15 commits
Aug 5, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -31,6 +31,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added remaining `sklearn` metrics: `AveragePrecision`, `BalancedAccuracy`, `CohenKappaScore`, `DCG`, `Hamming`, `Hinge`, `Jaccard`, `MeanAbsoluteError`, `MeanSquaredError`, `MeanSquaredLogError`, `MedianAbsoluteError`, `R2Score`, `MeanPoissonDeviance`, `MeanGammaDeviance`, `MeanTweedieDeviance`, `ExplainedVariance` ([#2562](https://github.com/PyTorchLightning/pytorch-lightning/pull/2562))

- Added support for `limit_{mode}_batches (int)` to work with infinite dataloader (IterableDataset) ([#2787](https://github.com/PyTorchLightning/pytorch-lightning/pull/2787))

### Changed

- Truncated long version numbers in progress bar ([#2594](https://github.com/PyTorchLightning/pytorch-lightning/pull/2594))
21 changes: 15 additions & 6 deletions docs/source/sequences.rst
@@ -49,8 +49,8 @@ Lightning can handle TBTT automatically via this flag.
.. note:: If you need to modify how the batch is split,
override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`.

.. note:: Using this feature requires updating your LightningModule's :meth:`pytorch_lightning.core.LightningModule.training_step` to include
a `hiddens` arg.
.. note:: Using this feature requires updating your LightningModule's
:meth:`pytorch_lightning.core.LightningModule.training_step` to include a `hiddens` arg.

----------

@@ -59,10 +59,13 @@ Iterable Datasets
Lightning supports using IterableDatasets as well as map-style Datasets. IterableDatasets provide a more natural
option when using sequential data.

.. note:: When using an IterableDataset you must set the val_check_interval to 1.0 (the default) or to an int
(specifying the number of training batches to run before validation) when initializing the Trainer.
This is due to the fact that the IterableDataset does not have a __len__ and Lightning requires this to calculate
the validation interval when val_check_interval is less than one.
.. note:: When using an IterableDataset you must set ``val_check_interval`` to 1.0 (the default) or to an int
(specifying the number of training batches to run before validation) when initializing the Trainer. This is
because the IterableDataset does not have a ``__len__``, which Lightning needs to calculate the validation
interval when ``val_check_interval`` is less than one. Similarly, you can set ``limit_{mode}_batches`` to a float or
an int. If it is set to 0.0 or 0, ``num_{mode}_batches`` is set to 0; if it is an int, ``num_{mode}_batches`` is set
to ``limit_{mode}_batches``; if it is set to 1.0, the whole dataset is used; otherwise an exception is raised.
Here ``mode`` can be train/val/test.

.. testcode::

@@ -87,3 +90,9 @@ option when using sequential data.

# Set val_check_interval
trainer = Trainer(val_check_interval=100)

# Set limit_val_batches to 0.0 or 0
trainer = Trainer(limit_val_batches=0.0)

# Set limit_val_batches as an int
trainer = Trainer(limit_val_batches=100)
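
As a rough sketch, here is what an infinite dataloader that these flags apply to might look like (the ``RandomIterableDataset`` below is a hypothetical illustration, not part of this PR):

    import torch
    from torch.utils.data import DataLoader, IterableDataset
    from pytorch_lightning import Trainer

    class RandomIterableDataset(IterableDataset):
        # no __len__, so Lightning treats dataloaders built from it as infinite
        def __iter__(self):
            while True:
                yield torch.randn(32)

    train_loader = DataLoader(RandomIterableDataset(), batch_size=16)

    # with an infinite dataloader, limit_train_batches must be 0.0, 1.0 or an int,
    # and val_check_interval must be 1.0 or an int
    trainer = Trainer(limit_train_batches=100, val_check_interval=50)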
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -1754,7 +1754,7 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg
elif self.example_input_array is not None:
input_data = self.example_input_array
else:
raise ValueError(f'input_sample and example_input_array tensors are both missing.')
raise ValueError('`input_sample` and `example_input_array` tensors are both missing.')

if 'example_outputs' not in kwargs:
self.eval()
45 changes: 21 additions & 24 deletions pytorch_lightning/trainer/data_loading.py
@@ -212,18 +212,19 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
# automatically add samplers
self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)

self.num_training_batches = len(self.train_dataloader) if _has_len(self.train_dataloader) else float('inf')
self._worker_check(self.train_dataloader, 'train dataloader')
self._check_batch_limits('limit_train_batches')

if not _has_len(self.train_dataloader):
self.num_training_batches = float('inf')
else:
# try getting the length
if isinstance(self.limit_train_batches, float):
self.num_training_batches = len(self.train_dataloader)
self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
else:
self.num_training_batches = min(len(self.train_dataloader), self.limit_train_batches)
if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
elif self.num_training_batches != float('inf'):
self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
Member:
Can it also be 5.1, meaning 510%?
cc: @PyTorchLightning/core-contributors

Member:
@Borda I don't think so. We iterate like this:

# run epoch
        for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
                enumerate(_with_is_last(train_dataloader)), "get_train_batch"
        ):
            # stop epoch if we limited the number of training batches
            if batch_idx >= self.num_training_batches:
                break

So if the loader is exhausted before that point, it raises a StopIteration, meaning the condition above will never be True.

Member:
Still, it would be cleaner to have `val = min(val, 1.0)` there.

Contributor Author:
We can do `_check_batch_limits` there to avoid float values > 1.0.
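
A minimal sketch of what such a check could look like (a standalone illustration of the idea, not the actual `_check_batch_limits` implementation):

    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    def check_batch_limit(name: str, value) -> None:
        # hypothetical helper: float limits are fractions of the dataset, so they must stay in [0.0, 1.0];
        # int limits are taken as absolute batch counts and pass through unchanged
        if isinstance(value, float) and not 0.0 <= value <= 1.0:
            raise MisconfigurationException(
                f'`{name}` must be an int or a float in the range [0.0, 1.0], got {value}.')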

elif self.limit_train_batches != 1.0:
raise MisconfigurationException(
'When using an IterableDataset for `limit_train_batches`,'
' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
' `num_training_batches` to use.')

# determine when to check validation
# if int passed in, val checks that often
@@ -241,8 +242,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
self.val_check_batch = float('inf')
else:
raise MisconfigurationException(
'When using an infinite DataLoader (e.g. with an IterableDataset'
' or when DataLoader does not implement `__len__`) for `train_dataloader`,'
'When using an IterableDataset for `train_dataloader`,'
' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
' checking validation every k training batches.')
else:
@@ -304,24 +304,21 @@ def _reset_eval_dataloader(
for i, dataloader in enumerate(dataloaders):
num_batches = len(dataloader) if _has_len(dataloader) else float('inf')
self._worker_check(dataloader, f'{mode} dataloader {i}')
self._check_batch_limits(f'limit_{mode}_batches')

# percent or num_steps
limit_eval_batches = getattr(self, f'limit_{mode}_batches')

if num_batches != float('inf'):
self._check_batch_limits(f'limit_{mode}_batches')

# limit num batches either as a percent or num steps
if isinstance(limit_eval_batches, float):
num_batches = int(num_batches * limit_eval_batches)
else:
num_batches = min(len(dataloader), limit_eval_batches)

elif limit_eval_batches not in (0.0, 1.0):
# limit num batches either as a percent or num steps
if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
num_batches = min(num_batches, int(limit_eval_batches))
elif num_batches != float('inf'):
num_batches = int(num_batches * limit_eval_batches)
elif limit_eval_batches != 1.0:
raise MisconfigurationException(
'When using an infinite DataLoader (e.g. with an IterableDataset'
f' or when DataLoader does not implement `__len__`) for `limit_{mode}_batches`,'
f' `Trainer(limit_{mode}_batches)` must be `0.0` or `1.0`.')
f'When using an IterableDataset for `limit_{mode}_batches`,'
f' `Trainer(limit_{mode}_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
f' `num_{mode}_batches` to use.')

if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
min_pct = 1.0 / len(dataloader)
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_tricks.py
@@ -269,7 +269,7 @@ def _adjust_batch_size(trainer,
if hasattr(model, batch_arg_name):
setattr(model, batch_arg_name, value)
else:
setattr(model.hparams, batch_arg_name, value)
setattr(model.hparams, batch_arg_name, value)
new_size = value
if desc:
log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
2 changes: 1 addition & 1 deletion tests/models/test_onnx_save.py
@@ -84,7 +84,7 @@ def test_error_if_no_input(tmpdir):
model = EvalModelTemplate()
model.example_input_array = None
file_path = os.path.join(tmpdir, "model.onxx")
with pytest.raises(ValueError, match=r'input_sample and example_input_array tensors are both missing'):
with pytest.raises(ValueError, match=r'`input_sample` and `example_input_array` tensors are both missing'):
model.to_onnx(file_path)


79 changes: 71 additions & 8 deletions tests/trainer/test_dataloaders.py
@@ -256,6 +256,69 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'


@pytest.mark.parametrize(
['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
[
pytest.param(0.0, 0.0, 0.0),
pytest.param(1.0, 1.0, 1.0),
]
)
def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches,
limit_val_batches, limit_test_batches):
"""Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit in percent"""
model = EvalModelTemplate()
model.train_dataloader = model.train_dataloader__infinite
model.val_dataloader = model.val_dataloader__infinite
model.test_dataloader = model.test_dataloader__infinite

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=limit_train_batches,
limit_val_batches=limit_val_batches,
limit_test_batches=limit_test_batches,
)

results = trainer.fit(model)
assert results == 1
assert trainer.num_training_batches == (0 if limit_train_batches == 0.0 else float('inf'))
assert trainer.num_val_batches[0] == (0 if limit_val_batches == 0.0 else float('inf'))

trainer.test(ckpt_path=None)
assert trainer.num_test_batches[0] == (0 if limit_test_batches == 0.0 else float('inf'))


@pytest.mark.parametrize(
['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
[
pytest.param(0, 0, 0),
pytest.param(10, 10, 10),
]
)
def test_inf_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
"""Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number"""
model = EvalModelTemplate()
model.train_dataloader = model.train_dataloader__infinite
model.val_dataloader = model.val_dataloader__infinite
model.test_dataloader = model.test_dataloader__infinite

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=limit_train_batches,
limit_val_batches=limit_val_batches,
limit_test_batches=limit_test_batches,
)

results = trainer.fit(model)
assert results
assert trainer.num_training_batches == limit_train_batches
assert trainer.num_val_batches[0] == limit_val_batches

trainer.test(ckpt_path=None)
assert trainer.num_test_batches[0] == limit_test_batches


@pytest.mark.parametrize(
['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
[
@@ -266,7 +329,7 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
]
)
def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
"""Verify num_batches for val & test dataloaders passed with batch limit in percent"""
"""Verify num_batches for train, val & test dataloaders passed with batch limit in percent"""
model = EvalModelTemplate()
model.val_dataloader = model.val_dataloader__multiple_mixed_length
model.test_dataloader = model.test_dataloader__multiple_mixed_length
@@ -307,7 +370,7 @@ def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, lim
]
)
def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
"""Verify num_batches for val & test dataloaders passed with batch limit as number"""
"""Verify num_batches for train, val & test dataloaders passed with batch limit as number"""
os.environ['PL_DEV_DEBUG'] = '1'

model = EvalModelTemplate()
@@ -436,7 +499,7 @@ def test_train_inf_dataloader_error(tmpdir):

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=0.5)

with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
trainer.fit(model)


@@ -447,7 +510,7 @@ def test_val_inf_dataloader_error(tmpdir):

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.5)

with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
trainer.fit(model)


@@ -458,7 +521,7 @@ def test_test_inf_dataloader_error(tmpdir):

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=0.5)

with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
trainer.test(model)


@@ -774,7 +837,7 @@ def test_train_dataloader_not_implemented_error_failed(tmpdir):

trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, val_check_interval=0.5)

with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
trainer.fit(model)


@@ -785,7 +848,7 @@ def test_val_dataloader_not_implemented_error_failed(tmpdir):

trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_val_batches=0.5)

with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
trainer.fit(model)


@@ -796,5 +859,5 @@ def test_test_dataloader_not_implemented_error_failed(tmpdir):

trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_test_batches=0.5)

with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
trainer.test(model)