From cc65f39d971806ae4e5369c0232ee4998b413456 Mon Sep 17 00:00:00 2001 From: YehCF Date: Sat, 7 Dec 2019 21:47:59 +0800 Subject: [PATCH] Fix number of total steps shown in progress bar during sanity validation check when number of validation dataloaders >= 2 (#597) * type: debug Calculate the adequate number of steps to run during sanity_check. This fixes the bug when there are two or more validation dataloaders. - Before: total=self.num_sanity_val_steps - After: total=self.num_sanity_val_steps*len(self.get_val_dataloaders()) * type: refactor Put total=... in the next line * type: refactor run flake8 --- pytorch_lightning/trainer/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 79c86d1c7712a..66205fed96aac 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -491,7 +491,8 @@ def run_pretrain_routine(self, model): ref_model.on_sanity_check_start() if self.get_val_dataloaders() is not None and self.num_sanity_val_steps > 0: # init progress bars for validation sanity check - pbar = tqdm.tqdm(desc='Validation sanity check', total=self.num_sanity_val_steps, + pbar = tqdm.tqdm(desc='Validation sanity check', + total=self.num_sanity_val_steps * len(self.get_val_dataloaders()), leave=False, position=2 * self.process_position, disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch') self.main_progress_bar = pbar