Fix num_sanity_val_steps is clipped to limit_val_batches #2917

Merged · 11 commits · Aug 21, 2020
CHANGELOG.md (2 additions, 0 deletions)

@@ -148,6 +148,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed gathering of results with tensors of varying shape ([#3020](https://github.com/PyTorchLightning/pytorch-lightning/pull/3020))

+ - Fixed `num_sanity_val_steps` being clipped to `limit_val_batches` ([#2917](https://github.com/PyTorchLightning/pytorch-lightning/pull/2917))

## [0.8.5] - 2020-07-09

### Added
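For context, the fix changes how the two Trainer arguments interact: `num_sanity_val_steps` is no longer clipped against `limit_val_batches` at construction time; instead, each validation dataloader is clipped individually when the sanity check runs. A minimal sketch of the resulting behavior; `expected_sanity_batches` is an illustrative helper, not part of the PR:

```python
# Sketch only. `num_val_batches` stands for the per-dataloader batch counts
# after `limit_val_batches` has already been applied.
def expected_sanity_batches(num_sanity_val_steps, num_val_batches):
    if num_sanity_val_steps == -1:            # -1 means "run full validation"
        num_sanity_val_steps = float('inf')
    return [min(num_sanity_val_steps, n) for n in num_val_batches]

print(expected_sanity_batches(4, [2, 10]))   # [2, 4]  -> 6 batches in total
print(expected_sanity_batches(-1, [2, 10]))  # [2, 10] -> one full validation pass
```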
pytorch_lightning/callbacks/progress.py (1 addition, 1 deletion)

@@ -293,7 +293,7 @@ def init_test_tqdm(self) -> tqdm:
    def on_sanity_check_start(self, trainer, pl_module):
        super().on_sanity_check_start(trainer, pl_module)
        self.val_progress_bar = self.init_sanity_tqdm()
-       self.val_progress_bar.total = convert_inf(trainer.num_sanity_val_steps * len(trainer.val_dataloaders))
+       self.val_progress_bar.total = convert_inf(sum(trainer.num_sanity_val_batches))
        self.main_progress_bar = tqdm(disable=True)  # dummy progress bar

    def on_sanity_check_end(self, trainer, pl_module):
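The old progress-bar total assumed every validation dataloader contributes the full `num_sanity_val_steps`, which over-counts whenever a loader is shorter than that. A quick check with hypothetical loader sizes:

```python
# Hypothetical sizes: two val dataloaders with 2 and 10 batches, 4 sanity steps.
old_total = 4 * 2                         # 8: the bar would never reach 100%
new_total = sum([min(4, 2), min(4, 10)])  # 2 + 4 = 6: matches the batches actually run
```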
pytorch_lightning/trainer/trainer.py (7 additions, 6 deletions)

@@ -376,6 +376,7 @@ def __init__(
        self.callback_metrics = {}
        self.num_training_batches = 0
        self.num_val_batches = []
+       self.num_sanity_val_batches = []
        self.num_test_batches = []
        self.train_dataloader = None
        self.test_dataloaders = None
@@ -462,9 +463,9 @@ def __init__(
        self.min_steps = min_steps

        if num_sanity_val_steps == -1:
-           self.num_sanity_val_steps = float("inf")
+           self.num_sanity_val_steps = float('inf')
        else:
-           self.num_sanity_val_steps = min(num_sanity_val_steps, limit_val_batches)
+           self.num_sanity_val_steps = num_sanity_val_steps

        self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch
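One likely reason the clipping cannot stay in `__init__`: `limit_val_batches` may be given as a float fraction of the dataset (the tests below parametrize it with 0.3 and 0.5), which only resolves to an integer batch count once the dataloaders are attached. A constructor-time `min()` therefore compares incompatible quantities:

```python
# With limit_val_batches passed as a dataset fraction, the old constructor-time
# clipping produced a meaningless float "step count":
min(4, 0.3)   # -> 0.3, not a usable number of sanity batches
# The resolved per-loader counts only exist after reset_val_dataloader(),
# which is why the PR defers the min() to _run_sanity_check().
```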

@@ -1224,22 +1225,22 @@ def run_pretrain_routine(self, model: LightningModule):
        self.train()

    def _run_sanity_check(self, ref_model, model):

        using_val_step = ref_model.val_dataloader is not None and self.is_overridden('validation_step')
        should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0

        # run tiny validation (if validation defined)
        # to make sure program won't crash during val
        if should_sanity_check:
            self.reset_val_dataloader(ref_model)
+           self.num_sanity_val_batches = [
+               min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches
+           ]

            # hook and callback
            self.running_sanity_check = True
            self.on_sanity_check_start()

-           num_loaders = len(self.val_dataloaders)
-           max_batches = [self.num_sanity_val_steps] * num_loaders
-           eval_results = self._evaluate(model, self.val_dataloaders, max_batches, False)
+           eval_results = self._evaluate(model, self.val_dataloaders, self.num_sanity_val_batches, False)

            # allow no returns from eval
            if eval_results is not None and len(eval_results) > 0:
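`_evaluate` now receives one batch cap per dataloader instead of a uniform cap. A minimal sketch of the loop shape this implies, as a stand-in rather than Lightning's actual implementation (the real method also handles hooks, logging, and result reduction):

```python
# Each dataloader is capped independently, so a list of per-loader limits
# (num_sanity_val_batches) is the natural argument type.
def evaluate(val_dataloaders, max_batches, validation_step):
    results = []
    for dataloader, cap in zip(val_dataloaders, max_batches):
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= cap:  # cap may be float('inf'); then nothing is skipped
                break
            results.append(validation_step(batch, batch_idx))
    return results
```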
tests/trainer/test_trainer.py (35 additions, 5 deletions)

@@ -907,28 +907,58 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
    pytest.param(0.0),  # this should run no sanity checks
    pytest.param(1),
    pytest.param(1.0),
    pytest.param(0.3),
    pytest.param(0.5),
    pytest.param(5),
])
def test_num_sanity_val_steps(tmpdir, limit_val_batches):
    """ Test that the number of sanity check batches is clipped to limit_val_batches. """
    model = EvalModelTemplate()
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
    num_sanity_val_steps = 4

    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=num_sanity_val_steps,
        limit_val_batches=limit_val_batches,
        max_steps=1,
    )
    assert trainer.num_sanity_val_steps == num_sanity_val_steps
    val_dataloaders = model.val_dataloader__multiple_mixed_length()

    with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
        trainer.fit(model, val_dataloaders=val_dataloaders)
        assert mocked.call_count == sum(
            min(num_sanity_val_steps, num_batches) for num_batches in trainer.num_val_batches
        )
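To make the assertion concrete with hypothetical numbers (the real lengths come from `val_dataloader__multiple_mixed_length` and are not shown in this diff): if `limit_val_batches` resolves the two loaders to 5 and 3 batches, the wrapped `evaluation_forward` must be called exactly min(4, 5) + min(4, 3) = 4 + 3 = 7 times.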


@pytest.mark.parametrize(['limit_val_batches'], [
    pytest.param(0.0),  # this should run no sanity checks
    pytest.param(1),
    pytest.param(1.0),
    pytest.param(0.3),
])
def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
    """
-   Test that num_sanity_val_steps=-1 runs through all validation data once.
-   Makes sure this setting is independent of limit_val_batches.
+   Test that num_sanity_val_steps=-1 runs through all validation data once,
+   capped only by the "limit_val_batches" Trainer argument.
    """
    model = EvalModelTemplate()
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=-1,
-       limit_val_batches=limit_val_batches,  # should have no influence
+       limit_val_batches=limit_val_batches,
        max_steps=1,
    )
    assert trainer.num_sanity_val_steps == float('inf')
    val_dataloaders = model.val_dataloader__multiple()

    with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
        trainer.fit(model, val_dataloaders=val_dataloaders)
-       assert mocked.call_count == sum(len(dl) * (limit_val_batches > 0) for dl in val_dataloaders)
+       assert mocked.call_count == sum(trainer.num_val_batches)
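Because `min(float('inf'), n) == n` for every loader, `num_sanity_val_steps=-1` degenerates to exactly one pass over whatever `limit_val_batches` leaves. Reading `trainer.num_val_batches` directly gives that count; the old assertion recomputed it from raw dataloader lengths and would disagree once a fractional limit such as 0.3 is in play.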


@pytest.mark.parametrize("trainer_kwargs,expected", [