separate exception for IterableDataset
rohitgr7 committed Aug 2, 2020
1 parent 107ba83 commit d71306c
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions pytorch_lightning/trainer/data_loading.py
@@ -214,16 +214,20 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
         self._worker_check(self.train_dataloader, 'train dataloader')
         self._check_batch_limits('limit_train_batches')
 
-        if isinstance(self.limit_train_batches, int):
-            self.num_training_batches = min(self.num_training_batches, self.limit_train_batches)
+        # limit num batches either as a percent or num steps
+        if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
+            self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
         else:
             if self.num_training_batches != float('inf'):
                 self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
-            elif self.limit_train_batches not in (0.0, 1.0):
+            elif self.limit_train_batches != 1.0 and _has_iterable_dataset(self.train_dataloader):
                 raise MisconfigurationException(
-                    'When using an infinite DataLoader (e.g. with an IterableDataset'
-                    ' or when DataLoader does not implement `__len__`) for `limit_train_batches`,'
-                    ' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or `int`')
+                    'When using an Iterable Dataset for `limit_train_batches`,`'
+                    ' `Trainer(limit_train_batches)` must be `1.0` or an `int` value.')
+            elif self.limit_train_batches != 0.0:
+                raise MisconfigurationException(
+                    'When using a Dataloader that does not implement `__len__` for `limit_train_batches`,'
+                    ' `Trainer(limit_train_batches)` must be `0.0` or an `int` value')
 
         # determine when to check validation
         # if int passed in, val checks that often
@@ -309,16 +313,19 @@ def _reset_eval_dataloader(
                 limit_eval_batches = getattr(self, f'limit_{mode}_batches')
 
                 # limit num batches either as a percent or num steps
-                if isinstance(limit_eval_batches, int):
-                    num_batches = min(num_batches, limit_eval_batches)
+                if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
+                    num_batches = min(num_batches, int(limit_eval_batches))
                 else:
                     if num_batches != float('inf'):
                         num_batches = int(num_batches * limit_eval_batches)
-                    elif limit_eval_batches not in (0.0, 1.0):
+                    elif limit_eval_batches != 1.0 and _has_iterable_dataset(dataloader):
+                        raise MisconfigurationException(
+                            'When using an Iterable Dataset for `limit_{mode}_batches`,`'
+                            f' `Trainer(limit_{mode}_batches)` must be `1.0` or an `int` value.')
+                    elif limit_eval_batches != 0.0:
                         raise MisconfigurationException(
-                            'When using an infinite DataLoader (e.g. with an IterableDataset'
-                            f' or when DataLoader does not implement `__len__`) for `limit_{mode}_batches`,'
-                            f' `Trainer(limit_{mode}_batches)` must be `0.0`, `1.0` or `int`')
+                            'When using a Dataloader that does not implement `__len__` for `limit_{mode}_batches`,'
+                            f' `Trainer(limit_{mode}_batches)` must be `0.0` or an `int` value')
 
                 if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
                     min_pct = 1.0 / len(dataloader)
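
For readers skimming the change: the new branching treats an int limit (or 0.0) as a hard cap, a float limit as a fraction of a finite batch count, and an infinite loader now raises one of two dataset-specific errors instead of the old combined message. Below is a standalone sketch of that logic under stated assumptions: the function name, its parameters, and ValueError are illustrative stand-ins, not pytorch_lightning APIs (the real code raises MisconfigurationException and uses the private `_has_iterable_dataset` helper shown in the diff).

# Illustrative sketch of the branching this commit introduces; names here are
# hypothetical stand-ins, not pytorch_lightning code. `is_iterable_dataset`
# plays the role of `_has_iterable_dataset(dataloader)` from the diff.
def resolve_num_batches(num_batches, limit, is_iterable_dataset):
    if isinstance(limit, int) or limit == 0.0:
        # int limits (and 0.0) cap the batch count directly
        return min(num_batches, int(limit))
    if num_batches != float('inf'):
        # float limits scale a finite batch count
        return int(num_batches * limit)
    if limit != 1.0 and is_iterable_dataset:
        # new, dedicated error for an IterableDataset with a fractional float limit
        raise ValueError('IterableDataset: the limit must be `1.0` or an `int` value.')
    if limit != 0.0:
        # otherwise, an infinite loader with a non-zero float limit
        raise ValueError('DataLoader without `__len__`: the limit must be `0.0` or an `int` value.')
    return num_batches


print(resolve_num_batches(100, 0.25, False))        # -> 25
print(resolve_num_batches(float('inf'), 10, True))  # -> 10
try:
    resolve_num_batches(float('inf'), 0.5, True)    # -> raises the IterableDataset error
except ValueError as err:
    print(err)

In the trainer code itself both branches raise MisconfigurationException; the split simply gives the IterableDataset case its own message, which is what the commit title refers to.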
