Skip to content

Commit

Permalink
Structured results (train loop only. val loop separate PR) (PR 2/5) (#…
Browse files Browse the repository at this point in the history
…2615)

* r

* r

* r

* patched optimizer closure with sr

* patched optimizer closure with sr

* patched optimizer closure with sr

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added autoreduce for train step

* added auto reduce on train

* added auto reduce on train

* added auto reduce on train

* added auto reduce on train

* added auto reduce on train

* added auto reduce on train

* added hooks

* added hooks

* added hooks

* added hooks

* added hooks

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* cache

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* Update pytorch_lightning/callbacks/early_stopping.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/callbacks/early_stopping.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/callbacks/early_stopping.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/callbacks/model_checkpoint.py

* Update pytorch_lightning/core/step_result.py

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

* simple

* finished tests for structured results on train epoch

* simple

* simple

* revert

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* Update tests/base/deterministic_model.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* finished tests for structured results on train epoch

* docstring typos

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* Update pytorch_lightning/core/step_result.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Update pytorch_lightning/overrides/data_parallel.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

Co-authored-by: Jirka <jirka@pytorchlightning.ai>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
  • Loading branch information
5 people committed Jul 20, 2020
1 parent 816d8cf commit 6d10ac2
Show file tree
Hide file tree
Showing 23 changed files with 1,402 additions and 91 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ jobs:
uses: actions/cache@v1
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip-${{ hashFiles('requirements/base.txt') }}-${{ hashFiles('requirements/extra.txt') }}
key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ matrix.requires }}-pip-${{ hashFiles('requirements/base.txt') }}-${{ hashFiles('requirements/extra.txt') }}
restore-keys: |
${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip-
${{ runner.os }}-pip-${{ matrix.python-version }}-${{ matrix.requires }}-pip-
- name: Install dependencies
run: |
Expand Down
5 changes: 4 additions & 1 deletion pytorch_lightning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,17 @@
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning import metrics
from pytorch_lightning.core.step_result import TrainResult, EvalResult

__all__ = [
'Trainer',
'LightningModule',
'Callback',
'data_loader',
'seed_everything',
'metrics'
'metrics',
'EvalResult',
'TrainResult'
]

# necessary for regular bolts imports. Skip exception since bolts is not always installed
Expand Down
24 changes: 24 additions & 0 deletions pytorch_lightning/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,30 @@ def on_sanity_check_end(self, trainer, pl_module):
"""Called when the validation sanity check ends."""
pass

def on_train_epoch_start(self, trainer, pl_module):
"""Called when the train epoch begins."""
pass

def on_train_epoch_end(self, trainer, pl_module):
"""Called when the train epoch ends."""
pass

def on_validation_epoch_start(self, trainer, pl_module):
"""Called when the val epoch begins."""
pass

def on_validation_epoch_end(self, trainer, pl_module):
"""Called when the val epoch ends."""
pass

def on_test_epoch_start(self, trainer, pl_module):
"""Called when the test epoch begins."""
pass

def on_test_epoch_end(self, trainer, pl_module):
"""Called when the test epoch ends."""
pass

def on_epoch_start(self, trainer, pl_module):
"""Called when the epoch begins."""
pass
Expand Down
22 changes: 22 additions & 0 deletions pytorch_lightning/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""
from copy import deepcopy

import os
import numpy as np
import torch
import torch.distributed as dist
Expand Down Expand Up @@ -140,12 +141,33 @@ def on_sanity_check_end(self, trainer, pl_module):
def on_validation_end(self, trainer, pl_module):
self._run_early_stopping_check(trainer, pl_module)

def on_train_epoch_end(self, trainer, pl_module):
# early stopping can also work in the train loop when there is no val loop and when using structured results
should_check_early_stop = False
train_es_key = 'early_stop_on'
if trainer.callback_metrics.get(train_es_key, None) is not None:
self.monitor = train_es_key
should_check_early_stop = True

val_es_key = 'val_early_stop_on'
if trainer.callback_metrics.get(val_es_key, None) is not None:
self.monitor = val_es_key
should_check_early_stop = True

if should_check_early_stop:
self._run_early_stopping_check(trainer, pl_module)

def _run_early_stopping_check(self, trainer, pl_module):
logs = trainer.callback_metrics

if not self._validate_condition_metric(logs):
return # short circuit if metric not present

current = logs.get(self.monitor)

# when in dev debugging
trainer.dev_debugger.track_early_stopping_history(current)

if not isinstance(current, torch.Tensor):
current = torch.tensor(current, device=pl_module.device)

Expand Down
21 changes: 15 additions & 6 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,11 @@ def _del_model(self, filepath):
if os.path.isfile(filepath):
os.remove(filepath)

def _save_model(self, filepath):
def _save_model(self, filepath, trainer, pl_module):

# in debugging, track when we save checkpoints
trainer.dev_debugger.track_checkpointing_history(filepath)

# make paths
os.makedirs(os.path.dirname(filepath), exist_ok=True)

Expand Down Expand Up @@ -270,6 +274,11 @@ def on_validation_end(self, trainer, pl_module):

metrics = trainer.callback_metrics
epoch = trainer.current_epoch

# support structured results
if metrics.get('checkpoint_on') is not None:
self.monitor = 'checkpoint_on'

if self.save_top_k == 0:
# no models are saved
return
Expand All @@ -281,7 +290,7 @@ def on_validation_end(self, trainer, pl_module):

if self.save_last:
filepath = os.path.join(self.dirpath, self.prefix + 'last.ckpt')
self._save_model(filepath)
self._save_model(filepath, trainer, pl_module)

filepath = self.format_checkpoint_name(epoch, metrics)
version_cnt = 0
Expand All @@ -306,7 +315,7 @@ def on_validation_end(self, trainer, pl_module):
f'Can save best model only with {self.monitor} available, skipping.', RuntimeWarning
)
elif self.check_monitor_top_k(current):
self._do_check_save(filepath, current, epoch)
self._do_check_save(filepath, current, epoch, trainer, pl_module)
elif self.verbose > 0:
log.info(f'\nEpoch {epoch:05d}: {self.monitor} was not in top {self.save_top_k}')

