Skip to content

Commit

Permalink
fix for pyTorch 1.2 (#549)
Browse files Browse the repository at this point in the history
* min pytorch 1.2

* fix IterableDataset

* upgrade torchvision

* fix msg
  • Loading branch information
Borda authored and williamFalcon committed Nov 26, 2019
1 parent 55f3ffd commit f2191b0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
17 changes: 14 additions & 3 deletions pytorch_lightning/trainer/data_loading_mixin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
import warnings

import torch.distributed as dist
from torch.utils.data import IterableDataset
try:
# loading for pyTorch 1.3
from torch.utils.data import IterableDataset
except ImportError:
# loading for pyTorch 1.1
import torch
warnings.warn('Your version of pyTorch %s does not support `IterableDataset`,'
' please upgrade to 1.2+' % torch.__version__, ImportWarning)
EXIST_ITER_DATASET = False
else:
EXIST_ITER_DATASET = True
from torch.utils.data.distributed import DistributedSampler

from pytorch_lightning.utilities.debugging import MisconfigurationException
Expand All @@ -24,7 +34,7 @@ def init_train_dataloader(self, model):
self.get_train_dataloader = model.train_dataloader

# determine number of training batches
if isinstance(self.get_train_dataloader().dataset, IterableDataset):
if EXIST_ITER_DATASET and isinstance(self.get_train_dataloader().dataset, IterableDataset):
self.nb_training_batches = float('inf')
else:
self.nb_training_batches = len(self.get_train_dataloader())
Expand Down Expand Up @@ -167,7 +177,8 @@ def get_dataloaders(self, model):
self.get_val_dataloaders()

# support IterableDataset for train data
self.is_iterable_train_dataloader = isinstance(self.get_train_dataloader().dataset, IterableDataset)
self.is_iterable_train_dataloader = (
EXIST_ITER_DATASET and isinstance(self.get_train_dataloader().dataset, IterableDataset))
if self.is_iterable_train_dataloader and not isinstance(self.val_check_interval, int):
m = '''
When using an iterableDataset for train_dataloader,
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
scikit-learn>=0.20.2
tqdm>=4.35.0
numpy>=1.16.4
torch>=1.1
torchvision>=0.3.0
torch>=1.2
torchvision>=0.4.0
pandas>=0.24 # lower version do not support py3.7
test-tube>=0.6.9
# future>=0.17.1 # required for buildins in setup.py

2 comments on commit f2191b0

@svdc
Copy link

@svdc svdc commented on f2191b0 Nov 28, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any point in having separate functionality for pytorch <1.2 when requirements.txt forces >=1.2?

@Borda
Copy link
Member Author

@Borda Borda commented on f2191b0 Nov 28, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so fair we rely on test-tube for testing and as default logger and it crashes with PyTorch 1.1 #552

Please sign in to comment.