
Add missing test for "multiple dataloader + percent_check fix" #2226

Merged: 31 commits, Jun 23, 2020 (changes shown from 18 commits)

Commits
ad833fa
Init fix num_batches
rohitgr7 May 21, 2020
7de06cc
Fix num_batches in case of multiple dataloaders
rohitgr7 May 21, 2020
a9b996d
Apply suggestions from code review
awaelchli May 21, 2020
da4fbcc
Changes based on suggestions
rohitgr7 May 21, 2020
13bb5c0
Flake8
rohitgr7 May 21, 2020
5241bcb
Add test to check num_batches
rohitgr7 May 22, 2020
98b59a4
generalize dataloader percent check test
awaelchli May 26, 2020
e0bd090
fix formatting
awaelchli May 26, 2020
f61c2c5
remove hparams
rohitgr7 May 26, 2020
1aee122
tests
rohitgr7 Jun 12, 2020
c96e00a
CHANGELOG
rohitgr7 Jun 12, 2020
476a693
Update CHANGELOG.md
Borda Jun 15, 2020
a7f751e
max_batches can be int
rohitgr7 Jun 15, 2020
4e033aa
conflict and rebase
rohitgr7 Jun 15, 2020
46ab176
Merge branch 'master' into fix_num_batches_new
awaelchli Jun 17, 2020
06d5ab0
add back the test
awaelchli Jun 17, 2020
29f222c
update changelog
awaelchli Jun 17, 2020
db84ca9
Update CHANGELOG.md
Borda Jun 17, 2020
4cc7223
Fix num batches in case of multiple dataloaders and percent_check (#1…
rohitgr7 Jun 18, 2020
c9f61b4
Merge branch 'master' into fix_num_batches_new
awaelchli Jun 18, 2020
bd4cb16
missing union
awaelchli Jun 18, 2020
3ad692f
doc update suggestion by @rohitgr7
awaelchli Jun 18, 2020
28a8fe8
extend test
awaelchli Jun 18, 2020
d683a40
Merge branch 'master' into fix_num_batches_new
awaelchli Jun 18, 2020
f866dce
changelog
awaelchli Jun 18, 2020
2320661
docs add note about multiple loaders
awaelchli Jun 18, 2020
0914464
Merge branch 'master' into fix_num_batches_new
awaelchli Jun 20, 2020
2bf631a
update changelog
awaelchli Jun 20, 2020
005b45e
remove unused variable
awaelchli Jun 20, 2020
e9acd33
Merge branch 'master' into fix_num_batches_new
awaelchli Jun 20, 2020
3450180
Merge branch 'master' into num_batches_missing_test
Borda Jun 21, 2020
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -88,7 +88,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with `_auto_collect_arguments` collecting local variables that are not constructor arguments and not working for signatures that have the instance not named `self` ([#2048](https://github.com/PyTorchLightning/pytorch-lightning/pull/2048))
- Fixed mistake in parameters' grad norm tracking ([#2012](https://github.com/PyTorchLightning/pytorch-lightning/pull/2012))
- Fixed CPU and hanging GPU crash ([#2118](https://github.com/PyTorchLightning/pytorch-lightning/pull/2118))

- Fixed number batches in case of multiple dataloaders and `limit_{*}_batches` ([#1920](https://github.com/PyTorchLightning/pytorch-lightning/pull/1920), [#2226](https://github.com/PyTorchLightning/pytorch-lightning/pull/2226))
- Fixed an issue with the model summary and `example_input_array` depending on a specific ordering of the submodules in a LightningModule ([#1773](https://github.com/PyTorchLightning/pytorch-lightning/pull/1773))

## [0.7.6] - 2020-05-16
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -205,7 +205,7 @@ class TrainerTrainLoopMixin(ABC):
check_val_every_n_epoch: ...
num_training_batches: int
val_check_batch: ...
num_val_batches: int
num_val_batches: List[int]
disable_validation: bool
fast_dev_run: ...
accumulation_scheduler: ...
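The type change above (`num_val_batches: int` becomes `List[int]`) reflects the core of the fix: with multiple validation dataloaders, each loader needs its own batch budget rather than one shared count. A minimal plain-Python sketch of the intended semantics (a hypothetical helper, not Lightning's actual implementation): a float limit is a fraction of each loader's length, while an int limit caps every loader at that many batches.

```python
from typing import List, Union

def compute_limited_batches(dataloader_lengths: List[int],
                            limit: Union[int, float]) -> List[int]:
    """Hypothetical sketch: per-dataloader batch counts under limit_*_batches.

    A float limit keeps that fraction of each loader's batches (truncated);
    an int limit caps each loader at that many batches.
    """
    if isinstance(limit, float):
        return [int(n * limit) for n in dataloader_lengths]
    return [min(n, limit) for n in dataloader_lengths]
```

Under this sketch, three loaders of lengths 50, 30, and 40 with `limit=0.4` yield `[20, 12, 16]` — exactly the shape of the per-loader lists the test below asserts against.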
7 changes: 6 additions & 1 deletion tests/base/model_test_dataloaders.py
@@ -6,7 +6,7 @@
class TestDataloaderVariations(ABC):

@abstractmethod
def dataloader(self, train: bool):
def dataloader(self, *args, **kwargs):
"""placeholder"""

def test_dataloader(self):
@@ -15,6 +15,11 @@ def test_dataloader(self):
def test_dataloader__infinite(self):
return CustomInfDataloader(self.dataloader(train=False))

def test_dataloader__multiple_mixed_length(self):
lengths = [50, 30, 40]
dataloaders = [self.dataloader(train=False, num_samples=n) for n in lengths]
return dataloaders

def test_dataloader__empty(self):
return None

4 changes: 2 additions & 2 deletions tests/base/model_test_epoch_ends.py
@@ -7,7 +7,7 @@ class TestEpochEndVariations(ABC):

def test_epoch_end(self, outputs):
"""
Called at the end of validation to aggregate outputs
Called at the end of test epoch to aggregate outputs
:param outputs: list of individual outputs of each validation step
:return:
"""
@@ -40,7 +40,7 @@ def test_epoch_end(self, outputs):

def test_epoch_end__multiple_dataloaders(self, outputs):
"""
Called at the end of validation to aggregate outputs
Called at the end of test epoch to aggregate outputs
:param outputs: list of individual outputs of each validation step
:return:
"""
4 changes: 2 additions & 2 deletions tests/base/model_utilities.py
@@ -6,8 +6,8 @@
class ModelTemplateData:
hparams: ...

def dataloader(self, train):
dataset = TrialMNIST(root=self.data_root, train=train, download=True)
def dataloader(self, train: bool, num_samples: int = 100):
dataset = TrialMNIST(root=self.data_root, train=train, num_samples=num_samples, download=True)

loader = DataLoader(
dataset=dataset,
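The new `num_samples` parameter above is what lets the fixtures build loaders of different lengths. As a reminder of the arithmetic the tests rely on (a plain-Python sketch, not part of this diff): a map-style `DataLoader` over `n` samples with batch size `b` has `ceil(n / b)` batches, or `n // b` when `drop_last=True`.

```python
import math

def num_batches(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    # len(DataLoader) for a map-style dataset:
    # the last partial batch is kept unless drop_last is set.
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)
```

For example, 100 samples at batch size 32 give 4 batches (3 with `drop_last`), so the mixed-length fixtures produce loaders whose `len()` values differ per dataloader.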
7 changes: 6 additions & 1 deletion tests/base/model_valid_dataloaders.py
@@ -6,12 +6,17 @@
class ValDataloaderVariations(ABC):

@abstractmethod
def dataloader(self, train: bool):
def dataloader(self, *args, **kwargs):
"""placeholder"""

def val_dataloader(self):
return self.dataloader(train=False)

def val_dataloader__multiple_mixed_length(self):
lengths = [100, 30]
dataloaders = [self.dataloader(train=False, num_samples=n) for n in lengths]
return dataloaders

def val_dataloader__multiple(self):
return [self.dataloader(train=False),
self.dataloader(train=False)]
2 changes: 1 addition & 1 deletion tests/base/model_valid_epoch_ends.py
@@ -28,7 +28,7 @@ def _mean(res, key):
results = {'progress_bar': metrics_dict, 'log': metrics_dict}
return results

def validation_epoch_end_multiple_dataloaders(self, outputs):
def validation_epoch_end__multiple_dataloaders(self, outputs):
"""
Called at the end of validation to aggregate outputs

50 changes: 48 additions & 2 deletions tests/trainer/test_dataloaders.py
@@ -89,7 +89,7 @@ def test_multiple_val_dataloader(tmpdir):
model = EvalModelTemplate()
model.val_dataloader = model.val_dataloader__multiple
model.validation_step = model.validation_step__multiple_dataloaders
model.validation_epoch_end = model.validation_epoch_end_multiple_dataloaders
model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders

# fit model
trainer = Trainer(
@@ -224,7 +224,7 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):

model = EvalModelTemplate()
model.validation_step = model.validation_step__multiple_dataloaders
model.validation_epoch_end = model.validation_epoch_end_multiple_dataloaders
model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
model.test_step = model.test_step__multiple_dataloaders

# train, multiple val and multiple test passed to fit
@@ -251,6 +251,52 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'


@pytest.mark.parametrize(
['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
[
pytest.param(0.0, 0.0, 0.0),
pytest.param(0, 0, 0.5),
pytest.param(1.0, 1.0, 1.0),
pytest.param(0.2, 0.4, 0.4),
]
)
def test_dataloaders_with_limit_batches_percent(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
"""Verify num_batches for val & test dataloaders passed with batch limit in percent"""
model = EvalModelTemplate()
model.val_dataloader = model.val_dataloader__multiple_mixed_length
model.test_dataloader = model.test_dataloader__multiple_mixed_length
model.validation_step = model.validation_step__multiple_dataloaders
model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
model.test_step = model.test_step__multiple_dataloaders
model.test_epoch_end = model.test_epoch_end__multiple_dataloaders

# train, multiple val and multiple test passed with percent_check
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=limit_train_batches,
limit_val_batches=limit_val_batches,
limit_test_batches=limit_test_batches,
)
trainer.fit(model)
expected_train_batches = int(len(trainer.train_dataloader) * limit_train_batches)
expected_val_batches = [
int(len(dataloader) * limit_val_batches) for dataloader in trainer.val_dataloaders
]
assert trainer.num_training_batches == expected_train_batches, \
f'train_percent_check not working with train_dataloaders, got {trainer.num_training_batches}'

assert trainer.num_val_batches == expected_val_batches, \
f'val_percent_check not working with val_dataloaders, got {trainer.num_val_batches}'

trainer.test(ckpt_path=None)
expected_test_batches = [
int(len(dataloader) * limit_test_batches) for dataloader in trainer.test_dataloaders
]
assert trainer.num_test_batches == expected_test_batches, \
f'test_percent_check not working with test_dataloaders, got {trainer.num_test_batches}'
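The expected counts in this test truncate toward zero, which is worth making explicit: a small loader combined with a small fractional limit can legitimately end up with zero batches, which is why the `0.0` parametrization row is a meaningful edge case. A hypothetical helper mirroring the test's arithmetic:

```python
def expected_num_batches(loader_len: int, limit: float) -> int:
    # Mirrors the test's arithmetic: int(len(dataloader) * limit).
    # Truncation means small loaders with small fractions can yield 0 batches.
    return int(loader_len * limit)
```

For instance, a loader of length 3 with `limit=0.2` contributes 0 batches, while `limit=1.0` always keeps the full loader.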


@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
"""Verify that dataloaders can be passed to fit"""