diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index e7f5cfde80aa07..79af15d9532d6d 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -1,31 +1,9 @@ -from enum import Enum, auto - -from pytorch_lightning import Callback +from enum import Enum class TrainerState(Enum): - """ State which is set to the Trainer to indicate what is being executed. """ - INITIALIZE = auto() - RUNNING = auto() - FINISHED = auto() - - -class _TrainerStateSwitcher(Callback): - """ Special callback used by the Trainer. This callback sets proper - state to the trainer depending on what is being executed. - """ - - def on_init_start(self, trainer): - trainer.state = TrainerState.INITIALIZE - - def on_init_end(self, trainer): - trainer.state = TrainerState.INITIALIZE - - def setup(self, trainer, stage: str): - trainer.state = TrainerState.RUNNING - - def teardown(self, trainer, stage: str): - trainer.state = TrainerState.FINISHED - - def on_keyboard_interrupt(self, trainer, pl_module): - trainer.state = TrainerState.FINISHED + """ State which is set in the Trainer to indicate what is currently or was executed. """ + INITIALIZE = 'INITIALIZE' + RUNNING = 'RUNNING' + FINISHED = 'FINISHED' + INTERRUPTED = 'INTERRUPTED' diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d4a770c9678025..c8d9e4f3eee3d6 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -26,7 +26,7 @@ from pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin -from pytorch_lightning.trainer.states import _TrainerStateSwitcher, TrainerState +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.trainer.training_io import TrainerIOMixin from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin @@ -426,8 +426,6 @@ def __init__( # init callbacks self.callbacks = callbacks or [] - self.callbacks.append(_TrainerStateSwitcher()) - # configure early stop callback # creates a default one if none passed in early_stop_callback = self.configure_early_stopping(early_stop_callback) @@ -941,6 +939,8 @@ def fit( # check that model is configured correctly self.check_model_configuration(model) + self.state = TrainerState.RUNNING + # callbacks self.on_fit_start() if self.is_function_implemented('on_fit_start', model): @@ -1062,6 +1062,8 @@ def fit( if self.is_function_implemented('teardown'): model.teardown('fit') + if self.state != TrainerState.INTERRUPTED: + self.state = TrainerState.FINISHED # return 1 when finished # used for testing or when we need to know that training succeeded return results or 1 @@ -1313,6 +1315,8 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): if self.is_function_implemented('setup', model): model.setup('test') + self.state = TrainerState.RUNNING + # if user requests the best checkpoint but we don't have it, error if ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0: raise MisconfigurationException( @@ -1351,6 +1355,9 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): model_ref = self.get_model() model_ref.teardown('test') + if self.state != TrainerState.INTERRUPTED: + self.state = TrainerState.FINISHED + return results def __test_given_model(self, model, test_dataloaders): diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index fa493f2e1b09a5..989100444b9009 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -159,6 +159,7 @@ def training_step(self, batch, batch_idx): from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers import LightningLoggerBase +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -240,6 +241,7 @@ class TrainerTrainLoopMixin(ABC): terminate_on_nan: bool tpu_id: int interactive_ddp_procs: ... + state: TrainerState # Callback system callbacks: List[Callback] @@ -397,6 +399,7 @@ def train(self): # user could press ctrl+c many times... only shutdown once if not self.interrupted: self.interrupted = True + self.state = TrainerState.INTERRUPTED self.on_keyboard_interrupt() self.run_training_teardown() diff --git a/tests/trainer/test_states.py b/tests/trainer/test_states.py index e5ce20d956ce1a..fab19deb5840ba 100644 --- a/tests/trainer/test_states.py +++ b/tests/trainer/test_states.py @@ -1,4 +1,4 @@ -from pytorch_lightning import Trainer +from pytorch_lightning import Trainer, Callback from pytorch_lightning.trainer.states import TrainerState from tests.base import EvalModelTemplate @@ -15,6 +15,35 @@ def test_initialize_state(tmpdir): assert trainer.state == TrainerState.INITIALIZE +def test_running_state_during_fit(tmpdir): + """ + Tests that state is set to RUNNING during fit + """ + + class StateSnapshotCallback(Callback): + def __init__(self): + super().__init__() + self.trainer_state = None + + def on_batch_start(self, trainer, pl_module): + self.trainer_state = trainer.state + + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(**hparams) + + snapshot_callback = StateSnapshotCallback() + + trainer = Trainer( + callbacks=[snapshot_callback], + default_root_dir=tmpdir, + fast_dev_run=True, + ) + + trainer.fit(model) + + assert snapshot_callback.trainer_state == TrainerState.RUNNING + + def test_finished_state_after_fit(tmpdir): """ Tests that state is FINISHED after fit @@ -32,6 +61,35 @@ def test_finished_state_after_fit(tmpdir): assert trainer.state == TrainerState.FINISHED +def test_running_state_during_test(tmpdir): + """ + Tests that state is set to RUNNING during test + """ + + class StateSnapshotCallback(Callback): + def __init__(self): + super().__init__() + self.trainer_state = None + + def on_test_batch_start(self, trainer, pl_module): + self.trainer_state = trainer.state + + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(**hparams) + + snapshot_callback = StateSnapshotCallback() + + trainer = Trainer( + callbacks=[snapshot_callback], + default_root_dir=tmpdir, + fast_dev_run=True, + ) + + trainer.test(model) + + assert snapshot_callback.trainer_state == TrainerState.RUNNING + + def test_finished_state_after_test(tmpdir): """ Tests that state is FINISHED after fit @@ -47,3 +105,28 @@ def test_finished_state_after_test(tmpdir): trainer.test(model) assert trainer.state == TrainerState.FINISHED + + +def test_interrupt_state_on_keyboard_interrupt(tmpdir): + """ + Tests that state is set to INTERRUPTED on KeyboardInterrupt + """ + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(**hparams) + + class InterruptCallback(Callback): + def __init__(self): + super().__init__() + + def on_batch_start(self, trainer, pl_module): + raise KeyboardInterrupt + + trainer = Trainer( + callbacks=[InterruptCallback()], + default_root_dir=tmpdir, + fast_dev_run=True, + ) + + trainer.fit(model) + + assert trainer.state == TrainerState.INTERRUPTED