From 3e8f2d99a9951bfb5fc67a98614128317913be1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 24 Apr 2020 02:46:18 +0200 Subject: [PATCH] Progress bar callback (#1450) * squash and rebase sanity check hooks sanity check callback hook finish moved core progress bar functionality into callback wip remove duplicate merge clean up imports docs sanity check progress bar main sanity move callback calls init progrss bar callback configuration and docs changelog rate decorator pass process_position disable on rank > 0 position index is_enabled remove decorator refactor init tqdm bars callback method ordering cannot reset when disabled sequence -> list default values fix has no attr _time() move on_val_end to proper place fix the pickle issue update warning properties check for None remove old comment switch order pull out non-tqdm functionality into base class documentation for the base class docs fix refresh rate issue in validation restrict type hint of trainer arg more docs update trainer docs rst docs fix lines too long fix test add missing type hints fix typo move docstring to __init__ solves doctest failures remove doctest :(( can't fix the pickle error fix example simplify by saving trainer reference fix docs errors move docstring initial value multiple val checks per epoch simpler handling of inf dataset sizes update inf docs renamed training_tqdm_dict rename get_tqdm_dict rename occurences of tqdm update changelog fix doctest fix formatting errors added callback tests progress bar on off test more tests for progress bar weird test fix? add ignored property disable default progress bar in LR finder change enable/disable behavior trying doctest in CI again undo doctest pickle error undo doctest pickle error :(( remove progress_bar_callback Trainer arg and fix tests restore progress bar after auto lr find update docs fix rebase fix wrong negation * fix fast dev run total * more thorough testing * remove old args * fix merge * fix merge * separate tests * type hint total batches * reduce if Co-Authored-By: Jirka Borovec * is_disabled Co-Authored-By: Jirka Borovec * is_enabled Co-Authored-By: Jirka Borovec * rename enabled/disabled * move deprecated api * remove duplicated test from merge * fix rename is_disabled * newline * test also testprogress for fast dev run Co-authored-by: J. 
Borovec Co-authored-by: Jirka Borovec --- CHANGELOG.md | 2 + docs/source/callbacks.rst | 6 + docs/source/trainer.rst | 1 + pytorch_lightning/callbacks/__init__.py | 3 + pytorch_lightning/callbacks/base.py | 24 ++ pytorch_lightning/callbacks/progress.py | 367 ++++++++++++++++++ pytorch_lightning/core/lightning.py | 17 +- pytorch_lightning/trainer/__init__.py | 9 +- pytorch_lightning/trainer/callback_config.py | 27 +- pytorch_lightning/trainer/callback_hook.py | 34 +- pytorch_lightning/trainer/deprecated_api.py | 14 + .../trainer/distrib_data_parallel.py | 6 +- pytorch_lightning/trainer/distrib_parts.py | 4 +- pytorch_lightning/trainer/evaluation_loop.py | 61 +-- pytorch_lightning/trainer/logging.py | 12 +- pytorch_lightning/trainer/lr_finder.py | 7 +- pytorch_lightning/trainer/trainer.py | 67 +--- pytorch_lightning/trainer/training_loop.py | 44 +-- tests/base/utils.py | 2 +- tests/callbacks/__init__.py | 0 .../{trainer => callbacks}/test_callbacks.py | 59 ++- tests/callbacks/test_progress_bar.py | 221 +++++++++++ 22 files changed, 837 insertions(+), 150 deletions(-) create mode 100644 pytorch_lightning/callbacks/progress.py create mode 100644 tests/callbacks/__init__.py rename tests/{trainer => callbacks}/test_callbacks.py (71%) create mode 100644 tests/callbacks/test_progress_bar.py diff --git a/CHANGELOG.md b/CHANGELOG.md index e7bc2f1267ac9..8d758c6d3cb9d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed - Changed the default behaviour to no longer include a NaN check with each training iteration. ([#1475](https://github.com/PyTorchLightning/pytorch-lightning/pull/1475)) +- Decoupled the progress bar from trainer. It is a callback now and can be customized or even be replaced entirely ([#1450](https://github.com/PyTorchLightning/pytorch-lightning/pull/1450)). - Changed lr schedule step interval behavior to update every backwards pass instead of every forwards pass ([#1476](https://github.com/PyTorchLightning/pytorch-lightning/issues/1476)) @@ -41,6 +42,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated +- Deprecatd `training_tqdm_dict` in favor of `progress_bar_dict` ([#1450](https://github.com/PyTorchLightning/pytorch-lightning/pull/1450)). ### Removed diff --git a/docs/source/callbacks.rst b/docs/source/callbacks.rst index ffb7671b7211d..10323472facd8 100644 --- a/docs/source/callbacks.rst +++ b/docs/source/callbacks.rst @@ -78,3 +78,9 @@ We successfully extended functionality without polluting our super clean _save_model, _abc_impl, check_monitor_top_k, + +--------- + +.. 
automodule:: pytorch_lightning.callbacks.progress + :noindex: + :exclude-members: diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst index b160cfae3cf90..19c394db4854b 100644 --- a/docs/source/trainer.rst +++ b/docs/source/trainer.rst @@ -19,5 +19,6 @@ Trainer slurm_job_id, tng_tqdm_dic, training_tqdm_dict, + progress_bar_dict, init_optimizers, configure_schedulers diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py index b1c47767339cc..c232060ca4ecb 100644 --- a/pytorch_lightning/callbacks/__init__.py +++ b/pytorch_lightning/callbacks/__init__.py @@ -2,10 +2,13 @@ from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.callbacks.progress import ProgressBarBase, ProgressBar __all__ = [ 'Callback', 'EarlyStopping', 'ModelCheckpoint', 'GradientAccumulationScheduler', + 'ProgressBarBase', + 'ProgressBar', ] diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 9bf576b0c1926..50ea061df615e 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -22,6 +22,14 @@ def on_init_end(self, trainer): """Called when the trainer initialization ends, model has not yet been set.""" pass + def on_sanity_check_start(self, trainer, pl_module): + """Called when the validation sanity check starts.""" + pass + + def on_sanity_check_end(self, trainer, pl_module): + """Called when the validation sanity check ends.""" + pass + def on_epoch_start(self, trainer, pl_module): """Called when the epoch begins.""" pass @@ -34,6 +42,22 @@ def on_batch_start(self, trainer, pl_module): """Called when the training batch begins.""" pass + def on_validation_batch_start(self, trainer, pl_module): + """Called when the validation batch begins.""" + pass + + def on_validation_batch_end(self, trainer, pl_module): + """Called when the validation batch ends.""" + pass + + def on_test_batch_start(self, trainer, pl_module): + """Called when the test batch begins.""" + pass + + def on_test_batch_end(self, trainer, pl_module): + """Called when the test batch ends.""" + pass + def on_batch_end(self, trainer, pl_module): """Called when the training batch ends.""" pass diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py new file mode 100644 index 0000000000000..a397c0d2c8d56 --- /dev/null +++ b/pytorch_lightning/callbacks/progress.py @@ -0,0 +1,367 @@ +""" +Progress Bars +============= + +Use or override one of the progress bar callbacks. + +""" +import sys + +from tqdm.auto import tqdm + +from pytorch_lightning.callbacks import Callback + + +class ProgressBarBase(Callback): + r""" + The base class for progress bars in Lightning. It is a :class:`~pytorch_lightning.callbacks.Callback` + that keeps track of the batch progress in the :class:`~pytorch_lightning.trainer.trainer.Trainer`. + You should implement your highly custom progress bars with this as the base class. 
+ + Example:: + + class LitProgressBar(ProgressBarBase): + + def __init__(self): + super().__init__() # don't forget this :) + self.enabled = True + + def disable(self): + self.enabled = False + + def on_batch_end(self, trainer, pl_module): + super().on_batch_end(trainer, pl_module) # don't forget this :) + percent = (self.train_batch_idx / self.total_train_batches) * 100 + sys.stdout.flush() + sys.stdout.write(f'{percent:.01f} percent complete \r') + + bar = LitProgressBar() + trainer = Trainer(callbacks=[bar]) + + """ + def __init__(self): + + self._trainer = None + self._train_batch_idx = 0 + self._val_batch_idx = 0 + self._test_batch_idx = 0 + + @property + def trainer(self): + return self._trainer + + @property + def train_batch_idx(self) -> int: + """ + The current batch index being processed during training. + Use this to update your progress bar. + """ + return self._train_batch_idx + + @property + def val_batch_idx(self) -> int: + """ + The current batch index being processed during validation. + Use this to update your progress bar. + """ + return self._val_batch_idx + + @property + def test_batch_idx(self) -> int: + """ + The current batch index being processed during testing. + Use this to update your progress bar. + """ + return self._test_batch_idx + + @property + def total_train_batches(self) -> int: + """ + The total number of training batches during training, which may change from epoch to epoch. + Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the + training dataloader is of infinite size. + """ + total_train_batches = 1 if self.trainer.fast_dev_run else self.trainer.num_training_batches + return total_train_batches + + @property + def total_val_batches(self) -> int: + """ + The total number of validation batches during validation, which may change from epoch to epoch. + Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the + validation dataloader is of infinite size. + """ + trainer = self.trainer + total_val_batches = 0 + if trainer.fast_dev_run: + total_val_batches = len(trainer.val_dataloaders) + elif not self.trainer.disable_validation: + is_val_epoch = (trainer.current_epoch + 1) % trainer.check_val_every_n_epoch == 0 + total_val_batches = trainer.num_val_batches if is_val_epoch else 0 + return total_val_batches + + @property + def total_test_batches(self) -> int: + """ + The total number of test batches during testing, which may change from epoch to epoch. + Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the + test dataloader is of infinite size. + """ + if self.trainer.fast_dev_run: + total_test_batches = len(self.trainer.test_dataloaders) + else: + total_test_batches = self.trainer.num_test_batches + return total_test_batches + + def disable(self): + """ + You should provide a way to disable the progress bar. + The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this to disable the + output on processes that have a rank different from 0, e.g., in multi-node training. + """ + raise NotImplementedError + + def enable(self): + """ + You should provide a way to enable the progress bar. + The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this in e.g. pre-training + routines like the `learning rate finder `_ to temporarily enable and + disable the main progress bar.
+ """ + raise NotImplementedError + + def on_init_end(self, trainer): + self._trainer = trainer + + def on_train_start(self, trainer, pl_module): + self._train_batch_idx = trainer.batch_idx + + def on_epoch_start(self, trainer, pl_module): + self._train_batch_idx = 0 + + def on_batch_end(self, trainer, pl_module): + self._train_batch_idx += 1 + + def on_validation_start(self, trainer, pl_module): + self._val_batch_idx = 0 + + def on_validation_batch_end(self, trainer, pl_module): + self._val_batch_idx += 1 + + def on_test_start(self, trainer, pl_module): + self._test_batch_idx = 0 + + def on_test_batch_end(self, trainer, pl_module): + self._test_batch_idx += 1 + + +class ProgressBar(ProgressBarBase): + r""" + This is the default progress bar used by Lightning. It prints to `stdout` using the + :mod:`tqdm` package and shows up to four different bars: + + - **sanity check progress:** the progress during the sanity check run + - **main progress:** shows training + validation progress combined. It also accounts for + multiple validation runs during training when + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval` is used. + - **validation progress:** only visible during validation; + shows total progress over all validation datasets. + - **test progress:** only active when testing; shows total progress over all test datasets. + + For infinite datasets, the progress bar never ends. + + If you want to customize the default ``tqdm`` progress bars used by Lightning, you can override + specific methods of the callback class and pass your custom implementation to the + :class:`~pytorch_lightning.trainer.trainer.Trainer`: + + Example:: + + class LitProgressBar(ProgressBar): + + def init_validation_tqdm(self): + bar = super().init_validation_tqdm() + bar.set_description('running validation ...') + return bar + + bar = LitProgressBar() + trainer = Trainer(callbacks=[bar]) + + Args: + refresh_rate: + Determines at which rate (in number of batches) the progress bars get updated. + Set it to ``0`` to disable the display. By default, the + :class:`~pytorch_lightning.trainer.trainer.Trainer` uses this implementation of the progress + bar and sets the refresh rate to the value provided to the + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.progress_bar_refresh_rate` argument in the + :class:`~pytorch_lightning.trainer.trainer.Trainer`. + process_position: + Set this to a value greater than ``0`` to offset the progress bars by this many lines. + This is useful when you have progress bars defined elsewhere and want to show all of them + together. This corresponds to + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.process_position` in the + :class:`~pytorch_lightning.trainer.trainer.Trainer`. 
+ + """ + def __init__(self, refresh_rate: int = 1, process_position: int = 0): + super().__init__() + self._refresh_rate = refresh_rate + self._process_position = process_position + self._enabled = True + self.main_progress_bar = None + self.val_progress_bar = None + self.test_progress_bar = None + + def __getstate__(self): + # can't pickle the tqdm objects + state = self.__dict__.copy() + state['main_progress_bar'] = None + state['val_progress_bar'] = None + state['test_progress_bar'] = None + return state + + @property + def refresh_rate(self) -> int: + return self._refresh_rate + + @property + def process_position(self) -> int: + return self._process_position + + @property + def is_enabled(self) -> bool: + return self._enabled and self.refresh_rate > 0 + + @property + def is_disabled(self) -> bool: + return not self.is_enabled + + def disable(self) -> None: + self._enabled = False + + def enable(self) -> None: + self._enabled = True + + def init_sanity_tqdm(self) -> tqdm: + """ Override this to customize the tqdm bar for the validation sanity run. """ + bar = tqdm( + desc='Validation sanity check', + position=(2 * self.process_position), + disable=self.is_disabled, + leave=False, + dynamic_ncols=True, + file=sys.stdout, + ) + return bar + + def init_train_tqdm(self) -> tqdm: + """ Override this to customize the tqdm bar for training. """ + bar = tqdm( + desc='Training', + initial=self.train_batch_idx, + position=(2 * self.process_position), + disable=self.is_disabled, + leave=True, + dynamic_ncols=True, + file=sys.stdout, + smoothing=0, + ) + return bar + + def init_validation_tqdm(self) -> tqdm: + """ Override this to customize the tqdm bar for validation. """ + bar = tqdm( + desc='Validating', + position=(2 * self.process_position + 1), + disable=self.is_disabled, + leave=False, + dynamic_ncols=True, + file=sys.stdout + ) + return bar + + def init_test_tqdm(self) -> tqdm: + """ Override this to customize the tqdm bar for testing. 
""" + bar = tqdm( + desc='Testing', + position=(2 * self.process_position), + disable=self.is_disabled, + leave=True, + dynamic_ncols=True, + file=sys.stdout + ) + return bar + + def on_sanity_check_start(self, trainer, pl_module): + super().on_sanity_check_start(trainer, pl_module) + self.val_progress_bar = self.init_sanity_tqdm() + self.val_progress_bar.total = trainer.num_sanity_val_steps * len(trainer.val_dataloaders) + self.main_progress_bar = tqdm(disable=True) # dummy progress bar + + def on_sanity_check_end(self, trainer, pl_module): + super().on_sanity_check_end(trainer, pl_module) + self.main_progress_bar.close() + self.val_progress_bar.close() + + def on_train_start(self, trainer, pl_module): + super().on_train_start(trainer, pl_module) + self.main_progress_bar = self.init_train_tqdm() + + def on_epoch_start(self, trainer, pl_module): + super().on_epoch_start(trainer, pl_module) + total_train_batches = self.total_train_batches + total_val_batches = self.total_val_batches + if total_train_batches != float('inf') and not trainer.fast_dev_run: + # val can be checked multiple times per epoch + val_checks_per_epoch = total_train_batches // trainer.val_check_batch + total_val_batches = total_val_batches * val_checks_per_epoch + total_batches = total_train_batches + total_val_batches + if not self.main_progress_bar.disable: + self.main_progress_bar.reset(convert_inf(total_batches)) + self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch + 1}') + + def on_batch_end(self, trainer, pl_module): + super().on_batch_end(trainer, pl_module) + if self.is_enabled and self.train_batch_idx % self.refresh_rate == 0: + self.main_progress_bar.update(self.refresh_rate) + self.main_progress_bar.set_postfix(**trainer.progress_bar_dict) + + def on_validation_start(self, trainer, pl_module): + super().on_validation_start(trainer, pl_module) + self.val_progress_bar = self.init_validation_tqdm() + self.val_progress_bar.total = convert_inf(self.total_val_batches) + + def on_validation_batch_end(self, trainer, pl_module): + super().on_validation_batch_end(trainer, pl_module) + if self.is_enabled and self.val_batch_idx % self.refresh_rate == 0: + self.val_progress_bar.update(self.refresh_rate) + self.main_progress_bar.update(self.refresh_rate) + + def on_validation_end(self, trainer, pl_module): + super().on_validation_end(trainer, pl_module) + self.main_progress_bar.set_postfix(**trainer.progress_bar_dict) + self.val_progress_bar.close() + + def on_train_end(self, trainer, pl_module): + super().on_train_end(trainer, pl_module) + self.main_progress_bar.close() + + def on_test_start(self, trainer, pl_module): + super().on_test_start(trainer, pl_module) + self.test_progress_bar = self.init_test_tqdm() + self.test_progress_bar.total = convert_inf(self.total_test_batches) + + def on_test_batch_end(self, trainer, pl_module): + super().on_test_batch_end(trainer, pl_module) + if self.is_enabled and self.test_batch_idx % self.refresh_rate == 0: + self.test_progress_bar.update(self.refresh_rate) + + def on_test_end(self, trainer, pl_module): + super().on_test_end(trainer, pl_module) + self.test_progress_bar.close() + + +def convert_inf(x): + """ The tqdm doesn't support inf values. We have to convert it to None. 
""" + if x == float('inf'): + return None + return x diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 82be2d53f0ed4..d943a07461821 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1636,7 +1636,7 @@ def on_save_checkpoint(self, checkpoint): """ - def get_tqdm_dict(self) -> Dict[str, Union[int, str]]: + def get_progress_bar_dict(self) -> Dict[str, Union[int, str]]: r""" Additional items to be displayed in the progress bar. @@ -1657,3 +1657,18 @@ def get_tqdm_dict(self) -> Dict[str, Union[int, str]]: tqdm_dict['v_num'] = self.trainer.logger.version return tqdm_dict + + def get_tqdm_dict(self) -> Dict[str, Union[int, str]]: + """ + Additional items to be displayed in the progress bar. + + Return: + Dictionary with the items to be displayed in the progress bar. + + Warning: + Deprecated since v0.7.3. + Use :meth:`get_progress_bar_dict` instead. + """ + rank_zero_warn("`get_tqdm_dict` was renamed to `get_progress_bar_dict` in v0.7.3" + " and this method will be removed in v1.0.0", DeprecationWarning) + return self.get_progress_bar_dict() diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index 0bd161bf0a56d..41b7b1b999a75 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -652,14 +652,16 @@ def on_train_end(self): process_position ^^^^^^^^^^^^^^^^ -Orders the tqdm bar. Useful when running multiple trainers -on the same node. +Orders the progress bar. Useful when running multiple trainers on the same node. Example:: # default used by the Trainer trainer = Trainer(process_position=0) +Note: + This argument is ignored if a custom callback is passed to :paramref:`~Trainer.callbacks`. + profiler ^^^^^^^^ To profile individual steps during training and assist in identifying bottlenecks. @@ -698,6 +700,9 @@ def on_train_end(self): # disable progress bar trainer = Trainer(progress_bar_refresh_rate=0) +Note: + This argument is ignored if a custom callback is passed to :paramref:`~Trainer.callbacks`. + reload_dataloaders_every_epoch ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Set to True to reload dataloaders every epoch. 
diff --git a/pytorch_lightning/trainer/callback_config.py b/pytorch_lightning/trainer/callback_config.py index 7ace0fb20a255..39c67963169e4 100644 --- a/pytorch_lightning/trainer/callback_config.py +++ b/pytorch_lightning/trainer/callback_config.py @@ -1,20 +1,25 @@ import os from abc import ABC, abstractmethod -from typing import Union +from typing import Union, List -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping + +from pytorch_lightning.callbacks import Callback, ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar from pytorch_lightning.loggers import LightningLoggerBase +from pytorch_lightning.utilities.exceptions import MisconfigurationException class TrainerCallbackConfigMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class + callbacks: List[Callback] default_root_dir: str logger: Union[LightningLoggerBase, bool] weights_save_path: str ckpt_path: str checkpoint_callback: ModelCheckpoint + progress_bar_refresh_rate: int + process_position: int @property @abstractmethod @@ -101,3 +106,21 @@ def configure_early_stopping(self, early_stop_callback): else: self.early_stop_callback = early_stop_callback self.enable_early_stop = True + + def configure_progress_bar(self): + progress_bars = [c for c in self.callbacks if isinstance(c, ProgressBarBase)] + if len(progress_bars) > 1: + raise MisconfigurationException( + 'You added multiple progress bar callbacks to the Trainer, but currently only one' + ' progress bar is supported.' + ) + elif len(progress_bars) == 1: + self.progress_bar_callback = progress_bars[0] + elif self.progress_bar_refresh_rate > 0: + self.progress_bar_callback = ProgressBar( + refresh_rate=self.progress_bar_refresh_rate, + process_position=self.process_position, + ) + self.callbacks.append(self.progress_bar_callback) + else: + self.progress_bar_callback = None diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 48d703b84ebe0..37f56e6941039 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Callable +from typing import Callable, List from pytorch_lightning.callbacks import Callback @@ -9,7 +9,7 @@ class TrainerCallbackHookMixin(ABC): def __init__(self): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class - self.callbacks: list[Callback] = [] + self.callbacks: List[Callback] = [] self.get_model: Callable = ... 
def on_init_start(self): @@ -22,6 +22,16 @@ def on_init_end(self): for callback in self.callbacks: callback.on_init_end(self) + def on_sanity_check_start(self): + """Called when the validation sanity check starts.""" + for callback in self.callbacks: + callback.on_sanity_check_start(self, self.get_model()) + + def on_sanity_check_end(self): + """Called when the validation sanity check ends.""" + for callback in self.callbacks: + callback.on_sanity_check_end(self, self.get_model()) + def on_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: @@ -52,6 +62,26 @@ def on_batch_end(self): for callback in self.callbacks: callback.on_batch_end(self, self.get_model()) + def on_validation_batch_start(self): + """Called when the validation batch begins.""" + for callback in self.callbacks: + callback.on_validation_batch_start(self, self.get_model()) + + def on_validation_batch_end(self): + """Called when the validation batch ends.""" + for callback in self.callbacks: + callback.on_validation_batch_end(self, self.get_model()) + + def on_test_batch_start(self): + """Called when the test batch begins.""" + for callback in self.callbacks: + callback.on_test_batch_start(self, self.get_model()) + + def on_test_batch_end(self): + """Called when the test batch ends.""" + for callback in self.callbacks: + callback.on_test_batch_end(self, self.get_model()) + def on_validation_start(self): """Called when the validation loop begins.""" for callback in self.callbacks: diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index d9f461d5d039a..2705c4f160464 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -103,6 +103,13 @@ def default_save_path(self, path): " and this method will be removed in v0.8.0", DeprecationWarning) self.default_root_dir = path + @property + def tng_tqdm_dic(self): + """Back compatibility, will be removed in v0.8.0""" + rank_zero_warn("`tng_tqdm_dic` has renamed to `training_tqdm_dict` since v0.5.0" + " and this method will be removed in v0.8.0", DeprecationWarning) + return self.progress_bar_dict + class TrainerDeprecatedAPITillVer0_9(ABC): @@ -121,3 +128,10 @@ def show_progress_bar(self, tf): """Back compatibility, will be removed in v0.9.0""" rank_zero_warn("Argument `show_progress_bar` is now set by `progress_bar_refresh_rate` since v0.7.2" " and this method will be removed in v0.9.0", DeprecationWarning) + + @property + def training_tqdm_dict(self): + """Back compatibility, will be removed in v0.9.0""" + rank_zero_warn("`training_tqdm_dict` was renamed to `progress_bar_dict` in v0.7.3" + " and this method will be removed in v0.9.0", DeprecationWarning) + return self.progress_bar_dict diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py index 736af5cad928f..75a60a45122bc 100644 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ b/pytorch_lightning/trainer/distrib_data_parallel.py @@ -152,6 +152,7 @@ class TrainerDDPMixin(ABC): use_tpu: bool default_root_dir: str use_native_amp: bool + progress_bar_callback: ... 
@property @abstractmethod @@ -310,9 +311,8 @@ def ddp_train(self, process_idx, model): self.node_rank = 0 # show progressbar only on progress_rank 0 - self.progress_bar_refresh_rate = ( - self.progress_bar_refresh_rate if self.node_rank == 0 and process_idx == 0 else 0 - ) + if (self.node_rank != 0 or process_idx != 0) and self.progress_bar_callback is not None: + self.progress_bar_callback.disable() # determine which process we are and world size if self.use_ddp: diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 7b79922d82a00..b0d7cae2989f9 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -397,6 +397,7 @@ class TrainerDPMixin(ABC): use_native_amp: bool data_parallel_device_ids: ... logger: Union[LightningLoggerBase, bool] + progress_bar_callback: ... @property @abstractmethod @@ -499,7 +500,8 @@ def tpu_train(self, tpu_core_idx, model): self.tpu_global_core_rank = xm.get_ordinal() # avoid duplicating progress bar - self.progress_bar_refresh_rate = self.progress_bar_refresh_rate if self.tpu_global_core_rank == 0 else 0 + if self.tpu_global_core_rank != 0 and self.progress_bar_callback is not None: + self.progress_bar_callback.disable() # track current tpu self.current_tpu_idx = tpu_core_idx diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 54d80ac0eb4fe..0320bf35419ea 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -123,14 +123,12 @@ """ -import sys from abc import ABC, abstractmethod from pprint import pprint from typing import Callable import torch from torch.utils.data import DataLoader -from tqdm.auto import tqdm from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel @@ -157,9 +155,6 @@ class TrainerEvaluationLoopMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class - test_progress_bar: ... - val_progress_bar: ... - main_progress_bar: ... on_gpu: bool use_ddp: bool use_dp: bool @@ -171,9 +166,8 @@ class TrainerEvaluationLoopMixin(ABC): num_test_batches: int num_val_batches: int fast_dev_run: ... - process_position: ... process_output: ... - training_tqdm_dict: ... + progress_bar_dict: ... proc_rank: int current_epoch: int callback_metrics: ... @@ -181,9 +175,12 @@ class TrainerEvaluationLoopMixin(ABC): val_dataloaders: DataLoader use_tpu: bool reload_dataloaders_every_epoch: ... - progress_bar_refresh_rate: ... 
# Callback system + on_validation_batch_start: Callable + on_validation_batch_end: Callable + on_test_batch_start: Callable + on_test_batch_end: Callable on_validation_start: Callable on_validation_end: Callable on_test_start: Callable @@ -210,7 +207,7 @@ def transfer_batch_to_gpu(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod - def add_tqdm_metrics(self, *args): + def add_progress_bar_metrics(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod @@ -265,6 +262,12 @@ def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_ if batch_idx >= max_batches: break + # callbacks + if test_mode: + self.on_test_batch_start() + else: + self.on_validation_batch_start() + # ----------------- # RUN EVALUATION STEP # ----------------- @@ -280,22 +283,17 @@ def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_ model_ref = self.get_model() with self.profiler.profile('test_step_end'): output = model_ref.test_step_end(output) + self.on_test_batch_end() else: if self.is_overriden('validation_step_end'): model_ref = self.get_model() with self.profiler.profile('validation_step_end'): output = model_ref.validation_step_end(output) + self.on_validation_batch_end() # track outputs for collation dl_outputs.append(output) - # batch done - if self.progress_bar_refresh_rate >= 1 and batch_idx % self.progress_bar_refresh_rate == 0: - if test_mode: - self.test_progress_bar.update(self.progress_bar_refresh_rate) - else: - self.val_progress_bar.update(self.progress_bar_refresh_rate) - self.main_progress_bar.update(self.progress_bar_refresh_rate) outputs.append(dl_outputs) eval_results = {} @@ -343,12 +341,6 @@ def run_evaluation(self, test_mode: bool = False): "You called `.test()` without defining model's `.test_step()`." 
" Please define and try again") - # Validation/Test begin callbacks - if test_mode: - self.on_test_start() - else: - self.on_validation_start() - # hook model = self.get_model() model.on_pre_performance_check() @@ -372,21 +364,18 @@ def run_evaluation(self, test_mode: bool = False): if self.fast_dev_run: max_batches = 1 - # init validation or test progress bar - # main progress bar will already be closed when testing so initial position is free - position = 2 * self.process_position + (not test_mode) - desc = 'Testing' if test_mode else 'Validating' - total = max_batches if max_batches != float('inf') else None - pbar = tqdm(desc=desc, total=total, leave=test_mode, position=position, - disable=not self.progress_bar_refresh_rate, dynamic_ncols=True, file=sys.stdout) - setattr(self, f'{"test" if test_mode else "val"}_progress_bar', pbar) + # Validation/Test begin callbacks + if test_mode: + self.on_test_start() + else: + self.on_validation_start() # run evaluation eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode) _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(eval_results) # add metrics to prog bar - self.add_tqdm_metrics(prog_bar_metrics) + self.add_progress_bar_metrics(prog_bar_metrics) # log results of test if test_mode and self.proc_rank == 0: @@ -404,16 +393,6 @@ def run_evaluation(self, test_mode: bool = False): # hook model.on_post_performance_check() - # add model specific metrics - if not test_mode: - self.main_progress_bar.set_postfix(**self.training_tqdm_dict) - - # close progress bar - if test_mode: - self.test_progress_bar.close() - else: - self.val_progress_bar.close() - # eventual dataset reloading if test_mode: if self.reload_dataloaders_every_epoch: diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index a0599f5871614..3ba662fb5c844 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -16,7 +16,7 @@ class TrainerLoggingMixin(ABC): on_gpu: bool log_gpu_memory: ... logger: Union[LightningLoggerBase, bool] - tqdm_metrics: ... + progress_bar_metrics: ... global_step: int proc_rank: int use_dp: bool @@ -75,12 +75,12 @@ def log_metrics(self, metrics, grad_norm_dic, step=None): self.logger.agg_and_log_metrics(scalar_metrics, step=step) self.logger.save() - def add_tqdm_metrics(self, metrics): + def add_progress_bar_metrics(self, metrics): for k, v in metrics.items(): if isinstance(v, torch.Tensor): v = v.item() - self.tqdm_metrics[k] = v + self.progress_bar_metrics[k] = v def metrics_to_scalars(self, metrics): new_metrics = {} @@ -98,7 +98,7 @@ def metrics_to_scalars(self, metrics): def process_output(self, output, train=False): """Reduces output according to the training mode. 
- Separates loss from logging and tqdm metrics + Separates loss from logging and progress bar metrics """ # --------------- # EXTRACT CALLBACK KEYS @@ -119,7 +119,7 @@ def process_output(self, output, train=False): try: progress_output = output['progress_bar'] - # reduce progress metrics for tqdm when using dp + # reduce progress metrics for progress bar when using dp if train and (self.use_dp or self.use_ddp2): num_gpus = self.num_gpus progress_output = self.reduce_distributed_output(progress_output, num_gpus) @@ -135,7 +135,7 @@ def process_output(self, output, train=False): try: log_output = output['log'] - # reduce progress metrics for tqdm when using dp + # reduce progress metrics for progress bar when using dp if train and (self.use_dp or self.use_ddp2): num_gpus = self.num_gpus log_output = self.reduce_distributed_output(log_output, num_gpus) diff --git a/pytorch_lightning/trainer/lr_finder.py b/pytorch_lightning/trainer/lr_finder.py index ff14e93a49e1b..23eab617001b0 100755 --- a/pytorch_lightning/trainer/lr_finder.py +++ b/pytorch_lightning/trainer/lr_finder.py @@ -123,7 +123,8 @@ def lr_find(self, self.max_steps = num_training # Disable standard progress bar for fit - self.progress_bar_refresh_rate = False + if self.progress_bar_callback: + self.progress_bar_callback.disable() # Accumulation of gradients self.accumulate_grad_batches = num_accumulation_steps @@ -165,6 +166,8 @@ def lr_find(self, # Finish by resetting variables so trainer is ready to fit model self._restore_params(model) + if self.progress_bar_callback: + self.progress_bar_callback.enable() return lr_finder @@ -178,6 +181,7 @@ def _dump_params(self, model): 'progress_bar_refresh_rate': self.progress_bar_refresh_rate, 'accumulate_grad_batches': self.accumulate_grad_batches, 'checkpoint_callback': self.checkpoint_callback, + 'progress_bar_callback': self.progress_bar_callback, 'configure_optimizers': model.configure_optimizers, } @@ -189,6 +193,7 @@ def _restore_params(self, model): self.progress_bar_refresh_rate = self._params['progress_bar_refresh_rate'] self.accumulate_grad_batches = self._params['accumulate_grad_batches'] self.checkpoint_callback = self._params['checkpoint_callback'] + self.progress_bar_callback = self._params['progress_bar_callback'] model.configure_optimizers = self._params['configure_optimizers'] diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index fe69a98ad1205..ce8b9ca090d26 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1,7 +1,6 @@ import distutils import inspect import os -import sys from argparse import ArgumentParser from typing import Union, Optional, List, Dict, Tuple, Iterable, Any @@ -9,10 +8,9 @@ import torch.distributed as torch_distrib import torch.multiprocessing as mp from torch.utils.data import DataLoader -from tqdm.auto import tqdm from pytorch_lightning import _logger as log -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback +from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback, ProgressBarBase from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler, BaseProfiler @@ -74,9 +72,9 @@ class Trainer( ): DEPRECATED_IN_0_8 = ( 'gradient_clip', 'nb_gpu_nodes', 'max_nb_epochs', 'min_nb_epochs', - 'add_row_log_interval', 'nb_sanity_val_steps' + 'add_row_log_interval', 'nb_sanity_val_steps', 
'tng_tqdm_dic', ) - DEPRECATED_IN_0_9 = ('use_amp', 'show_progress_bar') + DEPRECATED_IN_0_9 = ('use_amp', 'show_progress_bar', 'training_tqdm_dict') def __init__( self, @@ -123,6 +121,7 @@ def __init__( reload_dataloaders_every_epoch: bool = False, auto_lr_find: Union[bool, str] = False, replace_sampler_ddp: bool = True, + progress_bar_callback: Optional[Union[ProgressBarBase, bool]] = True, amp_level: str = 'O1', # backward compatible, todo: remove in v0.8.0 default_save_path=None, # backward compatible, todo: remove in v0.8.0 gradient_clip=None, # backward compatible, todo: remove in v0.8.0 @@ -162,7 +161,7 @@ def __init__( Use `gradient_clip_val` instead. Will remove 0.9.0. - process_position: orders the tqdm bar when running multiple models on same machine. + process_position: orders the progress bar when running multiple models on same machine. num_nodes: number of GPU nodes for distributed training. @@ -190,6 +189,7 @@ def __init__( Set `progress_bar_refresh_rate` to postive integer to enable. Will remove 0.9.0. progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar. + Ignored when a custom callback is passed to :paramref:`~Trainer.callbacks`. overfit_pct: How much of training-, validation-, and test dataset to check. @@ -312,7 +312,6 @@ def __init__( " and this method will be removed in v0.8.0", DeprecationWarning) self.gradient_clip = gradient_clip - self.progress_bar_refresh_rate = progress_bar_refresh_rate self.check_val_every_n_epoch = check_val_every_n_epoch self.track_grad_norm = track_grad_norm self.on_gpu = True if (gpus and torch.cuda.is_available()) else False @@ -390,7 +389,7 @@ def __init__( self.total_batch_idx = 0 self.running_loss = TensorRunningAccum(window_length=20) self.batch_idx = 0 - self.tqdm_metrics = {} + self.progress_bar_metrics = {} self.callback_metrics = {} self.num_val_batches = 0 self.num_training_batches = 0 @@ -408,7 +407,6 @@ def __init__( self.optimizer_frequencies = [] self.global_step = 0 self.current_epoch = 0 - self.total_batches = 0 self.interrupted = False # configure logger @@ -464,12 +462,14 @@ def __init__( # nvidia setup self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids) - # can't init progress bar here because starting a new process - # means the progress_bar won't survive pickling # backward compatibility if show_progress_bar is not None: self.show_progress_bar = show_progress_bar + self.progress_bar_refresh_rate = progress_bar_refresh_rate + self.progress_bar_callback = None + self.configure_progress_bar() + # logging self.log_save_interval = log_save_interval self.val_check_interval = val_check_interval @@ -632,26 +632,10 @@ def data_parallel(self) -> bool: return self.use_dp or self.use_ddp or self.use_ddp2 @property - def training_tqdm_dict(self) -> dict: - """Read-only for tqdm metrics. - :return: - """ + def progress_bar_dict(self) -> dict: + """ Read-only for progress bar metrics. """ ref_model = self.model if not self.data_parallel else self.model.module - - return dict(**ref_model.get_tqdm_dict(), **self.tqdm_metrics) - - @property - def tng_tqdm_dic(self): - """Read-only for tqdm metrics. - - .. warning:: .. deprecated:: 0.5.0 - - Use `training_tqdm_dict` instead. Will remove 0.8.0. 
- - """ - rank_zero_warn("`tng_tqdm_dic` has renamed to `training_tqdm_dict` since v0.5.0" - " and this method will be removed in v0.8.0", DeprecationWarning) - return self.training_tqdm_dict + return dict(**ref_model.get_progress_bar_dict(), **self.progress_bar_metrics) # ----------------------------- # MODEL TRAINING @@ -870,17 +854,12 @@ def run_pretrain_routine(self, model: LightningModule): # run tiny validation (if validation defined) # to make sure program won't crash during val - ref_model.on_sanity_check_start() if not self.disable_validation and self.num_sanity_val_steps > 0: self.reset_val_dataloader(ref_model) - # init progress bars for validation sanity check - pbar = tqdm(desc='Validation sanity check', - total=self.num_sanity_val_steps * len(self.val_dataloaders), - leave=False, position=2 * self.process_position, - disable=not self.progress_bar_refresh_rate, dynamic_ncols=True) - self.main_progress_bar = pbar - # dummy validation progress bar - self.val_progress_bar = tqdm(disable=True) + + # hook and callback + ref_model.on_sanity_check_start() + self.on_sanity_check_start() eval_results = self._evaluate(model, self.val_dataloaders, @@ -888,20 +867,12 @@ def run_pretrain_routine(self, model: LightningModule): False) _, _, _, callback_metrics, _ = self.process_output(eval_results) - # close progress bars - self.main_progress_bar.close() - self.val_progress_bar.close() + self.on_sanity_check_end() # verify that early stop has conditioned on a metric that exists if self.enable_early_stop: self.early_stop_callback._validate_condition_metric(callback_metrics) - # init progress bar - pbar = tqdm(leave=True, position=2 * self.process_position, - disable=not self.show_progress_bar, dynamic_ncols=True, - file=sys.stdout, smoothing=0) - self.main_progress_bar = pbar - # clear cache before training if self.on_gpu: torch.cuda.empty_cache() diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index b6756d1e94bbb..33c8339bb85d0 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -203,7 +203,6 @@ class TrainerTrainLoopMixin(ABC): num_val_batches: int disable_validation: bool fast_dev_run: ... - main_progress_bar: ... accumulation_scheduler: ... lr_schedulers: ... enable_early_stop: ... @@ -215,7 +214,6 @@ class TrainerTrainLoopMixin(ABC): log_save_interval: float proc_rank: int row_log_interval: float - total_batches: int truncated_bptt_steps: ... optimizers: ... optimizer_frequencies: ... @@ -224,14 +222,13 @@ class TrainerTrainLoopMixin(ABC): model: LightningModule interrupted: bool running_loss: ... - training_tqdm_dict: ... + progress_bar_dict: ... reduce_lr_on_plateau_scheduler: ... profiler: ... batch_idx: int precision: ... train_dataloader: DataLoader reload_dataloaders_every_epoch: bool - progress_bar_refresh_rate: ... 
max_steps: int min_steps: int total_batch_idx: int @@ -281,7 +278,7 @@ def is_overriden(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod - def add_tqdm_metrics(self, *args): + def add_progress_bar_metrics(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod @@ -342,18 +339,6 @@ def train(self): model.current_epoch = epoch self.current_epoch = epoch - total_val_batches = 0 - is_val_epoch = False - if not self.disable_validation and self.num_training_batches != float('inf'): - # val can be checked multiple times in epoch - is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0 - val_checks_per_epoch = self.num_training_batches // self.val_check_batch - val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0 - total_val_batches = self.num_val_batches * val_checks_per_epoch - - # total batches includes multiple val checks - self.total_batches = self.num_training_batches + total_val_batches - # changing gradient according accumulation_scheduler self.accumulation_scheduler.on_epoch_start(self, self.get_model()) @@ -362,22 +347,6 @@ def train(self): window_length=self.accumulate_grad_batches ) - if self.fast_dev_run: - # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run - num_iterations = 2 - elif self.total_batches == float('inf'): - # for infinite train or val loader, the progress bar never ends - num_iterations = None - else: - num_iterations = self.total_batches - - # reset progress bar - # .reset() doesn't work on disabled progress bar so we should check - if not self.main_progress_bar.disable: - self.main_progress_bar.reset(num_iterations) - desc = f'Epoch {epoch + 1}' - self.main_progress_bar.set_description(desc) - # ----------------- # RUN TNG EPOCH # ----------------- @@ -614,7 +583,7 @@ def optimizer_closure(): all_callback_metrics.append(callback_metrics) # track progress bar metrics - self.add_tqdm_metrics(progress_bar_metrics) + self.add_progress_bar_metrics(progress_bar_metrics) all_log_metrics.append(log_metrics) if self.use_horovod: @@ -676,11 +645,6 @@ def optimizer_closure(): if self.is_function_implemented('on_batch_end'): self.get_model().on_batch_end() - # update progress bar - if self.progress_bar_refresh_rate >= 1 and batch_idx % self.progress_bar_refresh_rate == 0: - self.main_progress_bar.update(self.progress_bar_refresh_rate) - self.main_progress_bar.set_postfix(**self.training_tqdm_dict) - # collapse all metrics into one dict all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()} @@ -703,8 +667,6 @@ def _get_optimizers_iterable(self): return [(opt_idx, self.optimizers[opt_idx])] def run_training_teardown(self): - self.main_progress_bar.close() - # Train end events with self.profiler.profile('on_train_end'): # callbacks diff --git a/tests/base/utils.py b/tests/base/utils.py index 1bb485f270ada..1f0d582ed6e01 100644 --- a/tests/base/utils.py +++ b/tests/base/utils.py @@ -220,7 +220,7 @@ def run_prediction(dataloader, trained_model, dp=False, min_acc=0.5): def assert_ok_model_acc(trainer, key='test_acc', thr=0.5): # this model should get 0.80+ acc - acc = trainer.training_tqdm_dict[key] + acc = trainer.progress_bar_dict[key] assert acc > thr, f"Model failed to get expected {thr} accuracy. 
{key} = {acc}" diff --git a/tests/callbacks/__init__.py b/tests/callbacks/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/trainer/test_callbacks.py b/tests/callbacks/test_callbacks.py similarity index 71% rename from tests/trainer/test_callbacks.py rename to tests/callbacks/test_callbacks.py index 4f77dabbd12a5..4731d4351679e 100644 --- a/tests/trainer/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -33,10 +33,16 @@ def __init__(self): super().__init__() self.on_init_start_called = False self.on_init_end_called = False + self.on_sanity_check_start_called = False + self.on_sanity_check_end_called = False self.on_epoch_start_called = False self.on_epoch_end_called = False self.on_batch_start_called = False self.on_batch_end_called = False + self.on_validation_batch_start_called = False + self.on_validation_batch_end_called = False + self.on_test_batch_start_called = False + self.on_test_batch_end_called = False self.on_train_start_called = False self.on_train_end_called = False self.on_validation_start_called = False @@ -52,6 +58,14 @@ def on_init_end(self, trainer): assert isinstance(trainer, Trainer) self.on_init_end_called = True + def on_sanity_check_start(self, trainer, pl_module): + _check_args(trainer, pl_module) + self.on_sanity_check_start_called = True + + def on_sanity_check_end(self, trainer, pl_module): + _check_args(trainer, pl_module) + self.on_sanity_check_end_called = True + def on_epoch_start(self, trainer, pl_module): _check_args(trainer, pl_module) self.on_epoch_start_called = True @@ -68,6 +82,22 @@ def on_batch_end(self, trainer, pl_module): _check_args(trainer, pl_module) self.on_batch_end_called = True + def on_validation_batch_start(self, trainer, pl_module): + _check_args(trainer, pl_module) + self.on_validation_batch_start_called = True + + def on_validation_batch_end(self, trainer, pl_module): + _check_args(trainer, pl_module) + self.on_validation_batch_end_called = True + + def on_test_batch_start(self, trainer, pl_module): + _check_args(trainer, pl_module) + self.on_test_batch_start_called = True + + def on_test_batch_end(self, trainer, pl_module): + _check_args(trainer, pl_module) + self.on_test_batch_end_called = True + def on_train_start(self, trainer, pl_module): _check_args(trainer, pl_module) self.on_train_start_called = True @@ -104,10 +134,16 @@ def on_test_end(self, trainer, pl_module): assert not test_callback.on_init_start_called assert not test_callback.on_init_end_called + assert not test_callback.on_sanity_check_start_called + assert not test_callback.on_sanity_check_end_called assert not test_callback.on_epoch_start_called assert not test_callback.on_epoch_start_called assert not test_callback.on_batch_start_called assert not test_callback.on_batch_end_called + assert not test_callback.on_validation_batch_start_called + assert not test_callback.on_validation_batch_end_called + assert not test_callback.on_test_batch_start_called + assert not test_callback.on_test_batch_end_called assert not test_callback.on_train_start_called assert not test_callback.on_train_end_called assert not test_callback.on_validation_start_called @@ -121,10 +157,16 @@ def on_test_end(self, trainer, pl_module): assert trainer.callbacks[0] == test_callback assert test_callback.on_init_start_called assert test_callback.on_init_end_called + assert not test_callback.on_sanity_check_start_called + assert not test_callback.on_sanity_check_end_called assert not test_callback.on_epoch_start_called assert not 
test_callback.on_epoch_start_called assert not test_callback.on_batch_start_called assert not test_callback.on_batch_end_called + assert not test_callback.on_validation_batch_start_called + assert not test_callback.on_validation_batch_end_called + assert not test_callback.on_test_batch_start_called + assert not test_callback.on_test_batch_end_called assert not test_callback.on_train_start_called assert not test_callback.on_train_end_called assert not test_callback.on_validation_start_called @@ -136,21 +178,36 @@ def on_test_end(self, trainer, pl_module): assert test_callback.on_init_start_called assert test_callback.on_init_end_called + assert test_callback.on_sanity_check_start_called + assert test_callback.on_sanity_check_end_called assert test_callback.on_epoch_start_called assert test_callback.on_epoch_start_called assert test_callback.on_batch_start_called assert test_callback.on_batch_end_called + assert test_callback.on_validation_batch_start_called + assert test_callback.on_validation_batch_end_called assert test_callback.on_train_start_called assert test_callback.on_train_end_called assert test_callback.on_validation_start_called assert test_callback.on_validation_end_called + assert not test_callback.on_test_batch_start_called + assert not test_callback.on_test_batch_end_called assert not test_callback.on_test_start_called assert not test_callback.on_test_end_called - trainer.test() + test_callback = TestCallback() + trainer_options['callbacks'] = [test_callback] + trainer = Trainer(**trainer_options) + trainer.test(model) + assert test_callback.on_test_batch_start_called + assert test_callback.on_test_batch_end_called assert test_callback.on_test_start_called assert test_callback.on_test_end_called + assert not test_callback.on_validation_start_called + assert not test_callback.on_validation_end_called + assert not test_callback.on_validation_batch_end_called + assert not test_callback.on_validation_batch_start_called def test_early_stopping_no_val_step(tmpdir): diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py new file mode 100644 index 0000000000000..7cd5d5435adef --- /dev/null +++ b/tests/callbacks/test_progress_bar.py @@ -0,0 +1,221 @@ +import pytest + +import tests.base.utils as tutils +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ProgressBarBase, ProgressBar, ModelCheckpoint +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.base import ( + LightTrainDataloader, + LightTestMixin, + LightValidationMixin, + TestModelBase +) + + +@pytest.mark.parametrize('callbacks,refresh_rate', [ + ([], 1), + ([], 2), + ([ProgressBar(refresh_rate=1)], 0), + ([ProgressBar(refresh_rate=2)], 0), + ([ProgressBar(refresh_rate=2)], 1), +]) +def test_progress_bar_on(callbacks, refresh_rate): + """Test different ways the progress bar can be turned on.""" + + trainer = Trainer( + callbacks=callbacks, + progress_bar_refresh_rate=refresh_rate, + max_epochs=1, + overfit_pct=0.2, + ) + + progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)] + # Trainer supports only a single progress bar callback at the moment + assert len(progress_bars) == 1 + assert progress_bars[0] is trainer.progress_bar_callback + + +@pytest.mark.parametrize('callbacks,refresh_rate', [ + ([], 0), + ([], False), + ([ModelCheckpoint('../trainer')], 0), +]) +def test_progress_bar_off(callbacks, refresh_rate): + """Test different ways the progress bar can be turned off.""" + + trainer = Trainer( + 
callbacks=callbacks, + progress_bar_refresh_rate=refresh_rate, + ) + + progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBar)] + assert 0 == len(progress_bars) + assert not trainer.progress_bar_callback + + +def test_progress_bar_misconfiguration(): + """Test that Trainer doesn't accept multiple progress bars.""" + callbacks = [ProgressBar(), ProgressBar(), ModelCheckpoint('../trainer')] + with pytest.raises(MisconfigurationException, match=r'^You added multiple progress bar callbacks'): + Trainer(callbacks=callbacks) + + +def test_progress_bar_totals(): + """Test that the progress finishes with the correct total steps processed.""" + + class CurrentTestModel( + LightTrainDataloader, + LightTestMixin, + LightValidationMixin, + TestModelBase, + ): + pass + + hparams = tutils.get_default_hparams() + model = CurrentTestModel(hparams) + + trainer = Trainer( + progress_bar_refresh_rate=1, + val_percent_check=1.0, + max_epochs=1, + ) + bar = trainer.progress_bar_callback + assert 0 == bar.total_train_batches + assert 0 == bar.total_val_batches + assert 0 == bar.total_test_batches + + trainer.fit(model) + + # check main progress bar total + n = bar.total_train_batches + m = bar.total_val_batches + assert len(trainer.train_dataloader) == n + assert bar.main_progress_bar.total == n + m + + # check val progress bar total + assert sum(len(loader) for loader in trainer.val_dataloaders) == m + assert bar.val_progress_bar.total == m + + # main progress bar should have reached the end (train batches + val batches) + assert bar.main_progress_bar.n == n + m + assert bar.train_batch_idx == n + + # val progress bar should have reached the end + assert bar.val_progress_bar.n == m + assert bar.val_batch_idx == m + + # check that the test progress bar is off + assert 0 == bar.total_test_batches + assert bar.test_progress_bar is None + + trainer.test(model) + + # check test progress bar total + k = bar.total_test_batches + assert sum(len(loader) for loader in trainer.test_dataloaders) == k + assert bar.test_progress_bar.total == k + + # test progress bar should have reached the end + assert bar.test_progress_bar.n == k + assert bar.test_batch_idx == k + + +def test_progress_bar_fast_dev_run(): + class CurrentTestModel( + LightTrainDataloader, + LightTestMixin, + LightValidationMixin, + TestModelBase, + ): + pass + + hparams = tutils.get_default_hparams() + model = CurrentTestModel(hparams) + + trainer = Trainer( + fast_dev_run=True, + ) + + progress_bar = trainer.progress_bar_callback + assert 1 == progress_bar.total_train_batches + # total val batches are known only after val dataloaders have reloaded + + trainer.fit(model) + + assert 1 == progress_bar.total_val_batches + assert 1 == progress_bar.train_batch_idx + assert 1 == progress_bar.val_batch_idx + assert 0 == progress_bar.test_batch_idx + + # the main progress bar should display 2 batches (1 train, 1 val) + assert 2 == progress_bar.main_progress_bar.total + assert 2 == progress_bar.main_progress_bar.n + + trainer.test(model) + + # the test progress bar should display 1 batch + assert 1 == progress_bar.test_batch_idx + assert 1 == progress_bar.test_progress_bar.total + assert 1 == progress_bar.test_progress_bar.n + + +@pytest.mark.parametrize('refresh_rate', [0, 1, 50]) +def test_progress_bar_progress_refresh(refresh_rate): + """Test that the three progress bars get correctly updated when using different refresh rates.""" + + class CurrentTestModel( + LightTrainDataloader, + LightTestMixin, + LightValidationMixin, + TestModelBase, + 
): + pass + + hparams = tutils.get_default_hparams() + model = CurrentTestModel(hparams) + + class CurrentProgressBar(ProgressBar): + + train_batches_seen = 0 + val_batches_seen = 0 + test_batches_seen = 0 + + def on_batch_start(self, trainer, pl_module): + super().on_batch_start(trainer, pl_module) + assert self.train_batch_idx == trainer.batch_idx + + def on_batch_end(self, trainer, pl_module): + super().on_batch_end(trainer, pl_module) + assert self.train_batch_idx == trainer.batch_idx + 1 + if not self.is_disabled and self.train_batch_idx % self.refresh_rate == 0: + assert self.main_progress_bar.n == self.train_batch_idx + self.train_batches_seen += 1 + + def on_validation_batch_end(self, trainer, pl_module): + super().on_validation_batch_end(trainer, pl_module) + if not self.is_disabled and self.val_batch_idx % self.refresh_rate == 0: + assert self.val_progress_bar.n == self.val_batch_idx + self.val_batches_seen += 1 + + def on_test_batch_end(self, trainer, pl_module): + super().on_test_batch_end(trainer, pl_module) + if not self.is_disabled and self.test_batch_idx % self.refresh_rate == 0: + assert self.test_progress_bar.n == self.test_batch_idx + self.test_batches_seen += 1 + + progress_bar = CurrentProgressBar(refresh_rate=refresh_rate) + trainer = Trainer( + callbacks=[progress_bar], + progress_bar_refresh_rate=101, # should not matter if custom callback provided + train_percent_check=1.0, + num_sanity_val_steps=2, + max_epochs=3, + ) + assert trainer.progress_bar_callback.refresh_rate == refresh_rate != trainer.progress_bar_refresh_rate + + trainer.fit(model) + assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches + assert progress_bar.val_batches_seen == 3 * progress_bar.total_val_batches + trainer.num_sanity_val_steps + + trainer.test(model) + assert progress_bar.test_batches_seen == progress_bar.total_test_batches
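For completeness, a short end-to-end sketch (illustrative only, not part of the diff) of the customization path these tests exercise: subclass the new ``ProgressBar`` callback, tweak one of its ``init_*_tqdm`` hooks, and hand it to the ``Trainer``. The ``SlowRefreshBar`` name and the refresh rate of 20 are assumptions made for the example::

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ProgressBar


    class SlowRefreshBar(ProgressBar):
        """Updates every 20 batches and relabels the validation bar."""

        def __init__(self):
            # refresh_rate/process_position are the constructor args added in this patch
            super().__init__(refresh_rate=20, process_position=0)

        def init_validation_tqdm(self):
            # reuse the default tqdm bar and only change its description
            bar = super().init_validation_tqdm()
            bar.set_description('running validation ...')
            return bar


    bar = SlowRefreshBar()
    trainer = Trainer(callbacks=[bar], max_epochs=1)

    # the Trainer adopts the user-provided bar instead of creating a default one
    assert trainer.progress_bar_callback is bar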