fixing bug in testing for IterableDataset (#547)
MikeScarp authored and williamFalcon committed Nov 26, 2019
1 parent 4627887 commit 55f3ffd
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/data_loading_mixin.py
@@ -24,7 +24,7 @@ def init_train_dataloader(self, model):
         self.get_train_dataloader = model.train_dataloader

         # determine number of training batches
-        if isinstance(self.get_train_dataloader(), IterableDataset):
+        if isinstance(self.get_train_dataloader().dataset, IterableDataset):
             self.nb_training_batches = float('inf')
         else:
             self.nb_training_batches = len(self.get_train_dataloader())
@@ -167,7 +167,7 @@ def get_dataloaders(self, model):
         self.get_val_dataloaders()

         # support IterableDataset for train data
-        self.is_iterable_train_dataloader = isinstance(self.get_train_dataloader(), IterableDataset)
+        self.is_iterable_train_dataloader = isinstance(self.get_train_dataloader().dataset, IterableDataset)
         if self.is_iterable_train_dataloader and not isinstance(self.val_check_interval, int):
             m = '''
             When using an iterableDataset for train_dataloader,
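Context for the fix, as a minimal sketch (the SampleStream class and names below are illustrative, not part of the commit): a torch.utils.data.DataLoader is never itself an IterableDataset, so the old isinstance check always returned False; the wrapped dataset is exposed on the loader's .dataset attribute, which is what the corrected check inspects.

    import torch
    from torch.utils.data import DataLoader, IterableDataset


    class SampleStream(IterableDataset):
        """Illustrative streaming dataset with no __len__."""

        def __iter__(self):
            for i in range(10):
                yield torch.tensor([i], dtype=torch.float32)


    loader = DataLoader(SampleStream(), batch_size=4)

    # Old check: the DataLoader itself is never an IterableDataset.
    print(isinstance(loader, IterableDataset))          # False

    # Fixed check: inspect the dataset wrapped by the DataLoader.
    print(isinstance(loader.dataset, IterableDataset))  # True

This is also why the first hunk falls back to float('inf') for nb_training_batches: an iterable dataset generally has no __len__, so the number of training batches is not known up front.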
