
Commit

Add INTERRUPTED state, improve tests, move state switching from a callback to the trainer.
zerogerc committed Jul 8, 2020
1 parent 482010c commit 17947ab
Showing 4 changed files with 103 additions and 32 deletions.
34 changes: 6 additions & 28 deletions pytorch_lightning/trainer/states.py
@@ -1,31 +1,9 @@
-from enum import Enum, auto
-
-from pytorch_lightning import Callback
+from enum import Enum


 class TrainerState(Enum):
-    """ State which is set to the Trainer to indicate what is being executed. """
-    INITIALIZE = auto()
-    RUNNING = auto()
-    FINISHED = auto()
-
-
-class _TrainerStateSwitcher(Callback):
-    """ Special callback used by the Trainer. This callback sets proper
-    state to the trainer depending on what is being executed.
-    """
-
-    def on_init_start(self, trainer):
-        trainer.state = TrainerState.INITIALIZE
-
-    def on_init_end(self, trainer):
-        trainer.state = TrainerState.INITIALIZE
-
-    def setup(self, trainer, stage: str):
-        trainer.state = TrainerState.RUNNING
-
-    def teardown(self, trainer, stage: str):
-        trainer.state = TrainerState.FINISHED
-
-    def on_keyboard_interrupt(self, trainer, pl_module):
-        trainer.state = TrainerState.FINISHED
+    """ State which is set in the Trainer to indicate what is currently or was executed. """
+    INITIALIZE = 'INITIALIZE'
+    RUNNING = 'RUNNING'
+    FINISHED = 'FINISHED'
+    INTERRUPTED = 'INTERRUPTED'
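
With the switcher callback gone, the enum carries explicit string values, which makes the state easy to log, compare, and serialize. A minimal sketch of inspecting the new enum; it assumes the Trainer still initializes its state to INITIALIZE in __init__, which the existing test_initialize_state test implies but this diff does not show:

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import TrainerState

trainer = Trainer(fast_dev_run=True)
# The trainer starts out in INITIALIZE before fit() or test() is called.
assert trainer.state == TrainerState.INITIALIZE
# String values read cleanly in logs and checkpoints.
assert trainer.state.value == 'INITIALIZE'
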
13 changes: 10 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -26,7 +26,7 @@
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
 from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
-from pytorch_lightning.trainer.states import _TrainerStateSwitcher, TrainerState
+from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.trainer.training_io import TrainerIOMixin
 from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
@@ -419,8 +419,6 @@ def __init__(
         # init callbacks
         self.callbacks = callbacks or []

-        self.callbacks.append(_TrainerStateSwitcher())
-
         # configure early stop callback
         # creates a default one if none passed in
         early_stop_callback = self.configure_early_stopping(early_stop_callback)
@@ -914,6 +912,8 @@ def fit(
         # check that model is configured correctly
         self.check_model_configuration(model)

+        self.state = TrainerState.RUNNING
+
         # callbacks
         self.on_fit_start()
         if self.is_function_implemented('on_fit_start', model):
@@ -1031,6 +1031,8 @@ def fit(
         if self.is_function_implemented('teardown'):
             model.teardown('fit')

+        if self.state != TrainerState.INTERRUPTED:
+            self.state = TrainerState.FINISHED
         # return 1 when finished
         # used for testing or when we need to know that training succeeded
         return results or 1
@@ -1246,6 +1248,8 @@ def test(
         if self.is_function_implemented('setup', model_ref):
             model_ref.setup('test')

+        self.state = TrainerState.RUNNING
+
         # if user requests the best checkpoint but we don't have it, error
         if model is None and ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0:
             raise MisconfigurationException(
@@ -1295,6 +1299,9 @@ def test(
             model_ref = self.get_model()
             model_ref.teardown('test')

+        if self.state != TrainerState.INTERRUPTED:
+            self.state = TrainerState.FINISHED
+
         return results

     def check_model_configuration(self, model: LightningModule):
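
The guard before FINISHED is the heart of this change: train() records INTERRUPTED when the user hits ctrl+c, and fit()/test() must not overwrite that on the way out. A condensed, self-contained sketch of the transition logic; MiniTrainer and its interrupt flag are illustrative stand-ins, not the real Trainer:

from enum import Enum


class TrainerState(Enum):
    INITIALIZE = 'INITIALIZE'
    RUNNING = 'RUNNING'
    FINISHED = 'FINISHED'
    INTERRUPTED = 'INTERRUPTED'


class MiniTrainer:
    """Toy model of the state transitions introduced by this commit."""

    def __init__(self):
        self.state = TrainerState.INITIALIZE

    def fit(self, interrupt=False):
        self.state = TrainerState.RUNNING
        if interrupt:
            # stands in for train() catching KeyboardInterrupt
            self.state = TrainerState.INTERRUPTED
        # fit() only claims FINISHED if nothing was interrupted
        if self.state != TrainerState.INTERRUPTED:
            self.state = TrainerState.FINISHED


clean = MiniTrainer()
clean.fit()
assert clean.state == TrainerState.FINISHED

stopped = MiniTrainer()
stopped.fit(interrupt=True)
assert stopped.state == TrainerState.INTERRUPTED
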
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/training_loop.py
@@ -159,6 +159,7 @@ def training_step(self, batch, batch_idx):
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers import LightningLoggerBase
+from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -240,6 +241,7 @@ class TrainerTrainLoopMixin(ABC):
     terminate_on_nan: bool
     tpu_id: int
     interactive_ddp_procs: ...
+    state: TrainerState

     # Callback system
     callbacks: List[Callback]
@@ -397,6 +399,7 @@ def train(self):
             # user could press ctrl+c many times... only shutdown once
             if not self.interrupted:
                 self.interrupted = True
+                self.state = TrainerState.INTERRUPTED
                 self.on_keyboard_interrupt()

         self.run_training_teardown()
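
Setting the state before on_keyboard_interrupt() fires means callbacks already observe INTERRUPTED when they run. A runnable sketch of the pattern; LoopSketch and epochs_that_get_interrupted are stand-ins for the real train loop, not Lightning code:

from enum import Enum


class TrainerState(Enum):
    RUNNING = 'RUNNING'
    INTERRUPTED = 'INTERRUPTED'


class LoopSketch:
    def __init__(self):
        self.interrupted = False
        self.state = TrainerState.RUNNING

    def on_keyboard_interrupt(self):
        # callbacks dispatched here already see state == INTERRUPTED
        print('graceful shutdown, state =', self.state.value)

    def train(self, run_epochs):
        try:
            run_epochs()
        except KeyboardInterrupt:
            # user could press ctrl+c many times... only shut down once
            if not self.interrupted:
                self.interrupted = True
                self.state = TrainerState.INTERRUPTED
                self.on_keyboard_interrupt()


def epochs_that_get_interrupted():
    raise KeyboardInterrupt


loop = LoopSketch()
loop.train(epochs_that_get_interrupted)
assert loop.state == TrainerState.INTERRUPTED
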
85 changes: 84 additions & 1 deletion tests/trainer/test_states.py
@@ -1,4 +1,4 @@
-from pytorch_lightning import Trainer
+from pytorch_lightning import Trainer, Callback
 from pytorch_lightning.trainer.states import TrainerState
 from tests.base import EvalModelTemplate

@@ -15,6 +15,35 @@ def test_initialize_state(tmpdir):
     assert trainer.state == TrainerState.INITIALIZE


+def test_running_state_during_fit(tmpdir):
+    """
+    Tests that state is set to RUNNING during fit
+    """
+
+    class StateSnapshotCallback(Callback):
+        def __init__(self):
+            super().__init__()
+            self.trainer_state = None
+
+        def on_batch_start(self, trainer, pl_module):
+            self.trainer_state = trainer.state
+
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
+
+    snapshot_callback = StateSnapshotCallback()
+
+    trainer = Trainer(
+        callbacks=[snapshot_callback],
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+    )
+
+    trainer.fit(model)
+
+    assert snapshot_callback.trainer_state == TrainerState.RUNNING
+
+
 def test_finished_state_after_fit(tmpdir):
     """
     Tests that state is FINISHED after fit
@@ -32,6 +61,35 @@ def test_finished_state_after_fit(tmpdir):
     assert trainer.state == TrainerState.FINISHED


+def test_running_state_during_test(tmpdir):
+    """
+    Tests that state is set to RUNNING during test
+    """
+
+    class StateSnapshotCallback(Callback):
+        def __init__(self):
+            super().__init__()
+            self.trainer_state = None
+
+        def on_test_batch_start(self, trainer, pl_module):
+            self.trainer_state = trainer.state
+
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
+
+    snapshot_callback = StateSnapshotCallback()
+
+    trainer = Trainer(
+        callbacks=[snapshot_callback],
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+    )
+
+    trainer.test(model)
+
+    assert snapshot_callback.trainer_state == TrainerState.RUNNING
+
+
 def test_finished_state_after_test(tmpdir):
     """
     Tests that state is FINISHED after test
@@ -47,3 +105,28 @@
     trainer.test(model)

     assert trainer.state == TrainerState.FINISHED
+
+
+def test_interrupt_state_on_keyboard_interrupt(tmpdir):
+    """
+    Tests that state is set to INTERRUPTED on KeyboardInterrupt
+    """
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
+
+    class InterruptCallback(Callback):
+        def __init__(self):
+            super().__init__()
+
+        def on_batch_start(self, trainer, pl_module):
+            raise KeyboardInterrupt
+
+    trainer = Trainer(
+        callbacks=[InterruptCallback()],
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+    )
+
+    trainer.fit(model)
+
+    assert trainer.state == TrainerState.INTERRUPTED
