Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable validation when val_percent_check=0 #1251

Merged
merged 5 commits into from
Mar 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed


- `Trainer.add_argparse_args` classmethod fixed. Now it adds a type for the arguments ([#1147](https://github.com/PyTorchLightning/pytorch-lightning/pull/1147)).
- Fixed bug related to type cheking of `ReduceLROnPlateau` lr schedulers([#1114](https://github.com/PyTorchLightning/pytorch-lightning/issues/1114))
- Fixed a bug to ensure lightning checkpoints to be backward compatible ([#1132](https://github.com/PyTorchLightning/pytorch-lightning/pull/1132))
- Fixed all warnings and errors in the docs build process ([#1191](https://github.com/PyTorchLightning/pytorch-lightning/pull/1191))
- Fixed an issue where `val_percent_check=0` would not disable validation ([#1251](https://github.com/PyTorchLightning/pytorch-lightning/pull/1251))

## [0.7.1] - 2020-03-07

Expand Down
29 changes: 16 additions & 13 deletions docs/source/fast_training.rst
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
Fast Training
================
=============
There are multiple options to speed up different parts of the training by choosing to train
on a subset of data. This could be done for speed or debugging purposes.

Check validation every n epochs
-------------------------------------
-------------------------------
If you have a small dataset you might want to check validation every n epochs

.. code-block:: python
Expand All @@ -13,7 +13,7 @@ If you have a small dataset you might want to check validation every n epochs
trainer = Trainer(check_val_every_n_epoch=1)

Force training for min or max epochs
-------------------------------------
------------------------------------
It can be useful to force training for a minimum number of epochs or limit to a max number.

.. seealso::
Expand All @@ -26,7 +26,7 @@ It can be useful to force training for a minimum number of epochs or limit to a


Set validation check frequency within 1 training epoch
-------------------------------------------------------
------------------------------------------------------
For large datasets it's often desirable to check validation multiple times within a training loop.
Pass in a float to check that often within 1 training epoch. Pass in an int k to check every k training batches.
Must use an int if using an IterableDataset.
Expand All @@ -43,7 +43,7 @@ Must use an int if using an IterableDataset.
trainer = Trainer(val_check_interval=100)

Use training data subset
----------------------------------
------------------------
If you don't want to check 100% of the training set (for debugging or if it's huge), set this flag.

.. code-block:: python
Expand All @@ -54,12 +54,11 @@ If you don't want to check 100% of the training set (for debugging or if it's hu
# check 10% only
trainer = Trainer(train_percent_check=0.1)

.. note:: train_percent_check will be overwritten by overfit_pct if overfit_pct > 0
.. note:: ``train_percent_check`` will be overwritten by ``overfit_pct`` if ``overfit_pct`` > 0.

Use test data subset
-------------------------------------
If you don't want to check 100% of the test set (for debugging or if it's huge), set this flag
test_percent_check will be overwritten by overfit_pct if overfit_pct > 0.
--------------------
If you don't want to check 100% of the test set (for debugging or if it's huge), set this flag.

.. code-block:: python

Expand All @@ -69,15 +68,19 @@ test_percent_check will be overwritten by overfit_pct if overfit_pct > 0.
# check 10% only
trainer = Trainer(test_percent_check=0.1)

.. note:: ``test_percent_check`` will be overwritten by ``overfit_pct`` if ``overfit_pct`` > 0.

Use validation data subset
--------------------------------------------
If you don't want to check 100% of the validation set (for debugging or if it's huge), set this flag
val_percent_check will be overwritten by overfit_pct if overfit_pct > 0
--------------------------
If you don't want to check 100% of the validation set (for debugging or if it's huge), set this flag.

.. code-block:: python

# DEFAULT
trainer = Trainer(val_percent_check=1.0)

# check 10% only
trainer = Trainer(val_percent_check=0.1)
trainer = Trainer(val_percent_check=0.1)

.. note:: ``val_percent_check`` will be overwritten by ``overfit_pct`` if ``overfit_pct`` > 0 and ignored if
``fast_dev_run=True``.
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,7 +876,8 @@ def run_pretrain_routine(self, model: LightningModule):
return

# check if we should run validation during training
self.disable_validation = not self.is_overriden('validation_step') and not self.fast_dev_run
self.disable_validation = not (self.is_overriden('validation_step') and self.val_percent_check > 0) \
and not self.fast_dev_run

# run tiny validation (if validation defined)
# to make sure program won't crash during val
Expand Down
50 changes: 50 additions & 0 deletions tests/models/test_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
LightTrainDataloader,
LightningTestModel,
LightTestMixin,
LightValidationMixin
)


Expand Down Expand Up @@ -156,6 +157,55 @@ class CurrentTestModel(LightTrainDataloader, LightTestMixin, TestModelBase):
tutils.assert_ok_model_acc(trainer)


def test_disabled_validation():
"""Verify that `val_percent_check=0` disables the validation loop unless `fast_dev_run=True`."""
tutils.reset_seed()

class CurrentModel(LightTrainDataloader, LightValidationMixin, TestModelBase):

validation_step_invoked = False
validation_end_invoked = False

def validation_step(self, *args, **kwargs):
self.validation_step_invoked = True
return super().validation_step(*args, **kwargs)

def validation_end(self, *args, **kwargs):
self.validation_end_invoked = True
return super().validation_end(*args, **kwargs)

hparams = tutils.get_default_hparams()
model = CurrentModel(hparams)

trainer_options = dict(
show_progress_bar=False,
max_epochs=2,
train_percent_check=0.4,
val_percent_check=0.0,
fast_dev_run=False,
)

trainer = Trainer(**trainer_options)
result = trainer.fit(model)

# check that val_percent_check=0 turns off validation
assert result == 1, 'training failed to complete'
assert trainer.current_epoch == 1
assert not model.validation_step_invoked, '`validation_step` should not run when `val_percent_check=0`'
assert not model.validation_end_invoked, '`validation_end` should not run when `val_percent_check=0`'

# check that val_percent_check has no influence when fast_dev_run is turned on
model = CurrentModel(hparams)
trainer_options.update(fast_dev_run=True)
trainer = Trainer(**trainer_options)
result = trainer.fit(model)

assert result == 1, 'training failed to complete'
assert trainer.current_epoch == 0
assert model.validation_step_invoked, 'did not run `validation_step` with `fast_dev_run=True`'
assert model.validation_end_invoked, 'did not run `validation_end` with `fast_dev_run=True`'


def test_single_gpu_batch_parse():
tutils.reset_seed()

Expand Down