support num_sanity_val_steps=-1 #2246

Merged (19 commits) on Jul 23, 2020
Changes from 10 commits
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for `Trainer(num_sanity_val_steps=-1)` to check all validation data before training ([#2246](https://github.com/PyTorchLightning/pytorch-lightning/pull/2246))

### Changed

- Changed epoch indexing from 0 instead of 1 ([#2289](https://github.com/PyTorchLightning/pytorch-lightning/pull/2289))
2 changes: 1 addition & 1 deletion docs/source/debugging.rst
@@ -129,4 +129,4 @@ argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)
.. testcode::

# DEFAULT
trainer = Trainer(num_sanity_val_steps=5)
trainer = Trainer(num_sanity_val_steps=2)
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/progress.py
@@ -295,7 +295,7 @@ def init_test_tqdm(self) -> tqdm:
def on_sanity_check_start(self, trainer, pl_module):
super().on_sanity_check_start(trainer, pl_module)
self.val_progress_bar = self.init_sanity_tqdm()
self.val_progress_bar.total = trainer.num_sanity_val_steps * len(trainer.val_dataloaders)
self.val_progress_bar.total = convert_inf(trainer.num_sanity_val_steps * len(trainer.val_dataloaders))
Contributor: what is this?

Member Author: tqdm does not understand `float("inf")`, so we have to convert it to `None` in the case `num_sanity_val_steps=inf` or the dataloader has inf length.

Contributor: right, but where is this imported?

Contributor: oh i see... just kind of stuck at the bottom lol
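For context, a minimal sketch of what a `convert_inf` helper along these lines could look like (the actual implementation in `pytorch_lightning/callbacks/progress.py` may differ):

```python
def convert_inf(x):
    # tqdm cannot render an infinite total, so map inf to None ("unknown length").
    if x == float("inf"):
        return None
    return x
```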

self.main_progress_bar = tqdm(disable=True) # dummy progress bar

def on_sanity_check_end(self, trainer, pl_module):
7 changes: 5 additions & 2 deletions pytorch_lightning/trainer/__init__.py
@@ -603,16 +603,19 @@ def on_train_end(self, trainer, pl_module):

Sanity check runs n batches of val before starting the training routine.
This catches any bugs in your validation without having to wait for the first validation check.
The Trainer uses 5 steps by default. Turn it off or modify it here.
The Trainer uses 2 steps by default. Turn it off or modify it here.

.. testcode::

# default used by the Trainer
trainer = Trainer(num_sanity_val_steps=5)
trainer = Trainer(num_sanity_val_steps=2)

# turn it off
trainer = Trainer(num_sanity_val_steps=0)

# check all validation data
trainer = Trainer(num_sanity_val_steps=-1)

num_tpu_cores
^^^^^^^^^^^^^
.. warning:: .. deprecated:: 0.7.6
7 changes: 4 additions & 3 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -72,15 +72,16 @@
-----------------------------------------

Lightning runs a few steps of validation in the beginning of training.
This avoids crashing in the validation loop sometime deep into a lengthy training loop.

.. code-block:: python

# DEFAULT
trainer = Trainer(num_sanity_val_steps=5)
trainer = Trainer(num_sanity_val_steps=2)


You can use `Trainer(num_sanity_val_steps=0)` to skip the sanity check.
You can use `Trainer(num_sanity_val_steps=0)` to skip the sanity check or `Trainer(num_sanity_val_steps=-1)`
to check all the validation data.

# Testing loop

5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -278,7 +278,8 @@ def __init__(

amp_level: The optimization level to use (O1, O2, etc...).

num_sanity_val_steps: Sanity check runs n batches of val before starting the training routine.
num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
Set it to `-1` to run all batches in all validation dataloaders. Default: 2

truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of

@@ -376,7 +377,7 @@ def __init__(
self.max_steps = max_steps
self.min_steps = min_steps

self.num_sanity_val_steps = num_sanity_val_steps
self.num_sanity_val_steps = float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps
# Backward compatibility, TODO: remove in v0.9.0
if print_nan_grads:
rank_zero_warn("Argument `print_nan_grads` has no effect and will be removed in v0.9.0."
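The `-1` sentinel is mapped to `float("inf")` once at construction time; a hypothetical illustration (function and variable names are not from the PR) of why an infinite value composes cleanly with the usual `min()`-based batch capping:

```python
def num_sanity_batches(num_sanity_val_steps, num_val_batches):
    # With a finite setting this caps the sanity check; with inf it runs everything.
    return min(num_sanity_val_steps, num_val_batches)

assert num_sanity_batches(2, 100) == 2               # default: 2 batches
assert num_sanity_batches(float("inf"), 100) == 100  # -1 -> inf -> all batches
```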
19 changes: 19 additions & 0 deletions tests/trainer/test_trainer.py
@@ -6,6 +6,7 @@
import types
from argparse import Namespace
from pathlib import Path
from unittest.mock import patch

import cloudpickle
import pytest
@@ -804,6 +805,24 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
assert trainer.tpu_id == expected_tpu_id


def test_num_sanity_val_steps(tmpdir):
""" Test that num_sanity_val_steps=-1 runs through all validation data once. """
model = EvalModelTemplate()
trainer = Trainer(
num_sanity_val_steps=-1,
# TODO: limit_val_batches influences num_sanity_val_step. Fix it.
# limit_val_batches=0,
max_steps=1,
default_root_dir=tmpdir
)
assert trainer.num_sanity_val_steps == float('inf')
val_dataloaders = model.val_dataloader__multiple()

with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
trainer.fit(model, val_dataloaders=val_dataloaders)
assert mocked.call_count == sum(len(dl) for dl in val_dataloaders)
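The test relies on `unittest.mock.patch.object(..., wraps=...)`, which keeps the real method running while recording its calls; a minimal standalone sketch of that pattern (the `Worker` class is purely illustrative):

```python
from unittest.mock import patch

class Worker:
    def step(self, x):
        return x + 1

w = Worker()
# wraps= delegates to the real method, so behaviour is unchanged,
# but the mock still records how often it was called.
with patch.object(w, "step", wraps=w.step) as mocked:
    w.step(1)
    w.step(2)
assert mocked.call_count == 2
```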


@pytest.mark.parametrize("trainer_kwargs,expected", [
pytest.param(
dict(distributed_backend=None, gpus=None),