Skip to content

Commit

Permalink
Progress bar callback (#1450)
Browse files Browse the repository at this point in the history
* squash and rebase

sanity check hooks


sanity check callback hook finish


moved core progress bar functionality into callback


wip


remove duplicate merge


clean up


imports


docs


sanity check progress bar main


sanity


move callback calls


init progrss bar callback


configuration and docs


changelog


rate decorator


pass process_position


disable on rank > 0


position index


is_enabled


remove decorator


refactor init tqdm bars


callback method ordering 


cannot reset when disabled


sequence -> list


default values


fix has no attr _time() 


move on_val_end to proper place


fix the pickle issue


update warning


properties


check for None


remove old comment


switch order


pull out non-tqdm functionality into base class


documentation for the base class


docs


fix refresh rate issue in validation


restrict type hint of trainer arg


more docs


update trainer docs


rst docs


fix lines too long


fix test


add missing type hints


fix typo


move docstring to __init__ solves doctest failures


remove doctest :(( can't fix the pickle error


fix example


simplify by saving trainer reference


fix docs errors


move docstring


initial value


multiple val checks per epoch


simpler handling of inf dataset sizes


update inf docs


renamed training_tqdm_dict


rename get_tqdm_dict


rename occurences of tqdm 


update changelog


fix doctest


fix formatting errors


added callback tests


progress bar on off test


more tests for progress bar


weird test fix?


add ignored property


disable default progress bar in LR finder


change enable/disable behavior


trying doctest in CI again


undo doctest pickle error


undo doctest pickle error :((


remove progress_bar_callback Trainer arg and fix tests


restore progress bar after auto lr find


update docs


fix rebase


fix wrong negation

* fix fast dev run total

* more thorough testing

* remove old args

* fix merge

* fix merge

* separate tests

* type hint total batches

* reduce if

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* is_disabled

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* is_enabled

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* rename enabled/disabled

* move deprecated api

* remove duplicated test from merge

* fix rename is_disabled

* newline

* test also testprogress for fast dev run

Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
3 people authored Apr 24, 2020
1 parent fe2b666 commit 3e8f2d9
Show file tree
Hide file tree
Showing 22 changed files with 837 additions and 150 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Changed

- Changed the default behaviour to no longer include a NaN check with each training iteration. ([#1475](https://github.com/PyTorchLightning/pytorch-lightning/pull/1475))
- Decoupled the progress bar from trainer. It is a callback now and can be customized or even be replaced entirely ([#1450](https://github.com/PyTorchLightning/pytorch-lightning/pull/1450)).

- Changed lr schedule step interval behavior to update every backwards pass instead of every forwards pass ([#1476](https://github.com/PyTorchLightning/pytorch-lightning/issues/1476))

Expand All @@ -41,6 +42,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

- Deprecatd `training_tqdm_dict` in favor of `progress_bar_dict` ([#1450](https://github.com/PyTorchLightning/pytorch-lightning/pull/1450)).


### Removed
Expand Down
6 changes: 6 additions & 0 deletions docs/source/callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,9 @@ We successfully extended functionality without polluting our super clean
_save_model,
_abc_impl,
check_monitor_top_k,

---------

.. automodule:: pytorch_lightning.callbacks.progress
:noindex:
:exclude-members:
1 change: 1 addition & 0 deletions docs/source/trainer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@ Trainer
slurm_job_id,
tng_tqdm_dic,
training_tqdm_dict,
progress_bar_dict,
init_optimizers,
configure_schedulers
3 changes: 3 additions & 0 deletions pytorch_lightning/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.progress import ProgressBarBase, ProgressBar

__all__ = [
'Callback',
'EarlyStopping',
'ModelCheckpoint',
'GradientAccumulationScheduler',
'ProgressBarBase',
'ProgressBar',
]
24 changes: 24 additions & 0 deletions pytorch_lightning/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ def on_init_end(self, trainer):
"""Called when the trainer initialization ends, model has not yet been set."""
pass

def on_sanity_check_start(self, trainer, pl_module):
"""Called when the validation sanity check starts."""
pass

def on_sanity_check_end(self, trainer, pl_module):
"""Called when the validation sanity check ends."""
pass

def on_epoch_start(self, trainer, pl_module):
"""Called when the epoch begins."""
pass
Expand All @@ -34,6 +42,22 @@ def on_batch_start(self, trainer, pl_module):
"""Called when the training batch begins."""
pass

def on_validation_batch_start(self, trainer, pl_module):
"""Called when the validation batch begins."""
pass

def on_validation_batch_end(self, trainer, pl_module):
"""Called when the validation batch ends."""
pass

def on_test_batch_start(self, trainer, pl_module):
"""Called when the test batch begins."""
pass

def on_test_batch_end(self, trainer, pl_module):
"""Called when the test batch ends."""
pass

def on_batch_end(self, trainer, pl_module):
"""Called when the training batch ends."""
pass
Expand Down
Loading

0 comments on commit 3e8f2d9

Please sign in to comment.