From 5035ce54749a68f3b4f002724d526ad4e2a6295d Mon Sep 17 00:00:00 2001
From: Vadim Bereznyuk
Date: Wed, 5 Feb 2020 14:24:43 +0300
Subject: [PATCH] Make default tqdm dict overridable (#749)

* overridable tqdm_dict

* Slim down default tqdm_metrics

* gpu fix
---
 pytorch_lightning/core/lightning.py          | 19 +++++++++++++++++
 pytorch_lightning/trainer/evaluation_loop.py |  5 ++---
 pytorch_lightning/trainer/trainer.py         | 22 ++++----------------
 3 files changed, 25 insertions(+), 21 deletions(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 0ec7d03ff5da2..66b887fa76374 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1210,6 +1210,25 @@ def on_save_checkpoint(self, checkpoint):
 
         """
 
+    def get_tqdm_dict(self):
+        r"""
+        Additional items to be displayed in the progress bar.
+
+        Return:
+            Dictionary with the items to be displayed in the progress bar.
+        """
+        tqdm_dict = {
+            'loss': '{:.3f}'.format(self.trainer.avg_loss)
+        }
+
+        if self.trainer.truncated_bptt_steps is not None:
+            tqdm_dict['split_idx'] = self.trainer.split_idx
+
+        if self.trainer.logger is not None and self.trainer.logger.version is not None:
+            tqdm_dict['v_num'] = self.trainer.logger.version
+
+        return tqdm_dict
+
 
 def load_hparams_from_tags_csv(tags_csv):
     if not os.path.isfile(tags_csv):
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index b5e2fe9554b73..82eb052a95e3c 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -295,7 +295,7 @@ def run_evaluation(self, test=False):
         desc = 'Testing' if test else 'Validating'
         pbar = tqdm(desc=desc, total=max_batches, leave=test, position=position,
                     disable=not self.show_progress_bar, dynamic_ncols=True,
-                    unit='batch', file=sys.stdout)
+                    file=sys.stdout)
         setattr(self, f'{"test" if test else "val"}_progress_bar', pbar)
 
         # run evaluation
@@ -319,9 +319,8 @@ def run_evaluation(self, test=False):
             model.on_post_performance_check()
 
         # add model specific metrics
-        tqdm_metrics = self.training_tqdm_dict
         if not test:
-            self.main_progress_bar.set_postfix(**tqdm_metrics)
+            self.main_progress_bar.set_postfix(**self.training_tqdm_dict)
 
         # close progress bar
         if test:
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 0b15100606f6c..6a758f7cca0f0 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -681,23 +681,9 @@ def training_tqdm_dict(self):
         """Read-only for tqdm metrics.
         :return:
         """
-        tqdm_dict = {
-            'loss': '{0:.3f}'.format(self.avg_loss),
-            'batch_idx': '{}'.format(self.batch_idx),
-        }
+        ref_model = self.model if not self.data_parallel else self.model.module
 
-        if self.truncated_bptt_steps is not None:
-            tqdm_dict['split_idx'] = self.split_idx
-
-        if self.logger is not None and self.logger.version is not None:
-            tqdm_dict['v_num'] = self.logger.version
-
-        tqdm_dict.update(self.tqdm_metrics)
-
-        if self.on_gpu:
-            tqdm_dict['gpu'] = '{}'.format(torch.cuda.current_device())
-
-        return tqdm_dict
+        return dict(**ref_model.get_tqdm_dict(), **self.tqdm_metrics)
 
     @property
     def tng_tqdm_dic(self):
@@ -855,7 +841,7 @@ def run_pretrain_routine(self, model):
             pbar = tqdm(desc='Validation sanity check',
                         total=self.num_sanity_val_steps * len(self.get_val_dataloaders()),
                         leave=False, position=2 * self.process_position,
-                        disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch')
+                        disable=not self.show_progress_bar, dynamic_ncols=True)
             self.main_progress_bar = pbar
             # dummy validation progress bar
             self.val_progress_bar = tqdm(disable=True)
@@ -873,7 +859,7 @@ def run_pretrain_routine(self, model):
 
         # init progress bar
         pbar = tqdm(leave=True, position=2 * self.process_position,
-                    disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch',
+                    disable=not self.show_progress_bar, dynamic_ncols=True,
                     file=sys.stdout)
         self.main_progress_bar = pbar
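
A usage sketch (not part of the patch): with this change, the progress-bar
contents become a hook on the LightningModule instead of being hard-coded in
the Trainer. The module below is hypothetical; get_tqdm_dict and the trainer
attributes it reads (batch_idx, on_gpu) all appear in the code added or
removed above. Overriding get_tqdm_dict and calling super() starts from the
slimmed defaults (loss, split_idx, v_num) and can re-add the items this patch
drops:

import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    """Hypothetical module that restores the progress-bar items removed above."""

    def get_tqdm_dict(self):
        # start from the new slimmed-down defaults: loss, split_idx, v_num
        tqdm_dict = super().get_tqdm_dict()

        # re-add the entries trainer.py no longer emits by default
        tqdm_dict['batch_idx'] = '{}'.format(self.trainer.batch_idx)
        if self.trainer.on_gpu:
            tqdm_dict['gpu'] = '{}'.format(torch.cuda.current_device())

        return tqdm_dict

Values are pre-formatted as strings because the Trainer passes the merged dict
straight to tqdm's set_postfix, which displays string values verbatim; keys
from self.tqdm_metrics are merged in on top by training_tqdm_dict, so logged
metrics override module-provided entries of the same name.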