Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent flickering progress bar #6009

Merged
merged 14 commits into from
Feb 17, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed missing `process_dataloader` call for `TPUSpawn` when in distributed mode ([#6015](https://github.com/PyTorchLightning/pytorch-lightning/pull/6015))


- Fixed progress bar flickering by appending 0 to floats/strings ([#6009](https://github.com/PyTorchLightning/pytorch-lightning/pull/6009))


- Fixed synchronization issues with TPU training ([#6027](https://github.com/PyTorchLightning/pytorch-lightning/pull/6027))


Expand Down
29 changes: 27 additions & 2 deletions pytorch_lightning/callbacks/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,37 @@
from typing import Optional, Union

if importlib.util.find_spec('ipywidgets') is not None:
from tqdm.auto import tqdm
from tqdm.auto import tqdm as _tqdm
else:
from tqdm import tqdm
from tqdm import tqdm as _tqdm

from pytorch_lightning.callbacks import Callback

_PAD_SIZE = 5


class tqdm(_tqdm):
"""
Custom tqdm progressbar where we append 0 to floating points/strings to
prevent the progress bar from flickering
"""

@staticmethod
def format_num(n) -> str:
""" Add additional padding to the formatted numbers """
should_be_padded = isinstance(n, (float, str))
if not isinstance(n, str):
n = _tqdm.format_num(n)
carmocca marked this conversation as resolved.
Show resolved Hide resolved
if should_be_padded and 'e' not in n:
if '.' not in n and len(n) < _PAD_SIZE:
try:
_ = float(n)
except ValueError:
return n
n += '.'
n += "0" * (_PAD_SIZE - len(n))
return n


class ProgressBarBase(Callback):
r"""
Expand Down
10 changes: 10 additions & 0 deletions tests/callbacks/test_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks.progress import tqdm
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel

Expand Down Expand Up @@ -371,3 +372,12 @@ def training_step(self, batch, batch_idx):
pbar = trainer.progress_bar_callback.main_progress_bar
actual = str(pbar.postfix)
assert actual.endswith("foo=0.123, bar={'baz': tensor([1])}")


@pytest.mark.parametrize(
"input_num, expected", [[1, '1'], [1.0, '1.000'], [0.1, '0.100'], [1e-3, '0.001'], [1e-5, '1e-5'], ['1.0', '1.000'],
['10000', '10000'], ['abc', 'abc']]
)
def test_tqdm_format_num(input_num, expected):
""" Check that the specialized tqdm.format_num appends 0 to floats and strings """
assert tqdm.format_num(input_num) == expected