From b18accc64ccd24095c11fdbd64cc924456134592 Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Sun, 5 Apr 2020 16:07:16 +0100
Subject: [PATCH] Add warning for few workers (#1378)

* Add warning for few workers

* Fix style issue

* Update CHANGELOG.md

* Update test

* formatting

* formatting

Co-authored-by: Jirka Borovec
---
 CHANGELOG.md                              |  1 +
 pytorch_lightning/trainer/data_loading.py | 15 ++++++--
 tests/trainer/test_dataloaders.py         | 42 +++++++++++++++++++++++
 3 files changed, 55 insertions(+), 3 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4435c4e6ccdff..795233c6c908f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added model configuration checking ([#1199](https://github.com/PyTorchLightning/pytorch-lightning/pull/1199))
 - Added support for optimizer frequencies through `LightningModule.configure_optimizers()` ([#1269](https://github.com/PyTorchLightning/pytorch-lightning/pull/1269))
 - Added option to run without an optimizer by returning `None` from `configure_optimizers`. ([#1279](https://github.com/PyTorchLightning/pytorch-lightning/pull/1279))
+- Added a warning when the number of data loader workers is small. ([#1378](https://github.com/PyTorchLightning/pytorch-lightning/pull/1378))
 
 ### Changed
 
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index fe1adf75c3fdc..66b83fd4bfd00 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -1,8 +1,9 @@
+import warnings
 from abc import ABC, abstractmethod
 from typing import Union, List, Tuple, Callable
 
 import torch.distributed as torch_distrib
-from torch.utils.data import SequentialSampler, DataLoader
+from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
 from pytorch_lightning.core import LightningModule
@@ -73,6 +74,12 @@ def _percent_range_check(self, name: str) -> None:
         if not 0. <= value <= 1.:
             raise ValueError(msg)
 
+    def _worker_check(self, dataloader: DataLoader, name: str) -> None:
+        if isinstance(dataloader, DataLoader) and dataloader.num_workers <= 2:
+            warnings.warn(f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
+                          ' Consider increasing the value of the `num_workers` argument'
+                          ' in the `DataLoader` init to improve performance.')
+
     def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
 
         # don't do anything if it's not a dataloader
@@ -112,11 +119,13 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
             model: The current `LightningModule`
         """
         self.train_dataloader = self.request_dataloader(model.train_dataloader)
+
         self.num_training_batches = 0
 
         # automatically add samplers
         self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)
 
+        self._worker_check(self.train_dataloader, 'train dataloader')
         self._percent_range_check('train_percent_check')
 
         if not _has_len(self.train_dataloader):
@@ -176,10 +185,10 @@ def _reset_eval_dataloader(self, model: LightningModule,
 
         # determine number of batches
         # datasets could be none, 1 or 2+
         if len(dataloaders) != 0:
-            for dataloader in dataloaders:
+            for i, dataloader in enumerate(dataloaders):
+                self._worker_check(dataloader, f'{mode} dataloader {i}')
                 if not _has_len(dataloader):
                     num_batches = float('inf')
-                    break
 
         percent_check = getattr(self, f'{mode}_percent_check')
diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index d0da0044f217e..408774430c398 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -15,6 +15,7 @@
     LightValStepFitMultipleDataloadersMixin,
     LightValStepFitSingleDataloaderMixin,
     LightTrainDataloader,
+    LightValidationDataloader,
     LightInfTrainDataloader,
     LightInfValDataloader,
     LightInfTestDataloader,
@@ -485,6 +486,47 @@ class CurrentTestModel(
     trainer.fit(model)
 
 
+def test_warning_with_few_workers(tmpdir):
+    """ Test that a warning is raised if a dataloader with only a few workers is used """
+    tutils.reset_seed()
+
+    class CurrentTestModel(
+        LightTrainDataloader,
+        LightValStepFitSingleDataloaderMixin,
+        LightTestFitSingleTestDataloadersMixin,
+        LightEmptyTestStep,
+        TestModelBase,
+    ):
+        pass
+
+    hparams = tutils.get_default_hparams()
+    model = CurrentTestModel(hparams)
+
+    # logger file to get meta
+    trainer_options = dict(
+        default_save_path=tmpdir,
+        max_epochs=1,
+        val_percent_check=0.1,
+        train_percent_check=0.2
+    )
+
+    fit_options = dict(train_dataloader=model._dataloader(train=True),
+                       val_dataloaders=model._dataloader(train=False),
+                       test_dataloaders=model._dataloader(train=False))
+
+    trainer = Trainer(**trainer_options)
+
+    # fit model
+    with pytest.warns(UserWarning, match='train'):
+        trainer.fit(model, **fit_options)
+
+    with pytest.warns(UserWarning, match='val'):
+        trainer.fit(model, **fit_options)
+
+    with pytest.warns(UserWarning, match='test'):
+        trainer.test()
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
 def test_dataloader_reinit_for_subclass():
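Note on the behavior this patch introduces: the sketch below is not part of the commit; it restates the new `Trainer._worker_check` logic as a hypothetical standalone `worker_check` helper so it can be run in isolation and the warning observed.

import warnings

import torch
from torch.utils.data import DataLoader, TensorDataset


def worker_check(dataloader, name):
    # Same condition as the patched Trainer._worker_check: warn when a
    # real DataLoader is configured with 2 or fewer worker processes.
    if isinstance(dataloader, DataLoader) and dataloader.num_workers <= 2:
        warnings.warn(f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
                      ' Consider increasing the value of the `num_workers` argument'
                      ' in the `DataLoader` init to improve performance.')


dataset = TensorDataset(torch.randn(8, 2))

# `num_workers` defaults to 0, so this emits the UserWarning
worker_check(DataLoader(dataset), 'train dataloader')

# with 4 workers the check passes silently
worker_check(DataLoader(dataset, num_workers=4), 'train dataloader')

The `<= 2` threshold is hard-coded in the patch, and the `isinstance` guard means objects that are not `torch.utils.data.DataLoader` instances are skipped, so custom iterables never trigger a false positive.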