diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index e0148d60e6cbd..b5e2fe9554b73 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -127,7 +127,7 @@ from abc import ABC, abstractmethod import torch -import tqdm +from tqdm.auto import tqdm from pytorch_lightning.utilities.debugging import MisconfigurationException @@ -293,9 +293,9 @@ def run_evaluation(self, test=False): # main progress bar will already be closed when testing so initial position is free position = 2 * self.process_position + (not test) desc = 'Testing' if test else 'Validating' - pbar = tqdm.tqdm(desc=desc, total=max_batches, leave=test, position=position, - disable=not self.show_progress_bar, dynamic_ncols=True, - unit='batch', file=sys.stdout) + pbar = tqdm(desc=desc, total=max_batches, leave=test, position=position, + disable=not self.show_progress_bar, dynamic_ncols=True, + unit='batch', file=sys.stdout) setattr(self, f'{"test" if test else "val"}_progress_bar', pbar) # run evaluation diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c2add79416021..9441f63e8d5a2 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -7,7 +7,7 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp -import tqdm +from tqdm.auto import tqdm from torch.optim.optimizer import Optimizer from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin @@ -808,13 +808,13 @@ def run_pretrain_routine(self, model): ref_model.on_train_start() if not self.disable_validation and self.num_sanity_val_steps > 0: # init progress bars for validation sanity check - pbar = tqdm.tqdm(desc='Validation sanity check', + pbar = tqdm(desc='Validation sanity check', total=self.num_sanity_val_steps * len(self.get_val_dataloaders()), leave=False, position=2 * self.process_position, disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch') self.main_progress_bar = pbar # dummy validation progress bar - self.val_progress_bar = tqdm.tqdm(disable=True) + self.val_progress_bar = tqdm(disable=True) eval_results = self.evaluate(model, self.get_val_dataloaders(), self.num_sanity_val_steps, False) @@ -828,9 +828,9 @@ def run_pretrain_routine(self, model): self.early_stop_callback.check_metrics(callback_metrics) # init progress bar - pbar = tqdm.tqdm(leave=True, position=2 * self.process_position, - disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch', - file=sys.stdout) + pbar = tqdm(leave=True, position=2 * self.process_position, + disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch', + file=sys.stdout) self.main_progress_bar = pbar # clear cache before training