From 17947ab2a194a4f34ac78623b97c9387a3fc514d Mon Sep 17 00:00:00 2001
From: Uladzislau Sazanovich
Date: Wed, 8 Jul 2020 11:30:30 +0300
Subject: [PATCH] Add INTERRUPTED state, improve tests, move state switching
 from callback to a trainer.

---
 pytorch_lightning/trainer/states.py        | 34 ++-------
 pytorch_lightning/trainer/trainer.py       | 13 +++-
 pytorch_lightning/trainer/training_loop.py |  3 +
 tests/trainer/test_states.py               | 85 +++++++++++++++++++++-
 4 files changed, 103 insertions(+), 32 deletions(-)

diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py
index e7f5cfde80aa07..79af15d9532d6d 100644
--- a/pytorch_lightning/trainer/states.py
+++ b/pytorch_lightning/trainer/states.py
@@ -1,31 +1,9 @@
-from enum import Enum, auto
-
-from pytorch_lightning import Callback
+from enum import Enum
 
 
 class TrainerState(Enum):
-    """ State which is set to the Trainer to indicate what is being executed. """
-    INITIALIZE = auto()
-    RUNNING = auto()
-    FINISHED = auto()
-
-
-class _TrainerStateSwitcher(Callback):
-    """ Special callback used by the Trainer. This callback sets proper
-    state to the trainer depending on what is being executed.
-    """
-
-    def on_init_start(self, trainer):
-        trainer.state = TrainerState.INITIALIZE
-
-    def on_init_end(self, trainer):
-        trainer.state = TrainerState.INITIALIZE
-
-    def setup(self, trainer, stage: str):
-        trainer.state = TrainerState.RUNNING
-
-    def teardown(self, trainer, stage: str):
-        trainer.state = TrainerState.FINISHED
-
-    def on_keyboard_interrupt(self, trainer, pl_module):
-        trainer.state = TrainerState.FINISHED
+    """ State which is set in the Trainer to indicate what is currently or was executed. """
+    INITIALIZE = 'INITIALIZE'
+    RUNNING = 'RUNNING'
+    FINISHED = 'FINISHED'
+    INTERRUPTED = 'INTERRUPTED'
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 676bad427ce317..82310122ed9cee 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -26,7 +26,7 @@
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
 from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
-from pytorch_lightning.trainer.states import _TrainerStateSwitcher, TrainerState
+from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.trainer.training_io import TrainerIOMixin
 from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
@@ -419,8 +419,6 @@ def __init__(
         # init callbacks
         self.callbacks = callbacks or []
 
-        self.callbacks.append(_TrainerStateSwitcher())
-
         # configure early stop callback
         # creates a default one if none passed in
         early_stop_callback = self.configure_early_stopping(early_stop_callback)
@@ -914,6 +912,8 @@ def fit(
         # check that model is configured correctly
         self.check_model_configuration(model)
 
+        self.state = TrainerState.RUNNING
+
         # callbacks
         self.on_fit_start()
         if self.is_function_implemented('on_fit_start', model):
@@ -1031,6 +1031,8 @@ def fit(
         if self.is_function_implemented('teardown'):
             model.teardown('fit')
 
+        if self.state != TrainerState.INTERRUPTED:
+            self.state = TrainerState.FINISHED
         # return 1 when finished
         # used for testing or when we need to know that training succeeded
         return results or 1
@@ -1246,6 +1248,8 @@ def test(
         if self.is_function_implemented('setup', model_ref):
             model_ref.setup('test')
 
+        self.state = TrainerState.RUNNING
+
         # if user requests the best checkpoint but we don't have it, error
         if model is None and ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0:
             raise MisconfigurationException(
@@ -1295,6 +1299,9 @@ def test(
             model_ref = self.get_model()
             model_ref.teardown('test')
 
+        if self.state != TrainerState.INTERRUPTED:
+            self.state = TrainerState.FINISHED
+
         return results
 
     def check_model_configuration(self, model: LightningModule):
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index bd55881dd7f38c..36da4a4770e26c 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -159,6 +159,7 @@ def training_step(self, batch, batch_idx):
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers import LightningLoggerBase
+from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -240,6 +241,7 @@ class TrainerTrainLoopMixin(ABC):
     terminate_on_nan: bool
     tpu_id: int
     interactive_ddp_procs: ...
+    state: TrainerState
 
     # Callback system
     callbacks: List[Callback]
@@ -397,6 +399,7 @@ def train(self):
             # user could press ctrl+c many times... only shutdown once
             if not self.interrupted:
                 self.interrupted = True
+                self.state = TrainerState.INTERRUPTED
                 self.on_keyboard_interrupt()
 
         self.run_training_teardown()
diff --git a/tests/trainer/test_states.py b/tests/trainer/test_states.py
index e5ce20d956ce1a..fab19deb5840ba 100644
--- a/tests/trainer/test_states.py
+++ b/tests/trainer/test_states.py
@@ -1,4 +1,4 @@
-from pytorch_lightning import Trainer
+from pytorch_lightning import Trainer, Callback
 from pytorch_lightning.trainer.states import TrainerState
 from tests.base import EvalModelTemplate
 
@@ -15,6 +15,35 @@ def test_initialize_state(tmpdir):
     assert trainer.state == TrainerState.INITIALIZE
 
 
+def test_running_state_during_fit(tmpdir):
+    """
+    Tests that state is set to RUNNING during fit
+    """
+
+    class StateSnapshotCallback(Callback):
+        def __init__(self):
+            super().__init__()
+            self.trainer_state = None
+
+        def on_batch_start(self, trainer, pl_module):
+            self.trainer_state = trainer.state
+
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
+
+    snapshot_callback = StateSnapshotCallback()
+
+    trainer = Trainer(
+        callbacks=[snapshot_callback],
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+    )
+
+    trainer.fit(model)
+
+    assert snapshot_callback.trainer_state == TrainerState.RUNNING
+
+
 def test_finished_state_after_fit(tmpdir):
     """
     Tests that state is FINISHED after fit
@@ -32,6 +61,35 @@ def test_finished_state_after_fit(tmpdir):
     assert trainer.state == TrainerState.FINISHED
 
 
+def test_running_state_during_test(tmpdir):
+    """
+    Tests that state is set to RUNNING during test
+    """
+
+    class StateSnapshotCallback(Callback):
+        def __init__(self):
+            super().__init__()
+            self.trainer_state = None
+
+        def on_test_batch_start(self, trainer, pl_module):
+            self.trainer_state = trainer.state
+
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
+
+    snapshot_callback = StateSnapshotCallback()
+
+    trainer = Trainer(
+        callbacks=[snapshot_callback],
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+    )
+
+    trainer.test(model)
+
+    assert snapshot_callback.trainer_state == TrainerState.RUNNING
+
+
 def test_finished_state_after_test(tmpdir):
     """
     Tests that state is FINISHED after fit
@@ -47,3 +105,28 @@ def test_finished_state_after_test(tmpdir):
     trainer.test(model)
 
     assert trainer.state == TrainerState.FINISHED
+
+
+def test_interrupt_state_on_keyboard_interrupt(tmpdir):
+    """
+    Tests that state is set to INTERRUPTED on KeyboardInterrupt
+    """
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
+
+    class InterruptCallback(Callback):
+        def __init__(self):
+            super().__init__()
+
+        def on_batch_start(self, trainer, pl_module):
+            raise KeyboardInterrupt
+
+    trainer = Trainer(
+        callbacks=[InterruptCallback()],
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+    )
+
+    trainer.fit(model)
+
+    assert trainer.state == TrainerState.INTERRUPTED
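
--
Note (not part of the patch): a minimal sketch of the state transitions this
change produces from user code, assuming only the APIs exercised in the tests
above (Trainer, TrainerState, fast_dev_run, and the test-only
EvalModelTemplate from tests.base):

    from pytorch_lightning import Trainer
    from pytorch_lightning.trainer.states import TrainerState
    from tests.base import EvalModelTemplate  # test-only model, as in the tests above

    model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())
    trainer = Trainer(fast_dev_run=True)

    # A freshly constructed Trainer reports INITIALIZE
    # (see test_initialize_state above).
    assert trainer.state == TrainerState.INITIALIZE

    trainer.fit(model)

    # fit() now sets RUNNING on entry; on exit it sets FINISHED, unless the
    # train loop caught a KeyboardInterrupt and left the state as INTERRUPTED.
    assert trainer.state == TrainerState.FINISHED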