support num_sanity_val_steps=-1 (#2246)
* support sanity_val_step=-1

* fix list size

* simplification

* simplify

* add test for num_sanity_val_steps=-1

* update test

* update docs

* extend tests to multiple dataloaders

* changelog

* Update tests/trainer/test_trainer.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* improve test

* refactor the sanity check decision

* fix merge

* Update trainer.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: William Falcon <waf2107@columbia.edu>
3 people committed Jul 23, 2020
1 parent 62ce00f commit 1e68968
Showing 8 changed files with 62 additions and 17 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added BLEU metrics ([#2535](https://github.com/PyTorchLightning/pytorch-lightning/pull/2535))

+- Added support for `Trainer(num_sanity_val_steps=-1)` to check all validation data before training ([#2246](https://github.com/PyTorchLightning/pytorch-lightning/pull/2246))

### Changed


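A minimal usage sketch of the new option (the `model` below is a placeholder LightningModule for illustration, not part of this commit):

    from pytorch_lightning import Trainer

    # -1 asks the sanity check to run every batch of every validation dataloader
    trainer = Trainer(num_sanity_val_steps=-1)
    trainer.fit(model)  # `model`: any LightningModule, assumed here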
2 changes: 1 addition & 1 deletion docs/source/debugging.rst
@@ -129,4 +129,4 @@ argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)
.. testcode::

    # DEFAULT
-    trainer = Trainer(num_sanity_val_steps=5)
+    trainer = Trainer(num_sanity_val_steps=2)
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/progress.py
@@ -102,7 +102,7 @@ def total_val_batches(self) -> int:
        total_val_batches = 0
        if trainer.fast_dev_run and trainer.val_dataloaders is not None:
            total_val_batches = len(trainer.val_dataloaders)
-        elif not self.trainer.disable_validation:
+        elif self.trainer.enable_validation:
            is_val_epoch = (trainer.current_epoch + 1) % trainer.check_val_every_n_epoch == 0
            total_val_batches = sum(trainer.num_val_batches) if is_val_epoch else 0
        return total_val_batches
@@ -302,7 +302,7 @@ def init_test_tqdm(self) -> tqdm:
    def on_sanity_check_start(self, trainer, pl_module):
        super().on_sanity_check_start(trainer, pl_module)
        self.val_progress_bar = self.init_sanity_tqdm()
-        self.val_progress_bar.total = trainer.num_sanity_val_steps * len(trainer.val_dataloaders)
+        self.val_progress_bar.total = convert_inf(trainer.num_sanity_val_steps * len(trainer.val_dataloaders))
        self.main_progress_bar = tqdm(disable=True)  # dummy progress bar

    def on_sanity_check_end(self, trainer, pl_module):
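Note on the `convert_inf` call above: `num_sanity_val_steps` can now be `float('inf')`, and tqdm cannot render an infinite total. A sketch of the helper's likely behavior (the actual implementation ships in `pytorch_lightning/callbacks/progress.py`):

    from typing import Optional, Union

    def convert_inf(x: Union[int, float]) -> Optional[Union[int, float]]:
        # tqdm falls back to a plain iteration counter when total is None
        return None if x == float('inf') else x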
7 changes: 5 additions & 2 deletions pytorch_lightning/trainer/__init__.py
@@ -603,16 +603,19 @@ def on_train_end(self, trainer, pl_module):
Sanity check runs n batches of val before starting the training routine.
This catches any bugs in your validation without having to wait for the first validation check.
-The Trainer uses 5 steps by default. Turn it off or modify it here.
+The Trainer uses 2 steps by default. Turn it off or modify it here.

.. testcode::

    # default used by the Trainer
-    trainer = Trainer(num_sanity_val_steps=5)
+    trainer = Trainer(num_sanity_val_steps=2)

    # turn it off
    trainer = Trainer(num_sanity_val_steps=0)

+    # check all validation data
+    trainer = Trainer(num_sanity_val_steps=-1)
num_tpu_cores
^^^^^^^^^^^^^
.. warning:: .. deprecated:: 0.7.6
7 changes: 4 additions & 3 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -72,15 +72,16 @@
-----------------------------------------
Lightning runs a few steps of validation in the beginning of training.
-This avoids crashing in the validation loop sometime deep into a lengthy training loop.
+This avoids crashing in the validation loop sometime deep into a lengthy training loop.

.. code-block:: python

    # DEFAULT
-    trainer = Trainer(num_sanity_val_steps=5)
+    trainer = Trainer(num_sanity_val_steps=2)

-You can use `Trainer(num_sanity_val_steps=0)` to skip the sanity check.
+You can use `Trainer(num_sanity_val_steps=0)` to skip the sanity check or `Trainer(num_sanity_val_steps=-1)`
+to check all the validation data.

# Testing loop
26 changes: 18 additions & 8 deletions pytorch_lightning/trainer/trainer.py
@@ -336,7 +336,8 @@ def __init__(
            amp_level: The optimization level to use (O1, O2, etc...).

-            num_sanity_val_steps: Sanity check runs n batches of val before starting the training routine.
+            num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
+                Set it to `-1` to run all batches in all validation dataloaders. Default: 2

            truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of
@@ -408,7 +409,6 @@ def __init__(
        # training state
        self.model = None
        self.testing = False
-        self.disable_validation = False
        self.prepare_data_per_node = prepare_data_per_node
        self.lr_schedulers = []
        self.optimizers = None
@@ -488,7 +488,7 @@ def __init__(
        self.max_steps = max_steps
        self.min_steps = min_steps

-        self.num_sanity_val_steps = num_sanity_val_steps
+        self.num_sanity_val_steps = float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps

        # Backward compatibility, TODO: remove in v0.9.0
        if print_nan_grads:
            rank_zero_warn("Argument `print_nan_grads` has no effect and will be removed in v0.9.0."
@@ -883,6 +883,17 @@ def progress_bar_dict(self) -> dict:
        ref_model = self.model if not self.data_parallel else self.model.module
        return dict(**ref_model.get_progress_bar_dict(), **self.progress_bar_metrics)

+    @property
+    def disable_validation(self) -> bool:
+        """ Check if validation is disabled during training. """
+        return not self.enable_validation
+
+    @property
+    def enable_validation(self) -> bool:
+        """ Check if we should run validation during training. """
+        val_loop_enabled = (self.is_overridden('validation_step') and self.limit_val_batches > 0)
+        return val_loop_enabled or self.fast_dev_run
+
    # -----------------------------
    # MODEL TRAINING
    # -----------------------------
@@ -1186,10 +1197,6 @@ def run_pretrain_routine(self, model: LightningModule):

            return eval_loop_results

-        # check if we should run validation during training
-        self.disable_validation = not (self.is_overridden('validation_step') and self.limit_val_batches > 0) \
-            and not self.fast_dev_run
-
        # run a few val batches before training starts
        self._run_sanity_check(ref_model, model)

@@ -1204,9 +1211,12 @@ def run_pretrain_routine(self, model: LightningModule):
        self.train()

    def _run_sanity_check(self, ref_model, model):
+        should_sanity_check = self.is_overridden('validation_step') and self.num_sanity_val_steps > 0 \
+            and self.limit_val_batches > 0
+
        # run tiny validation (if validation defined)
        # to make sure program won't crash during val
-        if not self.disable_validation and self.num_sanity_val_steps > 0:
+        if should_sanity_check:
            self.reset_val_dataloader(ref_model)

            # hook and callback
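Taken together, the constructor change and the new `_run_sanity_check` guard behave roughly like this standalone sketch (function names are illustrative, not part of the commit):

    def normalized_sanity_steps(num_sanity_val_steps: int) -> float:
        # -1 is a sentinel for "all batches"; float('inf') means the evaluation
        # loop's batch-index cutoff never triggers an early stop
        return float('inf') if num_sanity_val_steps == -1 else num_sanity_val_steps

    def should_sanity_check(has_validation_step: bool, num_sanity_val_steps: float,
                            limit_val_batches: float) -> bool:
        # mirrors the guard added to _run_sanity_check above
        return has_validation_step and num_sanity_val_steps > 0 and limit_val_batches > 0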
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -651,7 +651,7 @@ def should_check_val(self, batch_idx, is_last_batch):
        # decide if we should run validation
        is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
        can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
-        can_check_val = not self.disable_validation and can_check_epoch
+        can_check_val = self.enable_validation and can_check_epoch
        should_check_val = is_val_check_batch or self.should_stop
        is_last_batch_for_infinite_dataset = (is_last_batch and self.val_check_batch == float('inf'))
        should_check_val = can_check_val and (should_check_val or is_last_batch_for_infinite_dataset)
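Read as a pure function, the decision in this hunk amounts to the following (attributes passed explicitly; a rewording of the diff, not new behavior):

    def should_check_val(enable_validation: bool, batch_idx: int, val_check_batch: float,
                         current_epoch: int, check_val_every_n_epoch: int,
                         is_last_batch: bool, should_stop: bool) -> bool:
        is_val_check_batch = (batch_idx + 1) % val_check_batch == 0
        can_check_epoch = (current_epoch + 1) % check_val_every_n_epoch == 0
        can_check_val = enable_validation and can_check_epoch
        check_now = is_val_check_batch or should_stop
        # iterable datasets set val_check_batch to inf, so validate on the last batch instead
        is_last_for_infinite_dataset = is_last_batch and val_check_batch == float('inf')
        return can_check_val and (check_now or is_last_for_infinite_dataset)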
29 changes: 29 additions & 0 deletions tests/trainer/test_trainer.py
@@ -6,6 +6,7 @@
import types
from argparse import Namespace
from pathlib import Path
+from unittest.mock import patch

import cloudpickle
import pytest
@@ -807,6 +808,34 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
    assert trainer.tpu_id == expected_tpu_id


+@pytest.mark.parametrize(['limit_val_batches'], [
+    pytest.param(0.0),  # this should run no sanity checks
+    pytest.param(1),
+    pytest.param(1.0),
+    pytest.param(0.3),
+])
+def test_num_sanity_val_steps(tmpdir, limit_val_batches):
+    """
+    Test that num_sanity_val_steps=-1 runs through all validation data once.
+    Makes sure this setting is independent of limit_val_batches.
+    """
+    model = EvalModelTemplate()
+    model.validation_step = model.validation_step__multiple_dataloaders
+    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        num_sanity_val_steps=-1,
+        limit_val_batches=limit_val_batches,  # should have no influence
+        max_steps=1,
+    )
+    assert trainer.num_sanity_val_steps == float('inf')
+    val_dataloaders = model.val_dataloader__multiple()
+
+    with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
+        trainer.fit(model, val_dataloaders=val_dataloaders)
+        assert mocked.call_count == sum(len(dl) * (limit_val_batches > 0) for dl in val_dataloaders)


@pytest.mark.parametrize("trainer_kwargs,expected", [
    pytest.param(
        dict(distributed_backend=None, gpus=None),
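To make the final assertion concrete: suppose the two validation dataloaders have 5 batches each (sizes assumed for illustration). With `num_sanity_val_steps=-1` the sanity check visits every batch once, so the mock is hit 10 times whenever `limit_val_batches > 0`, and 0 times when it is 0.0:

    dataloader_lengths = [5, 5]  # assumed sizes, for illustration only

    assert sum(n * (0.3 > 0) for n in dataloader_lengths) == 10  # limit_val_batches=0.3
    assert sum(n * (0.0 > 0) for n in dataloader_lengths) == 0   # limit_val_batches=0.0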
