-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add tracking of basic states in Trainer [wip - to-be-merged after v0.9] (#2541)
* Add initial tracking of states in Trainer.
* Add INTERRUPTED state, improve tests, move state switching from callback to a trainer.
* Move part of a trainer state switching to a decorator.
* Add documentation.
* Fix docs, rename state enum, restore state to previous on exit if None, add tests for decorator only.
* Fix callback typing.
Co-authored-by: William Falcon <waf2107@columbia.edu>
- Loading branch information
1 parent
13fe0a4
commit e9846dd
Showing
4 changed files
with
266 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from enum import Enum | ||
from functools import wraps | ||
from typing import Callable, Optional | ||
|
||
import pytorch_lightning | ||
|
||
|
||
class TrainerState(Enum):
    """ Execution state of a :class:`~pytorch_lightning.trainer.trainer.Trainer`.

    The trainer carries one of these values to indicate what it is currently
    executing or what it executed last.
    """
    # State of a freshly created trainer.
    INITIALIZING = 'INITIALIZING'
    # A run function (e.g. fit or test) is currently executing.
    RUNNING = 'RUNNING'
    # A run function completed without being interrupted.
    FINISHED = 'FINISHED'
    # A run function was interrupted (e.g. by a KeyboardInterrupt).
    INTERRUPTED = 'INTERRUPTED'
|
||
|
||
def trainer_state(*, entering: Optional[TrainerState] = None, exiting: Optional[TrainerState] = None) -> Callable:
    """ Build a decorator that switches the state of a
    :class:`~pytorch_lightning.trainer.trainer.Trainer` around a method call.

    Args:
        entering: state assigned right before the decorated function runs.
            With `None` the state is left untouched on entry.
        exiting: state assigned after the decorated function returns.
            With `None` the state observed before the call is restored.

    Returns:
        A decorator for `Trainer` methods. When the first argument of the
        decorated function is not a `Trainer`, the function is called as is
        and no state handling takes place.

    If the decorated function leaves the trainer in the `INTERRUPTED` state,
    that state is kept and neither `exiting` nor the previous state is applied.
    """

    def decorator(fn) -> Callable:
        @wraps(fn)
        def wrapped_fn(self, *args, **kwargs):
            # Only Trainer instances carry a state to manage.
            if not isinstance(self, pytorch_lightning.Trainer):
                return fn(self, *args, **kwargs)

            previous_state = self.state
            if entering is not None:
                self.state = entering

            outcome = fn(self, *args, **kwargs)

            # INTERRUPTED may be set inside the run function; it is retained so the
            # caller can tell that the run was interrupted.
            if not (self.state == TrainerState.INTERRUPTED):
                self.state = previous_state if exiting is None else exiting
            return outcome

        return wrapped_fn

    return decorator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,210 @@ | ||
import pytest | ||
|
||
from pytorch_lightning import Trainer, Callback | ||
from pytorch_lightning.trainer.states import TrainerState, trainer_state | ||
from tests.base import EvalModelTemplate | ||
|
||
|
||
class StateSnapshotCallback(Callback):
    """ Captures a snapshot of the trainer state inside one selected callback hook. """

    def __init__(self, snapshot_method: str):
        super().__init__()
        # Only these two hooks are implemented below.
        assert snapshot_method in ('on_batch_start', 'on_test_batch_start')
        # Stays None until the selected hook has fired at least once.
        self.trainer_state = None
        self.snapshot_method = snapshot_method

    def on_batch_start(self, trainer, pl_module):
        if self.snapshot_method != 'on_batch_start':
            return
        self.trainer_state = trainer.state

    def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        if self.snapshot_method != 'on_test_batch_start':
            return
        self.trainer_state = trainer.state
|
||
|
||
def test_state_decorator_nothing_passed(tmpdir):
    """ Test that the state is left unchanged when the decorator receives no arguments. """

    @trainer_state()
    def probe(self):
        return self.state

    trainer = Trainer(default_root_dir=tmpdir)
    trainer.state = TrainerState.INITIALIZING

    state_inside = probe(trainer)

    # Neither during nor after the call may the state differ from the initial one.
    assert state_inside == TrainerState.INITIALIZING
    assert trainer.state == TrainerState.INITIALIZING
|
||
|
||
def test_state_decorator_entering_only(tmpdir):
    """ Tests that the state equals `entering` inside the run function and is restored afterwards. """

    @trainer_state(entering=TrainerState.RUNNING)
    def probe(self):
        return self.state

    trainer = Trainer(default_root_dir=tmpdir)
    trainer.state = TrainerState.INITIALIZING

    state_inside = probe(trainer)

    assert state_inside == TrainerState.RUNNING
    # No `exiting` was given, so the previous state comes back.
    assert trainer.state == TrainerState.INITIALIZING
|
||
|
||
def test_state_decorator_exiting_only(tmpdir):
    """ Tests that the state is untouched inside the run function and set to `exiting` afterwards. """

    @trainer_state(exiting=TrainerState.FINISHED)
    def probe(self):
        return self.state

    trainer = Trainer(default_root_dir=tmpdir)
    trainer.state = TrainerState.INITIALIZING

    state_inside = probe(trainer)

    # No `entering` was given, so the function still sees the initial state.
    assert state_inside == TrainerState.INITIALIZING
    assert trainer.state == TrainerState.FINISHED
|
||
|
||
def test_state_decorator_entering_and_exiting(tmpdir):
    """ Tests that the state is set to `entering` inside the run function and to `exiting` afterwards. """

    @trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
    def probe(self):
        return self.state

    trainer = Trainer(default_root_dir=tmpdir)
    trainer.state = TrainerState.INITIALIZING

    state_inside = probe(trainer)

    assert state_inside == TrainerState.RUNNING
    assert trainer.state == TrainerState.FINISHED
|
||
|
||
def test_state_decorator_interrupt(tmpdir):
    """ Tests that the state remains `INTERRUPTED` if it is set inside the run function. """

    @trainer_state(exiting=TrainerState.FINISHED)
    def interrupting_method(self):
        self.state = TrainerState.INTERRUPTED

    trainer = Trainer(default_root_dir=tmpdir)
    trainer.state = TrainerState.INITIALIZING

    interrupting_method(trainer)

    # `exiting` must not overwrite an INTERRUPTED state.
    assert trainer.state == TrainerState.INTERRUPTED
|
||
|
||
def test_initialize_state(tmpdir):
    """ Tests that the state is INITIALIZING right after Trainer creation """
    assert Trainer(default_root_dir=tmpdir).state == TrainerState.INITIALIZING
|
||
|
||
@pytest.mark.parametrize("extra_params", [
    pytest.param({'fast_dev_run': True}, id='Fast-Run'),
    pytest.param({'max_steps': 1}, id='Single-Step'),
])
def test_running_state_during_fit(tmpdir, extra_params):
    """ Tests that the state observed inside a training batch hook during fit is RUNNING """
    model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())

    # The callback stores the trainer state seen in `on_batch_start`.
    state_recorder = StateSnapshotCallback(snapshot_method='on_batch_start')
    trainer = Trainer(
        callbacks=[state_recorder],
        default_root_dir=tmpdir,
        **extra_params
    )

    trainer.fit(model)

    assert state_recorder.trainer_state == TrainerState.RUNNING
