Add support for IterableDatasets everywhere (#1104)
* Add support for IterableDatasets everywhere

* Added type hints, simplified code and improved coverage in data_loading.py

* Update CHANGELOG.md
ethanwharris committed Mar 12, 2020
1 parent 1383f64 commit 2b3f443
Showing 7 changed files with 322 additions and 142 deletions.
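Before the diff: the commit extends the existing infinite-train-loader handling so that `val_dataloader` and `test_dataloader` may also return loaders without `__len__`, e.g. a `DataLoader` wrapping an `IterableDataset`. A minimal sketch of such a loader follows; the `RandomStream` class, batch size, and tensor shapes are illustrative only and not part of the PR. As enforced in `data_loading.py` below, such a train loader requires an integer `val_check_interval`.

import torch
from torch.utils.data import DataLoader, IterableDataset


class RandomStream(IterableDataset):
    """A length-less stream of (feature, target) pairs."""

    def __iter__(self):
        while True:
            yield torch.randn(32), torch.randint(0, 2, (1,))


# A DataLoader over an IterableDataset does not implement __len__, so the
# Trainer cannot infer batches per epoch; validation must then be scheduled
# with an integer val_check_interval (e.g. check every 100 training batches).
train_loader = DataLoader(RandomStream(), batch_size=16)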
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

--
+- Added support for IterableDataset in validation and testing ([#1104](https://github.com/PyTorchLightning/pytorch-lightning/pull/1104))

### Changed

199 changes: 105 additions & 94 deletions pytorch_lightning/trainer/data_loading.py
@@ -1,9 +1,11 @@
 from abc import ABC, abstractmethod
+from typing import Union, List, Tuple, Callable

 import torch.distributed as dist
 from torch.utils.data import SequentialSampler, DataLoader
 from torch.utils.data.distributed import DistributedSampler

+from pytorch_lightning.core import LightningModule
 from pytorch_lightning.utilities.debugging import MisconfigurationException

 try:
@@ -23,6 +25,15 @@
    XLA_AVAILABLE = True


+def _has_len(dataloader: DataLoader) -> bool:
+    try:
+        # try getting the length
+        _ = len(dataloader)
+        return True
+    except TypeError:
+        return False
+
+
class TrainerDataLoadingMixin(ABC):

    # this is just a summary on variables used in this abstract class,
@@ -35,27 +46,30 @@ class TrainerDataLoadingMixin(ABC):
    use_tpu: bool
    tpu_local_core_rank: int
    train_dataloader: DataLoader
-    num_training_batches: int
+    num_training_batches: Union[int, float]
    val_check_batch: ...
-    val_dataloaders: DataLoader
-    num_val_batches: int
-    test_dataloaders: DataLoader
-    num_test_batches: int
+    val_dataloaders: List[DataLoader]
+    num_val_batches: Union[int, float]
+    test_dataloaders: List[DataLoader]
+    num_test_batches: Union[int, float]
    train_percent_check: float
    val_percent_check: float
    test_percent_check: float

    @abstractmethod
    def is_overriden(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""

-    def _percent_range_check(self, name):
+    def _percent_range_check(self, name: str) -> None:
        value = getattr(self, name)
-        msg = f"`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}."
-        if name == "val_check_interval":
-            msg += " If you want to disable validation set `val_percent_check` to 0.0 instead."
+        msg = f'`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}.'
+        if name == 'val_check_interval':
+            msg += ' If you want to disable validation set `val_percent_check` to 0.0 instead.'

        if not 0. <= value <= 1.:
            raise ValueError(msg)

-    def auto_add_sampler(self, dataloader, train):
+    def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
        if self.use_ddp or self.use_ddp2 or self.use_tpu:
            dl_args = {
                'dataset': dataloader.dataset,
@@ -88,22 +102,22 @@ def auto_add_sampler(self, dataloader, train):
            dataloader = DataLoader(**dl_args)
        return dataloader

-    def reset_train_dataloader(self, model):
-        """
-        Dataloaders are provided by the model
-        :param model:
-        :return:
-        """
-        self.train_dataloader = self.request_data_loader(model.train_dataloader)
+    def reset_train_dataloader(self, model: LightningModule) -> None:
+        """Resets the train dataloader and initialises required variables
+        (number of batches, when to validate, etc.).
+
+        Args:
+            model: The current `LightningModule`
+        """
+        self.train_dataloader = self.request_dataloader(model.train_dataloader)
        self.num_training_batches = 0

        # automatically add samplers
        self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)

        self._percent_range_check('train_percent_check')

-        if self.is_infinite_dataloader(self.train_dataloader):
+        if not _has_len(self.train_dataloader):
            self.num_training_batches = float('inf')
        else:
            # try getting the length
@@ -117,122 +131,119 @@ def reset_train_dataloader(self, model):
            self.val_check_batch = self.val_check_interval
            if self.val_check_batch > self.num_training_batches:
                raise ValueError(
-                    f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
-                    f"to the number of the training batches ({self.num_training_batches}). "
-                    f"If you want to disable validation set `val_percent_check` to 0.0 instead.")
+                    f'`val_check_interval` ({self.val_check_interval}) must be less than or equal '
+                    f'to the number of the training batches ({self.num_training_batches}). '
+                    'If you want to disable validation set `val_percent_check` to 0.0 instead.')
        else:
-            if self.is_infinite_dataloader(self.train_dataloader):
-                m = '''
-                When using an infinite DataLoader (e.g. with an IterableDataset or when DataLoader
-                does not implement `__len__`) for `train_dataloader`, `Trainer(val_check_interval)`
-                must be an int. An int k specifies checking validation every k training batches.
-                '''
-                raise MisconfigurationException(m)
+            if not _has_len(self.train_dataloader):
+                raise MisconfigurationException(
+                    'When using an infinite DataLoader (e.g. with an IterableDataset or when '
+                    'DataLoader does not implement `__len__`) for `train_dataloader`, '
+                    '`Trainer(val_check_interval)` must be an int. An int k specifies checking '
+                    'validation every k training batches.')

            self._percent_range_check('val_check_interval')

            self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
            self.val_check_batch = max(1, self.val_check_batch)

-    def is_infinite_dataloader(self, dataloader):
-        try:
-            # try getting the length
-            _ = len(dataloader)
-            return False
-        except TypeError as e:
-            return True
-
-    def reset_val_dataloader(self, model):
-        """
-        Dataloaders are provided by the model
-        :param model:
-        :return:
-        """
-        if not self.is_overriden('validation_step'):
-            return
-
-        self.val_dataloaders = self.request_data_loader(model.val_dataloader)
-        if not isinstance(self.val_dataloaders, list):
-            self.val_dataloaders = [self.val_dataloaders]
-        self.num_val_batches = 0
+    def _reset_eval_dataloader(self, model: LightningModule,
+                               mode: str) -> Tuple[int, List[DataLoader]]:
+        """Generic method to reset a dataloader for evaluation.
+
+        Args:
+            model: The current `LightningModule`
+            mode: Either `'val'` or `'test'`
+
+        Returns:
+            Tuple (num_batches, dataloaders)
+        """
+        dataloaders = self.request_dataloader(getattr(model, f'{mode}_dataloader'))
+
+        if not isinstance(dataloaders, list):
+            dataloaders = [dataloaders]

        # add samplers
-        self.val_dataloaders = [self.auto_add_sampler(dl, train=False)
-                                for dl in self.val_dataloaders if dl]
+        dataloaders = [self.auto_add_sampler(dl, train=False) for dl in dataloaders if dl]

-        # determine number of validation batches
-        # val datasets could be none, 1 or 2+
-        if self.val_dataloaders is not None:
-            self._percent_range_check('val_percent_check')
-
-            self.num_val_batches = sum(len(dataloader) for dataloader in self.val_dataloaders)
-            self.num_val_batches = int(self.num_val_batches * self.val_percent_check)
-
-    def reset_test_dataloader(self, model):
-        """Dataloaders are provided by the model.
-
-        :param model:
-        """
-        if not self.is_overriden('test_step'):
-            return
-
-        # get actual loader
-        self.test_dataloaders = self.request_data_loader(model.test_dataloader)
-        if not isinstance(self.test_dataloaders, list):
-            self.test_dataloaders = [self.test_dataloaders]
-        self.num_test_batches = 0
-
-        # add samplers
-        self.test_dataloaders = [self.auto_add_sampler(dl, train=False)
-                                 for dl in self.test_dataloaders if dl]
-
-        # determine number of test batches
-        if self.test_dataloaders is not None:
-            self._percent_range_check('test_percent_check')
-
-            len_sum = sum(len(dataloader) for dataloader in self.test_dataloaders)
-            self.num_test_batches = len_sum
-            self.num_test_batches = int(self.num_test_batches * self.test_percent_check)
+        num_batches = 0
+
+        # determine number of batches
+        # datasets could be none, 1 or 2+
+        if len(dataloaders) != 0:
+            for dataloader in dataloaders:
+                if not _has_len(dataloader):
+                    num_batches = float('inf')
+                    break
+
+            percent_check = getattr(self, f'{mode}_percent_check')
+
+            if num_batches != float('inf'):
+                self._percent_range_check(f'{mode}_percent_check')
+
+                num_batches = sum(len(dataloader) for dataloader in dataloaders)
+                num_batches = int(num_batches * percent_check)
+            elif percent_check not in (0.0, 1.0):
+                raise MisconfigurationException(
+                    'When using an infinite DataLoader (e.g. with an IterableDataset or when '
+                    f'DataLoader does not implement `__len__`) for `{mode}_dataloader`, '
+                    f'`Trainer({mode}_percent_check)` must be `0.0` or `1.0`.')
+        return num_batches, dataloaders
+
+    def reset_val_dataloader(self, model: LightningModule) -> None:
+        """Resets the validation dataloader and determines the number of batches.
+
+        Args:
+            model: The current `LightningModule`
+        """
+        if self.is_overriden('validation_step'):
+            self.num_val_batches, self.val_dataloaders =\
+                self._reset_eval_dataloader(model, 'val')
+
+    def reset_test_dataloader(self, model) -> None:
+        """Resets the validation dataloader and determines the number of batches.
+
+        Args:
+            model: The current `LightningModule`
+        """
+        if self.is_overriden('test_step'):
+            self.num_test_batches, self.test_dataloaders =\
+                self._reset_eval_dataloader(model, 'test')

-    def request_data_loader(self, data_loader_fx):
-        """
-        Handles downloading data in the GPU or TPU case.
-
-        :param data_loader_fx:
-        :return:
-        """
-        # get the function we'll use to get data
+    def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
+        """Handles downloading data in the GPU or TPU case.
+
+        Args:
+            dataloader_fx: The bound dataloader getter
+
+        Returns:
+            The dataloader
+        """
+        dataloader = dataloader_fx()
+
        if self.use_ddp or self.use_ddp2:
-            data_loader = data_loader_fx()
-
            # all processes wait until data download has happened
            dist.barrier()

        # data download/load on TPU
        elif self.use_tpu and XLA_AVAILABLE:
-            data_loader = data_loader_fx()
-
            # all processes wait until data download has happened
-            torch_xla.core.xla_model.rendezvous("pl.TrainerDataLoadingMixin.get_dataloaders")
+            torch_xla.core.xla_model.rendezvous('pl.TrainerDataLoadingMixin.get_dataloaders')

-        # regular start
-        else:
-            data_loader = data_loader_fx()
-
-        return data_loader
+        return dataloader

-    def determine_data_use_amount(self, train_percent_check, val_percent_check,
-                                  test_percent_check, overfit_pct):
-        """
-        Use less data for debugging purposes
-        """
+    def determine_data_use_amount(self, train_percent_check: float, val_percent_check: float,
+                                  test_percent_check: float, overfit_pct: float) -> None:
+        """Use less data for debugging purposes
+        """
        self.train_percent_check = train_percent_check
        self.val_percent_check = val_percent_check
        self.test_percent_check = test_percent_check
        if overfit_pct > 0:
            if overfit_pct > 1:
-                raise ValueError(f"`overfit_pct` must be not greater than 1.0, but got "
-                                 f"{overfit_pct:.3f}.")
+                raise ValueError(
+                    f'`overfit_pct` must be not greater than 1.0, but got {overfit_pct:.3f}.')

            self.train_percent_check = overfit_pct
            self.val_percent_check = overfit_pct
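The whole feature hinges on the `_has_len` probe added at the top of this file: a map-style dataset gives the loader a length, while a `DataLoader` over an `IterableDataset` makes `len()` raise `TypeError`. For such length-less eval loaders the diff requires `val_percent_check`/`test_percent_check` to be `0.0` or `1.0`, since a fraction of an unknown length is undefined. A standalone sketch of the probe (the `Stream` class and `has_len` name are local to this example, restating the helper above for illustration):

from torch.utils.data import DataLoader, IterableDataset


class Stream(IterableDataset):
    def __iter__(self):
        return iter(range(10))


def has_len(dataloader: DataLoader) -> bool:
    # same check as the _has_len helper above: the loader is finite iff len() succeeds
    try:
        _ = len(dataloader)
        return True
    except TypeError:
        return False


print(has_len(DataLoader(list(range(10)))))  # True  - map-style dataset defines __len__
print(has_len(DataLoader(Stream())))         # False - len() raises TypeError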
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -358,9 +358,9 @@ def run_evaluation(self, test_mode: bool = False):
        # main progress bar will already be closed when testing so initial position is free
        position = 2 * self.process_position + (not test_mode)
        desc = 'Testing' if test_mode else 'Validating'
-        pbar = tqdm(desc=desc, total=max_batches, leave=test_mode, position=position,
-                    disable=not self.show_progress_bar, dynamic_ncols=True,
-                    file=sys.stdout)
+        total = max_batches if max_batches != float('inf') else None
+        pbar = tqdm(desc=desc, total=total, leave=test_mode, position=position,
+                    disable=not self.show_progress_bar, dynamic_ncols=True, file=sys.stdout)
        setattr(self, f'{"test" if test_mode else "val"}_progress_bar', pbar)

        # run evaluation
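For reference, `tqdm` with `total=None` falls back to a plain running counter (count, elapsed time, rate) instead of a percentage bar, which is why the evaluation bar above can still be rendered when `max_batches` is infinite. A tiny sketch, unrelated to the trainer code itself:

from tqdm import tqdm

# A generator has no len(); with total=None tqdm shows only a count and rate.
stream = (i for i in range(1000))
for _ in tqdm(stream, total=None, desc='Validating'):
    pass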
12 changes: 4 additions & 8 deletions pytorch_lightning/trainer/training_loop.py
@@ -222,10 +222,6 @@ def get_model(self):
    def is_function_implemented(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""

-    @abstractmethod
-    def is_infinite_dataloader(self, *args):
-        """Warning: this is just empty shell for code implemented in other class."""
-
    @abstractmethod
    def run_evaluation(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""
@@ -309,7 +305,7 @@ def train(self):

            total_val_batches = 0
            is_val_epoch = False
-            if not self.disable_validation:
+            if not self.disable_validation and self.num_training_batches != float('inf'):
                # val can be checked multiple times in epoch
                is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
                val_checks_per_epoch = self.num_training_batches // self.val_check_batch
@@ -323,8 +319,8 @@
            if self.fast_dev_run:
                # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
                num_iterations = 2
-            elif self.is_infinite_dataloader(self.train_dataloader):
-                # for infinite train loader, the progress bar never ends
+            elif self.total_batches == float('inf'):
+                # for infinite train or val loader, the progress bar never ends
                num_iterations = None
            else:
                num_iterations = self.total_batches
@@ -333,7 +329,7 @@
            # .reset() doesn't work on disabled progress bar so we should check
            if not self.main_progress_bar.disable:
                self.main_progress_bar.reset(num_iterations)
-            desc = f'Epoch {epoch + 1}' if not self.is_infinite_dataloader(self.train_dataloader) else ''
+            desc = f'Epoch {epoch + 1}'
            self.main_progress_bar.set_description(desc)

            # -----------------
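The scheduling changes above reduce to arithmetic on the batch counts: with an infinite train loader, per-epoch validation counting is skipped and the main progress bar gets no fixed length. A simplified sketch in which the variable names mirror the trainer attributes but which is not the trainer code itself:

# With an infinite train loader, total_batches is inf and the bar is unbounded.
num_training_batches = float('inf')
total_val_batches = 0  # skipped, since the number of training batches is unknown
total_batches = num_training_batches + total_val_batches
num_iterations = None if total_batches == float('inf') else int(total_batches)
print(total_batches, num_iterations)  # inf None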
3 changes: 3 additions & 0 deletions tests/models/__init__.py
@@ -19,6 +19,9 @@
    LightValStepFitMultipleDataloadersMixin,
    LightTrainDataloader,
    LightTestDataloader,
+    LightInfTrainDataloader,
+    LightInfValDataloader,
+    LightInfTestDataloader,
    LightTestOptimizerWithSchedulingMixin,
    LightTestMultipleOptimizersWithSchedulingMixin,
    LightTestOptimizersWithMixedSchedulingMixin
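The bodies of the new `LightInfTrainDataloader`, `LightInfValDataloader` and `LightInfTestDataloader` test mixins are not shown on this page. Purely as an assumption-labelled sketch of what such a mixin could look like (the `_InfiniteDataset` below is a hypothetical stand-in, not the project's test dataset):

import torch
from torch.utils.data import DataLoader, IterableDataset


class _InfiniteDataset(IterableDataset):
    """Yields random samples forever, so the resulting DataLoader has no __len__."""

    def __iter__(self):
        while True:
            yield torch.randn(28 * 28), torch.randint(0, 10, (1,))


class LightInfTrainDataloader:
    """Mixin: gives the test model an infinite (length-less) train dataloader."""

    def train_dataloader(self):
        return DataLoader(_InfiniteDataset(), batch_size=32)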
