
Add tracking of basic states in Trainer [wip - to-be-merged after v0.9] #2541

Merged
49 changes: 49 additions & 0 deletions pytorch_lightning/trainer/states.py
@@ -0,0 +1,49 @@
from enum import Enum
from functools import wraps
from typing import Callable, Optional

import pytorch_lightning


class TrainerState(Enum):
""" State which is set in the :class:`~pytorch_lightning.trainer.trainer.Trainer`
to indicate what is currently or was executed. """
INITIALIZING = 'INITIALIZING'
RUNNING = 'RUNNING'
FINISHED = 'FINISHED'
INTERRUPTED = 'INTERRUPTED'


def trainer_state(*, entering: Optional[TrainerState] = None, exiting: Optional[TrainerState] = None) -> Callable:
""" Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods
which changes state to `entering` before the function execution and `exiting`
after the function is executed. If `None` is passed to `entering`, the state is not changed.
If `None` is passed to `exiting`, the state is restored to the state before function execution.
If `INTERRUPTED` state is set inside a run function, the state remains `INTERRUPTED`.
"""

def wrapper(fn) -> Callable:
@wraps(fn)
def wrapped_fn(self, *args, **kwargs):
if not isinstance(self, pytorch_lightning.Trainer):
return fn(self, *args, **kwargs)

state_before = self.state
if entering is not None:
self.state = entering
result = fn(self, *args, **kwargs)

# The INTERRUPTED state can be set inside the run function. To signal that
# the run was interrupted, we retain the INTERRUPTED state on exit.
if self.state == TrainerState.INTERRUPTED:
return result

if exiting is not None:
self.state = exiting
else:
self.state = state_before
return result

return wrapped_fn

return wrapper
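
For reference, a minimal sketch of how the decorator composes, mirroring the tests added below (`run_stage` is a hypothetical stand-in for a decorated Trainer method such as `fit` or `test`, not part of this PR):

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import TrainerState, trainer_state

@trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
def run_stage(self):
    # While inside the decorated function, the trainer reports RUNNING.
    assert self.state == TrainerState.RUNNING

trainer = Trainer()  # state starts as INITIALIZING
run_stage(trainer)
assert trainer.state == TrainerState.FINISHED

On objects that are not a pytorch_lightning.Trainer the decorator is a no-op, since the wrapper only touches self.state after the isinstance check.
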
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -43,6 +43,7 @@
from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
from pytorch_lightning.trainer.states import TrainerState, trainer_state
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.trainer.training_io import TrainerIOMixin
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
@@ -396,6 +397,7 @@ def __init__(
self.interrupted = False
self.should_stop = False
self.running_sanity_check = False
self.state = TrainerState.INITIALIZING
Contributor: Here you set trainer's self.state to TrainerState.INITIALIZING... (see next comment)

self._default_root_dir = default_root_dir or os.getcwd()
self._weights_save_path = weights_save_path or self._default_root_dir
@@ -889,6 +891,7 @@ def weights_save_path(self) -> str:
# -----------------------------
# MODEL TRAINING
# -----------------------------
@trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
def fit(
self,
model: LightningModule,
@@ -1241,6 +1244,7 @@ def _run_sanity_check(self, ref_model, model):
self.on_sanity_check_end()
self.running_sanity_check = False

@trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
def test(
self,
model: Optional[LightningModule] = None,
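
With fit and test decorated, the lifecycle becomes observable from user code. A minimal sketch mirroring the tests below (EvalModelTemplate is the test model used in this PR's test suite):

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import TrainerState
from tests.base import EvalModelTemplate

model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())
trainer = Trainer(fast_dev_run=True)
assert trainer.state == TrainerState.INITIALIZING

trainer.fit(model)  # state is RUNNING while fit executes
assert trainer.state == TrainerState.FINISHED
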
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/training_loop.py
@@ -174,6 +174,7 @@ def training_step(self, batch, batch_idx):
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.supporters import TensorRunningAccum, Accumulator
from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -256,6 +257,7 @@ class TrainerTrainLoopMixin(ABC):
terminate_on_nan: bool
tpu_id: int
interactive_ddp_procs: ...
state: TrainerState

# Callback system
callbacks: List[Callback]
@@ -419,6 +421,7 @@ def train(self):
# user could press ctrl+c many times... only shutdown once
if not self.interrupted:
self.interrupted = True
Contributor (author): I guess that self.interrupted can be removed in the future?

Member: ye, more of them could be cleaned...

self.state = TrainerState.INTERRUPTED
self.on_keyboard_interrupt()

self.run_training_teardown()
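
Following the review thread above, a minimal sketch of how self.interrupted could later be derived from the new state instead of being tracked as a separate boolean (a hypothetical refactor, not part of this PR; TrainerStateView is an illustrative class, not the actual Trainer):

from pytorch_lightning.trainer.states import TrainerState

class TrainerStateView:
    def __init__(self):
        self.state = TrainerState.INITIALIZING

    @property
    def interrupted(self) -> bool:
        # True exactly when a KeyboardInterrupt has set the INTERRUPTED state.
        return self.state == TrainerState.INTERRUPTED
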
210 changes: 210 additions & 0 deletions tests/trainer/test_states.py
@@ -0,0 +1,210 @@
import pytest

from pytorch_lightning import Trainer, Callback
from pytorch_lightning.trainer.states import TrainerState, trainer_state
from tests.base import EvalModelTemplate


class StateSnapshotCallback(Callback):
""" Allows to shapshot the state inside a particular trainer method. """

def __init__(self, snapshot_method: str):
super().__init__()
assert snapshot_method in ['on_batch_start', 'on_test_batch_start']
self.snapshot_method = snapshot_method
self.trainer_state = None

def on_batch_start(self, trainer, pl_module):
if self.snapshot_method == 'on_batch_start':
self.trainer_state = trainer.state

def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
if self.snapshot_method == 'on_test_batch_start':
self.trainer_state = trainer.state


def test_state_decorator_nothing_passed(tmpdir):
""" Test that state is not changed if nothing is passed to a decorator"""

@trainer_state()
def test_method(self):
return self.state

trainer = Trainer(default_root_dir=tmpdir)
trainer.state = TrainerState.INITIALIZING

snapshot_state = test_method(trainer)

assert snapshot_state == TrainerState.INITIALIZING
assert trainer.state == TrainerState.INITIALIZING


def test_state_decorator_entering_only(tmpdir):
""" Tests that state is set to entering inside a run function and restored to the previous value after. """

@trainer_state(entering=TrainerState.RUNNING)
def test_method(self):
return self.state

trainer = Trainer(default_root_dir=tmpdir)
trainer.state = TrainerState.INITIALIZING
Contributor: Why are you setting trainer.state = TrainerState.INITIALIZING if it's set in the init of Trainer? Should this be ==?

Contributor (suggested change): remove the line trainer.state = TrainerState.INITIALIZING

Contributor: Could be wrong, just let me know what's going on here 😄

Member: Very good point by nate, we should make the state a read-only property, and only the trainer should write to the internal attribute trainer._state.

snapshot_state = test_method(trainer)

assert snapshot_state == TrainerState.RUNNING
assert trainer.state == TrainerState.INITIALIZING
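
Picking up the reviewer's read-only suggestion above, a minimal sketch of state as a property backed by an internal attribute (hypothetical, not part of this PR; StatefulTrainer is an illustrative class, not the actual Trainer):

from pytorch_lightning.trainer.states import TrainerState

class StatefulTrainer:
    def __init__(self):
        self._state = TrainerState.INITIALIZING  # written only by trainer internals

    @property
    def state(self) -> TrainerState:
        # Public read-only view; external assignments to trainer.state would fail.
        return self._state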


def test_state_decorator_exiting_only(tmpdir):
""" Tests that state is not changed inside a run function and set to `exiting` after. """

@trainer_state(exiting=TrainerState.FINISHED)
def test_method(self):
return self.state

trainer = Trainer(default_root_dir=tmpdir)
trainer.state = TrainerState.INITIALIZING

snapshot_state = test_method(trainer)

assert snapshot_state == TrainerState.INITIALIZING
assert trainer.state == TrainerState.FINISHED


def test_state_decorator_entering_and_exiting(tmpdir):
""" Tests that state is set to `entering` inside a run function and set ot `exiting` after. """

@trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
def test_method(self):
return self.state

trainer = Trainer(default_root_dir=tmpdir)
trainer.state = TrainerState.INITIALIZING

snapshot_state = test_method(trainer)

assert snapshot_state == TrainerState.RUNNING
assert trainer.state == TrainerState.FINISHED


def test_state_decorator_interrupt(tmpdir):
""" Tests that state remains `INTERRUPTED` is its set in run function. """

@trainer_state(exiting=TrainerState.FINISHED)
def test_method(self):
self.state = TrainerState.INTERRUPTED

trainer = Trainer(default_root_dir=tmpdir)
trainer.state = TrainerState.INITIALIZING

test_method(trainer)
assert trainer.state == TrainerState.INTERRUPTED


def test_initialize_state(tmpdir):
""" Tests that state is INITIALIZE after Trainer creation """
trainer = Trainer(default_root_dir=tmpdir)
assert trainer.state == TrainerState.INITIALIZING


@pytest.mark.parametrize("extra_params", [
pytest.param(dict(fast_dev_run=True), id='Fast-Run'),
pytest.param(dict(max_steps=1), id='Single-Step'),
])
def test_running_state_during_fit(tmpdir, extra_params):
""" Tests that state is set to RUNNING during fit """

hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)

snapshot_callback = StateSnapshotCallback(snapshot_method='on_batch_start')

trainer = Trainer(
callbacks=[snapshot_callback],
default_root_dir=tmpdir,
**extra_params
)

trainer.fit(model)

assert snapshot_callback.trainer_state == TrainerState.RUNNING


@pytest.mark.parametrize("extra_params", [
pytest.param(dict(fast_dev_run=True), id='Fast-Run'),
pytest.param(dict(max_steps=1), id='Single-Step'),
])
def test_finished_state_after_fit(tmpdir, extra_params):
""" Tests that state is FINISHED after fit """
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)

trainer = Trainer(
default_root_dir=tmpdir,
**extra_params
)

trainer.fit(model)

assert trainer.state == TrainerState.FINISHED


def test_running_state_during_test(tmpdir):
""" Tests that state is set to RUNNING during test """

hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)

snapshot_callback = StateSnapshotCallback(snapshot_method='on_test_batch_start')

trainer = Trainer(
callbacks=[snapshot_callback],
default_root_dir=tmpdir,
fast_dev_run=True,
)

trainer.test(model)

assert snapshot_callback.trainer_state == TrainerState.RUNNING


def test_finished_state_after_test(tmpdir):
""" Tests that state is FINISHED after fit """
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)

trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
)

trainer.test(model)

assert trainer.state == TrainerState.FINISHED


@pytest.mark.parametrize("extra_params", [
pytest.param(dict(fast_dev_run=True), id='Fast-Run'),
pytest.param(dict(max_steps=1), id='Single-Step'),
])
def test_interrupt_state_on_keyboard_interrupt(tmpdir, extra_params):
""" Tests that state is set to INTERRUPTED on KeyboardInterrupt """
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)

class InterruptCallback(Callback):
def __init__(self):
super().__init__()

def on_batch_start(self, trainer, pl_module):
raise KeyboardInterrupt

trainer = Trainer(
callbacks=[InterruptCallback()],
default_root_dir=tmpdir,
**extra_params
)

trainer.fit(model)

assert trainer.state == TrainerState.INTERRUPTED
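
These tests follow the repository's standard pytest layout, so they can presumably be run in isolation with:

python -m pytest tests/trainer/test_states.py -v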