Fix num_sanity_val_steps is clipped to limit_val_batches (#2917)
* Fix num_sanity_val_steps according to limit_val_batches

* fix test

* add num_sanity_batches

* pep

* update docstring in test

* add more test

* chlog

* update comments and docstring in test

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Adrian Wälchli <adrian.waelchli@inf.unibe.ch>
Co-authored-by: Ananya Harsh Jha <ananya@pytorchlightning.ai>
4 people committed Aug 21, 2020
1 parent bcdb750 commit 7cca385
Showing 4 changed files with 44 additions and 13 deletions.
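
The heart of the change: instead of clamping `num_sanity_val_steps` against `limit_val_batches` at `Trainer` construction time, the sanity check now clips the requested step count against each validation dataloader's already-limited batch count. A minimal illustrative sketch of that per-dataloader clipping (the numbers are assumptions for the example, not taken from the commit):

```python
# Illustrative sketch of the new clipping rule; values are assumed for the example.
num_sanity_val_steps = 4      # Trainer(num_sanity_val_steps=4)
num_val_batches = [10, 2]     # per-dataloader batch counts after limit_val_batches is applied

# Mirrors the list comprehension added to Trainer._run_sanity_check below:
num_sanity_val_batches = [min(num_sanity_val_steps, n) for n in num_val_batches]
assert num_sanity_val_batches == [4, 2]   # clipped per dataloader, not globally
```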
CHANGELOG.md (1 addition, 1 deletion)
@@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

+- Fixed `num_sanity_val_steps` is clipped to `limit_val_batches` ([#2917](https://github.com/PyTorchLightning/pytorch-lightning/pull/2917))

## [0.9.0] - YYYY-MM-DD

@@ -121,7 +122,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed automatic batch scaling not working with half precision ([#3045](https://github.com/PyTorchLightning/pytorch-lightning/pull/3045))
- Fixed setting device to root gpu ([#3042](https://github.com/PyTorchLightning/pytorch-lightning/pull/3042))


## [0.8.5] - 2020-07-09

### Added
pytorch_lightning/callbacks/progress.py (1 addition, 1 deletion)
@@ -307,7 +307,7 @@ def init_test_tqdm(self) -> tqdm:
    def on_sanity_check_start(self, trainer, pl_module):
        super().on_sanity_check_start(trainer, pl_module)
        self.val_progress_bar = self.init_sanity_tqdm()
-       self.val_progress_bar.total = convert_inf(trainer.num_sanity_val_steps * len(trainer.val_dataloaders))
+       self.val_progress_bar.total = convert_inf(sum(trainer.num_sanity_val_batches))
        self.main_progress_bar = tqdm(disable=True)  # dummy progress bar

    def on_sanity_check_end(self, trainer, pl_module):
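
The progress-bar change follows directly from the per-dataloader clipping: the sanity-check bar total becomes the sum of the clipped counts rather than `num_sanity_val_steps` times the number of dataloaders. A rough before/after comparison, reusing the assumed numbers from the sketch above:

```python
# Assumed numbers: two val dataloaders, 4 requested sanity steps, the shorter loader has 2 batches.
num_sanity_val_steps = 4
num_sanity_val_batches = [4, 2]   # already clipped per dataloader

old_total = num_sanity_val_steps * len(num_sanity_val_batches)  # 8, overcounts the short loader
new_total = sum(num_sanity_val_batches)                         # 6, matches the batches actually run
```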
pytorch_lightning/trainer/trainer.py (7 additions, 6 deletions)
@@ -377,6 +377,7 @@ def __init__(
        self.logged_metrics = {}
        self.num_training_batches = 0
        self.num_val_batches = []
+       self.num_sanity_val_batches = []
        self.num_test_batches = []
        self.train_dataloader = None
        self.test_dataloaders = None
@@ -463,9 +464,9 @@ def __init__(
        self.min_steps = min_steps

        if num_sanity_val_steps == -1:
-           self.num_sanity_val_steps = float("inf")
+           self.num_sanity_val_steps = float('inf')
        else:
-           self.num_sanity_val_steps = min(num_sanity_val_steps, limit_val_batches)
+           self.num_sanity_val_steps = num_sanity_val_steps

        self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

@@ -1239,22 +1240,22 @@ def run_pretrain_routine(self, model: LightningModule):
        self.train()

    def _run_sanity_check(self, ref_model, model):
-
        using_val_step = ref_model.val_dataloader is not None and self.is_overridden('validation_step')
        should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0

        # run tiny validation (if validation defined)
        # to make sure program won't crash during val
        if should_sanity_check:
            self.reset_val_dataloader(ref_model)
+           self.num_sanity_val_batches = [
+               min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches
+           ]

            # hook and callback
            self.running_sanity_check = True
            self.on_sanity_check_start()

-           num_loaders = len(self.val_dataloaders)
-           max_batches = [self.num_sanity_val_steps] * num_loaders
-           eval_results = self._evaluate(model, self.val_dataloaders, max_batches, False)
+           eval_results = self._evaluate(model, self.val_dataloaders, self.num_sanity_val_batches, False)

            # allow no returns from eval
            if eval_results is not None and len(eval_results) > 0:
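
Taken together, the trainer changes mean that a setup like the following hypothetical example (written against the 0.9-era API shown in this diff, not taken from the commit) runs at most `min(num_sanity_val_steps, limited batch count)` sanity batches per validation dataloader:

```python
# Hypothetical usage sketch; assumes some LightningModule subclass `MyModel` with val dataloaders.
from pytorch_lightning import Trainer

trainer = Trainer(
    num_sanity_val_steps=4,   # request up to 4 sanity batches per val dataloader
    limit_val_batches=2,      # each val dataloader is limited to 2 batches
)
# Previously num_sanity_val_steps was clamped against limit_val_batches in __init__, which also
# broke down when limit_val_batches was a fraction (min(4, 0.5) == 0.5). Now the clipping happens
# per dataloader in _run_sanity_check, so each loader runs min(4, 2) == 2 sanity batches.
# trainer.fit(MyModel())
```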
tests/trainer/test_trainer.py (35 additions, 5 deletions)
@@ -907,28 +907,58 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
    pytest.param(0.0),  # this should run no sanity checks
    pytest.param(1),
    pytest.param(1.0),
-   pytest.param(0.3),
+   pytest.param(0.5),
+   pytest.param(5),
])
def test_num_sanity_val_steps(tmpdir, limit_val_batches):
+   """ Test that the number of sanity check batches is clipped to limit_val_batches. """
+   model = EvalModelTemplate()
+   model.validation_step = model.validation_step__multiple_dataloaders
+   model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
+   num_sanity_val_steps = 4
+
+   trainer = Trainer(
+       default_root_dir=tmpdir,
+       num_sanity_val_steps=num_sanity_val_steps,
+       limit_val_batches=limit_val_batches,
+       max_steps=1,
+   )
+   assert trainer.num_sanity_val_steps == num_sanity_val_steps
+   val_dataloaders = model.val_dataloader__multiple_mixed_length()
+
+   with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
+       trainer.fit(model, val_dataloaders=val_dataloaders)
+       assert mocked.call_count == sum(
+           min(num_sanity_val_steps, num_batches) for num_batches in trainer.num_val_batches
+       )
+
+
+@pytest.mark.parametrize(['limit_val_batches'], [
+   pytest.param(0.0),  # this should run no sanity checks
+   pytest.param(1),
+   pytest.param(1.0),
+   pytest.param(0.3),
+])
+def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
    """
-   Test that num_sanity_val_steps=-1 runs through all validation data once.
-   Makes sure this setting is independent of limit_val_batches.
+   Test that num_sanity_val_steps=-1 runs through all validation data once, and as many batches as
+   limited by "limit_val_batches" Trainer argument.
    """
    model = EvalModelTemplate()
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=-1,
-       limit_val_batches=limit_val_batches,  # should have no influence
+       limit_val_batches=limit_val_batches,
        max_steps=1,
    )
    assert trainer.num_sanity_val_steps == float('inf')
    val_dataloaders = model.val_dataloader__multiple()

    with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
        trainer.fit(model, val_dataloaders=val_dataloaders)
-       assert mocked.call_count == sum(len(dl) * (limit_val_batches > 0) for dl in val_dataloaders)
+       assert mocked.call_count == sum(trainer.num_val_batches)


@pytest.mark.parametrize("trainer_kwargs,expected", [
