
Commit

check
rohitgr7 committed Aug 5, 2020
1 parent dea47b1 commit 0246e3c
Showing 2 changed files with 4 additions and 5 deletions.
docs/source/sequences.rst (4 changes: 2 additions & 2 deletions)
@@ -59,10 +59,10 @@ Iterable Datasets
 Lightning supports using IterableDatasets as well as map-style Datasets. IterableDatasets provide a more natural
 option when using sequential data.

-.. note:: When using an IterableDataset you must set the val_check_interval to 1.0 (the default) or an int
+.. note:: When using an IterableDataset you must set the ``val_check_interval`` to 1.0 (the default) or an int
     (specifying the number of training batches to run before validation) when initializing the Trainer. This is
     because the IterableDataset does not have a ``__len__`` and Lightning requires this to calculate the validation
-    interval when ``val_check_interval`` is less than one. Similarly, you can set limit_{mode}_batches to a float or
+    interval when ``val_check_interval`` is less than one. Similarly, you can set ``limit_{mode}_batches`` to a float or
     an int. If it is set to 0.0 or 0 it will set ``num_{mode}_batches`` to 0, if it is an int it will set ``num_{mode}_batches``
     to ``limit_{mode}_batches``, if it is set to 1.0 it will run for the whole dataset, otherwise it will throw an exception.
     Here mode can be train/val/test.
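
To make the note above concrete, here is a minimal sketch of how these Trainer flags might be set. The flag names come from the docs above; the datasets and batch counts are illustrative assumptions only:

from pytorch_lightning import Trainer

# Map-style dataset (has __len__): a float limit acts as a fraction.
# With 1000 training batches, limit_train_batches=0.25 yields 250.
trainer = Trainer(limit_train_batches=0.25)

# IterableDataset (no __len__): fractions cannot be resolved, so use the
# default 1.0, 0/0.0, or an explicit int.
trainer = Trainer(
    val_check_interval=100,  # run validation every 100 training batches
    limit_val_batches=50,    # cap validation at 50 batches per run
)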
pytorch_lightning/trainer/data_loading.py (5 changes: 2 additions & 3 deletions)
@@ -219,13 +219,12 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
         if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
             self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
         elif self.num_training_batches != float('inf'):
-            self.num_training_batches = min(1.0, self.num_training_batches)
             self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
         elif self.limit_train_batches != 1.0:
             raise MisconfigurationException(
                 'When using an IterableDataset for `limit_train_batches`,'
                 ' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
-                ' num_training_batches to use.')
+                ' `num_training_batches` to use.')

         # determine when to check validation
         # if int passed in, val checks that often
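
The line deleted in this hunk is the substance of the fix: clamping self.num_training_batches to at most 1.0 before multiplying by a fractional limit_train_batches collapses the result to int(1.0 * fraction) == 0. A self-contained sketch of the corrected arithmetic; the helper name and the numbers are illustrative, not part of the codebase:

def resolve_num_batches(total_batches, limit):
    # int (or exactly 0.0): absolute cap on the number of batches
    if isinstance(limit, int) or limit == 0.0:
        return min(total_batches, int(limit))
    # float in (0.0, 1.0] with a sized dataloader: fraction of the total
    if total_batches != float('inf'):
        return int(total_batches * limit)
    # unsized dataloader (IterableDataset): only 0.0, 1.0 or an int are valid
    if limit != 1.0:
        raise ValueError('limit must be 0.0, 1.0 or an int for an unsized dataloader')
    return total_batches

assert resolve_num_batches(1000, 0.25) == 250  # old code: int(min(1.0, 1000) * 0.25) == 0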
@@ -305,6 +304,7 @@ def _reset_eval_dataloader(
         for i, dataloader in enumerate(dataloaders):
             num_batches = len(dataloader) if _has_len(dataloader) else float('inf')
             self._worker_check(dataloader, f'{mode} dataloader {i}')
+            self._check_batch_limits(f'limit_{mode}_batches')

             # percent or num_steps
             limit_eval_batches = getattr(self, f'limit_{mode}_batches')
@@ -313,7 +313,6 @@
             if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
                 num_batches = min(num_batches, int(limit_eval_batches))
             elif num_batches != float('inf'):
-                num_batches = min(1.0, num_batches)
                 num_batches = int(num_batches * limit_eval_batches)
             elif limit_eval_batches != 1.0:
                 raise MisconfigurationException(
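
The eval-side hunk above applies the same resolution independently for each mode (train/val/test). For instance, assuming a validation dataloader with 200 batches (an illustrative number):

num_batches = 200
assert int(num_batches * 0.5) == 100    # limit_val_batches=0.5: fraction of the total
assert min(num_batches, 20) == 20       # limit_val_batches=20: absolute cap
assert min(num_batches, int(0.0)) == 0  # limit_val_batches=0.0: validation disabled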
