diff --git a/CHANGELOG.md b/CHANGELOG.md index e7bc2f1267ac9..8d758c6d3cb9d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed - Changed the default behaviour to no longer include a NaN check with each training iteration. ([#1475](https://github.com/PyTorchLightning/pytorch-lightning/pull/1475)) +- Decoupled the progress bar from trainer. It is a callback now and can be customized or even replaced entirely ([#1450](https://github.com/PyTorchLightning/pytorch-lightning/pull/1450)). - Changed lr schedule step interval behavior to update every backwards pass instead of every forwards pass ([#1476](https://github.com/PyTorchLightning/pytorch-lightning/issues/1476)) @@ -41,6 +42,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated +- Deprecated `training_tqdm_dict` in favor of `progress_bar_dict` ([#1450](https://github.com/PyTorchLightning/pytorch-lightning/pull/1450)). ### Removed diff --git a/docs/source/callbacks.rst b/docs/source/callbacks.rst index ffb7671b7211d..10323472facd8 100644 --- a/docs/source/callbacks.rst +++ b/docs/source/callbacks.rst @@ -78,3 +78,9 @@ We successfully extended functionality without polluting our super clean _save_model, _abc_impl, check_monitor_top_k, + +--------- + +.. automodule:: pytorch_lightning.callbacks.progress + :noindex: + :exclude-members: diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst index b160cfae3cf90..19c394db4854b 100644 --- a/docs/source/trainer.rst +++ b/docs/source/trainer.rst @@ -19,5 +19,6 @@ Trainer slurm_job_id, tng_tqdm_dic, training_tqdm_dict, + progress_bar_dict, init_optimizers, configure_schedulers diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py index b1c47767339cc..c232060ca4ecb 100644 --- a/pytorch_lightning/callbacks/__init__.py +++ b/pytorch_lightning/callbacks/__init__.py @@ -2,10 +2,13 @@ from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.callbacks.progress import ProgressBarBase, ProgressBar __all__ = [ 'Callback', 'EarlyStopping', 'ModelCheckpoint', 'GradientAccumulationScheduler', + 'ProgressBarBase', + 'ProgressBar', ] diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 9bf576b0c1926..50ea061df615e 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -22,6 +22,14 @@ def on_init_end(self, trainer): """Called when the trainer initialization ends, model has not yet been set.""" pass + def on_sanity_check_start(self, trainer, pl_module): + """Called when the validation sanity check starts.""" + pass + + def on_sanity_check_end(self, trainer, pl_module): + """Called when the validation sanity check ends.""" + pass + def on_epoch_start(self, trainer, pl_module): """Called when the epoch begins.""" pass @@ -34,6 +42,22 @@ def on_batch_start(self, trainer, pl_module): """Called when the training batch begins.""" pass + def on_validation_batch_start(self, trainer, pl_module): + """Called when the validation batch begins.""" + pass + + def on_validation_batch_end(self, trainer, pl_module): + """Called when the validation batch ends.""" + pass + + def on_test_batch_start(self, trainer, pl_module): + """Called 
when the test batch begins.""" + pass + + def on_test_batch_end(self, trainer, pl_module): + """Called when the test batch ends.""" + pass + def on_batch_end(self, trainer, pl_module): """Called when the training batch ends.""" pass diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py new file mode 100644 index 0000000000000..a397c0d2c8d56 --- /dev/null +++ b/pytorch_lightning/callbacks/progress.py @@ -0,0 +1,367 @@ +""" +Progress Bars +============= + +Use or override one of the progress bar callbacks. + +""" +import sys + +from tqdm.auto import tqdm + +from pytorch_lightning.callbacks import Callback + + +class ProgressBarBase(Callback): + r""" + The base class for progress bars in Lightning. It is a :class:`~pytorch_lightning.callbacks.Callback` + that keeps track of the batch progress in the :class:`~pytorch_lightning.trainer.trainer.Trainer`. + You should implement your highly custom progress bars with this as the base class. + + Example:: + + class LitProgressBar(ProgressBarBase): + + def __init__(self): + super().__init__() # don't forget this :) + self.enabled = True + + def disable(self): + self.enabled = False + + def on_batch_end(self, trainer, pl_module): + super().on_batch_end(trainer, pl_module) # don't forget this :) + percent = (self.train_batch_idx / self.total_train_batches) * 100 + sys.stdout.flush() + sys.stdout.write(f'{percent:.01f} percent complete \r') + + bar = LitProgressBar() + trainer = Trainer(callbacks=[bar]) + + """ + def __init__(self): + + self._trainer = None + self._train_batch_idx = 0 + self._val_batch_idx = 0 + self._test_batch_idx = 0 + + @property + def trainer(self): + return self._trainer + + @property + def train_batch_idx(self) -> int: + """ + The current batch index being processed during training. + Use this to update your progress bar. + """ + return self._train_batch_idx + + @property + def val_batch_idx(self) -> int: + """ + The current batch index being processed during validation. + Use this to update your progress bar. + """ + return self._val_batch_idx + + @property + def test_batch_idx(self) -> int: + """ + The current batch index being processed during testing. + Use this to update your progress bar. + """ + return self._test_batch_idx + + @property + def total_train_batches(self) -> int: + """ + The total number of training batches during training, which may change from epoch to epoch. + Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the + training dataloader is of infinite size. + """ + total_train_batches = 1 if self.trainer.fast_dev_run else self.trainer.num_training_batches + return total_train_batches + + @property + def total_val_batches(self) -> int: + """ + The total number of validation batches during validation, which may change from epoch to epoch. + Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the + validation dataloader is of infinite size. + """ + trainer = self.trainer + total_val_batches = 0 + if trainer.fast_dev_run: + total_val_batches = len(trainer.val_dataloaders) + elif not self.trainer.disable_validation: + is_val_epoch = (trainer.current_epoch + 1) % trainer.check_val_every_n_epoch == 0 + total_val_batches = trainer.num_val_batches if is_val_epoch else 0 + return total_val_batches + + @property + def total_test_batches(self) -> int: + """ + The total number of test batches during testing, which may change from epoch to epoch. 
+ Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the + test dataloader is of infinite size. + """ + if self.trainer.fast_dev_run: + total_test_batches = len(self.trainer.test_dataloaders) + else: + total_test_batches = self.trainer.num_test_batches + return total_test_batches + + def disable(self): + """ + You should provide a way to disable the progress bar. + The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this to disable the + output on processes that have a rank different from 0, e.g., in multi-node training. + """ + raise NotImplementedError + + def enable(self): + """ + You should provide a way to enable the progress bar. + The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this in e.g. pre-training + routines like the learning rate finder to temporarily enable and + disable the main progress bar. + """ + raise NotImplementedError + + def on_init_end(self, trainer): + self._trainer = trainer + + def on_train_start(self, trainer, pl_module): + self._train_batch_idx = trainer.batch_idx + + def on_epoch_start(self, trainer, pl_module): + self._train_batch_idx = 0 + + def on_batch_end(self, trainer, pl_module): + self._train_batch_idx += 1 + + def on_validation_start(self, trainer, pl_module): + self._val_batch_idx = 0 + + def on_validation_batch_end(self, trainer, pl_module): + self._val_batch_idx += 1 + + def on_test_start(self, trainer, pl_module): + self._test_batch_idx = 0 + + def on_test_batch_end(self, trainer, pl_module): + self._test_batch_idx += 1 + + +class ProgressBar(ProgressBarBase): + r""" + This is the default progress bar used by Lightning. It prints to `stdout` using the + :mod:`tqdm` package and shows up to four different bars: + + - **sanity check progress:** the progress during the sanity check run + - **main progress:** shows training + validation progress combined. It also accounts for + multiple validation runs during training when + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval` is used. + - **validation progress:** only visible during validation; + shows total progress over all validation datasets. + - **test progress:** only active when testing; shows total progress over all test datasets. + + For infinite datasets, the progress bar never ends. + + If you want to customize the default ``tqdm`` progress bars used by Lightning, you can override + specific methods of the callback class and pass your custom implementation to the + :class:`~pytorch_lightning.trainer.trainer.Trainer`: + + Example:: + + class LitProgressBar(ProgressBar): + + def init_validation_tqdm(self): + bar = super().init_validation_tqdm() + bar.set_description('running validation ...') + return bar + + bar = LitProgressBar() + trainer = Trainer(callbacks=[bar]) + + Args: + refresh_rate: + Determines at which rate (in number of batches) the progress bars get updated. + Set it to ``0`` to disable the display. By default, the + :class:`~pytorch_lightning.trainer.trainer.Trainer` uses this implementation of the progress + bar and sets the refresh rate to the value provided to the + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.progress_bar_refresh_rate` argument in the + :class:`~pytorch_lightning.trainer.trainer.Trainer`. + process_position: + Set this to a value greater than ``0`` to offset the progress bars by this many lines. + This is useful when you have progress bars defined elsewhere and want to show all of them + together. 
This corresponds to + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.process_position` in the + :class:`~pytorch_lightning.trainer.trainer.Trainer`. + + """ + def __init__(self, refresh_rate: int = 1, process_position: int = 0): + super().__init__() + self._refresh_rate = refresh_rate + self._process_position = process_position + self._enabled = True + self.main_progress_bar = None + self.val_progress_bar = None + self.test_progress_bar = None + + def __getstate__(self): + # can't pickle the tqdm objects + state = self.__dict__.copy() + state['main_progress_bar'] = None + state['val_progress_bar'] = None + state['test_progress_bar'] = None + return state + + @property + def refresh_rate(self) -> int: + return self._refresh_rate + + @property + def process_position(self) -> int: + return self._process_position + + @property + def is_enabled(self) -> bool: + return self._enabled and self.refresh_rate > 0 + + @property + def is_disabled(self) -> bool: + return not self.is_enabled + + def disable(self) -> None: + self._enabled = False + + def enable(self) -> None: + self._enabled = True + + def init_sanity_tqdm(self) -> tqdm: + """ Override this to customize the tqdm bar for the validation sanity run. """ + bar = tqdm( + desc='Validation sanity check', + position=(2 * self.process_position), + disable=self.is_disabled, + leave=False, + dynamic_ncols=True, + file=sys.stdout, + ) + return bar + + def init_train_tqdm(self) -> tqdm: + """ Override this to customize the tqdm bar for training. """ + bar = tqdm( + desc='Training', + initial=self.train_batch_idx, + position=(2 * self.process_position), + disable=self.is_disabled, + leave=True, + dynamic_ncols=True, + file=sys.stdout, + smoothing=0, + ) + return bar + + def init_validation_tqdm(self) -> tqdm: + """ Override this to customize the tqdm bar for validation. """ + bar = tqdm( + desc='Validating', + position=(2 * self.process_position + 1), + disable=self.is_disabled, + leave=False, + dynamic_ncols=True, + file=sys.stdout + ) + return bar + + def init_test_tqdm(self) -> tqdm: + """ Override this to customize the tqdm bar for testing. 
""" + bar = tqdm( + desc='Testing', + position=(2 * self.process_position), + disable=self.is_disabled, + leave=True, + dynamic_ncols=True, + file=sys.stdout + ) + return bar + + def on_sanity_check_start(self, trainer, pl_module): + super().on_sanity_check_start(trainer, pl_module) + self.val_progress_bar = self.init_sanity_tqdm() + self.val_progress_bar.total = trainer.num_sanity_val_steps * len(trainer.val_dataloaders) + self.main_progress_bar = tqdm(disable=True) # dummy progress bar + + def on_sanity_check_end(self, trainer, pl_module): + super().on_sanity_check_end(trainer, pl_module) + self.main_progress_bar.close() + self.val_progress_bar.close() + + def on_train_start(self, trainer, pl_module): + super().on_train_start(trainer, pl_module) + self.main_progress_bar = self.init_train_tqdm() + + def on_epoch_start(self, trainer, pl_module): + super().on_epoch_start(trainer, pl_module) + total_train_batches = self.total_train_batches + total_val_batches = self.total_val_batches + if total_train_batches != float('inf') and not trainer.fast_dev_run: + # val can be checked multiple times per epoch + val_checks_per_epoch = total_train_batches // trainer.val_check_batch + total_val_batches = total_val_batches * val_checks_per_epoch + total_batches = total_train_batches + total_val_batches + if not self.main_progress_bar.disable: + self.main_progress_bar.reset(convert_inf(total_batches)) + self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch + 1}') + + def on_batch_end(self, trainer, pl_module): + super().on_batch_end(trainer, pl_module) + if self.is_enabled and self.train_batch_idx % self.refresh_rate == 0: + self.main_progress_bar.update(self.refresh_rate) + self.main_progress_bar.set_postfix(**trainer.progress_bar_dict) + + def on_validation_start(self, trainer, pl_module): + super().on_validation_start(trainer, pl_module) + self.val_progress_bar = self.init_validation_tqdm() + self.val_progress_bar.total = convert_inf(self.total_val_batches) + + def on_validation_batch_end(self, trainer, pl_module): + super().on_validation_batch_end(trainer, pl_module) + if self.is_enabled and self.val_batch_idx % self.refresh_rate == 0: + self.val_progress_bar.update(self.refresh_rate) + self.main_progress_bar.update(self.refresh_rate) + + def on_validation_end(self, trainer, pl_module): + super().on_validation_end(trainer, pl_module) + self.main_progress_bar.set_postfix(**trainer.progress_bar_dict) + self.val_progress_bar.close() + + def on_train_end(self, trainer, pl_module): + super().on_train_end(trainer, pl_module) + self.main_progress_bar.close() + + def on_test_start(self, trainer, pl_module): + super().on_test_start(trainer, pl_module) + self.test_progress_bar = self.init_test_tqdm() + self.test_progress_bar.total = convert_inf(self.total_test_batches) + + def on_test_batch_end(self, trainer, pl_module): + super().on_test_batch_end(trainer, pl_module) + if self.is_enabled and self.test_batch_idx % self.refresh_rate == 0: + self.test_progress_bar.update(self.refresh_rate) + + def on_test_end(self, trainer, pl_module): + super().on_test_end(trainer, pl_module) + self.test_progress_bar.close() + + +def convert_inf(x): + """ The tqdm doesn't support inf values. We have to convert it to None. 
""" + if x == float('inf'): + return None + return x diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 82be2d53f0ed4..d943a07461821 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1636,7 +1636,7 @@ def on_save_checkpoint(self, checkpoint): """ - def get_tqdm_dict(self) -> Dict[str, Union[int, str]]: + def get_progress_bar_dict(self) -> Dict[str, Union[int, str]]: r""" Additional items to be displayed in the progress bar. @@ -1657,3 +1657,18 @@ def get_tqdm_dict(self) -> Dict[str, Union[int, str]]: tqdm_dict['v_num'] = self.trainer.logger.version return tqdm_dict + + def get_tqdm_dict(self) -> Dict[str, Union[int, str]]: + """ + Additional items to be displayed in the progress bar. + + Return: + Dictionary with the items to be displayed in the progress bar. + + Warning: + Deprecated since v0.7.3. + Use :meth:`get_progress_bar_dict` instead. + """ + rank_zero_warn("`get_tqdm_dict` was renamed to `get_progress_bar_dict` in v0.7.3" + " and this method will be removed in v1.0.0", DeprecationWarning) + return self.get_progress_bar_dict() diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index 0bd161bf0a56d..41b7b1b999a75 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -652,14 +652,16 @@ def on_train_end(self): process_position ^^^^^^^^^^^^^^^^ -Orders the tqdm bar. Useful when running multiple trainers -on the same node. +Orders the progress bar. Useful when running multiple trainers on the same node. Example:: # default used by the Trainer trainer = Trainer(process_position=0) +Note: + This argument is ignored if a custom callback is passed to :paramref:`~Trainer.callbacks`. + profiler ^^^^^^^^ To profile individual steps during training and assist in identifying bottlenecks. @@ -698,6 +700,9 @@ def on_train_end(self): # disable progress bar trainer = Trainer(progress_bar_refresh_rate=0) +Note: + This argument is ignored if a custom callback is passed to :paramref:`~Trainer.callbacks`. + reload_dataloaders_every_epoch ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Set to True to reload dataloaders every epoch. 
diff --git a/pytorch_lightning/trainer/callback_config.py b/pytorch_lightning/trainer/callback_config.py index 7ace0fb20a255..39c67963169e4 100644 --- a/pytorch_lightning/trainer/callback_config.py +++ b/pytorch_lightning/trainer/callback_config.py @@ -1,20 +1,25 @@ import os from abc import ABC, abstractmethod -from typing import Union +from typing import Union, List -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping + +from pytorch_lightning.callbacks import Callback, ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar from pytorch_lightning.loggers import LightningLoggerBase +from pytorch_lightning.utilities.exceptions import MisconfigurationException class TrainerCallbackConfigMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class + callbacks: List[Callback] default_root_dir: str logger: Union[LightningLoggerBase, bool] weights_save_path: str ckpt_path: str checkpoint_callback: ModelCheckpoint + progress_bar_refresh_rate: int + process_position: int @property @abstractmethod @@ -101,3 +106,21 @@ def configure_early_stopping(self, early_stop_callback): else: self.early_stop_callback = early_stop_callback self.enable_early_stop = True + + def configure_progress_bar(self): + progress_bars = [c for c in self.callbacks if isinstance(c, ProgressBarBase)] + if len(progress_bars) > 1: + raise MisconfigurationException( + 'You added multiple progress bar callbacks to the Trainer, but currently only one' + ' progress bar is supported.' + ) + elif len(progress_bars) == 1: + self.progress_bar_callback = progress_bars[0] + elif self.progress_bar_refresh_rate > 0: + self.progress_bar_callback = ProgressBar( + refresh_rate=self.progress_bar_refresh_rate, + process_position=self.process_position, + ) + self.callbacks.append(self.progress_bar_callback) + else: + self.progress_bar_callback = None diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 48d703b84ebe0..37f56e6941039 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Callable +from typing import Callable, List from pytorch_lightning.callbacks import Callback @@ -9,7 +9,7 @@ class TrainerCallbackHookMixin(ABC): def __init__(self): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class - self.callbacks: list[Callback] = [] + self.callbacks: List[Callback] = [] self.get_model: Callable = ... 
def on_init_start(self): @@ -22,6 +22,16 @@ def on_init_end(self): for callback in self.callbacks: callback.on_init_end(self) + def on_sanity_check_start(self): + """Called when the validation sanity check starts.""" + for callback in self.callbacks: + callback.on_sanity_check_start(self, self.get_model()) + + def on_sanity_check_end(self): + """Called when the validation sanity check ends.""" + for callback in self.callbacks: + callback.on_sanity_check_end(self, self.get_model()) + def on_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: @@ -52,6 +62,26 @@ def on_batch_end(self): for callback in self.callbacks: callback.on_batch_end(self, self.get_model()) + def on_validation_batch_start(self): + """Called when the validation batch begins.""" + for callback in self.callbacks: + callback.on_validation_batch_start(self, self.get_model()) + + def on_validation_batch_end(self): + """Called when the validation batch ends.""" + for callback in self.callbacks: + callback.on_validation_batch_end(self, self.get_model()) + + def on_test_batch_start(self): + """Called when the test batch begins.""" + for callback in self.callbacks: + callback.on_test_batch_start(self, self.get_model()) + + def on_test_batch_end(self): + """Called when the test batch ends.""" + for callback in self.callbacks: + callback.on_test_batch_end(self, self.get_model()) + def on_validation_start(self): """Called when the validation loop begins.""" for callback in self.callbacks: diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index d9f461d5d039a..2705c4f160464 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -103,6 +103,13 @@ def default_save_path(self, path): " and this method will be removed in v0.8.0", DeprecationWarning) self.default_root_dir = path + @property + def tng_tqdm_dic(self): + """Back compatibility, will be removed in v0.8.0""" + rank_zero_warn("`tng_tqdm_dic` has renamed to `training_tqdm_dict` since v0.5.0" + " and this method will be removed in v0.8.0", DeprecationWarning) + return self.progress_bar_dict + class TrainerDeprecatedAPITillVer0_9(ABC): @@ -121,3 +128,10 @@ def show_progress_bar(self, tf): """Back compatibility, will be removed in v0.9.0""" rank_zero_warn("Argument `show_progress_bar` is now set by `progress_bar_refresh_rate` since v0.7.2" " and this method will be removed in v0.9.0", DeprecationWarning) + + @property + def training_tqdm_dict(self): + """Back compatibility, will be removed in v0.9.0""" + rank_zero_warn("`training_tqdm_dict` was renamed to `progress_bar_dict` in v0.7.3" + " and this method will be removed in v0.9.0", DeprecationWarning) + return self.progress_bar_dict diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py index 736af5cad928f..75a60a45122bc 100644 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ b/pytorch_lightning/trainer/distrib_data_parallel.py @@ -152,6 +152,7 @@ class TrainerDDPMixin(ABC): use_tpu: bool default_root_dir: str use_native_amp: bool + progress_bar_callback: ... 
@property @abstractmethod @@ -310,9 +311,8 @@ def ddp_train(self, process_idx, model): self.node_rank = 0 # show progressbar only on progress_rank 0 - self.progress_bar_refresh_rate = ( - self.progress_bar_refresh_rate if self.node_rank == 0 and process_idx == 0 else 0 - ) + if (self.node_rank != 0 or process_idx != 0) and self.progress_bar_callback is not None: + self.progress_bar_callback.disable() # determine which process we are and world size if self.use_ddp: diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 7b79922d82a00..b0d7cae2989f9 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -397,6 +397,7 @@ class TrainerDPMixin(ABC): use_native_amp: bool data_parallel_device_ids: ... logger: Union[LightningLoggerBase, bool] + progress_bar_callback: ... @property @abstractmethod @@ -499,7 +500,8 @@ def tpu_train(self, tpu_core_idx, model): self.tpu_global_core_rank = xm.get_ordinal() # avoid duplicating progress bar - self.progress_bar_refresh_rate = self.progress_bar_refresh_rate if self.tpu_global_core_rank == 0 else 0 + if self.tpu_global_core_rank != 0 and self.progress_bar_callback is not None: + self.progress_bar_callback.disable() # track current tpu self.current_tpu_idx = tpu_core_idx diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 54d80ac0eb4fe..0320bf35419ea 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -123,14 +123,12 @@ """ -import sys from abc import ABC, abstractmethod from pprint import pprint from typing import Callable import torch from torch.utils.data import DataLoader -from tqdm.auto import tqdm from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel @@ -157,9 +155,6 @@ class TrainerEvaluationLoopMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class - test_progress_bar: ... - val_progress_bar: ... - main_progress_bar: ... on_gpu: bool use_ddp: bool use_dp: bool @@ -171,9 +166,8 @@ class TrainerEvaluationLoopMixin(ABC): num_test_batches: int num_val_batches: int fast_dev_run: ... - process_position: ... process_output: ... - training_tqdm_dict: ... + progress_bar_dict: ... proc_rank: int current_epoch: int callback_metrics: ... @@ -181,9 +175,12 @@ class TrainerEvaluationLoopMixin(ABC): val_dataloaders: DataLoader use_tpu: bool reload_dataloaders_every_epoch: ... - progress_bar_refresh_rate: ... 
# Callback system + on_validation_batch_start: Callable + on_validation_batch_end: Callable + on_test_batch_start: Callable + on_test_batch_end: Callable on_validation_start: Callable on_validation_end: Callable on_test_start: Callable @@ -210,7 +207,7 @@ def transfer_batch_to_gpu(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod - def add_tqdm_metrics(self, *args): + def add_progress_bar_metrics(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod @@ -265,6 +262,12 @@ def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_ if batch_idx >= max_batches: break + # callbacks + if test_mode: + self.on_test_batch_start() + else: + self.on_validation_batch_start() + # ----------------- # RUN EVALUATION STEP # ----------------- @@ -280,22 +283,17 @@ def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_ model_ref = self.get_model() with self.profiler.profile('test_step_end'): output = model_ref.test_step_end(output) + self.on_test_batch_end() else: if self.is_overriden('validation_step_end'): model_ref = self.get_model() with self.profiler.profile('validation_step_end'): output = model_ref.validation_step_end(output) + self.on_validation_batch_end() # track outputs for collation dl_outputs.append(output) - # batch done - if self.progress_bar_refresh_rate >= 1 and batch_idx % self.progress_bar_refresh_rate == 0: - if test_mode: - self.test_progress_bar.update(self.progress_bar_refresh_rate) - else: - self.val_progress_bar.update(self.progress_bar_refresh_rate) - self.main_progress_bar.update(self.progress_bar_refresh_rate) outputs.append(dl_outputs) eval_results = {} @@ -343,12 +341,6 @@ def run_evaluation(self, test_mode: bool = False): "You called `.test()` without defining model's `.test_step()`." 
" Please define and try again") - # Validation/Test begin callbacks - if test_mode: - self.on_test_start() - else: - self.on_validation_start() - # hook model = self.get_model() model.on_pre_performance_check() @@ -372,21 +364,18 @@ def run_evaluation(self, test_mode: bool = False): if self.fast_dev_run: max_batches = 1 - # init validation or test progress bar - # main progress bar will already be closed when testing so initial position is free - position = 2 * self.process_position + (not test_mode) - desc = 'Testing' if test_mode else 'Validating' - total = max_batches if max_batches != float('inf') else None - pbar = tqdm(desc=desc, total=total, leave=test_mode, position=position, - disable=not self.progress_bar_refresh_rate, dynamic_ncols=True, file=sys.stdout) - setattr(self, f'{"test" if test_mode else "val"}_progress_bar', pbar) + # Validation/Test begin callbacks + if test_mode: + self.on_test_start() + else: + self.on_validation_start() # run evaluation eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode) _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(eval_results) # add metrics to prog bar - self.add_tqdm_metrics(prog_bar_metrics) + self.add_progress_bar_metrics(prog_bar_metrics) # log results of test if test_mode and self.proc_rank == 0: @@ -404,16 +393,6 @@ def run_evaluation(self, test_mode: bool = False): # hook model.on_post_performance_check() - # add model specific metrics - if not test_mode: - self.main_progress_bar.set_postfix(**self.training_tqdm_dict) - - # close progress bar - if test_mode: - self.test_progress_bar.close() - else: - self.val_progress_bar.close() - # eventual dataset reloading if test_mode: if self.reload_dataloaders_every_epoch: diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index a0599f5871614..3ba662fb5c844 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -16,7 +16,7 @@ class TrainerLoggingMixin(ABC): on_gpu: bool log_gpu_memory: ... logger: Union[LightningLoggerBase, bool] - tqdm_metrics: ... + progress_bar_metrics: ... global_step: int proc_rank: int use_dp: bool @@ -75,12 +75,12 @@ def log_metrics(self, metrics, grad_norm_dic, step=None): self.logger.agg_and_log_metrics(scalar_metrics, step=step) self.logger.save() - def add_tqdm_metrics(self, metrics): + def add_progress_bar_metrics(self, metrics): for k, v in metrics.items(): if isinstance(v, torch.Tensor): v = v.item() - self.tqdm_metrics[k] = v + self.progress_bar_metrics[k] = v def metrics_to_scalars(self, metrics): new_metrics = {} @@ -98,7 +98,7 @@ def metrics_to_scalars(self, metrics): def process_output(self, output, train=False): """Reduces output according to the training mode. 
- Separates loss from logging and tqdm metrics + Separates loss from logging and progress bar metrics """ # --------------- # EXTRACT CALLBACK KEYS @@ -119,7 +119,7 @@ def process_output(self, output, train=False): try: progress_output = output['progress_bar'] - # reduce progress metrics for tqdm when using dp + # reduce progress metrics for progress bar when using dp if train and (self.use_dp or self.use_ddp2): num_gpus = self.num_gpus progress_output = self.reduce_distributed_output(progress_output, num_gpus) @@ -135,7 +135,7 @@ def process_output(self, output, train=False): try: log_output = output['log'] - # reduce progress metrics for tqdm when using dp + # reduce progress metrics for progress bar when using dp if train and (self.use_dp or self.use_ddp2): num_gpus = self.num_gpus log_output = self.reduce_distributed_output(log_output, num_gpus) diff --git a/pytorch_lightning/trainer/lr_finder.py b/pytorch_lightning/trainer/lr_finder.py index ff14e93a49e1b..23eab617001b0 100755 --- a/pytorch_lightning/trainer/lr_finder.py +++ b/pytorch_lightning/trainer/lr_finder.py @@ -123,7 +123,8 @@ def lr_find(self, self.max_steps = num_training # Disable standard progress bar for fit - self.progress_bar_refresh_rate = False + if self.progress_bar_callback: + self.progress_bar_callback.disable() # Accumulation of gradients self.accumulate_grad_batches = num_accumulation_steps @@ -165,6 +166,8 @@ def lr_find(self, # Finish by resetting variables so trainer is ready to fit model self._restore_params(model) + if self.progress_bar_callback: + self.progress_bar_callback.enable() return lr_finder @@ -178,6 +181,7 @@ def _dump_params(self, model): 'progress_bar_refresh_rate': self.progress_bar_refresh_rate, 'accumulate_grad_batches': self.accumulate_grad_batches, 'checkpoint_callback': self.checkpoint_callback, + 'progress_bar_callback': self.progress_bar_callback, 'configure_optimizers': model.configure_optimizers, } @@ -189,6 +193,7 @@ def _restore_params(self, model): self.progress_bar_refresh_rate = self._params['progress_bar_refresh_rate'] self.accumulate_grad_batches = self._params['accumulate_grad_batches'] self.checkpoint_callback = self._params['checkpoint_callback'] + self.progress_bar_callback = self._params['progress_bar_callback'] model.configure_optimizers = self._params['configure_optimizers'] diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index fe69a98ad1205..ce8b9ca090d26 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1,7 +1,6 @@ import distutils import inspect import os -import sys from argparse import ArgumentParser from typing import Union, Optional, List, Dict, Tuple, Iterable, Any @@ -9,10 +8,9 @@ import torch.distributed as torch_distrib import torch.multiprocessing as mp from torch.utils.data import DataLoader -from tqdm.auto import tqdm from pytorch_lightning import _logger as log -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback +from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback, ProgressBarBase from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler, BaseProfiler @@ -74,9 +72,9 @@ class Trainer( ): DEPRECATED_IN_0_8 = ( 'gradient_clip', 'nb_gpu_nodes', 'max_nb_epochs', 'min_nb_epochs', - 'add_row_log_interval', 'nb_sanity_val_steps' + 'add_row_log_interval', 'nb_sanity_val_steps', 
'tng_tqdm_dic', ) - DEPRECATED_IN_0_9 = ('use_amp', 'show_progress_bar') + DEPRECATED_IN_0_9 = ('use_amp', 'show_progress_bar', 'training_tqdm_dict') def __init__( self, @@ -123,6 +121,7 @@ def __init__( reload_dataloaders_every_epoch: bool = False, auto_lr_find: Union[bool, str] = False, replace_sampler_ddp: bool = True, + progress_bar_callback: Optional[Union[ProgressBarBase, bool]] = True, amp_level: str = 'O1', # backward compatible, todo: remove in v0.8.0 default_save_path=None, # backward compatible, todo: remove in v0.8.0 gradient_clip=None, # backward compatible, todo: remove in v0.8.0 @@ -162,7 +161,7 @@ def __init__( Use `gradient_clip_val` instead. Will remove 0.9.0. - process_position: orders the tqdm bar when running multiple models on same machine. + process_position: orders the progress bar when running multiple models on same machine. num_nodes: number of GPU nodes for distributed training. @@ -190,6 +189,7 @@ def __init__( Set `progress_bar_refresh_rate` to postive integer to enable. Will remove 0.9.0. progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar. + Ignored when a custom callback is passed to :paramref:`~Trainer.callbacks`. overfit_pct: How much of training-, validation-, and test dataset to check. @@ -312,7 +312,6 @@ def __init__( " and this method will be removed in v0.8.0", DeprecationWarning) self.gradient_clip = gradient_clip - self.progress_bar_refresh_rate = progress_bar_refresh_rate self.check_val_every_n_epoch = check_val_every_n_epoch self.track_grad_norm = track_grad_norm self.on_gpu = True if (gpus and torch.cuda.is_available()) else False @@ -390,7 +389,7 @@ def __init__( self.total_batch_idx = 0 self.running_loss = TensorRunningAccum(window_length=20) self.batch_idx = 0 - self.tqdm_metrics = {} + self.progress_bar_metrics = {} self.callback_metrics = {} self.num_val_batches = 0 self.num_training_batches = 0 @@ -408,7 +407,6 @@ def __init__( self.optimizer_frequencies = [] self.global_step = 0 self.current_epoch = 0 - self.total_batches = 0 self.interrupted = False # configure logger @@ -464,12 +462,14 @@ def __init__( # nvidia setup self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids) - # can't init progress bar here because starting a new process - # means the progress_bar won't survive pickling # backward compatibility if show_progress_bar is not None: self.show_progress_bar = show_progress_bar + self.progress_bar_refresh_rate = progress_bar_refresh_rate + self.progress_bar_callback = None + self.configure_progress_bar() + # logging self.log_save_interval = log_save_interval self.val_check_interval = val_check_interval @@ -632,26 +632,10 @@ def data_parallel(self) -> bool: return self.use_dp or self.use_ddp or self.use_ddp2 @property - def training_tqdm_dict(self) -> dict: - """Read-only for tqdm metrics. - :return: - """ + def progress_bar_dict(self) -> dict: + """ Read-only for progress bar metrics. """ ref_model = self.model if not self.data_parallel else self.model.module - - return dict(**ref_model.get_tqdm_dict(), **self.tqdm_metrics) - - @property - def tng_tqdm_dic(self): - """Read-only for tqdm metrics. - - .. warning:: .. deprecated:: 0.5.0 - - Use `training_tqdm_dict` instead. Will remove 0.8.0. 
- - """ - rank_zero_warn("`tng_tqdm_dic` has renamed to `training_tqdm_dict` since v0.5.0" - " and this method will be removed in v0.8.0", DeprecationWarning) - return self.training_tqdm_dict + return dict(**ref_model.get_progress_bar_dict(), **self.progress_bar_metrics) # ----------------------------- # MODEL TRAINING @@ -870,17 +854,12 @@ def run_pretrain_routine(self, model: LightningModule): # run tiny validation (if validation defined) # to make sure program won't crash during val - ref_model.on_sanity_check_start() if not self.disable_validation and self.num_sanity_val_steps > 0: self.reset_val_dataloader(ref_model) - # init progress bars for validation sanity check - pbar = tqdm(desc='Validation sanity check', - total=self.num_sanity_val_steps * len(self.val_dataloaders), - leave=False, position=2 * self.process_position, - disable=not self.progress_bar_refresh_rate, dynamic_ncols=True) - self.main_progress_bar = pbar - # dummy validation progress bar - self.val_progress_bar = tqdm(disable=True) + + # hook and callback + ref_model.on_sanity_check_start() + self.on_sanity_check_start() eval_results = self._evaluate(model, self.val_dataloaders, @@ -888,20 +867,12 @@ def run_pretrain_routine(self, model: LightningModule): False) _, _, _, callback_metrics, _ = self.process_output(eval_results) - # close progress bars - self.main_progress_bar.close() - self.val_progress_bar.close() + self.on_sanity_check_end() # verify that early stop has conditioned on a metric that exists if self.enable_early_stop: self.early_stop_callback._validate_condition_metric(callback_metrics) - # init progress bar - pbar = tqdm(leave=True, position=2 * self.process_position, - disable=not self.show_progress_bar, dynamic_ncols=True, - file=sys.stdout, smoothing=0) - self.main_progress_bar = pbar - # clear cache before training if self.on_gpu: torch.cuda.empty_cache() diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index b6756d1e94bbb..33c8339bb85d0 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -203,7 +203,6 @@ class TrainerTrainLoopMixin(ABC): num_val_batches: int disable_validation: bool fast_dev_run: ... - main_progress_bar: ... accumulation_scheduler: ... lr_schedulers: ... enable_early_stop: ... @@ -215,7 +214,6 @@ class TrainerTrainLoopMixin(ABC): log_save_interval: float proc_rank: int row_log_interval: float - total_batches: int truncated_bptt_steps: ... optimizers: ... optimizer_frequencies: ... @@ -224,14 +222,13 @@ class TrainerTrainLoopMixin(ABC): model: LightningModule interrupted: bool running_loss: ... - training_tqdm_dict: ... + progress_bar_dict: ... reduce_lr_on_plateau_scheduler: ... profiler: ... batch_idx: int precision: ... train_dataloader: DataLoader reload_dataloaders_every_epoch: bool - progress_bar_refresh_rate: ... 
max_steps: int min_steps: int total_batch_idx: int @@ -281,7 +278,7 @@ def is_overriden(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod - def add_tqdm_metrics(self, *args): + def add_progress_bar_metrics(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod @@ -342,18 +339,6 @@ def train(self): model.current_epoch = epoch self.current_epoch = epoch - total_val_batches = 0 - is_val_epoch = False - if not self.disable_validation and self.num_training_batches != float('inf'): - # val can be checked multiple times in epoch - is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0 - val_checks_per_epoch = self.num_training_batches // self.val_check_batch - val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0 - total_val_batches = self.num_val_batches * val_checks_per_epoch - - # total batches includes multiple val checks - self.total_batches = self.num_training_batches + total_val_batches - # changing gradient according accumulation_scheduler self.accumulation_scheduler.on_epoch_start(self, self.get_model()) @@ -362,22 +347,6 @@ def train(self): window_length=self.accumulate_grad_batches ) - if self.fast_dev_run: - # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run - num_iterations = 2 - elif self.total_batches == float('inf'): - # for infinite train or val loader, the progress bar never ends - num_iterations = None - else: - num_iterations = self.total_batches - - # reset progress bar - # .reset() doesn't work on disabled progress bar so we should check - if not self.main_progress_bar.disable: - self.main_progress_bar.reset(num_iterations) - desc = f'Epoch {epoch + 1}' - self.main_progress_bar.set_description(desc) - # ----------------- # RUN TNG EPOCH # ----------------- @@ -614,7 +583,7 @@ def optimizer_closure(): all_callback_metrics.append(callback_metrics) # track progress bar metrics - self.add_tqdm_metrics(progress_bar_metrics) + self.add_progress_bar_metrics(progress_bar_metrics) all_log_metrics.append(log_metrics) if self.use_horovod: @@ -676,11 +645,6 @@ def optimizer_closure(): if self.is_function_implemented('on_batch_end'): self.get_model().on_batch_end() - # update progress bar - if self.progress_bar_refresh_rate >= 1 and batch_idx % self.progress_bar_refresh_rate == 0: - self.main_progress_bar.update(self.progress_bar_refresh_rate) - self.main_progress_bar.set_postfix(**self.training_tqdm_dict) - # collapse all metrics into one dict all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()} @@ -703,8 +667,6 @@ def _get_optimizers_iterable(self): return [(opt_idx, self.optimizers[opt_idx])] def run_training_teardown(self): - self.main_progress_bar.close() - # Train end events with self.profiler.profile('on_train_end'): # callbacks diff --git a/tests/base/utils.py b/tests/base/utils.py index 1bb485f270ada..1f0d582ed6e01 100644 --- a/tests/base/utils.py +++ b/tests/base/utils.py @@ -220,7 +220,7 @@ def run_prediction(dataloader, trained_model, dp=False, min_acc=0.5): def assert_ok_model_acc(trainer, key='test_acc', thr=0.5): # this model should get 0.80+ acc - acc = trainer.training_tqdm_dict[key] + acc = trainer.progress_bar_dict[key] assert acc > thr, f"Model failed to get expected {thr} accuracy. 
{key} = {acc}" diff --git a/tests/callbacks/__init__.py b/tests/callbacks/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/trainer/test_callbacks.py b/tests/callbacks/test_callbacks.py similarity index 71% rename from tests/trainer/test_callbacks.py rename to tests/callbacks/test_callbacks.py index 4f77dabbd12a5..4731d4351679e 100644 --- a/tests/trainer/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -33,10 +33,16 @@ def __init__(self): super().__init__() self.on_init_start_called = False self.on_init_end_called = False + self.on_sanity_check_start_called = False + self.on_sanity_check_end_called = False self.on_epoch_start_called = False self.on_epoch_end_called = False self.on_batch_start_called = False self.on_batch_end_called = False + self.on_validation_batch_start_called = False + self.on_validation_batch_end_called = False + self.on_test_batch_start_called = False + self.on_test_batch_end_called = False self.on_train_start_called = False self.on_train_end_called = False self.on_validation_start_called = False @@ -52,6 +58,14 @@ def on_init_end(self, trainer): assert isinstance(trainer, Trainer) self.on_init_end_called = True + def on_sanity_check_start(self, trainer, pl_module): + _check_args(trainer, pl_module) + self.on_sanity_check_start_called = True + + def on_sanity_check_end(self, trainer, pl_module): + _check_args(trainer, pl_module) + self.on_sanity_check_end_called = True + def on_epoch_start(self, trainer, pl_module): _check_args(trainer, pl_module) self.on_epoch_start_called = True @@ -68,6 +82,22 @@ def on_batch_end(self, trainer, pl_module): _check_args(trainer, pl_module) self.on_batch_end_called = True + def on_validation_batch_start(self, trainer, pl_module): + _check_args(trainer, pl_module) + self.on_validation_batch_start_called = True + + def on_validation_batch_end(self, trainer, pl_module): + _check_args(trainer, pl_module) + self.on_validation_batch_end_called = True + + def on_test_batch_start(self, trainer, pl_module): + _check_args(trainer, pl_module) + self.on_test_batch_start_called = True + + def on_test_batch_end(self, trainer, pl_module): + _check_args(trainer, pl_module) + self.on_test_batch_end_called = True + def on_train_start(self, trainer, pl_module): _check_args(trainer, pl_module) self.on_train_start_called = True @@ -104,10 +134,16 @@ def on_test_end(self, trainer, pl_module): assert not test_callback.on_init_start_called assert not test_callback.on_init_end_called + assert not test_callback.on_sanity_check_start_called + assert not test_callback.on_sanity_check_end_called assert not test_callback.on_epoch_start_called assert not test_callback.on_epoch_start_called assert not test_callback.on_batch_start_called assert not test_callback.on_batch_end_called + assert not test_callback.on_validation_batch_start_called + assert not test_callback.on_validation_batch_end_called + assert not test_callback.on_test_batch_start_called + assert not test_callback.on_test_batch_end_called assert not test_callback.on_train_start_called assert not test_callback.on_train_end_called assert not test_callback.on_validation_start_called @@ -121,10 +157,16 @@ def on_test_end(self, trainer, pl_module): assert trainer.callbacks[0] == test_callback assert test_callback.on_init_start_called assert test_callback.on_init_end_called + assert not test_callback.on_sanity_check_start_called + assert not test_callback.on_sanity_check_end_called assert not test_callback.on_epoch_start_called assert not 
test_callback.on_epoch_start_called assert not test_callback.on_batch_start_called assert not test_callback.on_batch_end_called + assert not test_callback.on_validation_batch_start_called + assert not test_callback.on_validation_batch_end_called + assert not test_callback.on_test_batch_start_called + assert not test_callback.on_test_batch_end_called assert not test_callback.on_train_start_called assert not test_callback.on_train_end_called assert not test_callback.on_validation_start_called @@ -136,21 +178,36 @@ def on_test_end(self, trainer, pl_module): assert test_callback.on_init_start_called assert test_callback.on_init_end_called + assert test_callback.on_sanity_check_start_called + assert test_callback.on_sanity_check_end_called assert test_callback.on_epoch_start_called assert test_callback.on_epoch_start_called assert test_callback.on_batch_start_called assert test_callback.on_batch_end_called + assert test_callback.on_validation_batch_start_called + assert test_callback.on_validation_batch_end_called assert test_callback.on_train_start_called assert test_callback.on_train_end_called assert test_callback.on_validation_start_called assert test_callback.on_validation_end_called + assert not test_callback.on_test_batch_start_called + assert not test_callback.on_test_batch_end_called assert not test_callback.on_test_start_called assert not test_callback.on_test_end_called - trainer.test() + test_callback = TestCallback() + trainer_options['callbacks'] = [test_callback] + trainer = Trainer(**trainer_options) + trainer.test(model) + assert test_callback.on_test_batch_start_called + assert test_callback.on_test_batch_end_called assert test_callback.on_test_start_called assert test_callback.on_test_end_called + assert not test_callback.on_validation_start_called + assert not test_callback.on_validation_end_called + assert not test_callback.on_validation_batch_end_called + assert not test_callback.on_validation_batch_start_called def test_early_stopping_no_val_step(tmpdir): diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py new file mode 100644 index 0000000000000..7cd5d5435adef --- /dev/null +++ b/tests/callbacks/test_progress_bar.py @@ -0,0 +1,221 @@ +import pytest + +import tests.base.utils as tutils +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ProgressBarBase, ProgressBar, ModelCheckpoint +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.base import ( + LightTrainDataloader, + LightTestMixin, + LightValidationMixin, + TestModelBase +) + + +@pytest.mark.parametrize('callbacks,refresh_rate', [ + ([], 1), + ([], 2), + ([ProgressBar(refresh_rate=1)], 0), + ([ProgressBar(refresh_rate=2)], 0), + ([ProgressBar(refresh_rate=2)], 1), +]) +def test_progress_bar_on(callbacks, refresh_rate): + """Test different ways the progress bar can be turned on.""" + + trainer = Trainer( + callbacks=callbacks, + progress_bar_refresh_rate=refresh_rate, + max_epochs=1, + overfit_pct=0.2, + ) + + progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)] + # Trainer supports only a single progress bar callback at the moment + assert len(progress_bars) == 1 + assert progress_bars[0] is trainer.progress_bar_callback + + +@pytest.mark.parametrize('callbacks,refresh_rate', [ + ([], 0), + ([], False), + ([ModelCheckpoint('../trainer')], 0), +]) +def test_progress_bar_off(callbacks, refresh_rate): + """Test different ways the progress bar can be turned off.""" + + trainer = Trainer( + 
callbacks=callbacks, + progress_bar_refresh_rate=refresh_rate, + ) + + progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBar)] + assert 0 == len(progress_bars) + assert not trainer.progress_bar_callback + + +def test_progress_bar_misconfiguration(): + """Test that Trainer doesn't accept multiple progress bars.""" + callbacks = [ProgressBar(), ProgressBar(), ModelCheckpoint('../trainer')] + with pytest.raises(MisconfigurationException, match=r'^You added multiple progress bar callbacks'): + Trainer(callbacks=callbacks) + + +def test_progress_bar_totals(): + """Test that the progress finishes with the correct total steps processed.""" + + class CurrentTestModel( + LightTrainDataloader, + LightTestMixin, + LightValidationMixin, + TestModelBase, + ): + pass + + hparams = tutils.get_default_hparams() + model = CurrentTestModel(hparams) + + trainer = Trainer( + progress_bar_refresh_rate=1, + val_percent_check=1.0, + max_epochs=1, + ) + bar = trainer.progress_bar_callback + assert 0 == bar.total_train_batches + assert 0 == bar.total_val_batches + assert 0 == bar.total_test_batches + + trainer.fit(model) + + # check main progress bar total + n = bar.total_train_batches + m = bar.total_val_batches + assert len(trainer.train_dataloader) == n + assert bar.main_progress_bar.total == n + m + + # check val progress bar total + assert sum(len(loader) for loader in trainer.val_dataloaders) == m + assert bar.val_progress_bar.total == m + + # main progress bar should have reached the end (train batches + val batches) + assert bar.main_progress_bar.n == n + m + assert bar.train_batch_idx == n + + # val progress bar should have reached the end + assert bar.val_progress_bar.n == m + assert bar.val_batch_idx == m + + # check that the test progress bar is off + assert 0 == bar.total_test_batches + assert bar.test_progress_bar is None + + trainer.test(model) + + # check test progress bar total + k = bar.total_test_batches + assert sum(len(loader) for loader in trainer.test_dataloaders) == k + assert bar.test_progress_bar.total == k + + # test progress bar should have reached the end + assert bar.test_progress_bar.n == k + assert bar.test_batch_idx == k + + +def test_progress_bar_fast_dev_run(): + class CurrentTestModel( + LightTrainDataloader, + LightTestMixin, + LightValidationMixin, + TestModelBase, + ): + pass + + hparams = tutils.get_default_hparams() + model = CurrentTestModel(hparams) + + trainer = Trainer( + fast_dev_run=True, + ) + + progress_bar = trainer.progress_bar_callback + assert 1 == progress_bar.total_train_batches + # total val batches are known only after val dataloaders have reloaded + + trainer.fit(model) + + assert 1 == progress_bar.total_val_batches + assert 1 == progress_bar.train_batch_idx + assert 1 == progress_bar.val_batch_idx + assert 0 == progress_bar.test_batch_idx + + # the main progress bar should display 2 batches (1 train, 1 val) + assert 2 == progress_bar.main_progress_bar.total + assert 2 == progress_bar.main_progress_bar.n + + trainer.test(model) + + # the test progress bar should display 1 batch + assert 1 == progress_bar.test_batch_idx + assert 1 == progress_bar.test_progress_bar.total + assert 1 == progress_bar.test_progress_bar.n + + +@pytest.mark.parametrize('refresh_rate', [0, 1, 50]) +def test_progress_bar_progress_refresh(refresh_rate): + """Test that the three progress bars get correctly updated when using different refresh rates.""" + + class CurrentTestModel( + LightTrainDataloader, + LightTestMixin, + LightValidationMixin, + TestModelBase, + 
): + pass + + hparams = tutils.get_default_hparams() + model = CurrentTestModel(hparams) + + class CurrentProgressBar(ProgressBar): + + train_batches_seen = 0 + val_batches_seen = 0 + test_batches_seen = 0 + + def on_batch_start(self, trainer, pl_module): + super().on_batch_start(trainer, pl_module) + assert self.train_batch_idx == trainer.batch_idx + + def on_batch_end(self, trainer, pl_module): + super().on_batch_end(trainer, pl_module) + assert self.train_batch_idx == trainer.batch_idx + 1 + if not self.is_disabled and self.train_batch_idx % self.refresh_rate == 0: + assert self.main_progress_bar.n == self.train_batch_idx + self.train_batches_seen += 1 + + def on_validation_batch_end(self, trainer, pl_module): + super().on_validation_batch_end(trainer, pl_module) + if not self.is_disabled and self.val_batch_idx % self.refresh_rate == 0: + assert self.val_progress_bar.n == self.val_batch_idx + self.val_batches_seen += 1 + + def on_test_batch_end(self, trainer, pl_module): + super().on_test_batch_end(trainer, pl_module) + if not self.is_disabled and self.test_batch_idx % self.refresh_rate == 0: + assert self.test_progress_bar.n == self.test_batch_idx + self.test_batches_seen += 1 + + progress_bar = CurrentProgressBar(refresh_rate=refresh_rate) + trainer = Trainer( + callbacks=[progress_bar], + progress_bar_refresh_rate=101, # should not matter if custom callback provided + train_percent_check=1.0, + num_sanity_val_steps=2, + max_epochs=3, + ) + assert trainer.progress_bar_callback.refresh_rate == refresh_rate != trainer.progress_bar_refresh_rate + + trainer.fit(model) + assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches + assert progress_bar.val_batches_seen == 3 * progress_bar.total_val_batches + trainer.num_sanity_val_steps + + trainer.test(model) + assert progress_bar.test_batches_seen == progress_bar.total_test_batches
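
A minimal usage sketch of the callback-based progress bar introduced in this diff. The names it relies on (``ProgressBar``, the ``init_validation_tqdm`` hook, and the ``Trainer`` arguments ``callbacks``, ``progress_bar_refresh_rate``, ``process_position``) all appear in the changes above; the ``LitProgressBar`` subclass and its description string are illustrative only.

Example::

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ProgressBar

    class LitProgressBar(ProgressBar):
        # illustrative override: keep the default tqdm bars, change only the validation description
        def init_validation_tqdm(self):
            bar = super().init_validation_tqdm()
            bar.set_description('running validation ...')
            return bar

    # explicit progress bar callback: replaces the one the Trainer would otherwise create
    trainer = Trainer(callbacks=[LitProgressBar(refresh_rate=10, process_position=0)])

    # default bar, updated every 20 batches
    trainer = Trainer(progress_bar_refresh_rate=20)

    # no progress bar at all
    trainer = Trainer(progress_bar_refresh_rate=0)

As ``configure_progress_bar`` above shows, ``progress_bar_refresh_rate`` and ``process_position`` are ignored once a ``ProgressBarBase`` callback is passed via ``callbacks``, and passing more than one progress bar callback raises a ``MisconfigurationException``.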