From 68fd3086f1f95c191caabfa9f4c68d5b20c5caf6 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 17 Feb 2021 20:01:51 +0100 Subject: [PATCH] Prevent flickering progress bar (#6009) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add padding * fix * fix * Update pytorch_lightning/callbacks/progress.py Co-authored-by: Carlos MocholĂ­ * updated based on suggestion * changelog * add test * fix pep8 * resolve test * fix code format Co-authored-by: Carlos MocholĂ­ Co-authored-by: tchaton --- CHANGELOG.md | 3 +++ pytorch_lightning/callbacks/progress.py | 29 +++++++++++++++++++++++-- tests/callbacks/test_progress_bar.py | 10 +++++++++ 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 57293d6ea8b4e..8ab18a66d37f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -297,6 +297,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed missing `process_dataloader` call for `TPUSpawn` when in distributed mode ([#6015](https://github.com/PyTorchLightning/pytorch-lightning/pull/6015)) +- Fixed progress bar flickering by appending 0 to floats/strings ([#6009](https://github.com/PyTorchLightning/pytorch-lightning/pull/6009)) + + - Fixed synchronization issues with TPU training ([#6027](https://github.com/PyTorchLightning/pytorch-lightning/pull/6027)) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 7de7982b4a2de..3f401669c351e 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -26,12 +26,37 @@ from typing import Optional, Union if importlib.util.find_spec('ipywidgets') is not None: - from tqdm.auto import tqdm + from tqdm.auto import tqdm as _tqdm else: - from tqdm import tqdm + from tqdm import tqdm as _tqdm from pytorch_lightning.callbacks import Callback +_PAD_SIZE = 5 + + +class tqdm(_tqdm): + """ + Custom tqdm progressbar where we append 0 to floating points/strings to + prevent the progress bar from flickering + """ + + @staticmethod + def format_num(n) -> str: + """ Add additional padding to the formatted numbers """ + should_be_padded = isinstance(n, (float, str)) + if not isinstance(n, str): + n = _tqdm.format_num(n) + if should_be_padded and 'e' not in n: + if '.' not in n and len(n) < _PAD_SIZE: + try: + _ = float(n) + except ValueError: + return n + n += '.' + n += "0" * (_PAD_SIZE - len(n)) + return n + class ProgressBarBase(Callback): r""" diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 8398aec88fe68..9ec48008512fb 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -20,6 +20,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase +from pytorch_lightning.callbacks.progress import tqdm from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel @@ -371,3 +372,12 @@ def training_step(self, batch, batch_idx): pbar = trainer.progress_bar_callback.main_progress_bar actual = str(pbar.postfix) assert actual.endswith("foo=0.123, bar={'baz': tensor([1])}") + + +@pytest.mark.parametrize( + "input_num, expected", [[1, '1'], [1.0, '1.000'], [0.1, '0.100'], [1e-3, '0.001'], [1e-5, '1e-5'], ['1.0', '1.000'], + ['10000', '10000'], ['abc', 'abc']] +) +def test_tqdm_format_num(input_num, expected): + """ Check that the specialized tqdm.format_num appends 0 to floats and strings """ + assert tqdm.format_num(input_num) == expected