diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py
new file mode 100644
index 0000000000000..3b51c1282141c
--- /dev/null
+++ b/pytorch_lightning/trainer/states.py
@@ -0,0 +1,49 @@
+from enum import Enum
+from functools import wraps
+from typing import Callable, Optional
+
+import pytorch_lightning
+
+
+class TrainerState(Enum):
+    """ State which is set in the :class:`~pytorch_lightning.trainer.trainer.Trainer`
+    to indicate what is currently being executed or was last executed. """
+    INITIALIZING = 'INITIALIZING'
+    RUNNING = 'RUNNING'
+    FINISHED = 'FINISHED'
+    INTERRUPTED = 'INTERRUPTED'
+
+
+def trainer_state(*, entering: Optional[TrainerState] = None, exiting: Optional[TrainerState] = None) -> Callable:
+    """ Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods
+    which sets the state to `entering` before the function executes and to `exiting`
+    after it has executed. If `None` is passed to `entering`, the state is not changed.
+    If `None` is passed to `exiting`, the state is restored to the state before the function was executed.
+    If the `INTERRUPTED` state is set inside the run function, the state remains `INTERRUPTED`.
+    """
+
+    def wrapper(fn) -> Callable:
+        @wraps(fn)
+        def wrapped_fn(self, *args, **kwargs):
+            if not isinstance(self, pytorch_lightning.Trainer):
+                return fn(self, *args, **kwargs)
+
+            state_before = self.state
+            if entering is not None:
+                self.state = entering
+            result = fn(self, *args, **kwargs)
+
+            # The INTERRUPTED state can be set inside the run function. To signal that the run was
+            # interrupted, we retain the INTERRUPTED state instead of overwriting it with `exiting`.
+            if self.state == TrainerState.INTERRUPTED:
+                return result
+
+            if exiting is not None:
+                self.state = exiting
+            else:
+                self.state = state_before
+            return result
+
+        return wrapped_fn
+
+    return wrapper
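The contract of `trainer_state` is easiest to see in a small usage sketch (illustration only, not part of the patch; `run_something` and `run_and_interrupt` are made-up helpers). It mirrors the unit tests added in `tests/trainer/test_states.py` further down and assumes the patch is applied:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import TrainerState, trainer_state

trainer = Trainer()  # the Trainer changes below set the state to INITIALIZING in __init__

@trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
def run_something(self):
    # inside the wrapped call the state has already been switched to `entering`
    return self.state

assert run_something(trainer) == TrainerState.RUNNING
assert trainer.state == TrainerState.FINISHED  # `exiting` is applied after the call returns

@trainer_state(exiting=TrainerState.FINISHED)
def run_and_interrupt(self):
    # the training loop does the equivalent of this when it catches a KeyboardInterrupt
    self.state = TrainerState.INTERRUPTED

run_and_interrupt(trainer)
assert trainer.state == TrainerState.INTERRUPTED  # INTERRUPTED is retained, `exiting` is skipped
```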
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index b8d246b5c8f25..2593ab786bf04 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -45,6 +45,7 @@
 from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
 from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
 from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
+from pytorch_lightning.trainer.states import TrainerState, trainer_state
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.trainer.training_io import TrainerIOMixin
 from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
@@ -395,6 +396,7 @@ def __init__(
         self.interrupted = False
         self.should_stop = False
         self.running_sanity_check = False
+        self.state = TrainerState.INITIALIZING
 
         self._default_root_dir = default_root_dir or os.getcwd()
         self._weights_save_path = weights_save_path or self._default_root_dir
@@ -888,6 +890,7 @@ def weights_save_path(self) -> str:
     # -----------------------------
     # MODEL TRAINING
     # -----------------------------
+    @trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
     def fit(
         self,
         model: LightningModule,
@@ -1240,6 +1243,7 @@ def _run_sanity_check(self, ref_model, model):
         self.on_sanity_check_end()
         self.running_sanity_check = False
 
+    @trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
     def test(
         self,
         model: Optional[LightningModule] = None,
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index ea9f915a8df14..ec5bd0938d15c 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -174,6 +174,7 @@ def training_step(self, batch, batch_idx):
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.step_result import EvalResult, Result
 from pytorch_lightning.loggers import LightningLoggerBase
+from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.trainer.supporters import TensorRunningAccum, Accumulator
 from pytorch_lightning.utilities import rank_zero_warn, AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -253,6 +254,7 @@ class TrainerTrainLoopMixin(ABC):
     terminate_on_nan: bool
     tpu_id: int
     interactive_ddp_procs: ...
+    state: TrainerState
     amp_type: AMPType
     on_tpu: bool
@@ -418,6 +420,7 @@ def train(self):
             # user could press ctrl+c many times... only shutdown once
             if not self.interrupted:
                 self.interrupted = True
+                self.state = TrainerState.INTERRUPTED
                 self.on_keyboard_interrupt()
                 self.run_training_teardown()
diff --git a/tests/trainer/test_states.py b/tests/trainer/test_states.py
new file mode 100644
index 0000000000000..2b2ad545c7539
--- /dev/null
+++ b/tests/trainer/test_states.py
@@ -0,0 +1,210 @@
+import pytest
+
+from pytorch_lightning import Trainer, Callback
+from pytorch_lightning.trainer.states import TrainerState, trainer_state
+from tests.base import EvalModelTemplate
+
+
+class StateSnapshotCallback(Callback):
+    """ Allows snapshotting the trainer state inside a particular trainer method. """
+
+    def __init__(self, snapshot_method: str):
+        super().__init__()
+        assert snapshot_method in ['on_batch_start', 'on_test_batch_start']
+        self.snapshot_method = snapshot_method
+        self.trainer_state = None
+
+    def on_batch_start(self, trainer, pl_module):
+        if self.snapshot_method == 'on_batch_start':
+            self.trainer_state = trainer.state
+
+    def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+        if self.snapshot_method == 'on_test_batch_start':
+            self.trainer_state = trainer.state
+
+
+def test_state_decorator_nothing_passed(tmpdir):
+    """ Tests that state is not changed if nothing is passed to the decorator """
+
+    @trainer_state()
+    def test_method(self):
+        return self.state
+
+    trainer = Trainer(default_root_dir=tmpdir)
+    trainer.state = TrainerState.INITIALIZING
+
+    snapshot_state = test_method(trainer)
+
+    assert snapshot_state == TrainerState.INITIALIZING
+    assert trainer.state == TrainerState.INITIALIZING
+
+
+def test_state_decorator_entering_only(tmpdir):
+    """ Tests that state is set to `entering` inside a run function and restored to the previous value after. """
+
+    @trainer_state(entering=TrainerState.RUNNING)
+    def test_method(self):
+        return self.state
+
+    trainer = Trainer(default_root_dir=tmpdir)
+    trainer.state = TrainerState.INITIALIZING
+
+    snapshot_state = test_method(trainer)
+
+    assert snapshot_state == TrainerState.RUNNING
+    assert trainer.state == TrainerState.INITIALIZING
+
+
+def test_state_decorator_exiting_only(tmpdir):
+    """ Tests that state is not changed inside a run function and set to `exiting` after. """
+
+    @trainer_state(exiting=TrainerState.FINISHED)
+    def test_method(self):
+        return self.state
+
+    trainer = Trainer(default_root_dir=tmpdir)
+    trainer.state = TrainerState.INITIALIZING
+
+    snapshot_state = test_method(trainer)
+
+    assert snapshot_state == TrainerState.INITIALIZING
+    assert trainer.state == TrainerState.FINISHED
+
+
+def test_state_decorator_entering_and_exiting(tmpdir):
+    """ Tests that state is set to `entering` inside a run function and set to `exiting` after. """
+
+    @trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
+    def test_method(self):
+        return self.state
+
+    trainer = Trainer(default_root_dir=tmpdir)
+    trainer.state = TrainerState.INITIALIZING
+
+    snapshot_state = test_method(trainer)
+
+    assert snapshot_state == TrainerState.RUNNING
+    assert trainer.state == TrainerState.FINISHED
+
+
+def test_state_decorator_interrupt(tmpdir):
+    """ Tests that state remains `INTERRUPTED` if it is set inside a run function. """
+
+    @trainer_state(exiting=TrainerState.FINISHED)
+    def test_method(self):
+        self.state = TrainerState.INTERRUPTED
+
+    trainer = Trainer(default_root_dir=tmpdir)
+    trainer.state = TrainerState.INITIALIZING
+
+    test_method(trainer)
+    assert trainer.state == TrainerState.INTERRUPTED
+
+
+def test_initialize_state(tmpdir):
+    """ Tests that state is INITIALIZING after Trainer creation """
+    trainer = Trainer(default_root_dir=tmpdir)
+    assert trainer.state == TrainerState.INITIALIZING
+
+
+@pytest.mark.parametrize("extra_params", [
+    pytest.param(dict(fast_dev_run=True), id='Fast-Run'),
+    pytest.param(dict(max_steps=1), id='Single-Step'),
+])
+def test_running_state_during_fit(tmpdir, extra_params):
+    """ Tests that state is set to RUNNING during fit """
+
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
+
+    snapshot_callback = StateSnapshotCallback(snapshot_method='on_batch_start')
+
+    trainer = Trainer(
+        callbacks=[snapshot_callback],
+        default_root_dir=tmpdir,
+        **extra_params
+    )
+
+    trainer.fit(model)
+
+    assert snapshot_callback.trainer_state == TrainerState.RUNNING
+
+
+@pytest.mark.parametrize("extra_params", [
+    pytest.param(dict(fast_dev_run=True), id='Fast-Run'),
+    pytest.param(dict(max_steps=1), id='Single-Step'),
+])
+def test_finished_state_after_fit(tmpdir, extra_params):
+    """ Tests that state is FINISHED after fit """
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        **extra_params
+    )
+
+    trainer.fit(model)
+
+    assert trainer.state == TrainerState.FINISHED
+
+
+def test_running_state_during_test(tmpdir):
+    """ Tests that state is set to RUNNING during test """
+
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
+
+    snapshot_callback = StateSnapshotCallback(snapshot_method='on_test_batch_start')
+
+    trainer = Trainer(
+        callbacks=[snapshot_callback],
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+    )
+
+    trainer.test(model)
+
+    assert snapshot_callback.trainer_state == TrainerState.RUNNING
+
+
+def test_finished_state_after_test(tmpdir):
+    """ Tests that state is FINISHED after test """
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+    )
+
+    trainer.test(model)
+
+    assert trainer.state == TrainerState.FINISHED
+
+
+@pytest.mark.parametrize("extra_params", [
+    pytest.param(dict(fast_dev_run=True), id='Fast-Run'),
+    pytest.param(dict(max_steps=1), id='Single-Step'),
+])
+def test_interrupt_state_on_keyboard_interrupt(tmpdir, extra_params):
+    """ Tests that state is set to INTERRUPTED on KeyboardInterrupt """
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
+
+    class InterruptCallback(Callback):
+        def __init__(self):
+            super().__init__()
+
+        def on_batch_start(self, trainer, pl_module):
+            raise KeyboardInterrupt
+
+    trainer = Trainer(
+        callbacks=[InterruptCallback()],
+        default_root_dir=tmpdir,
+        **extra_params
+    )
+
+    trainer.fit(model)
+
+    assert trainer.state == TrainerState.INTERRUPTED
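End to end, the user-visible effect of the patch is that `trainer.state` can be inspected after `fit`/`test` returns. A short sketch of that (assumes the patch is applied; it reuses the tests' `EvalModelTemplate` purely for illustration):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import TrainerState
from tests.base import EvalModelTemplate  # the template model used by the tests above

model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())
trainer = Trainer(fast_dev_run=True)
trainer.fit(model)  # a KeyboardInterrupt during training is caught inside train()

if trainer.state == TrainerState.INTERRUPTED:
    print("training was interrupted (e.g. Ctrl+C)")
elif trainer.state == TrainerState.FINISHED:
    print("training ran to completion")
```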