Expand All @@ -315,9 +324,9 @@ def on_validation_end(self, trainer, pl_module):
log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')

assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
self._save_model(filepath)
self._save_model(filepath, trainer, pl_module)

def _do_check_save(self, filepath, current, epoch):
def _do_check_save(self, filepath, current, epoch, trainer, pl_module):
# remove kth

del_list = []
Expand All @@ -343,7 +352,7 @@ def _do_check_save(self, filepath, current, epoch):
f'\nEpoch {epoch:05d}: {self.monitor} reached'
f' {current:0.5f} (best {self.best_model_score:0.5f}), saving model to'
f' {filepath} as top {self.save_top_k}')
self._save_model(filepath)
self._save_model(filepath, trainer, pl_module)

for cur_path in del_list:
if cur_path != filepath:
Expand Down
36 changes: 36 additions & 0 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,42 @@ def on_epoch_end(self) -> None:
"""
# do something when the epoch ends

def on_train_epoch_start(self) -> None:
"""
Called in the training loop at the very beginning of the epoch.
"""
# do something when the epoch starts

def on_train_epoch_end(self) -> None:
"""
Called in the training loop at the very end of the epoch.
"""
# do something when the epoch ends

def on_validation_epoch_start(self) -> None:
"""
Called in the validation loop at the very beginning of the epoch.
"""
# do something when the epoch starts

def on_validation_epoch_end(self) -> None:
"""
Called in the validation loop at the very end of the epoch.
"""
# do something when the epoch ends

def on_test_epoch_start(self) -> None:
"""
Called in the test loop at the very beginning of the epoch.
"""
# do something when the epoch starts

def on_test_epoch_end(self) -> None:
"""
Called in the test loop at the very end of the epoch.
"""
# do something when the epoch ends

def on_pre_performance_check(self) -> None:
"""
Called at the very beginning of the validation loop.
Expand Down
Loading

0 comments on commit 6d10ac2

Please sign in to comment.