From deffbaba7ffb16ff57b56fe65f62df761f25fbd6 Mon Sep 17 00:00:00 2001
From: Mike Clark
Date: Sun, 26 Jan 2020 15:19:09 +0000
Subject: [PATCH] for #330, use tqdm.auto in trainer (#752)

* use tqdm.auto in trainer

This will import the ipywidgets version of tqdm if available. This works
nicely in notebooks by not filling up the log. In the terminal it will use
the same old tqdm.

We might also want to consider passing in the tqdm we want as an argument,
since there may be some edge cases where ipywidgets is available but the
interface doesn't support it (e.g. vscode?) or isn't working, in which case
people will get a warning message but may want to configure it themselves.

* use `from tqdm.auto` in eval loop

* indents
---
 pytorch_lightning/trainer/evaluation_loop.py |  8 ++++----
 pytorch_lightning/trainer/trainer.py         | 12 ++++++------
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index e0148d60e6cbd..b5e2fe9554b73 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -127,7 +127,7 @@
 from abc import ABC, abstractmethod
 
 import torch
-import tqdm
+from tqdm.auto import tqdm
 
 from pytorch_lightning.utilities.debugging import MisconfigurationException
 
@@ -293,9 +293,9 @@ def run_evaluation(self, test=False):
         # main progress bar will already be closed when testing so initial position is free
         position = 2 * self.process_position + (not test)
         desc = 'Testing' if test else 'Validating'
-        pbar = tqdm.tqdm(desc=desc, total=max_batches, leave=test, position=position,
-                         disable=not self.show_progress_bar, dynamic_ncols=True,
-                         unit='batch', file=sys.stdout)
+        pbar = tqdm(desc=desc, total=max_batches, leave=test, position=position,
+                    disable=not self.show_progress_bar, dynamic_ncols=True,
+                    unit='batch', file=sys.stdout)
         setattr(self, f'{"test" if test else "val"}_progress_bar', pbar)
 
         # run evaluation
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 9cc9b55950bae..e1c1464c73cf8 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -7,7 +7,7 @@
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
-import tqdm
+from tqdm.auto import tqdm
 from torch.optim.optimizer import Optimizer
 
 from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin
@@ -850,13 +850,13 @@ def run_pretrain_routine(self, model):
         ref_model.on_train_start()
         if not self.disable_validation and self.num_sanity_val_steps > 0:
             # init progress bars for validation sanity check
-            pbar = tqdm.tqdm(desc='Validation sanity check',
+            pbar = tqdm(desc='Validation sanity check',
                              total=self.num_sanity_val_steps * len(self.get_val_dataloaders()),
                              leave=False, position=2 * self.process_position,
                              disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch')
             self.main_progress_bar = pbar
             # dummy validation progress bar
-            self.val_progress_bar = tqdm.tqdm(disable=True)
+            self.val_progress_bar = tqdm(disable=True)
 
             eval_results = self.evaluate(model, self.get_val_dataloaders(),
                                          self.num_sanity_val_steps, False)
@@ -870,9 +870,9 @@ def run_pretrain_routine(self, model):
             self.early_stop_callback.check_metrics(callback_metrics)
 
         # init progress bar
-        pbar = tqdm.tqdm(leave=True, position=2 * self.process_position,
-                         disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch',
-                         file=sys.stdout)
+        pbar = tqdm(leave=True, position=2 * self.process_position,
+                    disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch',
+                    file=sys.stdout)
         self.main_progress_bar = pbar
 
         # clear cache before training
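
The commit message above floats the idea of letting callers pass in the tqdm implementation they want instead of relying on tqdm.auto's detection. Below is a minimal sketch of that idea, not part of this patch: the make_progress_bar helper and its progress_bar_cls parameter are hypothetical illustrations, not pytorch_lightning API.

    # Sketch only: default to tqdm.auto's choice, but let the caller override it.
    from tqdm.auto import tqdm as auto_tqdm  # notebook-aware tqdm, plain tqdm in a terminal


    def make_progress_bar(total, show_progress_bar=True, progress_bar_cls=None):
        """Build a progress bar, using tqdm.auto's pick unless a class is supplied."""
        bar_cls = progress_bar_cls or auto_tqdm
        return bar_cls(desc='Validating', total=total, leave=False,
                       disable=not show_progress_bar, dynamic_ncols=True, unit='batch')


    if __name__ == '__main__':
        # Force the plain terminal tqdm, e.g. when ipywidgets is installed but the
        # front end does not render widgets (the vscode case mentioned above).
        from tqdm import tqdm as plain_tqdm
        pbar = make_progress_bar(total=10, progress_bar_cls=plain_tqdm)
        for _ in range(10):
            pbar.update(1)
        pbar.close()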