Add missing test for "multiple dataloader + percent_check fix" #2226

Merged · 31 commits · Jun 23, 2020

Commits
ad833fa  Init fix num_batches (rohitgr7, May 21, 2020)
7de06cc  Fix num_batches in case of multiple dataloaders (rohitgr7, May 21, 2020)
a9b996d  Apply suggestions from code review (awaelchli, May 21, 2020)
da4fbcc  Changes based on suggestions (rohitgr7, May 21, 2020)
13bb5c0  Flake8 (rohitgr7, May 21, 2020)
5241bcb  Add test to check num_batches (rohitgr7, May 22, 2020)
98b59a4  generalize dataloader percent check test (awaelchli, May 26, 2020)
e0bd090  fix formatting (awaelchli, May 26, 2020)
f61c2c5  remove hparams (rohitgr7, May 26, 2020)
1aee122  tests (rohitgr7, Jun 12, 2020)
c96e00a  CHANGELOG (rohitgr7, Jun 12, 2020)
476a693  Update CHANGELOG.md (Borda, Jun 15, 2020)
a7f751e  max_batches can be int (rohitgr7, Jun 15, 2020)
4e033aa  conflict and rebase (rohitgr7, Jun 15, 2020)
46ab176  Merge branch 'master' into fix_num_batches_new (awaelchli, Jun 17, 2020)
06d5ab0  add back the test (awaelchli, Jun 17, 2020)
29f222c  update changelog (awaelchli, Jun 17, 2020)
db84ca9  Update CHANGELOG.md (Borda, Jun 17, 2020)
4cc7223  Fix num batches in case of multiple dataloaders and percent_check (#1… (rohitgr7, Jun 18, 2020)
c9f61b4  Merge branch 'master' into fix_num_batches_new (awaelchli, Jun 18, 2020)
bd4cb16  missing union (awaelchli, Jun 18, 2020)
3ad692f  doc update suggestion by @rohitgr7 (awaelchli, Jun 18, 2020)
28a8fe8  extend test (awaelchli, Jun 18, 2020)
d683a40  Merge branch 'master' into fix_num_batches_new (awaelchli, Jun 18, 2020)
f866dce  changelog (awaelchli, Jun 18, 2020)
2320661  docs add note about multiple loaders (awaelchli, Jun 18, 2020)
0914464  Merge branch 'master' into fix_num_batches_new (awaelchli, Jun 20, 2020)
2bf631a  update changelog (awaelchli, Jun 20, 2020)
005b45e  remove unused variable (awaelchli, Jun 20, 2020)
e9acd33  Merge branch 'master' into fix_num_batches_new (awaelchli, Jun 20, 2020)
3450180  Merge branch 'master' into num_batches_missing_test (Borda, Jun 21, 2020)
CHANGELOG.md (2 additions, 0 deletions)

@@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

+- Fixed number batches in case of multiple dataloaders and `limit_{*}_batches` ([#1920](https://github.com/PyTorchLightning/pytorch-lightning/pull/1920), [#2226](https://github.com/PyTorchLightning/pytorch-lightning/pull/2226))

- Fixed an issue with forward hooks not being removed after model summary ([#2298](https://github.com/PyTorchLightning/pytorch-lightning/pull/2298))


pytorch_lightning/trainer/__init__.py (4 additions, 0 deletions)

@@ -456,6 +456,8 @@ def on_train_end(self, trainer, pl_module):
# run for only 10 batches
trainer = Trainer(limit_test_batches=10)

+In the case of multiple test dataloaders, the limit applies to each dataloader individually.

limit_val_batches
^^^^^^^^^^^^^^^^^

@@ -473,6 +475,8 @@ def on_train_end(self, trainer, pl_module):
# run for only 10 batches
trainer = Trainer(limit_val_batches=10)

+In the case of multiple validation dataloaders, the limit applies to each dataloader individually.

log_gpu_memory
^^^^^^^^^^^^^^
Options:
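The per-loader behavior documented above is easiest to see with numbers. A minimal sketch, assuming a model with two validation dataloaders of 100 and 30 batches (matching the mixed-length fixtures added later in this PR):

```python
from pytorch_lightning import Trainer

# With multiple validation dataloaders, limit_val_batches is applied to
# each loader separately rather than to their combined length.
trainer = Trainer(limit_val_batches=0.5, max_epochs=1)

# After trainer.fit(model) with val loaders of 100 and 30 batches:
#   trainer.num_val_batches == [int(100 * 0.5), int(30 * 0.5)] == [50, 15]
```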
pytorch_lightning/trainer/data_loading.py (0 additions, 2 deletions)

@@ -287,8 +287,6 @@ def _reset_eval_dataloader(
        for i, dataloader in enumerate(dataloaders):
            num_batches = 0
            self._worker_check(dataloader, f'{mode} dataloader {i}')
-           if not _has_len(dataloader):
-               num_batches = float('inf')

            # percent or num_steps
            limit_eval_batches = getattr(self, f'limit_{mode}_batches')
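The two deleted lines had initialized `num_batches` to infinity for loaders that do not report a length; presumably that case is now handled further down in the same method (the diff is truncated here). A rough sketch of how such a check can work, using a hypothetical `has_len` helper rather than the library's exact `_has_len`:

```python
from torch.utils.data import DataLoader

def has_len(dataloader: DataLoader) -> bool:
    """Return True if the dataloader exposes a usable length."""
    try:
        len(dataloader)  # raises TypeError for iterable-style datasets
        return True
    except TypeError:
        return False

# Loaders without a length are treated as unbounded:
# num_batches = float('inf') if not has_len(dataloader) else 0
```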
pytorch_lightning/trainer/evaluation_loop.py (16 additions, 5 deletions)

@@ -124,7 +124,7 @@

from abc import ABC, abstractmethod
from pprint import pprint
-from typing import Callable, Optional, List
+from typing import Callable, Optional, List, Union

import torch
from torch.utils.data import DataLoader
@@ -222,13 +222,20 @@ def reset_test_dataloader(self, *args):
    def reset_val_dataloader(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""

-   def _evaluate(self, model: LightningModule, dataloaders, max_batches: List[int], test_mode: bool = False):
+   def _evaluate(
+       self,
+       model: LightningModule,
+       dataloaders: List[DataLoader],
+       max_batches: Union[int, List[int]],
+       test_mode: bool = False
+   ):
        """Run evaluation code.

        Args:
-           model: PT model
-           dataloaders: list of PT dataloaders
-           max_batches: List of scalars
+           model: The model to evaluate.
+           dataloaders: A list of PyTorch dataloaders.
+           max_batches: An integer or list of integers with length of the number of dataloaders. Each
+               entry is the number of batches to process in the corresponding dataloader.
            test_mode:
        """
# enable eval mode
@@ -244,6 +251,10 @@ def _evaluate(self, model: LightningModule, dataloaders, max_batches: List[int],
        # bookkeeping
        outputs = []

+       # convert max_batches to list
+       if isinstance(max_batches, int):
+           max_batches = [max_batches] * len(dataloaders)
+
        # run validation
        for dataloader_idx, dataloader in enumerate(dataloaders):
            dl_outputs = []
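The core of the change: `max_batches` may now arrive as a single int, and the added lines broadcast it into one entry per dataloader before the evaluation loop runs. A self-contained sketch of that normalization (hypothetical function name):

```python
from typing import List, Union

def normalize_max_batches(max_batches: Union[int, List[int]],
                          num_dataloaders: int) -> List[int]:
    # Mirror of the added branch in _evaluate: a scalar limit is
    # broadcast so every dataloader gets the same batch budget.
    if isinstance(max_batches, int):
        return [max_batches] * num_dataloaders
    return max_batches

assert normalize_max_batches(5, 3) == [5, 5, 5]
assert normalize_max_batches([2, 4], 2) == [2, 4]
```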
pytorch_lightning/trainer/trainer.py (1 addition, 1 deletion)

@@ -223,7 +223,7 @@ def __init__(

min_steps: Force training for at least these number of steps. Disabled by default (None).

-limit_train_batches: How much of training dataset to check.
+limit_train_batches: How much of training dataset to check (floats = percent, int = num_batches)

limit_val_batches: How much of validation dataset to check (floats = percent, int = num_batches)

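The docstring tweak brings `limit_train_batches` in line with the val/test variants: a float is read as a fraction of the dataset, an int as an absolute batch count. A small sketch of that interpretation (hypothetical helper; the truncating `int(len * fraction)` rounding matches the expectations in the tests below):

```python
from typing import Union

def resolve_num_batches(limit: Union[int, float], num_available: int) -> int:
    # float => fraction of the available batches, truncated toward zero
    # int   => absolute number of batches, used as-is
    if isinstance(limit, float):
        return int(num_available * limit)
    return limit

assert resolve_num_batches(0.25, 100) == 25
assert resolve_num_batches(10, 100) == 10
```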
pytorch_lightning/trainer/training_loop.py (0 additions, 1 deletion)

@@ -206,7 +206,6 @@ class TrainerTrainLoopMixin(ABC):
    check_val_every_n_epoch: ...
    num_training_batches: int
    val_check_batch: ...
-   num_val_batches: int
    disable_validation: bool
    fast_dev_run: ...
    accumulation_scheduler: ...
tests/base/model_test_dataloaders.py (6 additions, 1 deletion)

@@ -7,7 +7,7 @@
class TestDataloaderVariations(ABC):

    @abstractmethod
-   def dataloader(self, train: bool):
+   def dataloader(self, *args, **kwargs):
        """placeholder"""

    def test_dataloader(self):
@@ -19,6 +19,11 @@ def test_dataloader__infinite(self):
    def test_dataloader__not_implemented_error(self):
        return CustomNotImplementedErrorDataloader(self.dataloader(train=False))

+   def test_dataloader__multiple_mixed_length(self):
+       lengths = [50, 30, 40]
+       dataloaders = [self.dataloader(train=False, num_samples=n) for n in lengths]
+       return dataloaders
+
    def test_dataloader__empty(self):
        return None

tests/base/model_test_epoch_ends.py (2 additions, 2 deletions)

@@ -7,7 +7,7 @@ class TestEpochEndVariations(ABC):

    def test_epoch_end(self, outputs):
        """
-       Called at the end of validation to aggregate outputs
+       Called at the end of test epoch to aggregate outputs
        :param outputs: list of individual outputs of each validation step
        :return:
        """
@@ -40,7 +40,7 @@ def test_epoch_end(self, outputs):

    def test_epoch_end__multiple_dataloaders(self, outputs):
        """
-       Called at the end of validation to aggregate outputs
+       Called at the end of test epoch to aggregate outputs
        :param outputs: list of individual outputs of each validation step
        :return:
        """
tests/base/model_utilities.py (2 additions, 2 deletions)

@@ -6,8 +6,8 @@
class ModelTemplateData:
    hparams: ...

-   def dataloader(self, train):
-       dataset = TrialMNIST(root=self.data_root, train=train, download=True)
+   def dataloader(self, train: bool, num_samples: int = 100):
+       dataset = TrialMNIST(root=self.data_root, train=train, num_samples=num_samples, download=True)

        loader = DataLoader(
            dataset=dataset,
tests/base/model_valid_dataloaders.py (6 additions, 1 deletion)

@@ -7,12 +7,17 @@
class ValDataloaderVariations(ABC):

    @abstractmethod
-   def dataloader(self, train: bool):
+   def dataloader(self, *args, **kwargs):
        """placeholder"""

    def val_dataloader(self):
        return self.dataloader(train=False)

+   def val_dataloader__multiple_mixed_length(self):
+       lengths = [100, 30]
+       dataloaders = [self.dataloader(train=False, num_samples=n) for n in lengths]
+       return dataloaders
+
    def val_dataloader__multiple(self):
        return [self.dataloader(train=False),
                self.dataloader(train=False)]
tests/base/model_valid_epoch_ends.py (1 addition, 1 deletion)

@@ -28,7 +28,7 @@ def _mean(res, key):
        results = {'progress_bar': metrics_dict, 'log': metrics_dict}
        return results

-   def validation_epoch_end_multiple_dataloaders(self, outputs):
+   def validation_epoch_end__multiple_dataloaders(self, outputs):
        """
        Called at the end of validation to aggregate outputs

tests/test_deprecated.py (2 additions, 2 deletions)

@@ -140,7 +140,7 @@ def test_tbd_remove_in_v1_0_0_model_hooks():
    with pytest.deprecated_call(match='v1.0'):
        trainer = Trainer(logger=False)
        # TODO: why `dataloder` is required if it is not used
-       result = trainer._evaluate(model, dataloaders=[[None]], max_batches=[1])
+       result = trainer._evaluate(model, dataloaders=[[None]], max_batches=1)
    assert result == {'val_loss': torch.tensor(0.6)}

    model = ModelVer0_7(hparams)
@@ -153,5 +153,5 @@
    with pytest.deprecated_call(match='v1.0'):
        trainer = Trainer(logger=False)
        # TODO: why `dataloder` is required if it is not used
-       result = trainer._evaluate(model, dataloaders=[[None]], max_batches=[1])
+       result = trainer._evaluate(model, dataloaders=[[None]], max_batches=1)
    assert result == {'val_loss': torch.tensor(0.7)}
tests/trainer/test_dataloaders.py (77 additions, 2 deletions)

@@ -89,7 +89,7 @@ def test_multiple_val_dataloader(tmpdir):
    model = EvalModelTemplate()
    model.val_dataloader = model.val_dataloader__multiple
    model.validation_step = model.validation_step__multiple_dataloaders
-   model.validation_epoch_end = model.validation_epoch_end_multiple_dataloaders
+   model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders

    # fit model
    trainer = Trainer(
@@ -224,7 +224,7 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):

    model = EvalModelTemplate()
    model.validation_step = model.validation_step__multiple_dataloaders
-   model.validation_epoch_end = model.validation_epoch_end_multiple_dataloaders
+   model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
    model.test_step = model.test_step__multiple_dataloaders

    # train, multiple val and multiple test passed to fit
@@ -251,6 +251,81 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'


+@pytest.mark.parametrize(
+    ['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
+    [
+        pytest.param(0.0, 0.0, 0.0),
+        pytest.param(0, 0, 0.5),
+        pytest.param(1.0, 1.0, 1.0),
+        pytest.param(0.2, 0.4, 0.4),
+    ]
+)
+def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
+    """Verify num_batches for val & test dataloaders passed with batch limit in percent"""
+    model = EvalModelTemplate()
+    model.val_dataloader = model.val_dataloader__multiple_mixed_length
+    model.test_dataloader = model.test_dataloader__multiple_mixed_length
+    model.validation_step = model.validation_step__multiple_dataloaders
+    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
+    model.test_step = model.test_step__multiple_dataloaders
+    model.test_epoch_end = model.test_epoch_end__multiple_dataloaders
+
+    # train, multiple val and multiple test passed with percent_check
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=limit_val_batches,
+        limit_test_batches=limit_test_batches,
+    )
+    trainer.fit(model)
+    expected_train_batches = int(len(trainer.train_dataloader) * limit_train_batches)
+    expected_val_batches = [
+        int(len(dataloader) * limit_val_batches) for dataloader in trainer.val_dataloaders
+    ]
+    assert trainer.num_training_batches == expected_train_batches
+    assert trainer.num_val_batches == expected_val_batches
+
+    trainer.test(ckpt_path=None)
+    expected_test_batches = [
+        int(len(dataloader) * limit_test_batches) for dataloader in trainer.test_dataloaders
+    ]
+    assert trainer.num_test_batches == expected_test_batches
+
+
+@pytest.mark.parametrize(
+    ['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
+    [
+        pytest.param(0, 0, 0),
+        pytest.param(1, 2, 3),
+        pytest.param(1, 2, 1e50),
+    ]
+)
+def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
+    """Verify num_batches for val & test dataloaders passed with batch limit as number"""
+    model = EvalModelTemplate()
+    model.val_dataloader = model.val_dataloader__multiple_mixed_length
+    model.test_dataloader = model.test_dataloader__multiple_mixed_length
+    model.validation_step = model.validation_step__multiple_dataloaders
+    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
+    model.test_step = model.test_step__multiple_dataloaders
+    model.test_epoch_end = model.test_epoch_end__multiple_dataloaders
+
+    # train, multiple val and multiple test passed with num batches
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=limit_val_batches,
+        limit_test_batches=limit_test_batches,
+    )
+    trainer.fit(model)
+    assert trainer.num_training_batches == limit_train_batches
+    assert trainer.num_val_batches == [limit_val_batches] * len(trainer.val_dataloaders)
+    trainer.test(ckpt_path=None)
+    assert trainer.num_test_batches == [limit_test_batches] * len(trainer.test_dataloaders)

@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
    """Verify that dataloaders can be passed to fit"""
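As a usage note: from a development checkout, the two new tests can be exercised locally with a pytest keyword filter, e.g. `pytest tests/trainer/test_dataloaders.py -k "limit"` (any `-k` expression matching the new test names works).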