Support limit_mode_batches (int) for infinite dataloader #2840

Merged: 23 commits, Aug 7, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -35,6 +35,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added remaining `sklearn` metrics: `AveragePrecision`, `BalancedAccuracy`, `CohenKappaScore`, `DCG`, `Hamming`, `Hinge`, `Jaccard`, `MeanAbsoluteError`, `MeanSquaredError`, `MeanSquaredLogError`, `MedianAbsoluteError`, `R2Score`, `MeanPoissonDeviance`, `MeanGammaDeviance`, `MeanTweedieDeviance`, `ExplainedVariance` ([#2562](https://github.com/PyTorchLightning/pytorch-lightning/pull/2562))

- Added support for `limit_{mode}_batches (int)` to work with infinite dataloader (IterableDataset) ([#2840](https://github.com/PyTorchLightning/pytorch-lightning/pull/2840))

- Added support returning python scalars in DP ([#1935](https://github.com/PyTorchLightning/pytorch-lightning/pull/1935))

### Changed
21 changes: 15 additions & 6 deletions docs/source/sequences.rst
@@ -49,8 +49,8 @@ Lightning can handle TBTT automatically via this flag.
.. note:: If you need to modify how the batch is split,
override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`.

.. note:: Using this feature requires updating your LightningModule's :meth:`pytorch_lightning.core.LightningModule.training_step` to include
a `hiddens` arg.
.. note:: Using this feature requires updating your LightningModule's
:meth:`pytorch_lightning.core.LightningModule.training_step` to include a `hiddens` arg.

----------

@@ -59,10 +59,13 @@ Iterable Datasets
Lightning supports using IterableDatasets as well as map-style Datasets. IterableDatasets provide a more natural
option when using sequential data.

.. note:: When using an IterableDataset you must set the val_check_interval to 1.0 (the default) or to an int
(specifying the number of training batches to run before validation) when initializing the Trainer.
This is due to the fact that the IterableDataset does not have a __len__ and Lightning requires this to calculate
the validation interval when val_check_interval is less than one.
.. note:: When using an IterableDataset you must set the ``val_check_interval`` to 1.0 (the default) or an int
(specifying the number of training batches to run before validation) when initializing the Trainer. This is
because the IterableDataset does not have a ``__len__`` and Lightning requires this to calculate the validation
interval when ``val_check_interval`` is less than one. Similarly, you can set ``limit_{mode}_batches`` to a float or
an int. If it is set to 0.0 or 0 it will set ``num_{mode}_batches`` to 0, if it is an int it will set ``num_{mode}_batches``
to ``limit_{mode}_batches``, if it is set to 1.0 it will run for the whole dataset, otherwise it will throw an exception.
Here mode can be train/val/test.

.. testcode::

@@ -87,3 +90,9 @@

# Set val_check_interval
trainer = Trainer(val_check_interval=100)

# Set limit_val_batches to 0.0 or 0
trainer = Trainer(limit_val_batches=0.0)

# Set limit_val_batches as an int
trainer = Trainer(limit_val_batches=100)
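
To make the documented behaviour concrete, here is a minimal sketch (not part of this diff) of an infinite dataloader paired with these flags; `RandomIterableDataset` is a hypothetical stand-in for any `IterableDataset` without a `__len__`:

import torch
from torch.utils.data import DataLoader, IterableDataset
from pytorch_lightning import Trainer

class RandomIterableDataset(IterableDataset):
    """Endless stream of random samples, so len() is undefined."""
    def __iter__(self):
        while True:
            yield torch.rand(32)

train_loader = DataLoader(RandomIterableDataset(), batch_size=8)
val_loader = DataLoader(RandomIterableDataset(), batch_size=8)

# ints cap each loop at a fixed number of batches; 0 or 0.0 disables the loop
trainer = Trainer(max_epochs=1, limit_train_batches=100, limit_val_batches=10)
# trainer.fit(model, train_loader, val_loader)  # model is any LightningModule

Passing a float other than 0.0 or 1.0 here raises a MisconfigurationException, since a fraction of an unknown length cannot be computed.
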
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -1771,7 +1771,7 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg
elif self.example_input_array is not None:
input_data = self.example_input_array
else:
raise ValueError('input_sample and example_input_array tensors are both missing.')
raise ValueError('`input_sample` and `example_input_array` tensors are both missing.')

if 'example_outputs' not in kwargs:
self.eval()
65 changes: 20 additions & 45 deletions pytorch_lightning/trainer/data_loading.py
@@ -103,21 +103,6 @@ class TrainerDataLoadingMixin(ABC):
def is_overridden(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""

def _check_batch_limits(self, name: str) -> None:
# TODO: verify it is still needed and deprecate it..
value = getattr(self, name)

# ints are fine
if isinstance(value, int):
return

msg = f'`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}. (or pass in an int)'
if name == 'val_check_interval':
msg += ' If you want to disable validation set `limit_val_batches` to 0.0 instead.'

if not 0. <= value <= 1.:
raise ValueError(msg)

def _worker_check(self, dataloader: DataLoader, name: str) -> None:
on_windows = platform.system() == 'Windows'

@@ -212,18 +197,18 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
# automatically add samplers
self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)

self.num_training_batches = len(self.train_dataloader) if _has_len(self.train_dataloader) else float('inf')
self._worker_check(self.train_dataloader, 'train dataloader')
self._check_batch_limits('limit_train_batches')

if not _has_len(self.train_dataloader):
self.num_training_batches = float('inf')
else:
# try getting the length
if isinstance(self.limit_train_batches, float):
self.num_training_batches = len(self.train_dataloader)
self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
else:
self.num_training_batches = min(len(self.train_dataloader), self.limit_train_batches)
if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
elif self.num_training_batches != float('inf'):
self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
elif self.limit_train_batches != 1.0:
raise MisconfigurationException(
'When using an IterableDataset for `limit_train_batches`,'
' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
' `num_training_batches` to use.')

# determine when to check validation
# if int passed in, val checks that often
@@ -241,13 +226,10 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
self.val_check_batch = float('inf')
else:
raise MisconfigurationException(
'When using an infinite DataLoader (e.g. with an IterableDataset'
' or when DataLoader does not implement `__len__`) for `train_dataloader`,'
'When using an IterableDataset for `train_dataloader`,'
' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
' checking validation every k training batches.')
else:
self._check_batch_limits('val_check_interval')

self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
self.val_check_batch = max(1, self.val_check_batch)

@@ -308,20 +290,16 @@ def _reset_eval_dataloader(
# percent or num_steps
limit_eval_batches = getattr(self, f'limit_{mode}_batches')

if num_batches != float('inf'):
self._check_batch_limits(f'limit_{mode}_batches')

# limit num batches either as a percent or num steps
if isinstance(limit_eval_batches, float):
num_batches = int(num_batches * limit_eval_batches)
else:
num_batches = min(len(dataloader), limit_eval_batches)

elif limit_eval_batches not in (0.0, 1.0):
# limit num batches either as a percent or num steps
if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
num_batches = min(num_batches, int(limit_eval_batches))
elif num_batches != float('inf'):
num_batches = int(num_batches * limit_eval_batches)
elif limit_eval_batches != 1.0:
raise MisconfigurationException(
'When using an infinite DataLoader (e.g. with an IterableDataset'
f' or when DataLoader does not implement `__len__`) for `limit_{mode}_batches`,'
f' `Trainer(limit_{mode}_batches)` must be `0.0` or `1.0`.')
f'When using an IterableDataset for `limit_{mode}_batches`,'
f' `Trainer(limit_{mode}_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
f' `num_{mode}_batches` to use.')

if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
min_pct = 1.0 / len(dataloader)
@@ -388,9 +366,6 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
def determine_data_use_amount(self, overfit_batches: float) -> None:
"""Use less data for debugging purposes"""
if overfit_batches > 0:
if isinstance(overfit_batches, float) and overfit_batches > 1:
raise ValueError('`overfit_batches` when used as a percentage must'
f' be in range 0.0 < x < 1.0 but got {overfit_batches:.3f}.')
self.limit_train_batches = overfit_batches
self.limit_val_batches = overfit_batches
self.limit_test_batches = overfit_batches
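
Read as one standalone rule, the new branching in `reset_train_dataloader` and `_reset_eval_dataloader` amounts to the following sketch (simplified; the helper name and plain ValueError are illustrative, not the trainer's actual attributes or exception type):

def resolve_num_batches(dataloader_len, limit):
    """Map a limit_{mode}_batches value (int or float) to a batch count.

    dataloader_len is float('inf') for loaders without __len__ (IterableDataset).
    """
    num_batches = dataloader_len
    if isinstance(limit, int) or limit == 0.0:
        # an int caps the loop; 0 or 0.0 disables it entirely
        return min(num_batches, int(limit))
    elif num_batches != float('inf'):
        # a fraction of a finite dataloader
        return int(num_batches * limit)
    elif limit != 1.0:
        # other fractions are meaningless when the length is unknown
        raise ValueError('limit must be 0.0, 1.0 or an int for an infinite dataloader')
    return num_batches  # limit == 1.0 on an infinite loader: run unbounded

So an infinite train loader with limit_train_batches=10 runs 10 batches, 0.0 runs none, 1.0 keeps float('inf'), and 0.5 raises.
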
16 changes: 7 additions & 9 deletions pytorch_lightning/trainer/trainer.py
@@ -534,7 +534,6 @@ def __init__(
# logging
self.configure_logger(logger)
self.log_save_interval = log_save_interval
self.val_check_interval = val_check_interval
self.row_log_interval = row_log_interval

# how much of the data to use
@@ -547,9 +546,6 @@
)
overfit_batches = overfit_pct

# convert floats to ints
self.overfit_batches = _determine_limit_batches(overfit_batches)

# TODO: remove in 0.10.0
if val_percent_check is not None:
rank_zero_warn(
@@ -577,9 +573,11 @@
)
limit_train_batches = train_percent_check

self.limit_test_batches = _determine_limit_batches(limit_test_batches)
self.limit_val_batches = _determine_limit_batches(limit_val_batches)
self.limit_train_batches = _determine_limit_batches(limit_train_batches)
self.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches')
self.limit_val_batches = _determine_batch_limits(limit_val_batches, 'limit_val_batches')
self.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches')
self.val_check_interval = _determine_batch_limits(val_check_interval, 'val_check_interval')
self.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches')
self.determine_data_use_amount(self.overfit_batches)

# AMP init
@@ -1430,12 +1428,12 @@ def __call__(self) -> Union[List[DataLoader], DataLoader]:
return self.dataloader


def _determine_limit_batches(batches: Union[int, float]) -> Union[int, float]:
def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
if 0 <= batches <= 1:
return batches
elif batches > 1 and batches % 1.0 == 0:
return int(batches)
else:
raise MisconfigurationException(
f'You have passed invalid value {batches}, it has to be in (0, 1) or nature number.'
f'You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int.'
)
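
For reference, the renamed helper accepts fractions in [0.0, 1.0] and whole numbers; a few assumed example values (not taken from the tests) illustrate the rule:

# fractions in [0.0, 1.0] pass through unchanged
assert _determine_batch_limits(0.25, 'limit_val_batches') == 0.25
# whole numbers greater than 1 are coerced to int (10.0 -> 10)
assert _determine_batch_limits(10.0, 'limit_train_batches') == 10
# anything else, e.g. 1.2 or -0.1, raises a MisconfigurationException
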
2 changes: 1 addition & 1 deletion tests/models/test_onnx.py
@@ -85,7 +85,7 @@ def test_error_if_no_input(tmpdir):
model = EvalModelTemplate()
model.example_input_array = None
file_path = os.path.join(tmpdir, "model.onxx")
with pytest.raises(ValueError, match=r'input_sample and example_input_array tensors are both missing'):
with pytest.raises(ValueError, match=r'`input_sample` and `example_input_array` tensors are both missing'):
model.to_onnx(file_path)


82 changes: 69 additions & 13 deletions tests/trainer/test_dataloaders.py
@@ -53,19 +53,15 @@ def test_fit_val_loader_only(tmpdir):


@pytest.mark.parametrize("dataloader_options", [
dict(val_check_interval=1.1),
dict(val_check_interval=10000),
])
def test_dataloader_config_errors_runtime(tmpdir, dataloader_options):

model = EvalModelTemplate()

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
**dataloader_options,
)

with pytest.raises(ValueError):
# fit model
trainer.fit(model)
@@ -78,9 +74,13 @@ def test_dataloader_config_errors_runtime(tmpdir, dataloader_options):
dict(limit_val_batches=1.2),
dict(limit_test_batches=-0.1),
dict(limit_test_batches=1.2),
dict(val_check_interval=-0.1),
dict(val_check_interval=1.2),
dict(overfit_batches=-0.1),
dict(overfit_batches=1.2),
])
def test_dataloader_config_errors_init(tmpdir, dataloader_options):
with pytest.raises(MisconfigurationException):
with pytest.raises(MisconfigurationException, match='passed invalid value'):
Trainer(
default_root_dir=tmpdir,
max_epochs=1,
@@ -256,6 +256,62 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'


@pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [
pytest.param(0.0, 0.0, 0.0),
pytest.param(1.0, 1.0, 1.0),
])
def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
"""Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit in percent"""
model = EvalModelTemplate()
model.train_dataloader = model.train_dataloader__infinite
model.val_dataloader = model.val_dataloader__infinite
model.test_dataloader = model.test_dataloader__infinite

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=limit_train_batches,
limit_val_batches=limit_val_batches,
limit_test_batches=limit_test_batches,
)

results = trainer.fit(model)
assert results == 1
assert trainer.num_training_batches == (0 if limit_train_batches == 0.0 else float('inf'))
assert trainer.num_val_batches[0] == (0 if limit_val_batches == 0.0 else float('inf'))

trainer.test(ckpt_path=None)
assert trainer.num_test_batches[0] == (0 if limit_test_batches == 0.0 else float('inf'))


@pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [
pytest.param(0, 0, 0),
pytest.param(10, 10, 10),
])
def test_inf_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
"""Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number"""
model = EvalModelTemplate()
model.train_dataloader = model.train_dataloader__infinite
model.val_dataloader = model.val_dataloader__infinite
model.test_dataloader = model.test_dataloader__infinite

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=limit_train_batches,
limit_val_batches=limit_val_batches,
limit_test_batches=limit_test_batches,
)

results = trainer.fit(model)
assert results
assert trainer.num_training_batches == limit_train_batches
assert trainer.num_val_batches[0] == limit_val_batches

trainer.test(ckpt_path=None)
assert trainer.num_test_batches[0] == limit_test_batches


@pytest.mark.parametrize(
['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
[
@@ -266,7 +322,7 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
]
)
def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
"""Verify num_batches for val & test dataloaders passed with batch limit in percent"""
"""Verify num_batches for train, val & test dataloaders passed with batch limit in percent"""
model = EvalModelTemplate()
model.val_dataloader = model.val_dataloader__multiple_mixed_length
model.test_dataloader = model.test_dataloader__multiple_mixed_length
@@ -307,7 +363,7 @@ def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, lim
]
)
def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
"""Verify num_batches for val & test dataloaders passed with batch limit as number"""
"""Verify num_batches for train, val & test dataloaders passed with batch limit as number"""
os.environ['PL_DEV_DEBUG'] = '1'

model = EvalModelTemplate()
@@ -436,7 +492,7 @@ def test_train_inf_dataloader_error(tmpdir):

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=0.5)

with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
trainer.fit(model)


@@ -447,7 +503,7 @@ def test_val_inf_dataloader_error(tmpdir):

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.5)

with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
trainer.fit(model)


@@ -458,7 +514,7 @@ def test_test_inf_dataloader_error(tmpdir):

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=0.5)

with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
trainer.test(model)


@@ -774,7 +830,7 @@ def test_train_dataloader_not_implemented_error_failed(tmpdir):

trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, val_check_interval=0.5)

with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
trainer.fit(model)


@@ -785,7 +841,7 @@ def test_val_dataloader_not_implemented_error_failed(tmpdir):

trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_val_batches=0.5)

with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
trainer.fit(model)


@@ -796,5 +852,5 @@ def test_test_dataloader_not_implemented_error_failed(tmpdir):

trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_test_batches=0.5)

with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
trainer.test(model)