Add support for IterableDatasets everywhere #1104

Merged
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-
- Added support for IterableDataset in validation and testing ([#1104](https://github.com/PyTorchLightning/pytorch-lightning/pull/1104))

### Changed

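The changelog line above is the user-facing change: validation and test loaders may now be backed by an `IterableDataset`, i.e. a dataset without `__len__`. Below is a minimal sketch of how that looks from the user side; `StreamDataset`, `StreamModel` and the hyper-parameter values are illustrative and not taken from this PR, and the hook names assume the Lightning API of roughly this release.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, IterableDataset
import pytorch_lightning as pl


class StreamDataset(IterableDataset):
    """Unsized dataset: len() on a DataLoader built from it raises TypeError."""

    def __iter__(self):
        for _ in range(1000):  # stand-in for an unbounded stream
            x = torch.randn(8)
            yield x, (x.sum() > 0).long()


class StreamModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {'loss': nn.functional.cross_entropy(self(x), y)}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        return {'val_loss': nn.functional.cross_entropy(self(x), y)}

    def validation_epoch_end(self, outputs):
        return {'val_loss': torch.stack([o['val_loss'] for o in outputs]).mean()}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

    def train_dataloader(self):
        return DataLoader(StreamDataset(), batch_size=16)

    def val_dataloader(self):
        return DataLoader(StreamDataset(), batch_size=16)


# An unsized train loader requires an integer `val_check_interval` (check every
# k batches); an unsized val loader only accepts `val_percent_check` 0.0 or 1.0.
trainer = pl.Trainer(max_epochs=1, val_check_interval=50, val_percent_check=1.0)
trainer.fit(StreamModel())
```

The two Trainer arguments on the last line mirror the constraints enforced later in this diff: an unsized train loader needs an integer `val_check_interval`, and an unsized val loader only accepts a `val_percent_check` of 0.0 or 1.0.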
199 changes: 105 additions & 94 deletions pytorch_lightning/trainer/data_loading.py
@@ -1,9 +1,11 @@
from abc import ABC, abstractmethod
from typing import Union, List, Tuple, Callable

import torch.distributed as dist
from torch.utils.data import SequentialSampler, DataLoader
from torch.utils.data.distributed import DistributedSampler

from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities.debugging import MisconfigurationException

try:
@@ -23,6 +25,15 @@
XLA_AVAILABLE = True


def _has_len(dataloader: DataLoader) -> bool:
try:
# try getting the length
_ = len(dataloader)
return True
except TypeError:
return False


class TrainerDataLoadingMixin(ABC):

# this is just a summary on variables used in this abstract class,
@@ -35,27 +46,30 @@ class TrainerDataLoadingMixin(ABC):
use_tpu: bool
tpu_local_core_rank: int
train_dataloader: DataLoader
num_training_batches: int
num_training_batches: Union[int, float]
val_check_batch: ...
val_dataloaders: DataLoader
num_val_batches: int
test_dataloaders: DataLoader
num_test_batches: int
val_dataloaders: List[DataLoader]
num_val_batches: Union[int, float]
test_dataloaders: List[DataLoader]
num_test_batches: Union[int, float]
train_percent_check: float
val_percent_check: float
test_percent_check: float

@abstractmethod
def is_overriden(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""

def _percent_range_check(self, name):
def _percent_range_check(self, name: str) -> None:
value = getattr(self, name)
msg = f"`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}."
if name == "val_check_interval":
msg += " If you want to disable validation set `val_percent_check` to 0.0 instead."
msg = f'`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}.'
if name == 'val_check_interval':
msg += ' If you want to disable validation set `val_percent_check` to 0.0 instead.'

if not 0. <= value <= 1.:
raise ValueError(msg)

def auto_add_sampler(self, dataloader, train):
def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
if self.use_ddp or self.use_ddp2 or self.use_tpu:
dl_args = {
'dataset': dataloader.dataset,
@@ -88,22 +102,22 @@ def auto_add_sampler(self, dataloader, train):
dataloader = DataLoader(**dl_args)
return dataloader

def reset_train_dataloader(self, model):
"""
Dataloaders are provided by the model
:param model:
:return:
"""
def reset_train_dataloader(self, model: LightningModule) -> None:
"""Resets the train dataloader and initialises required variables
(number of batches, when to validate, etc.).

self.train_dataloader = self.request_data_loader(model.train_dataloader)
Args:
model: The current `LightningModule`
"""
self.train_dataloader = self.request_dataloader(model.train_dataloader)
self.num_training_batches = 0

# automatically add samplers
self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)

self._percent_range_check('train_percent_check')

if self.is_infinite_dataloader(self.train_dataloader):
if not _has_len(self.train_dataloader):
self.num_training_batches = float('inf')
else:
# try getting the length
@@ -117,122 +131,119 @@ def reset_train_dataloader(self, model):
self.val_check_batch = self.val_check_interval
if self.val_check_batch > self.num_training_batches:
raise ValueError(
f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
f"to the number of the training batches ({self.num_training_batches}). "
f"If you want to disable validation set `val_percent_check` to 0.0 instead.")
f'`val_check_interval` ({self.val_check_interval}) must be less than or equal '
f'to the number of the training batches ({self.num_training_batches}). '
'If you want to disable validation set `val_percent_check` to 0.0 instead.')
else:
if self.is_infinite_dataloader(self.train_dataloader):
m = '''
When using an infinite DataLoader (e.g. with an IterableDataset or when DataLoader
does not implement `__len__`) for `train_dataloader`, `Trainer(val_check_interval)`
must be an int. An int k specifies checking validation every k training batches.
'''
raise MisconfigurationException(m)
if not _has_len(self.train_dataloader):
raise MisconfigurationException(
'When using an infinite DataLoader (e.g. with an IterableDataset or when '
'DataLoader does not implement `__len__`) for `train_dataloader`, '
'`Trainer(val_check_interval)` must be an int. An int k specifies checking '
'validation every k training batches.')

self._percent_range_check('val_check_interval')

self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
self.val_check_batch = max(1, self.val_check_batch)

def is_infinite_dataloader(self, dataloader):
try:
# try getting the length
_ = len(dataloader)
return False
except TypeError as e:
return True
def _reset_eval_dataloader(self, model: LightningModule,
mode: str) -> Tuple[int, List[DataLoader]]:
"""Generic method to reset a dataloader for evaluation.

def reset_val_dataloader(self, model):
"""
Dataloaders are provided by the model
:param model:
:return:
Args:
model: The current `LightningModule`
mode: Either `'val'` or `'test'`

Returns:
Tuple (num_batches, dataloaders)
"""
if not self.is_overriden('validation_step'):
return
dataloaders = self.request_dataloader(getattr(model, f'{mode}_dataloader'))

self.val_dataloaders = self.request_data_loader(model.val_dataloader)
if not isinstance(self.val_dataloaders, list):
self.val_dataloaders = [self.val_dataloaders]
self.num_val_batches = 0
if not isinstance(dataloaders, list):
dataloaders = [dataloaders]

# add samplers
self.val_dataloaders = [self.auto_add_sampler(dl, train=False)
for dl in self.val_dataloaders if dl]
dataloaders = [self.auto_add_sampler(dl, train=False) for dl in dataloaders if dl]

# determine number of validation batches
# val datasets could be none, 1 or 2+
if self.val_dataloaders is not None:
self._percent_range_check('val_percent_check')
num_batches = 0

self.num_val_batches = sum(len(dataloader) for dataloader in self.val_dataloaders)
self.num_val_batches = int(self.num_val_batches * self.val_percent_check)
# determine number of batches
# datasets could be none, 1 or 2+
if len(dataloaders) != 0:
for dataloader in dataloaders:
if not _has_len(dataloader):
num_batches = float('inf')
break

def reset_test_dataloader(self, model):
"""Dataloaders are provided by the model.
percent_check = getattr(self, f'{mode}_percent_check')

:param model:
"""
if not self.is_overriden('test_step'):
return
if num_batches != float('inf'):
self._percent_range_check(f'{mode}_percent_check')

# get actual loader
self.test_dataloaders = self.request_data_loader(model.test_dataloader)
if not isinstance(self.test_dataloaders, list):
self.test_dataloaders = [self.test_dataloaders]
self.num_test_batches = 0
num_batches = sum(len(dataloader) for dataloader in dataloaders)
num_batches = int(num_batches * percent_check)
elif percent_check not in (0.0, 1.0):
raise MisconfigurationException(
'When using an infinite DataLoader (e.g. with an IterableDataset or when '
f'DataLoader does not implement `__len__`) for `{mode}_dataloader`, '
f'`Trainer({mode}_percent_check)` must be `0.0` or `1.0`.')
return num_batches, dataloaders

# add samplers
self.test_dataloaders = [self.auto_add_sampler(dl, train=False)
for dl in self.test_dataloaders if dl]
def reset_val_dataloader(self, model: LightningModule) -> None:
"""Resets the validation dataloader and determines the number of batches.

# determine number of test batches
if self.test_dataloaders is not None:
self._percent_range_check('test_percent_check')
Args:
model: The current `LightningModule`
"""
if self.is_overriden('validation_step'):
self.num_val_batches, self.val_dataloaders =\
self._reset_eval_dataloader(model, 'val')

len_sum = sum(len(dataloader) for dataloader in self.test_dataloaders)
self.num_test_batches = len_sum
self.num_test_batches = int(self.num_test_batches * self.test_percent_check)
def reset_test_dataloader(self, model) -> None:
"""Resets the validation dataloader and determines the number of batches.

def request_data_loader(self, data_loader_fx):
Args:
model: The current `LightningModule`
"""
Handles downloading data in the GPU or TPU case.
if self.is_overriden('test_step'):
self.num_test_batches, self.test_dataloaders =\
self._reset_eval_dataloader(model, 'test')

:param data_loader_fx:
:return:
def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
"""Handles downloading data in the GPU or TPU case.

Args:
dataloader_fx: The bound dataloader getter

Returns:
The dataloader
"""
dataloader = dataloader_fx()

# get the function we'll use to get data
if self.use_ddp or self.use_ddp2:
data_loader = data_loader_fx()

# all processes wait until data download has happened
dist.barrier()

# data download/load on TPU
elif self.use_tpu and XLA_AVAILABLE:
data_loader = data_loader_fx()

# all processes wait until data download has happened
torch_xla.core.xla_model.rendezvous("pl.TrainerDataLoadingMixin.get_dataloaders")
torch_xla.core.xla_model.rendezvous('pl.TrainerDataLoadingMixin.get_dataloaders')

# regular start
else:
data_loader = data_loader_fx()

return data_loader
return dataloader

def determine_data_use_amount(self, train_percent_check, val_percent_check,
test_percent_check, overfit_pct):
"""
Use less data for debugging purposes
def determine_data_use_amount(self, train_percent_check: float, val_percent_check: float,
test_percent_check: float, overfit_pct: float) -> None:
"""Use less data for debugging purposes
"""
self.train_percent_check = train_percent_check
self.val_percent_check = val_percent_check
self.test_percent_check = test_percent_check
if overfit_pct > 0:
if overfit_pct > 1:
raise ValueError(f"`overfit_pct` must be not greater than 1.0, but got "
f"{overfit_pct:.3f}.")
raise ValueError(
f'`overfit_pct` must be not greater than 1.0, but got {overfit_pct:.3f}.')

self.train_percent_check = overfit_pct
self.val_percent_check = overfit_pct
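To make concrete what the new `_has_len` helper distinguishes, here is the helper restated on its own against the two loader styles it must tell apart; the toy datasets are made up for illustration.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset, TensorDataset


def _has_len(dataloader: DataLoader) -> bool:
    """True when len(dataloader) works; False for iterable-style loaders."""
    try:
        _ = len(dataloader)
        return True
    except TypeError:
        return False


class Stream(IterableDataset):
    def __iter__(self):
        return iter(torch.randn(100, 4))


map_style = DataLoader(TensorDataset(torch.randn(100, 4)), batch_size=10)
iterable_style = DataLoader(Stream(), batch_size=10)

print(_has_len(map_style))       # True  -> batch counts come from len()
print(_has_len(iterable_style))  # False -> batch counts become float('inf')
```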
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -359,9 +359,9 @@ def run_evaluation(self, test_mode: bool = False):
# main progress bar will already be closed when testing so initial position is free
position = 2 * self.process_position + (not test_mode)
desc = 'Testing' if test_mode else 'Validating'
pbar = tqdm(desc=desc, total=max_batches, leave=test_mode, position=position,
disable=not self.show_progress_bar, dynamic_ncols=True,
file=sys.stdout)
total = max_batches if max_batches != float('inf') else None
pbar = tqdm(desc=desc, total=total, leave=test_mode, position=position,
disable=not self.show_progress_bar, dynamic_ncols=True, file=sys.stdout)
setattr(self, f'{"test" if test_mode else "val"}_progress_bar', pbar)

# run evaluation
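The change above passes `total=None` to tqdm when `max_batches` is infinite, and tqdm then falls back to a plain iteration counter rather than a percentage bar. A tiny standalone check of that behaviour, with nothing Lightning-specific in it:

```python
from tqdm import tqdm

# A generator has no len(), so there is no meaningful total; with total=None
# tqdm shows a plain "n it [elapsed, rate]" counter instead of a percentage bar.
stream = (x for x in range(1000))
for _ in tqdm(stream, total=None, desc='Validating'):
    pass
```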
12 changes: 4 additions & 8 deletions pytorch_lightning/trainer/training_loop.py
@@ -223,10 +223,6 @@ def get_model(self):
def is_function_implemented(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""

@abstractmethod
def is_infinite_dataloader(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""

@abstractmethod
def run_evaluation(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""
@@ -310,7 +306,7 @@ def train(self):

total_val_batches = 0
is_val_epoch = False
if not self.disable_validation:
if not self.disable_validation and self.num_training_batches != float('inf'):
# val can be checked multiple times in epoch
is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
val_checks_per_epoch = self.num_training_batches // self.val_check_batch
@@ -324,8 +320,8 @@
if self.fast_dev_run:
# limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
num_iterations = 2
elif self.is_infinite_dataloader(self.train_dataloader):
# for infinite train loader, the progress bar never ends
elif self.total_batches == float('inf'):
# for infinite train or val loader, the progress bar never ends
num_iterations = None
else:
num_iterations = self.total_batches
@@ -334,7 +330,7 @@
# .reset() doesn't work on disabled progress bar so we should check
if not self.main_progress_bar.disable:
self.main_progress_bar.reset(num_iterations)
desc = f'Epoch {epoch + 1}' if not self.is_infinite_dataloader(self.train_dataloader) else ''
desc = f'Epoch {epoch + 1}'
self.main_progress_bar.set_description(desc)

# -----------------
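The guards above lean on `float('inf')` propagating through the batch-count arithmetic, so a single comparison at the end detects an unsized loader anywhere in the sum. A toy restatement of that logic; the variable names echo the trainer attributes but this is not the trainer code:

```python
num_training_batches = float('inf')  # unsized train loader
total_val_batches = 20               # sized val loader, checked once per epoch

total_batches = num_training_batches + total_val_batches
assert total_batches == float('inf')  # inf swallows the finite part

# The progress bar has no meaningful total, so tqdm gets None and just counts.
num_iterations = None if total_batches == float('inf') else int(total_batches)
print(num_iterations)  # None
```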
3 changes: 3 additions & 0 deletions tests/models/__init__.py
@@ -19,6 +19,9 @@
LightValStepFitMultipleDataloadersMixin,
LightTrainDataloader,
LightTestDataloader,
LightInfTrainDataloader,
LightInfValDataloader,
LightInfTestDataloader,
LightTestOptimizerWithSchedulingMixin,
LightTestMultipleOptimizersWithSchedulingMixin,
LightTestOptimizersWithMixedSchedulingMixin
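The `LightInf*Dataloader` test mixins imported here are defined elsewhere in the test package and their bodies are not part of this diff. Purely as a guess at their shape, one of them might look like the sketch below; `_InfiniteWrapper` and everything inside it are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset, TensorDataset


class _InfiniteWrapper(IterableDataset):
    """Hypothetical helper: re-expose an existing dataset as an unsized stream."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        return iter(self.dataset)


class LightInfTrainDataloader:
    """Guess at the mixin's shape: same data as a usual test loader, but unsized."""

    def train_dataloader(self):
        base = TensorDataset(torch.randn(64, 28 * 28), torch.randint(0, 10, (64,)))
        return DataLoader(_InfiniteWrapper(base), batch_size=16)
```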