From f2191b0cdf4305ae3a5ad2b1e404f99764a1a7c6 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 26 Nov 2019 16:58:50 +0100 Subject: [PATCH] fix for pyTorch 1.2 (#549) * min pytorch 1.2 * fix IterableDataset * upgrade torchvision * fix msg --- pytorch_lightning/trainer/data_loading_mixin.py | 17 ++++++++++++++--- requirements.txt | 4 ++-- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading_mixin.py b/pytorch_lightning/trainer/data_loading_mixin.py index 52541a86fc0cb..8533fdf4851d5 100644 --- a/pytorch_lightning/trainer/data_loading_mixin.py +++ b/pytorch_lightning/trainer/data_loading_mixin.py @@ -1,7 +1,17 @@ import warnings import torch.distributed as dist -from torch.utils.data import IterableDataset +try: + # loading for pyTorch 1.3 + from torch.utils.data import IterableDataset +except ImportError: + # loading for pyTorch 1.1 + import torch + warnings.warn('Your version of pyTorch %s does not support `IterableDataset`,' + ' please upgrade to 1.2+' % torch.__version__, ImportWarning) + EXIST_ITER_DATASET = False +else: + EXIST_ITER_DATASET = True from torch.utils.data.distributed import DistributedSampler from pytorch_lightning.utilities.debugging import MisconfigurationException @@ -24,7 +34,7 @@ def init_train_dataloader(self, model): self.get_train_dataloader = model.train_dataloader # determine number of training batches - if isinstance(self.get_train_dataloader().dataset, IterableDataset): + if EXIST_ITER_DATASET and isinstance(self.get_train_dataloader().dataset, IterableDataset): self.nb_training_batches = float('inf') else: self.nb_training_batches = len(self.get_train_dataloader()) @@ -167,7 +177,8 @@ def get_dataloaders(self, model): self.get_val_dataloaders() # support IterableDataset for train data - self.is_iterable_train_dataloader = isinstance(self.get_train_dataloader().dataset, IterableDataset) + self.is_iterable_train_dataloader = ( + EXIST_ITER_DATASET and isinstance(self.get_train_dataloader().dataset, IterableDataset)) if self.is_iterable_train_dataloader and not isinstance(self.val_check_interval, int): m = ''' When using an iterableDataset for train_dataloader, diff --git a/requirements.txt b/requirements.txt index 53f623e76467d..acdc11d904516 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ scikit-learn>=0.20.2 tqdm>=4.35.0 numpy>=1.16.4 -torch>=1.1 -torchvision>=0.3.0 +torch>=1.2 +torchvision>=0.4.0 pandas>=0.24 # lower version do not support py3.7 test-tube>=0.6.9 # future>=0.17.1 # required for buildins in setup.py \ No newline at end of file