Commit

Add tracking of basic states in Trainer [wip - to-be-merged after v0.9] (#2541)

* Add initial tracking of states in Trainer.

* Add INTERRUPTED state, improve tests, move state switching from the callback to the trainer.

* Move part of the trainer state switching into a decorator.

* Add documentation.

* Fix docs, rename state enum, restore state to previous on exit if None, add tests for decorator only.

* Fix callback typing.

Co-authored-by: William Falcon <waf2107@columbia.edu>
zerogerc and williamFalcon committed Aug 9, 2020
1 parent 13fe0a4 commit e9846dd
Showing 4 changed files with 266 additions and 0 deletions.
49 changes: 49 additions & 0 deletions pytorch_lightning/trainer/states.py
@@ -0,0 +1,49 @@
from enum import Enum
from functools import wraps
from typing import Callable, Optional

import pytorch_lightning


class TrainerState(Enum):
    """ State which is set in the :class:`~pytorch_lightning.trainer.trainer.Trainer`
    to indicate what is currently being executed or was last executed. """
    INITIALIZING = 'INITIALIZING'
    RUNNING = 'RUNNING'
    FINISHED = 'FINISHED'
    INTERRUPTED = 'INTERRUPTED'


def trainer_state(*, entering: Optional[TrainerState] = None, exiting: Optional[TrainerState] = None) -> Callable:
    """ Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods
    which sets the state to `entering` before the function executes and to `exiting`
    after it returns. If `entering` is `None`, the state is left unchanged on entry.
    If `exiting` is `None`, the state is restored to its value before the call.
    If the `INTERRUPTED` state is set inside the wrapped function, it is retained.
    """

    def wrapper(fn) -> Callable:
        @wraps(fn)
        def wrapped_fn(self, *args, **kwargs):
            # Only track state on real Trainer instances; otherwise pass through.
            if not isinstance(self, pytorch_lightning.Trainer):
                return fn(self, *args, **kwargs)

            state_before = self.state
            if entering is not None:
                self.state = entering
            result = fn(self, *args, **kwargs)

            # The INTERRUPTED state can be set inside the wrapped function.
            # Retain it so callers can tell the run was interrupted.
            if self.state == TrainerState.INTERRUPTED:
                return result

            if exiting is not None:
                self.state = exiting
            else:
                self.state = state_before
            return result

        return wrapped_fn

    return wrapper
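
The decorator's semantics are easiest to see in a standalone call. The following is a minimal usage sketch, not part of the diff; the free function `run` is illustrative and any Trainer instance works:

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import TrainerState, trainer_state


@trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
def run(self):
    # While the decorated function executes, the state is `entering`.
    return self.state


trainer = Trainer()
state_inside = run(trainer)

assert state_inside == TrainerState.RUNNING
assert trainer.state == TrainerState.FINISHED  # `exiting` applied on return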
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -45,6 +45,7 @@
from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
from pytorch_lightning.trainer.states import TrainerState, trainer_state
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.trainer.training_io import TrainerIOMixin
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
@@ -395,6 +396,7 @@ def __init__(
        self.interrupted = False
        self.should_stop = False
        self.running_sanity_check = False
        self.state = TrainerState.INITIALIZING

        self._default_root_dir = default_root_dir or os.getcwd()
        self._weights_save_path = weights_save_path or self._default_root_dir
@@ -888,6 +890,7 @@ def weights_save_path(self) -> str:
    # -----------------------------
    # MODEL TRAINING
    # -----------------------------
    @trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
    def fit(
        self,
        model: LightningModule,
@@ -1240,6 +1243,7 @@ def _run_sanity_check(self, ref_model, model):
        self.on_sanity_check_end()
        self.running_sanity_check = False

    @trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
    def test(
        self,
        model: Optional[LightningModule] = None,
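For user code, the net effect is a queryable lifecycle around `fit` and `test`. A hedged sketch of that lifecycle follows; `MyModel` is a placeholder for any LightningModule and is not defined in this commit:

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import TrainerState

model = MyModel()  # placeholder LightningModule, not part of this commit
trainer = Trainer(fast_dev_run=True)
assert trainer.state == TrainerState.INITIALIZING  # set in __init__

trainer.fit(model)  # state is RUNNING while fit executes
assert trainer.state == TrainerState.FINISHED  # unless the run was interrupted
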
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/training_loop.py
@@ -174,6 +174,7 @@ def training_step(self, batch, batch_idx):
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import EvalResult, Result
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.supporters import TensorRunningAccum, Accumulator
from pytorch_lightning.utilities import rank_zero_warn, AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -253,6 +254,7 @@ class TrainerTrainLoopMixin(ABC):
    terminate_on_nan: bool
    tpu_id: int
    interactive_ddp_procs: ...
    state: TrainerState
    amp_type: AMPType
    on_tpu: bool
@@ -418,6 +420,7 @@ def train(self):
            # user could press ctrl+c many times... only shutdown once
            if not self.interrupted:
                self.interrupted = True
                self.state = TrainerState.INTERRUPTED
                self.on_keyboard_interrupt()

        self.run_training_teardown()
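
Since the state is switched to INTERRUPTED before `on_keyboard_interrupt` is dispatched, a callback can already observe the new state inside that hook. A small sketch under that assumption; the class name is illustrative:

from pytorch_lightning import Callback
from pytorch_lightning.trainer.states import TrainerState


class InterruptObserver(Callback):
    def on_keyboard_interrupt(self, trainer, pl_module):
        # The trainer sets INTERRUPTED just before invoking this hook (see diff above).
        assert trainer.state == TrainerState.INTERRUPTED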
210 changes: 210 additions & 0 deletions tests/trainer/test_states.py
@@ -0,0 +1,210 @@
import pytest

from pytorch_lightning import Trainer, Callback
from pytorch_lightning.trainer.states import TrainerState, trainer_state
from tests.base import EvalModelTemplate


class StateSnapshotCallback(Callback):
    """ Snapshots the trainer state inside a particular trainer method. """

    def __init__(self, snapshot_method: str):
        super().__init__()
        assert snapshot_method in ['on_batch_start', 'on_test_batch_start']
        self.snapshot_method = snapshot_method
        self.trainer_state = None

    def on_batch_start(self, trainer, pl_module):
        if self.snapshot_method == 'on_batch_start':
            self.trainer_state = trainer.state

    def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        if self.snapshot_method == 'on_test_batch_start':
            self.trainer_state = trainer.state


def test_state_decorator_nothing_passed(tmpdir):
    """ Tests that state is not changed if nothing is passed to the decorator. """

    @trainer_state()
    def test_method(self):
        return self.state

    trainer = Trainer(default_root_dir=tmpdir)
    trainer.state = TrainerState.INITIALIZING

    snapshot_state = test_method(trainer)

    assert snapshot_state == TrainerState.INITIALIZING
    assert trainer.state == TrainerState.INITIALIZING


def test_state_decorator_entering_only(tmpdir):
    """ Tests that state is set to `entering` inside a run function and restored to the previous value after. """

    @trainer_state(entering=TrainerState.RUNNING)
    def test_method(self):
        return self.state

    trainer = Trainer(default_root_dir=tmpdir)
    trainer.state = TrainerState.INITIALIZING

    snapshot_state = test_method(trainer)

    assert snapshot_state == TrainerState.RUNNING
    assert trainer.state == TrainerState.INITIALIZING


def test_state_decorator_exiting_only(tmpdir):
    """ Tests that state is not changed inside a run function and set to `exiting` after. """

    @trainer_state(exiting=TrainerState.FINISHED)
    def test_method(self):
        return self.state

    trainer = Trainer(default_root_dir=tmpdir)
    trainer.state = TrainerState.INITIALIZING

    snapshot_state = test_method(trainer)

    assert snapshot_state == TrainerState.INITIALIZING
    assert trainer.state == TrainerState.FINISHED


def test_state_decorator_entering_and_exiting(tmpdir):
    """ Tests that state is set to `entering` inside a run function and set to `exiting` after. """

    @trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
    def test_method(self):
        return self.state

    trainer = Trainer(default_root_dir=tmpdir)
    trainer.state = TrainerState.INITIALIZING

    snapshot_state = test_method(trainer)

    assert snapshot_state == TrainerState.RUNNING
    assert trainer.state == TrainerState.FINISHED


def test_state_decorator_interrupt(tmpdir):
    """ Tests that state remains `INTERRUPTED` if it is set inside the run function. """

    @trainer_state(exiting=TrainerState.FINISHED)
    def test_method(self):
        self.state = TrainerState.INTERRUPTED

    trainer = Trainer(default_root_dir=tmpdir)
    trainer.state = TrainerState.INITIALIZING

    test_method(trainer)
    assert trainer.state == TrainerState.INTERRUPTED


def test_initialize_state(tmpdir):
    """ Tests that state is INITIALIZING after Trainer creation. """
    trainer = Trainer(default_root_dir=tmpdir)
    assert trainer.state == TrainerState.INITIALIZING


@pytest.mark.parametrize("extra_params", [
    pytest.param(dict(fast_dev_run=True), id='Fast-Run'),
    pytest.param(dict(max_steps=1), id='Single-Step'),
])
def test_running_state_during_fit(tmpdir, extra_params):
    """ Tests that state is set to RUNNING during fit. """

    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)

    snapshot_callback = StateSnapshotCallback(snapshot_method='on_batch_start')

    trainer = Trainer(
        callbacks=[snapshot_callback],
        default_root_dir=tmpdir,
        **extra_params
    )

    trainer.fit(model)

    assert snapshot_callback.trainer_state == TrainerState.RUNNING


@pytest.mark.parametrize("extra_params", [
    pytest.param(dict(fast_dev_run=True), id='Fast-Run'),
    pytest.param(dict(max_steps=1), id='Single-Step'),
])
def test_finished_state_after_fit(tmpdir, extra_params):
    """ Tests that state is FINISHED after fit. """
    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)

    trainer = Trainer(
        default_root_dir=tmpdir,
        **extra_params
    )

    trainer.fit(model)

    assert trainer.state == TrainerState.FINISHED


def test_running_state_during_test(tmpdir):
    """ Tests that state is set to RUNNING during test. """

    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)

    snapshot_callback = StateSnapshotCallback(snapshot_method='on_test_batch_start')

    trainer = Trainer(
        callbacks=[snapshot_callback],
        default_root_dir=tmpdir,
        fast_dev_run=True,
    )

    trainer.test(model)

    assert snapshot_callback.trainer_state == TrainerState.RUNNING


def test_finished_state_after_test(tmpdir):
    """ Tests that state is FINISHED after test. """
    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)

    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
    )

    trainer.test(model)

    assert trainer.state == TrainerState.FINISHED


@pytest.mark.parametrize("extra_params", [
    pytest.param(dict(fast_dev_run=True), id='Fast-Run'),
    pytest.param(dict(max_steps=1), id='Single-Step'),
])
def test_interrupt_state_on_keyboard_interrupt(tmpdir, extra_params):
    """ Tests that state is set to INTERRUPTED on KeyboardInterrupt. """
    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)

    class InterruptCallback(Callback):
        def __init__(self):
            super().__init__()

        def on_batch_start(self, trainer, pl_module):
            raise KeyboardInterrupt

    trainer = Trainer(
        callbacks=[InterruptCallback()],
        default_root_dir=tmpdir,
        **extra_params
    )

    trainer.fit(model)

    assert trainer.state == TrainerState.INTERRUPTED
