support num_sanity_val_steps=-1 #2246

Merged · 19 commits · Jul 23, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added BLEU metrics ([#2535](https://github.com/PyTorchLightning/pytorch-lightning/pull/2535))

- Added support for `Trainer(num_sanity_val_steps=-1)` to check all validation data before training ([#2246](https://github.com/PyTorchLightning/pytorch-lightning/pull/2246))

### Changed


2 changes: 1 addition & 1 deletion docs/source/debugging.rst
@@ -129,4 +129,4 @@ argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)
.. testcode::

# DEFAULT
trainer = Trainer(num_sanity_val_steps=5)
trainer = Trainer(num_sanity_val_steps=2)
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/progress.py
@@ -102,7 +102,7 @@ def total_val_batches(self) -> int:
total_val_batches = 0
if trainer.fast_dev_run and trainer.val_dataloaders is not None:
total_val_batches = len(trainer.val_dataloaders)
elif not self.trainer.disable_validation:
elif self.trainer.enable_validation:
is_val_epoch = (trainer.current_epoch + 1) % trainer.check_val_every_n_epoch == 0
total_val_batches = sum(trainer.num_val_batches) if is_val_epoch else 0
return total_val_batches
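
A quick worked sketch of the arithmetic above (values assumed purely for illustration, not taken from the PR):

    # with check_val_every_n_epoch=2 and num_val_batches=[10, 20]:
    check_val_every_n_epoch = 2
    num_val_batches = [10, 20]
    for current_epoch in (0, 1):
        is_val_epoch = (current_epoch + 1) % check_val_every_n_epoch == 0
        total_val_batches = sum(num_val_batches) if is_val_epoch else 0
        # epoch 0 -> 0 batches, epoch 1 -> 30 batches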
@@ -302,7 +302,7 @@ def init_test_tqdm(self) -> tqdm:
def on_sanity_check_start(self, trainer, pl_module):
super().on_sanity_check_start(trainer, pl_module)
self.val_progress_bar = self.init_sanity_tqdm()
self.val_progress_bar.total = trainer.num_sanity_val_steps * len(trainer.val_dataloaders)
self.val_progress_bar.total = convert_inf(trainer.num_sanity_val_steps * len(trainer.val_dataloaders))
Contributor:

what is this?

Member Author:

tqdm does not understand `float("inf")`, so we have to convert it to `None` when `num_sanity_val_steps` is infinite or the dataloader has infinite length.

Contributor:

right but where is this imported?

Contributor:

oh i see... just kind of stuck at the bottom lol
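
For reference, a minimal sketch of what such a `convert_inf` helper could look like (an assumed implementation, not necessarily the exact one shipped in `progress.py`; it only maps an infinite total to `None` so tqdm falls back to an open-ended bar):

    def convert_inf(x):
        # tqdm cannot render a total of float('inf'); use None for an open-ended bar
        if x == float('inf'):
            return None
        return x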

self.main_progress_bar = tqdm(disable=True) # dummy progress bar

def on_sanity_check_end(self, trainer, pl_module):
7 changes: 5 additions & 2 deletions pytorch_lightning/trainer/__init__.py
@@ -603,16 +603,19 @@ def on_train_end(self, trainer, pl_module):

Sanity check runs n batches of val before starting the training routine.
This catches any bugs in your validation without having to wait for the first validation check.
The Trainer uses 5 steps by default. Turn it off or modify it here.
The Trainer uses 2 steps by default. Turn it off or modify it here.

.. testcode::

# default used by the Trainer
trainer = Trainer(num_sanity_val_steps=5)
trainer = Trainer(num_sanity_val_steps=2)

# turn it off
trainer = Trainer(num_sanity_val_steps=0)

# check all validation data
trainer = Trainer(num_sanity_val_steps=-1)

num_tpu_cores
^^^^^^^^^^^^^
.. warning:: .. deprecated:: 0.7.6
7 changes: 4 additions & 3 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -72,15 +72,16 @@
-----------------------------------------

Lightning runs a few steps of validation in the beginning of training.
This avoids crashing in the validation loop sometime deep into a lengthy training loop.
This avoids crashing in the validation loop sometime deep into a lengthy training loop.

.. code-block:: python

# DEFAULT
trainer = Trainer(num_sanity_val_steps=5)
trainer = Trainer(num_sanity_val_steps=2)


You can use `Trainer(num_sanity_val_steps=0)` to skip the sanity check.
You can use `Trainer(num_sanity_val_steps=0)` to skip the sanity check or `Trainer(num_sanity_val_steps=-1)`
to check all the validation data.
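
For example (a short usage sketch; the mapping of `-1` to `float("inf")` is taken from the `trainer.py` change further down in this diff):

.. code-block:: python

    # skip the sanity check entirely
    trainer = Trainer(num_sanity_val_steps=0)

    # run every batch of every validation dataloader before training
    # (internally -1 is mapped to float("inf"))
    trainer = Trainer(num_sanity_val_steps=-1)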

# Testing loop

26 changes: 18 additions & 8 deletions pytorch_lightning/trainer/trainer.py
@@ -336,7 +336,8 @@ def __init__(

amp_level: The optimization level to use (O1, O2, etc...).

num_sanity_val_steps: Sanity check runs n batches of val before starting the training routine.
num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
Set it to `-1` to run all batches in all validation dataloaders. Default: 2

truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of

@@ -408,7 +409,6 @@ def __init__(
# training state
self.model = None
self.testing = False
self.disable_validation = False
self.prepare_data_per_node = prepare_data_per_node
self.lr_schedulers = []
self.optimizers = None
@@ -488,7 +488,7 @@ def __init__(
self.max_steps = max_steps
self.min_steps = min_steps

self.num_sanity_val_steps = num_sanity_val_steps
self.num_sanity_val_steps = float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps
# Backward compatibility, TODO: remove in v0.9.0
if print_nan_grads:
rank_zero_warn("Argument `print_nan_grads` has no effect and will be removed in v0.9.0."
@@ -883,6 +883,17 @@ def progress_bar_dict(self) -> dict:
ref_model = self.model if not self.data_parallel else self.model.module
return dict(**ref_model.get_progress_bar_dict(), **self.progress_bar_metrics)

@property
def disable_validation(self) -> bool:
""" Check if validation is disabled during training. """
return not self.enable_validation

@property
def enable_validation(self) -> bool:
""" Check if we should run validation during training. """
val_loop_enabled = (self.is_overridden('validation_step') and self.limit_val_batches > 0)
return val_loop_enabled or self.fast_dev_run
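
As a usage sketch (a hypothetical callback, not part of this PR), other components can consult the new property the same way the progress bar change above does:

    from pytorch_lightning.callbacks import Callback

    class ValStatusLogger(Callback):
        # hypothetical example: report whether validation will run during training
        def on_train_start(self, trainer, pl_module):
            print(f"validation enabled: {trainer.enable_validation}")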

# -----------------------------
# MODEL TRAINING
# -----------------------------
@@ -1186,10 +1197,6 @@ def run_pretrain_routine(self, model: LightningModule):

return eval_loop_results

# check if we should run validation during training
self.disable_validation = not (self.is_overridden('validation_step') and self.limit_val_batches > 0) \
and not self.fast_dev_run

# run a few val batches before training starts
self._run_sanity_check(ref_model, model)

@@ -1204,9 +1211,12 @@ def run_pretrain_routine(self, model: LightningModule):
self.train()

def _run_sanity_check(self, ref_model, model):
should_sanity_check = self.is_overridden('validation_step') and self.num_sanity_val_steps > 0 \
and self.limit_val_batches > 0

# run tiny validation (if validation defined)
# to make sure program won't crash during val
if not self.disable_validation and self.num_sanity_val_steps > 0:
if should_sanity_check:
self.reset_val_dataloader(ref_model)

# hook and callback
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -651,7 +651,7 @@ def should_check_val(self, batch_idx, is_last_batch):
# decide if we should run validation
is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
can_check_val = not self.disable_validation and can_check_epoch
can_check_val = self.enable_validation and can_check_epoch
should_check_val = is_val_check_batch or self.should_stop
is_last_batch_for_infinite_dataset = (is_last_batch and self.val_check_batch == float('inf'))
should_check_val = can_check_val and (should_check_val or is_last_batch_for_infinite_dataset)
29 changes: 29 additions & 0 deletions tests/trainer/test_trainer.py
@@ -6,6 +6,7 @@
import types
from argparse import Namespace
from pathlib import Path
from unittest.mock import patch

import cloudpickle
import pytest
@@ -807,6 +808,34 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
assert trainer.tpu_id == expected_tpu_id


@pytest.mark.parametrize(['limit_val_batches'], [
pytest.param(0.0), # this should run no sanity checks
pytest.param(1),
pytest.param(1.0),
pytest.param(0.3),
])
def test_num_sanity_val_steps(tmpdir, limit_val_batches):
"""
Test that num_sanity_val_steps=-1 runs through all validation data once.
Makes sure this setting is independent of limit_val_batches.
"""
model = EvalModelTemplate()
model.validation_step = model.validation_step__multiple_dataloaders
model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=-1,
limit_val_batches=limit_val_batches, # should have no influence
max_steps=1,
)
assert trainer.num_sanity_val_steps == float('inf')
val_dataloaders = model.val_dataloader__multiple()

with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
trainer.fit(model, val_dataloaders=val_dataloaders)
assert mocked.call_count == sum(len(dl) * (limit_val_batches > 0) for dl in val_dataloaders)
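
To make the assertion concrete, a small worked example (batch counts assumed purely for illustration):

    # with two val dataloaders of 5 batches each:
    #   limit_val_batches == 0.0 -> the sanity check is skipped, 0 forward calls
    #   limit_val_batches > 0    -> every batch runs once, 5 + 5 = 10 forward calls
    lengths = [5, 5]
    assert sum(n * (0.0 > 0) for n in lengths) == 0
    assert sum(n * (0.3 > 0) for n in lengths) == 10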


@pytest.mark.parametrize("trainer_kwargs,expected", [
pytest.param(
dict(distributed_backend=None, gpus=None),