diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2b4c0f0dc1759..9fcc8e3691955 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added BLEU metrics ([#2535](https://github.com/PyTorchLightning/pytorch-lightning/pull/2535))
 
+- Added support for `Trainer(num_sanity_val_steps=-1)` to check all validation data before training ([#2246](https://github.com/PyTorchLightning/pytorch-lightning/pull/2246))
+
 ### Changed
diff --git a/docs/source/debugging.rst b/docs/source/debugging.rst
index 18536082399f4..3384d3c9f25cc 100644
--- a/docs/source/debugging.rst
+++ b/docs/source/debugging.rst
@@ -129,4 +129,4 @@ argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)
 .. testcode::
 
     # DEFAULT
-    trainer = Trainer(num_sanity_val_steps=5)
\ No newline at end of file
+    trainer = Trainer(num_sanity_val_steps=2)
\ No newline at end of file
diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py
index 472d245f4f1b2..2ff744a03c3ab 100644
--- a/pytorch_lightning/callbacks/progress.py
+++ b/pytorch_lightning/callbacks/progress.py
@@ -102,7 +102,7 @@ def total_val_batches(self) -> int:
         total_val_batches = 0
         if trainer.fast_dev_run and trainer.val_dataloaders is not None:
             total_val_batches = len(trainer.val_dataloaders)
-        elif not self.trainer.disable_validation:
+        elif self.trainer.enable_validation:
             is_val_epoch = (trainer.current_epoch + 1) % trainer.check_val_every_n_epoch == 0
             total_val_batches = sum(trainer.num_val_batches) if is_val_epoch else 0
         return total_val_batches
@@ -302,7 +302,7 @@ def init_test_tqdm(self) -> tqdm:
     def on_sanity_check_start(self, trainer, pl_module):
         super().on_sanity_check_start(trainer, pl_module)
         self.val_progress_bar = self.init_sanity_tqdm()
-        self.val_progress_bar.total = trainer.num_sanity_val_steps * len(trainer.val_dataloaders)
+        self.val_progress_bar.total = convert_inf(trainer.num_sanity_val_steps * len(trainer.val_dataloaders))
         self.main_progress_bar = tqdm(disable=True)  # dummy progress bar
 
     def on_sanity_check_end(self, trainer, pl_module):
diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py
index 793e58671ac7e..b8f02a36c6806 100644
--- a/pytorch_lightning/trainer/__init__.py
+++ b/pytorch_lightning/trainer/__init__.py
@@ -603,16 +603,19 @@ def on_train_end(self, trainer, pl_module):
 Sanity check runs n batches of val before starting the training routine.
 This catches any bugs in your validation without having to wait for the first validation check.
-The Trainer uses 5 steps by default. Turn it off or modify it here.
+The Trainer uses 2 steps by default. Turn it off or modify it here.
 
 .. testcode::
 
     # default used by the Trainer
-    trainer = Trainer(num_sanity_val_steps=5)
+    trainer = Trainer(num_sanity_val_steps=2)
 
     # turn it off
     trainer = Trainer(num_sanity_val_steps=0)
 
+    # check all validation data
+    trainer = Trainer(num_sanity_val_steps=-1)
+
 num_tpu_cores
 ^^^^^^^^^^^^^
 .. warning:: .. deprecated:: 0.7.6
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index c0e3364aa00d3..ed90f19cb9809 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -72,15 +72,16 @@
 -----------------------------------------
 
 Lightning runs a few steps of validation in the beginning of training.
- This avoids crashing in the validation loop sometime deep into a lengthy training loop.
+This avoids crashing in the validation loop sometime deep into a lengthy training loop.
 
 .. code-block:: python
 
     # DEFAULT
-    trainer = Trainer(num_sanity_val_steps=5)
+    trainer = Trainer(num_sanity_val_steps=2)
 
-You can use `Trainer(num_sanity_val_steps=0)` to skip the sanity check.
+You can use `Trainer(num_sanity_val_steps=0)` to skip the sanity check or `Trainer(num_sanity_val_steps=-1)`
+to check all the validation data.
 
 # Testing loop
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index fa2de59c90bff..bee332f831cce 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -336,7 +336,8 @@ def __init__(
 
             amp_level: The optimization level to use (O1, O2, etc...).
 
-            num_sanity_val_steps: Sanity check runs n batches of val before starting the training routine.
+            num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
+                Set it to `-1` to run all batches in all validation dataloaders. Default: 2
 
             truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of
@@ -408,7 +409,6 @@ def __init__(
         # training state
        self.model = None
         self.testing = False
-        self.disable_validation = False
         self.prepare_data_per_node = prepare_data_per_node
         self.lr_schedulers = []
         self.optimizers = None
@@ -488,7 +488,7 @@ def __init__(
         self.max_steps = max_steps
         self.min_steps = min_steps
 
-        self.num_sanity_val_steps = num_sanity_val_steps
+        self.num_sanity_val_steps = float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps
 
         # Backward compatibility, TODO: remove in v0.9.0
         if print_nan_grads:
             rank_zero_warn("Argument `print_nan_grads` has no effect and will be removed in v0.9.0."
@@ -883,6 +883,17 @@ def progress_bar_dict(self) -> dict:
         ref_model = self.model if not self.data_parallel else self.model.module
         return dict(**ref_model.get_progress_bar_dict(), **self.progress_bar_metrics)
 
+    @property
+    def disable_validation(self) -> bool:
+        """ Check if validation is disabled during training. """
+        return not self.enable_validation
+
+    @property
+    def enable_validation(self) -> bool:
+        """ Check if we should run validation during training. """
+        val_loop_enabled = (self.is_overridden('validation_step') and self.limit_val_batches > 0)
+        return val_loop_enabled or self.fast_dev_run
+
     # -----------------------------
     # MODEL TRAINING
     # -----------------------------
@@ -1186,10 +1197,6 @@ def run_pretrain_routine(self, model: LightningModule):
 
             return eval_loop_results
 
-        # check if we should run validation during training
-        self.disable_validation = not (self.is_overridden('validation_step') and self.limit_val_batches > 0) \
-            and not self.fast_dev_run
-
         # run a few val batches before training starts
         self._run_sanity_check(ref_model, model)
@@ -1204,9 +1211,12 @@ def run_pretrain_routine(self, model: LightningModule):
         self.train()
 
     def _run_sanity_check(self, ref_model, model):
+        should_sanity_check = self.is_overridden('validation_step') and self.num_sanity_val_steps > 0 \
+            and self.limit_val_batches > 0
+
         # run tiny validation (if validation defined)
         # to make sure program won't crash during val
-        if not self.disable_validation and self.num_sanity_val_steps > 0:
+        if should_sanity_check:
            self.reset_val_dataloader(ref_model)

            # hook and callback
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 66a85fc545a13..c84769108a24f 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -651,7 +651,7 @@ def should_check_val(self, batch_idx, is_last_batch):
         # decide if we should run validation
         is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
         can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
-        can_check_val = not self.disable_validation and can_check_epoch
+        can_check_val = self.enable_validation and can_check_epoch
         should_check_val = is_val_check_batch or self.should_stop
         is_last_batch_for_infinite_dataset = (is_last_batch and self.val_check_batch == float('inf'))
         should_check_val = can_check_val and (should_check_val or is_last_batch_for_infinite_dataset)
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index e45040f8ceaaf..d21767ab29409 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -6,6 +6,7 @@ import types
 from argparse import Namespace
 from pathlib import Path
+from unittest.mock import patch
 
 import cloudpickle
 import pytest
@@ -807,6 +808,34 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
     assert trainer.tpu_id == expected_tpu_id
 
 
+@pytest.mark.parametrize(['limit_val_batches'], [
+    pytest.param(0.0),  # this should run no sanity checks
+    pytest.param(1),
+    pytest.param(1.0),
+    pytest.param(0.3),
+])
+def test_num_sanity_val_steps(tmpdir, limit_val_batches):
+    """
+    Test that num_sanity_val_steps=-1 runs through all validation data once.
+    Makes sure this setting is independent of limit_val_batches.
+    """
+    model = EvalModelTemplate()
+    model.validation_step = model.validation_step__multiple_dataloaders
+    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        num_sanity_val_steps=-1,
+        limit_val_batches=limit_val_batches,  # should have no influence
+        max_steps=1,
+    )
+    assert trainer.num_sanity_val_steps == float('inf')
+    val_dataloaders = model.val_dataloader__multiple()
+
+    with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
+        trainer.fit(model, val_dataloaders=val_dataloaders)
+        assert mocked.call_count == sum(len(dl) * (limit_val_batches > 0) for dl in val_dataloaders)
+
+
 @pytest.mark.parametrize("trainer_kwargs,expected", [
     pytest.param(
         dict(distributed_backend=None, gpus=None),