|
||
|
||
@pytest.mark.parametrize("extra_params", [
    pytest.param({'fast_dev_run': True}, id='Fast-Run'),
    pytest.param({'max_steps': 1}, id='Single-Step'),
])
def test_finished_state_after_fit(tmpdir, extra_params):
    """ Tests that the state is FINISHED once fit has returned """
    model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())
    trainer = Trainer(default_root_dir=tmpdir, **extra_params)

    trainer.fit(model)

    assert trainer.state == TrainerState.FINISHED
|
||
|
||
def test_running_state_during_test(tmpdir):
    """ Tests that the state observed inside a test batch hook during test is RUNNING """
    model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())

    # The callback stores the trainer state seen in `on_test_batch_start`.
    state_recorder = StateSnapshotCallback(snapshot_method='on_test_batch_start')
    trainer = Trainer(
        callbacks=[state_recorder],
        default_root_dir=tmpdir,
        fast_dev_run=True,
    )

    trainer.test(model)

    assert state_recorder.trainer_state == TrainerState.RUNNING
|
||
|
||
def test_finished_state_after_test(tmpdir):
    """ Tests that the state is FINISHED once test has returned """
    model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)

    trainer.test(model)

    assert trainer.state == TrainerState.FINISHED
|
||
|
||
@pytest.mark.parametrize("extra_params", [
    pytest.param({'fast_dev_run': True}, id='Fast-Run'),
    pytest.param({'max_steps': 1}, id='Single-Step'),
])
def test_interrupt_state_on_keyboard_interrupt(tmpdir, extra_params):
    """ Tests that the state is INTERRUPTED after a KeyboardInterrupt is raised during fit """

    class InterruptCallback(Callback):
        """ Raises KeyboardInterrupt from the `on_batch_start` hook. """

        def on_batch_start(self, trainer, pl_module):
            raise KeyboardInterrupt

    model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())
    trainer = Trainer(
        callbacks=[InterruptCallback()],
        default_root_dir=tmpdir,
        **extra_params
    )

    # The test expects fit to return normally even though the callback raises.
    trainer.fit(model)

    assert trainer.state == TrainerState.INTERRUPTED