From e1e9680dd7e405cbfad80ea358d6352a9ef80873 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 18 Jun 2020 23:50:23 +0200 Subject: [PATCH 01/14] support sanity_val_step=-1 --- pytorch_lightning/trainer/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d1601caba9392..166a479499964 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1061,7 +1061,7 @@ def run_pretrain_routine(self, model: LightningModule): # run tiny validation (if validation defined) # to make sure program won't crash during val - if not self.disable_validation and self.num_sanity_val_steps > 0: + if not self.disable_validation and self.num_sanity_val_steps: self.reset_val_dataloader(ref_model) # hook and callback @@ -1070,6 +1070,7 @@ def run_pretrain_routine(self, model: LightningModule): num_loaders = len(self.val_dataloaders) max_batches = [self.num_sanity_val_steps] * num_loaders + max_batches = [float('inf') for m in max_batches if m == -1] eval_results = self._evaluate(model, self.val_dataloaders, max_batches, From d375e2adc5906a41a7a2a9ed76f03f7a4d2d2681 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 20 Jun 2020 12:13:40 +0200 Subject: [PATCH 02/14] fix list size --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 166a479499964..6e789f2a12aa5 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1070,7 +1070,7 @@ def run_pretrain_routine(self, model: LightningModule): num_loaders = len(self.val_dataloaders) max_batches = [self.num_sanity_val_steps] * num_loaders - max_batches = [float('inf') for m in max_batches if m == -1] + max_batches = [float('inf') if m == -1 else m for m in max_batches] eval_results = self._evaluate(model, self.val_dataloaders, max_batches, From 0be420c1577c33c3ca81f9ef6cc12122dce39fbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 20 Jun 2020 12:21:28 +0200 Subject: [PATCH 03/14] simplification --- pytorch_lightning/trainer/trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6e789f2a12aa5..7170c64a89b1b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -379,7 +379,7 @@ def __init__( self.max_steps = max_steps self.min_steps = min_steps - self.num_sanity_val_steps = num_sanity_val_steps + self.num_sanity_val_steps = float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps # Backward compatibility, TODO: remove in v0.9.0 if print_nan_grads: rank_zero_warn("Argument `print_nan_grads` has no effect and will be removed in v0.9.0." 
@@ -1070,7 +1070,6 @@ def run_pretrain_routine(self, model: LightningModule): num_loaders = len(self.val_dataloaders) max_batches = [self.num_sanity_val_steps] * num_loaders - max_batches = [float('inf') if m == -1 else m for m in max_batches] eval_results = self._evaluate(model, self.val_dataloaders, max_batches, From 66f4e9157230d4454578a238ddd007d994a4c081 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 20 Jun 2020 12:22:30 +0200 Subject: [PATCH 04/14] simplify --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7170c64a89b1b..131e8c893b6cf 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1061,7 +1061,7 @@ def run_pretrain_routine(self, model: LightningModule): # run tiny validation (if validation defined) # to make sure program won't crash during val - if not self.disable_validation and self.num_sanity_val_steps: + if not self.disable_validation and self.num_sanity_val_steps > 0: self.reset_val_dataloader(ref_model) # hook and callback From d7b56e9aa7eacf93e10794891c7794464eb7b057 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 22 Jun 2020 23:01:52 +0200 Subject: [PATCH 05/14] add test for num_sanity_val_steps=-1 --- pytorch_lightning/callbacks/progress.py | 2 +- tests/trainer/test_trainer.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 358575273f945..4ea9ea1c91b4f 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -295,7 +295,7 @@ def init_test_tqdm(self) -> tqdm: def on_sanity_check_start(self, trainer, pl_module): super().on_sanity_check_start(trainer, pl_module) self.val_progress_bar = self.init_sanity_tqdm() - self.val_progress_bar.total = trainer.num_sanity_val_steps * len(trainer.val_dataloaders) + self.val_progress_bar.total = convert_inf(trainer.num_sanity_val_steps * len(trainer.val_dataloaders)) self.main_progress_bar = tqdm(disable=True) # dummy progress bar def on_sanity_check_end(self, trainer, pl_module): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 5e4cf3544e911..39cca94711aa0 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -6,6 +6,7 @@ import types from argparse import Namespace from pathlib import Path +from unittest.mock import patch import cloudpickle import pytest @@ -782,6 +783,24 @@ def test_gpu_choice(tmpdir): Trainer(**trainer_options, gpus=num_gpus + 1, auto_select_gpus=True) + +def test_num_sanity_val_steps(tmpdir): + model = EvalModelTemplate() + trainer = Trainer( + num_sanity_val_steps=-1, + # TODO: limit_val_batches influences num_sanity_val_step. Fix it. 
+ # limit_val_batches=0, + max_steps=1, + default_root_dir=tmpdir + ) + assert trainer.num_sanity_val_steps == float('inf') + val_dataloader = model.val_dataloader() + + with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked: + trainer.fit(model, val_dataloaders=val_dataloader) + assert mocked.call_count == len(val_dataloader) + + @pytest.mark.parametrize("trainer_kwargs,expected", [ pytest.param( dict(distributed_backend=None, gpus=None), From bd47b70093fab19d55a0476e59f0867f49e8eade Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 22 Jun 2020 23:05:35 +0200 Subject: [PATCH 06/14] update test --- tests/trainer/test_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 39cca94711aa0..3da84282eee9e 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -783,8 +783,8 @@ def test_gpu_choice(tmpdir): Trainer(**trainer_options, gpus=num_gpus + 1, auto_select_gpus=True) - def test_num_sanity_val_steps(tmpdir): + """ Test that num_sanity_val_steps=-1 runs through all validation data once. """ model = EvalModelTemplate() trainer = Trainer( num_sanity_val_steps=-1, From dd3862d4da07d748cef1dd2aca39f080e4280ed5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 23 Jun 2020 00:04:13 +0200 Subject: [PATCH 07/14] update docs --- docs/source/debugging.rst | 2 +- pytorch_lightning/trainer/__init__.py | 7 +++++-- pytorch_lightning/trainer/evaluation_loop.py | 7 ++++--- pytorch_lightning/trainer/trainer.py | 3 ++- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/docs/source/debugging.rst b/docs/source/debugging.rst index 06f9cd4344b43..0339cc6a98e7d 100644 --- a/docs/source/debugging.rst +++ b/docs/source/debugging.rst @@ -129,4 +129,4 @@ argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`) .. testcode:: # DEFAULT - trainer = Trainer(num_sanity_val_steps=5) \ No newline at end of file + trainer = Trainer(num_sanity_val_steps=2) \ No newline at end of file diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index 074fd29650642..e56d4df7e21ef 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -599,16 +599,19 @@ def on_train_end(self, trainer, pl_module): Sanity check runs n batches of val before starting the training routine. This catches any bugs in your validation without having to wait for the first validation check. -The Trainer uses 5 steps by default. Turn it off or modify it here. +The Trainer uses 2 steps by default. Turn it off or modify it here. .. testcode:: # default used by the Trainer - trainer = Trainer(num_sanity_val_steps=5) + trainer = Trainer(num_sanity_val_steps=2) # turn it off trainer = Trainer(num_sanity_val_steps=0) + # check all validation data + trainer = Trainer(num_sanity_val_steps=-1) + num_tpu_cores ^^^^^^^^^^^^^ .. warning:: .. deprecated:: 0.7.6 diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 230538ed89c01..dfd73efbb10d8 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -72,15 +72,16 @@ ----------------------------------------- Lightning runs a few steps of validation in the beginning of training. - This avoids crashing in the validation loop sometime deep into a lengthy training loop. 
+This avoids crashing in the validation loop sometime deep into a lengthy training loop. .. code-block:: python # DEFAULT - trainer = Trainer(num_sanity_val_steps=5) + trainer = Trainer(num_sanity_val_steps=2) -You can use `Trainer(num_sanity_val_steps=0)` to skip the sanity check. +You can use `Trainer(num_sanity_val_steps=0)` to skip the sanity check or `Trainer(num_sanity_val_steps=-1)` +to check all the validation data. # Testing loop diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 131e8c893b6cf..e59e97315b02b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -278,7 +278,8 @@ def __init__( amp_level: The optimization level to use (O1, O2, etc...). - num_sanity_val_steps: Sanity check runs n batches of val before starting the training routine. + num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine. + Set it to `-1` to run all batches in all validation dataloaders. Default: 2 truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of From a15fd8de22d80f7c886f0dd9718f17f933984756 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 23 Jun 2020 00:14:55 +0200 Subject: [PATCH 08/14] extend tests to multiple dataloaders --- tests/trainer/test_trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 3da84282eee9e..6ad582121707b 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -794,11 +794,11 @@ def test_num_sanity_val_steps(tmpdir): default_root_dir=tmpdir ) assert trainer.num_sanity_val_steps == float('inf') - val_dataloader = model.val_dataloader() + val_dataloaders = model.val_dataloader__multiple() with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked: - trainer.fit(model, val_dataloaders=val_dataloader) - assert mocked.call_count == len(val_dataloader) + trainer.fit(model, val_dataloaders=val_dataloaders) + assert mocked.call_count == sum(len(dl) for dl in val_dataloaders) @pytest.mark.parametrize("trainer_kwargs,expected", [ From cef85c3fec3e083582beeee6cb37ae4cfd851226 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 23 Jun 2020 00:21:00 +0200 Subject: [PATCH 09/14] changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 51e56e7c1b18b..bff49dc337d32 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Added +- Added support for `Trainer(num_sanity_val_steps=-1)` to check all validation data before training ([#2246](https://github.com/PyTorchLightning/pytorch-lightning/pull/2246)) + ### Changed - Changed epoch indexing from 0 instead of 1 ([#2289](https://github.com/PyTorchLightning/pytorch-lightning/pull/2289)) From e7c2a8d505312db41fe8d2e9ae5e6945f6684364 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 27 Jun 2020 01:14:16 +0200 Subject: [PATCH 10/14] Update tests/trainer/test_trainer.py Co-authored-by: Jirka Borovec --- tests/trainer/test_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 30455537ac5ed..33565ab79c74f 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -813,7 +813,7 @@ def test_num_sanity_val_steps(tmpdir): # TODO: limit_val_batches influences num_sanity_val_step. Fix it. # limit_val_batches=0, max_steps=1, - default_root_dir=tmpdir + default_root_dir=tmpdir, ) assert trainer.num_sanity_val_steps == float('inf') val_dataloaders = model.val_dataloader__multiple() From 231c0c6be4e4b28f7992cd216855742681f491b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 29 Jun 2020 23:32:51 +0200 Subject: [PATCH 11/14] improve test --- tests/trainer/test_trainer.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 7b147917dc55a..8f9c88744dc96 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -809,15 +809,25 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected): assert trainer.tpu_id == expected_tpu_id -def test_num_sanity_val_steps(tmpdir): - """ Test that num_sanity_val_steps=-1 runs through all validation data once. """ +@pytest.mark.parametrize(['limit_val_batches'], [ + pytest.param(0.0), + pytest.param(1), + pytest.param(1.0), + pytest.param(0.3), +]) +def test_num_sanity_val_steps(tmpdir, limit_val_batches): + """ + Test that num_sanity_val_steps=-1 runs through all validation data once. + Makes sure this setting is independent of limit_val_batches. + """ model = EvalModelTemplate() + model.validation_step = model.validation_step__multiple_dataloaders + model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders trainer = Trainer( + default_root_dir=tmpdir, num_sanity_val_steps=-1, - # TODO: limit_val_batches influences num_sanity_val_step. Fix it. 
- # limit_val_batches=0, + limit_val_batches=limit_val_batches, # should have no influence max_steps=1, - default_root_dir=tmpdir, ) assert trainer.num_sanity_val_steps == float('inf') val_dataloaders = model.val_dataloader__multiple() From 6dc3d8bd97a89149be4847e1201e41290df7e786 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Jun 2020 00:04:40 +0200 Subject: [PATCH 12/14] refactor the sanity check decision --- pytorch_lightning/callbacks/progress.py | 2 +- pytorch_lightning/trainer/trainer.py | 20 +++++++++++++++----- pytorch_lightning/trainer/training_loop.py | 2 +- tests/trainer/test_trainer.py | 4 ++-- 4 files changed, 19 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 4ea9ea1c91b4f..935670be1240e 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -95,7 +95,7 @@ def total_val_batches(self) -> int: total_val_batches = 0 if trainer.fast_dev_run and trainer.val_dataloaders is not None: total_val_batches = len(trainer.val_dataloaders) - elif not self.trainer.disable_validation: + elif self.trainer.enable_validation: is_val_epoch = (trainer.current_epoch + 1) % trainer.check_val_every_n_epoch == 0 total_val_batches = sum(trainer.num_val_batches) if is_val_epoch else 0 return total_val_batches diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index cc77748903c89..50cf583509fc4 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -345,7 +345,6 @@ def __init__( # training state self.model = None self.testing = False - self.disable_validation = False self.prepare_data_per_node = prepare_data_per_node self.lr_schedulers = [] self.optimizers = None @@ -798,6 +797,18 @@ def progress_bar_dict(self) -> dict: ref_model = self.model if not self.data_parallel else self.model.module return dict(**ref_model.get_progress_bar_dict(), **self.progress_bar_metrics) + @property + def disable_validation(self) -> bool: + """ Check if validation is disabled during training. """ + disable_validation = not (self.is_overridden('validation_step') and self.limit_val_batches > 0) \ + and not self.fast_dev_run + return disable_validation + + @property + def enable_validation(self) -> bool: + """ Check if we should run validation during training. 
""" + return not self.disable_validation + # ----------------------------- # MODEL TRAINING # ----------------------------- @@ -1066,13 +1077,12 @@ def run_pretrain_routine(self, model: LightningModule): self.run_evaluation(test_mode=True) return - # check if we should run validation during training - self.disable_validation = not (self.is_overridden('validation_step') and self.limit_val_batches > 0) \ - and not self.fast_dev_run + should_sanity_check = self.is_overridden('validation_step') and self.num_sanity_val_steps > 0 \ + and self.limit_val_batches > 0 # run tiny validation (if validation defined) # to make sure program won't crash during val - if not self.disable_validation and self.num_sanity_val_steps > 0: + if should_sanity_check: self.reset_val_dataloader(ref_model) # hook and callback diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 1f6a36eb56c89..3d667b9a88f3a 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -569,7 +569,7 @@ def should_check_val(self, batch_idx, is_last_batch): # decide if we should run validation is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0 can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0 - can_check_val = not self.disable_validation and can_check_epoch + can_check_val = self.enable_validation and can_check_epoch should_check_val = is_val_check_batch or self.should_stop is_last_batch_for_infinite_dataset = (is_last_batch and self.val_check_batch == float('inf')) should_check_val = can_check_val and (should_check_val or is_last_batch_for_infinite_dataset) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 8f9c88744dc96..d928b88811c8f 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -810,7 +810,7 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected): @pytest.mark.parametrize(['limit_val_batches'], [ - pytest.param(0.0), + pytest.param(0.0), # this should run no sanity checks pytest.param(1), pytest.param(1.0), pytest.param(0.3), @@ -834,7 +834,7 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches): with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked: trainer.fit(model, val_dataloaders=val_dataloaders) - assert mocked.call_count == sum(len(dl) for dl in val_dataloaders) + assert mocked.call_count == sum(len(dl) * (limit_val_batches > 0) for dl in val_dataloaders) @pytest.mark.parametrize("trainer_kwargs,expected", [ From 8713ccdd6ee6b0bc245db15133b1b99d5a6da735 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 23 Jul 2020 03:15:01 +0200 Subject: [PATCH 13/14] fix merge --- pytorch_lightning/trainer/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 810fd55c3dda3..0ee3bf1b69d9a 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1198,9 +1198,6 @@ def run_pretrain_routine(self, model: LightningModule): return eval_loop_results - should_sanity_check = self.is_overridden('validation_step') and self.num_sanity_val_steps > 0 \ - and self.limit_val_batches > 0 - # run a few val batches before training starts self._run_sanity_check(ref_model, model) @@ -1215,6 +1212,9 @@ def run_pretrain_routine(self, model: LightningModule): self.train() def _run_sanity_check(self, ref_model, model): + 
should_sanity_check = self.is_overridden('validation_step') and self.num_sanity_val_steps > 0 \ + and self.limit_val_batches > 0 + # run tiny validation (if validation defined) # to make sure program won't crash during val if should_sanity_check: From ed62155634fb6ed53f51c08349055885cea9d673 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 23 Jul 2020 06:47:55 -0400 Subject: [PATCH 14/14] Update trainer.py --- pytorch_lightning/trainer/trainer.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 0ee3bf1b69d9a..bee332f831cce 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -886,14 +886,13 @@ def progress_bar_dict(self) -> dict: @property def disable_validation(self) -> bool: """ Check if validation is disabled during training. """ - disable_validation = not (self.is_overridden('validation_step') and self.limit_val_batches > 0) \ - and not self.fast_dev_run - return disable_validation + return not self.enable_validation @property def enable_validation(self) -> bool: """ Check if we should run validation during training. """ - return not self.disable_validation + val_loop_enabled = (self.is_overridden('validation_step') and self.limit_val_batches > 0) + return val_loop_enabled or self.fast_dev_run # ----------------------------- # MODEL TRAINING
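For reference, a minimal usage sketch of the behaviour this series introduces, based on the documentation and tests patched above. The `model` below is illustrative; any LightningModule that defines `validation_step` and a `val_dataloader` would do.

.. code-block:: python

    from pytorch_lightning import Trainer

    # default: run 2 validation batches as a sanity check before training
    trainer = Trainer(num_sanity_val_steps=2)

    # turn the sanity check off
    trainer = Trainer(num_sanity_val_steps=0)

    # new in this series: run all batches of all validation dataloaders once
    # before training; internally -1 is mapped to float('inf')
    trainer = Trainer(num_sanity_val_steps=-1)
    # trainer.fit(model)

Note that `limit_val_batches=0` disables validation (and therefore the sanity check) regardless of `num_sanity_val_steps`, which is what the parametrized test above verifies.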