Add warning for few workers (#1378)
* Add warning for few workers

* Fix style issue

* Update CHANGELOG.md

* Update test

* formatting

* formatting

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
ethanwharris and Borda committed Apr 5, 2020
1 parent fdcf9cd commit b18accc
Showing 3 changed files with 55 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added model configuration checking ([#1199](https://github.com/PyTorchLightning/pytorch-lightning/pull/1199))
 - Added support for optimizer frequencies through `LightningModule.configure_optimizers()` ([#1269](https://github.com/PyTorchLightning/pytorch-lightning/pull/1269))
 - Added option to run without an optimizer by returning `None` from `configure_optimizers`. ([#1279](https://github.com/PyTorchLightning/pytorch-lightning/pull/1279))
+- Added a warning when the number of data loader workers is small. ([#1378](https://github.com/PyTorchLightning/pytorch-lightning/pull/1378))
 
 ### Changed

15 changes: 12 additions & 3 deletions pytorch_lightning/trainer/data_loading.py
@@ -1,8 +1,9 @@
+import warnings
 from abc import ABC, abstractmethod
 from typing import Union, List, Tuple, Callable
 
 import torch.distributed as torch_distrib
-from torch.utils.data import SequentialSampler, DataLoader
+from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
 from pytorch_lightning.core import LightningModule
@@ -73,6 +74,12 @@ def _percent_range_check(self, name: str) -> None:
         if not 0. <= value <= 1.:
             raise ValueError(msg)
 
+    def _worker_check(self, dataloader: DataLoader, name: str) -> None:
+        if isinstance(dataloader, DataLoader) and dataloader.num_workers <= 2:
+            warnings.warn(f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
+                          ' Consider increasing the value of the `num_workers` argument'
+                          ' in the `DataLoader` init to improve performance.')
+
     def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
 
         # don't do anything if it's not a dataloader
@@ -112,11 +119,13 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
             model: The current `LightningModule`
         """
         self.train_dataloader = self.request_dataloader(model.train_dataloader)
+
         self.num_training_batches = 0
 
         # automatically add samplers
         self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)
 
+        self._worker_check(self.train_dataloader, 'train dataloader')
         self._percent_range_check('train_percent_check')
 
         if not _has_len(self.train_dataloader):
@@ -176,10 +185,10 @@ def _reset_eval_dataloader(self, model: LightningModule,
         # determine number of batches
         # datasets could be none, 1 or 2+
         if len(dataloaders) != 0:
-            for dataloader in dataloaders:
+            for i, dataloader in enumerate(dataloaders):
+                self._worker_check(dataloader, f'{mode} dataloader {i}')
                 if not _has_len(dataloader):
                     num_batches = float('inf')
                     break
 
         percent_check = getattr(self, f'{mode}_percent_check')

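For context, here is a minimal, self-contained sketch of the check this diff introduces (illustrative only: `worker_check` is a free-function stand-in for the new `Trainer._worker_check` method, and the toy `TensorDataset` is hypothetical):

import warnings

import torch
from torch.utils.data import DataLoader, TensorDataset


def worker_check(dataloader, name):
    # Same guard as the new method: only genuine DataLoader instances are
    # inspected, and 2 or fewer workers triggers a UserWarning.
    if isinstance(dataloader, DataLoader) and dataloader.num_workers <= 2:
        warnings.warn(f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
                      ' Consider increasing the value of the `num_workers` argument'
                      ' in the `DataLoader` init to improve performance.')


dataset = TensorDataset(torch.randn(8, 3))
worker_check(DataLoader(dataset, num_workers=0), 'train dataloader')  # warns
worker_check(DataLoader(dataset, num_workers=4), 'train dataloader')  # silent
worker_check([1, 2, 3], 'train dataloader')  # not a DataLoader, so no warning

Note the `isinstance` guard: plain iterables and objects that merely wrap a `DataLoader` pass through silently, consistent with the call sites above, where `_worker_check` runs only on the dataloaders the trainer has already resolved.
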
42 changes: 42 additions & 0 deletions tests/trainer/test_dataloaders.py
@@ -15,6 +15,7 @@
     LightValStepFitMultipleDataloadersMixin,
     LightValStepFitSingleDataloaderMixin,
     LightTrainDataloader,
+    LightValidationDataloader,
     LightInfTrainDataloader,
     LightInfValDataloader,
     LightInfTestDataloader,
@@ -485,6 +486,47 @@ class CurrentTestModel(
     trainer.fit(model)
 
 
+def test_warning_with_few_workers(tmpdir):
+    """ Test that a warning is raised if a dataloader with only a few workers is used """
+    tutils.reset_seed()
+
+    class CurrentTestModel(
+        LightTrainDataloader,
+        LightValStepFitSingleDataloaderMixin,
+        LightTestFitSingleTestDataloadersMixin,
+        LightEmptyTestStep,
+        TestModelBase,
+    ):
+        pass
+
+    hparams = tutils.get_default_hparams()
+    model = CurrentTestModel(hparams)
+
+    # logger file to get meta
+    trainer_options = dict(
+        default_save_path=tmpdir,
+        max_epochs=1,
+        val_percent_check=0.1,
+        train_percent_check=0.2
+    )
+
+    fit_options = dict(train_dataloader=model._dataloader(train=True),
+                       val_dataloaders=model._dataloader(train=False),
+                       test_dataloaders=model._dataloader(train=False))
+
+    trainer = Trainer(**trainer_options)
+
+    # fit model
+    with pytest.warns(UserWarning, match='train'):
+        trainer.fit(model, **fit_options)
+
+    with pytest.warns(UserWarning, match='val'):
+        trainer.fit(model, **fit_options)
+
+    with pytest.warns(UserWarning, match='test'):
+        trainer.test()
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
 def test_dataloader_reinit_for_subclass():

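As a usage footnote (a sketch under assumed names, not part of this commit): once this warning ships, keeping `num_workers` above the threshold of 2 when building a loader avoids it entirely.

import torch
from torch.utils.data import DataLoader, TensorDataset

# A hypothetical stand-in dataset; any Dataset works here.
my_dataset = TensorDataset(torch.randn(64, 3))

# Four worker processes exceed the new threshold of 2, so the Trainer
# would emit no bottleneck warning for this loader.
train_loader = DataLoader(my_dataset, batch_size=32, shuffle=True, num_workers=4)
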
