Skip to content

Commit

Permalink
Add initial tracking of states in Trainer.
Browse files Browse the repository at this point in the history
  • Loading branch information
zerogerc committed Jul 16, 2020
1 parent 7b4db30 commit 605a527
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 0 deletions.
31 changes: 31 additions & 0 deletions pytorch_lightning/trainer/states.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from enum import Enum, auto

from pytorch_lightning import Callback


class TrainerState(Enum):
""" State which is set to the Trainer to indicate what is being executed. """
INITIALIZE = auto()
RUNNING = auto()
FINISHED = auto()


class _TrainerStateSwitcher(Callback):
""" Special callback used by the Trainer. This callback sets proper
state to the trainer depending on what is being executed.
"""

def on_init_start(self, trainer):
trainer.state = TrainerState.INITIALIZE

def on_init_end(self, trainer):
trainer.state = TrainerState.INITIALIZE

def setup(self, trainer, stage: str):
trainer.state = TrainerState.RUNNING

def teardown(self, trainer, stage: str):
trainer.state = TrainerState.FINISHED

def on_keyboard_interrupt(self, trainer, pl_module):
trainer.state = TrainerState.FINISHED
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
from pytorch_lightning.trainer.states import _TrainerStateSwitcher, TrainerState
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.trainer.training_io import TrainerIOMixin
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
Expand Down Expand Up @@ -415,6 +416,7 @@ def __init__(
self.current_epoch = 0
self.interrupted = False
self.should_stop = False
self.state = TrainerState.INITIALIZE

# set default save path if user didn't provide one
if default_root_dir is None:
Expand All @@ -424,6 +426,8 @@ def __init__(
# init callbacks
self.callbacks = callbacks or []

self.callbacks.append(_TrainerStateSwitcher())

# configure early stop callback
# creates a default one if none passed in
early_stop_callback = self.configure_early_stopping(early_stop_callback)
Expand Down
49 changes: 49 additions & 0 deletions tests/trainer/test_states.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import TrainerState
from tests.base import EvalModelTemplate


def test_initialize_state(tmpdir):
"""
Tests that state is INITIALIZE after Trainer creation
"""
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
)

assert trainer.state == TrainerState.INITIALIZE


def test_finished_state_after_fit(tmpdir):
"""
Tests that state is FINISHED after fit
"""
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)

trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
)

trainer.fit(model)

assert trainer.state == TrainerState.FINISHED


def test_finished_state_after_test(tmpdir):
"""
Tests that state is FINISHED after fit
"""
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)

trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
)

trainer.test(model)

assert trainer.state == TrainerState.FINISHED

0 comments on commit 605a527

Please sign in to comment.