Skip to content

Commit

Permalink
Fix docs, rename state enum, restore state to previous on exit if Non…
Browse files Browse the repository at this point in the history
…e, add tests for decorator only.
  • Loading branch information
zerogerc committed Jul 21, 2020
1 parent 7542142 commit 47b3c6a
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 54 deletions.
21 changes: 15 additions & 6 deletions pytorch_lightning/trainer/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,20 @@


class TrainerState(Enum):
    """ State which is set in the :class:`~pytorch_lightning.trainer.trainer.Trainer`
    to indicate what is currently or was executed. """
    INITIALIZING = 'INITIALIZING'  # Trainer.__init__ has run; no fit/test started yet
    RUNNING = 'RUNNING'            # a fit/test run is currently executing
    FINISHED = 'FINISHED'          # the last run completed
    INTERRUPTED = 'INTERRUPTED'    # the last run was aborted (e.g. KeyboardInterrupt)


def trainer_state(*, entering: Optional[TrainerState] = None, exiting: Optional[TrainerState] = None) -> Callable:
""" Decorator for :class:`~pytorch_lightning.Trainer` methods which changes
state to `entering` before the function execution and `exiting` after
the function is executed. If None is passed the state is not changed.
""" Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods
which changes state to `entering` before the function execution and `exiting`
after the function is executed. If `None` is passed to `entering`, the state is not changed.
If `None` is passed to `exiting`, the state is restored to the state before function execution.
If `INTERRUPTED` state is set inside a run function, the state remains `INTERRUPTED`.
"""

def wrapper(fn) -> Callable:
Expand All @@ -25,14 +28,20 @@ def wrapped_fn(self, *args, **kwargs):
if not isinstance(self, pytorch_lightning.Trainer):
return fn(self, *args, **kwargs)

state_before = self.state
if entering is not None:
self.state = entering
result = fn(self, *args, **kwargs)

# The INTERRUPTED state can be set inside the run function. To indicate that run was interrupted
# we retain INTERRUPTED state
if exiting is not None and self.state != TrainerState.INTERRUPTED:
if self.state == TrainerState.INTERRUPTED:
return result

if exiting is not None:
self.state = exiting
else:
self.state = state_before
return result

return wrapped_fn
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ def __init__(
self.current_epoch = 0
self.interrupted = False
self.should_stop = False
self.state = TrainerState.INITIALIZE
self.state = TrainerState.INITIALIZING

# set default save path if user didn't provide one
if default_root_dir is None:
Expand Down
172 changes: 125 additions & 47 deletions tests/trainer/test_states.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,147 @@
import pytest

from pytorch_lightning import Trainer, Callback
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.states import TrainerState, trainer_state
from tests.base import EvalModelTemplate


def test_initialize_state(tmpdir):
"""
Tests that state is INITIALIZE after Trainer creation
"""
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
)
class StateSnapshotCallback(Callback):
    """ Allows to snapshot the trainer state inside a particular trainer method. """

    def __init__(self, snapshot_method: str):
        super().__init__()
        # Only these two hooks are supported as snapshot points.
        assert snapshot_method in ['on_batch_start', 'on_test_batch_start']
        self.snapshot_method = snapshot_method
        # Filled with trainer.state when the chosen hook fires; None until then.
        self.trainer_state = None

    def on_batch_start(self, trainer, pl_module):
        if self.snapshot_method == 'on_batch_start':
            self.trainer_state = trainer.state

    def on_test_batch_start(self, trainer, pl_module):
        if self.snapshot_method == 'on_test_batch_start':
            self.trainer_state = trainer.state

def test_running_state_during_fit(tmpdir):
"""
Tests that state is set to RUNNING during fit
"""

class StateSnapshotCallback(Callback):
def __init__(self):
super().__init__()
self.trainer_state = None
def test_state_decorator_nothing_passed(tmpdir):
    """ Test that state is not changed if nothing is passed to the decorator. """

    @trainer_state()
    def test_method(self):
        # Returns the state observed while the wrapped function runs.
        return self.state

    trainer = Trainer(default_root_dir=tmpdir)
    trainer.state = TrainerState.INITIALIZING

    snapshot_state = test_method(trainer)

    # Neither `entering` nor `exiting` was given, so the state is untouched
    # both during and after the call.
    assert snapshot_state == TrainerState.INITIALIZING
    assert trainer.state == TrainerState.INITIALIZING


def test_state_decorator_entering_only(tmpdir):
    """ Checks that `entering` is applied while the wrapped function runs
    and that the previous state is restored once it returns. """

    @trainer_state(entering=TrainerState.RUNNING)
    def run_fn(self):
        # Capture the state as seen from inside the decorated call.
        return self.state

    trainer = Trainer(default_root_dir=tmpdir)
    trainer.state = TrainerState.INITIALIZING
    state_inside = run_fn(trainer)

    assert state_inside == TrainerState.RUNNING
    assert trainer.state == TrainerState.INITIALIZING


def test_state_decorator_exiting_only(tmpdir):
    """ Checks that the state stays untouched during the wrapped call
    and becomes `exiting` after it completes. """

    @trainer_state(exiting=TrainerState.FINISHED)
    def run_fn(self):
        # Capture the state as seen from inside the decorated call.
        return self.state

    trainer = Trainer(default_root_dir=tmpdir)
    trainer.state = TrainerState.INITIALIZING
    state_inside = run_fn(trainer)

    assert state_inside == TrainerState.INITIALIZING
    assert trainer.state == TrainerState.FINISHED


def test_state_decorator_entering_and_exiting(tmpdir):
    """ Tests that state is set to `entering` inside a run function and set to `exiting` after. """

    @trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
    def test_method(self):
        # Returns the state observed while the wrapped function runs.
        return self.state

    trainer = Trainer(default_root_dir=tmpdir)
    trainer.state = TrainerState.INITIALIZING

    snapshot_state = test_method(trainer)

    assert snapshot_state == TrainerState.RUNNING
    assert trainer.state == TrainerState.FINISHED


def test_state_decorator_interrupt(tmpdir):
    """ Tests that state remains `INTERRUPTED` if it is set inside the run function. """

    @trainer_state(exiting=TrainerState.FINISHED)
    def test_method(self):
        # Simulate an interruption occurring during the run.
        self.state = TrainerState.INTERRUPTED

    trainer = Trainer(default_root_dir=tmpdir)
    trainer.state = TrainerState.INITIALIZING

    test_method(trainer)
    # `exiting` must NOT override the INTERRUPTED state.
    assert trainer.state == TrainerState.INTERRUPTED


def test_initialize_state(tmpdir):
    """ Tests that state is INITIALIZING after Trainer creation """
    trainer = Trainer(default_root_dir=tmpdir)
    assert trainer.state == TrainerState.INITIALIZING


@pytest.mark.parametrize("extra_params", [
pytest.param(dict(fast_dev_run=True), id='Fast-Run'),
pytest.param(dict(max_steps=1), id='Single-Step'),
])
def test_running_state_during_fit(tmpdir, extra_params):
""" Tests that state is set to RUNNING during fit """

hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)

snapshot_callback = StateSnapshotCallback()
snapshot_callback = StateSnapshotCallback(snapshot_method='on_batch_start')

trainer = Trainer(
callbacks=[snapshot_callback],
default_root_dir=tmpdir,
fast_dev_run=True,
**extra_params
)

trainer.fit(model)

assert snapshot_callback.trainer_state == TrainerState.RUNNING


def test_finished_state_after_fit(tmpdir):
"""
Tests that state is FINISHED after fit
"""
@pytest.mark.parametrize("extra_params", [
pytest.param(dict(fast_dev_run=True), id='Fast-Run'),
pytest.param(dict(max_steps=1), id='Single-Step'),
])
def test_finished_state_after_fit(tmpdir, extra_params):
""" Tests that state is FINISHED after fit """
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)

trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
**extra_params
)

trainer.fit(model)
Expand All @@ -62,22 +150,12 @@ def test_finished_state_after_fit(tmpdir):


def test_running_state_during_test(tmpdir):
"""
Tests that state is set to RUNNING during test
"""

class StateSnapshotCallback(Callback):
def __init__(self):
super().__init__()
self.trainer_state = None

def on_test_batch_start(self, trainer, pl_module):
self.trainer_state = trainer.state
""" Tests that state is set to RUNNING during test """

hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)

snapshot_callback = StateSnapshotCallback()
snapshot_callback = StateSnapshotCallback(snapshot_method='on_test_batch_start')

trainer = Trainer(
callbacks=[snapshot_callback],
Expand All @@ -91,9 +169,7 @@ def on_test_batch_start(self, trainer, pl_module):


def test_finished_state_after_test(tmpdir):
"""
Tests that state is FINISHED after fit
"""
""" Tests that state is FINISHED after fit """
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)

Expand All @@ -107,10 +183,12 @@ def test_finished_state_after_test(tmpdir):
assert trainer.state == TrainerState.FINISHED


def test_interrupt_state_on_keyboard_interrupt(tmpdir):
"""
Tests that state is set to INTERRUPTED on KeyboardInterrupt
"""
@pytest.mark.parametrize("extra_params", [
pytest.param(dict(fast_dev_run=True), id='Fast-Run'),
pytest.param(dict(max_steps=1), id='Single-Step'),
])
def test_interrupt_state_on_keyboard_interrupt(tmpdir, extra_params):
""" Tests that state is set to INTERRUPTED on KeyboardInterrupt """
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)

Expand All @@ -124,7 +202,7 @@ def on_batch_start(self, trainer, pl_module):
trainer = Trainer(
callbacks=[InterruptCallback()],
default_root_dir=tmpdir,
fast_dev_run=True,
**extra_params
)

trainer.fit(model)
Expand Down

0 comments on commit 47b3c6a

Please sign in to comment.