diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index 47f5772cd75c74..3b51c1282141cd 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -6,17 +6,20 @@ class TrainerState(Enum): - """ State which is set in the Trainer to indicate what is currently or was executed. """ - INITIALIZE = 'INITIALIZE' + """ State which is set in the :class:`~pytorch_lightning.trainer.trainer.Trainer` + to indicate what is currently or was executed. """ + INITIALIZING = 'INITIALIZING' RUNNING = 'RUNNING' FINISHED = 'FINISHED' INTERRUPTED = 'INTERRUPTED' def trainer_state(*, entering: Optional[TrainerState] = None, exiting: Optional[TrainerState] = None) -> Callable: - """ Decorator for :class:`~pytorch_lightning.Trainer` methods which changes - state to `entering` before the function execution and `exiting` after - the function is executed. If None is passed the state is not changed. + """ Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods + which changes state to `entering` before the function execution and `exiting` + after the function is executed. If `None` is passed to `entering`, the state is not changed. + If `None` is passed to `exiting`, the state is restored to the state before function execution. + If `INTERRUPTED` state is set inside a run function, the state remains `INTERRUPTED`. """ def wrapper(fn) -> Callable: @@ -25,14 +28,20 @@ def wrapped_fn(self, *args, **kwargs): if not isinstance(self, pytorch_lightning.Trainer): return fn(self, *args, **kwargs) + state_before = self.state if entering is not None: self.state = entering result = fn(self, *args, **kwargs) # The INTERRUPTED state can be set inside the run function. To indicate that run was interrupted # we retain INTERRUPTED state - if exiting is not None and self.state != TrainerState.INTERRUPTED: + if self.state == TrainerState.INTERRUPTED: + return result + + if exiting is not None: self.state = exiting + else: + self.state = state_before return result return wrapped_fn diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 807636c969f0cd..bca8835607038b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -417,7 +417,7 @@ def __init__( self.current_epoch = 0 self.interrupted = False self.should_stop = False - self.state = TrainerState.INITIALIZE + self.state = TrainerState.INITIALIZING # set default save path if user didn't provide one if default_root_dir is None: diff --git a/tests/trainer/test_states.py b/tests/trainer/test_states.py index fab19deb5840ba..5e2ee61687bf2c 100644 --- a/tests/trainer/test_states.py +++ b/tests/trainer/test_states.py @@ -1,42 +1,128 @@ +import pytest + from pytorch_lightning import Trainer, Callback -from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.trainer.states import TrainerState, trainer_state from tests.base import EvalModelTemplate -def test_initialize_state(tmpdir): - """ - Tests that state is INITIALIZE after Trainer creation - """ - trainer = Trainer( - default_root_dir=tmpdir, - fast_dev_run=True, - ) +class StateSnapshotCallback(Callback): + """ Allows to shapshot the state inside a particular trainer method. """ + + def __init__(self, snapshot_method: str): + super().__init__() + assert snapshot_method in ['on_batch_start', 'on_test_batch_start'] + self.snapshot_method = snapshot_method + self.trainer_state = None - assert trainer.state == TrainerState.INITIALIZE + def on_batch_start(self, trainer, pl_module): + if self.snapshot_method == 'on_batch_start': + self.trainer_state = trainer.state + def on_test_batch_start(self, trainer, pl_module): + if self.snapshot_method == 'on_test_batch_start': + self.trainer_state = trainer.state -def test_running_state_during_fit(tmpdir): - """ - Tests that state is set to RUNNING during fit - """ - class StateSnapshotCallback(Callback): - def __init__(self): - super().__init__() - self.trainer_state = None +def test_state_decorator_nothing_passed(tmpdir): + """ Test that state is not changed if nothing is passed to a decorator""" - def on_batch_start(self, trainer, pl_module): - self.trainer_state = trainer.state + @trainer_state() + def test_method(self): + return self.state + + trainer = Trainer(default_root_dir=tmpdir) + trainer.state = TrainerState.INITIALIZING + + snapshot_state = test_method(trainer) + + assert snapshot_state == TrainerState.INITIALIZING + assert trainer.state == TrainerState.INITIALIZING + + +def test_state_decorator_entering_only(tmpdir): + """ Tests that state is set to entering inside a run function and restored to the previous value after. """ + + @trainer_state(entering=TrainerState.RUNNING) + def test_method(self): + return self.state + + trainer = Trainer(default_root_dir=tmpdir) + trainer.state = TrainerState.INITIALIZING + + snapshot_state = test_method(trainer) + + assert snapshot_state == TrainerState.RUNNING + assert trainer.state == TrainerState.INITIALIZING + + +def test_state_decorator_exiting_only(tmpdir): + """ Tests that state is not changed inside a run function and set to `exiting` after. """ + + @trainer_state(exiting=TrainerState.FINISHED) + def test_method(self): + return self.state + + trainer = Trainer(default_root_dir=tmpdir) + trainer.state = TrainerState.INITIALIZING + + snapshot_state = test_method(trainer) + + assert snapshot_state == TrainerState.INITIALIZING + assert trainer.state == TrainerState.FINISHED + + +def test_state_decorator_entering_and_exiting(tmpdir): + """ Tests that state is set to `entering` inside a run function and set ot `exiting` after. """ + + @trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED) + def test_method(self): + return self.state + + trainer = Trainer(default_root_dir=tmpdir) + trainer.state = TrainerState.INITIALIZING + + snapshot_state = test_method(trainer) + + assert snapshot_state == TrainerState.RUNNING + assert trainer.state == TrainerState.FINISHED + + +def test_state_decorator_interrupt(tmpdir): + """ Tests that state remains `INTERRUPTED` is its set in run function. """ + + @trainer_state(exiting=TrainerState.FINISHED) + def test_method(self): + self.state = TrainerState.INTERRUPTED + + trainer = Trainer(default_root_dir=tmpdir) + trainer.state = TrainerState.INITIALIZING + + test_method(trainer) + assert trainer.state == TrainerState.INTERRUPTED + + +def test_initialize_state(tmpdir): + """ Tests that state is INITIALIZE after Trainer creation """ + trainer = Trainer(default_root_dir=tmpdir) + assert trainer.state == TrainerState.INITIALIZING + + +@pytest.mark.parametrize("extra_params", [ + pytest.param(dict(fast_dev_run=True), id='Fast-Run'), + pytest.param(dict(max_steps=1), id='Single-Step'), +]) +def test_running_state_during_fit(tmpdir, extra_params): + """ Tests that state is set to RUNNING during fit """ hparams = EvalModelTemplate.get_default_hparams() model = EvalModelTemplate(**hparams) - snapshot_callback = StateSnapshotCallback() + snapshot_callback = StateSnapshotCallback(snapshot_method='on_batch_start') trainer = Trainer( callbacks=[snapshot_callback], default_root_dir=tmpdir, - fast_dev_run=True, + **extra_params ) trainer.fit(model) @@ -44,16 +130,18 @@ def on_batch_start(self, trainer, pl_module): assert snapshot_callback.trainer_state == TrainerState.RUNNING -def test_finished_state_after_fit(tmpdir): - """ - Tests that state is FINISHED after fit - """ +@pytest.mark.parametrize("extra_params", [ + pytest.param(dict(fast_dev_run=True), id='Fast-Run'), + pytest.param(dict(max_steps=1), id='Single-Step'), +]) +def test_finished_state_after_fit(tmpdir, extra_params): + """ Tests that state is FINISHED after fit """ hparams = EvalModelTemplate.get_default_hparams() model = EvalModelTemplate(**hparams) trainer = Trainer( default_root_dir=tmpdir, - fast_dev_run=True, + **extra_params ) trainer.fit(model) @@ -62,22 +150,12 @@ def test_finished_state_after_fit(tmpdir): def test_running_state_during_test(tmpdir): - """ - Tests that state is set to RUNNING during test - """ - - class StateSnapshotCallback(Callback): - def __init__(self): - super().__init__() - self.trainer_state = None - - def on_test_batch_start(self, trainer, pl_module): - self.trainer_state = trainer.state + """ Tests that state is set to RUNNING during test """ hparams = EvalModelTemplate.get_default_hparams() model = EvalModelTemplate(**hparams) - snapshot_callback = StateSnapshotCallback() + snapshot_callback = StateSnapshotCallback(snapshot_method='on_test_batch_start') trainer = Trainer( callbacks=[snapshot_callback], @@ -91,9 +169,7 @@ def on_test_batch_start(self, trainer, pl_module): def test_finished_state_after_test(tmpdir): - """ - Tests that state is FINISHED after fit - """ + """ Tests that state is FINISHED after fit """ hparams = EvalModelTemplate.get_default_hparams() model = EvalModelTemplate(**hparams) @@ -107,10 +183,12 @@ def test_finished_state_after_test(tmpdir): assert trainer.state == TrainerState.FINISHED -def test_interrupt_state_on_keyboard_interrupt(tmpdir): - """ - Tests that state is set to INTERRUPTED on KeyboardInterrupt - """ +@pytest.mark.parametrize("extra_params", [ + pytest.param(dict(fast_dev_run=True), id='Fast-Run'), + pytest.param(dict(max_steps=1), id='Single-Step'), +]) +def test_interrupt_state_on_keyboard_interrupt(tmpdir, extra_params): + """ Tests that state is set to INTERRUPTED on KeyboardInterrupt """ hparams = EvalModelTemplate.get_default_hparams() model = EvalModelTemplate(**hparams) @@ -124,7 +202,7 @@ def on_batch_start(self, trainer, pl_module): trainer = Trainer( callbacks=[InterruptCallback()], default_root_dir=tmpdir, - fast_dev_run=True, + **extra_params ) trainer.fit(model)