Skip to content

Commit

Permalink
Trainer.interupted
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Aug 7, 2020
1 parent 0a1e9ef commit c1c96db
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 6 deletions.
12 changes: 12 additions & 0 deletions pytorch_lightning/trainer/deprecated_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from abc import ABC
from typing import Union

from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import rank_zero_warn


Expand Down Expand Up @@ -101,3 +102,14 @@ def ckpt_path(self, path: str):
rank_zero_warn("Attribute `ckpt_path` is now set by `weights_save_path` since v0.9.0"
" and this method will be removed in v0.10.0", DeprecationWarning)
self._weights_save_path = path


class TrainerDeprecatedAPITillVer1_0(ABC):
state: TrainerState

@property
def interrupted(self) -> int:
"""Back compatibility, will be removed in v1.0"""
rank_zero_warn("Attribute `interrupted` is now set by `state` since v0.9.0"
" and this method will be removed in v1.0", DeprecationWarning)
return self.state == TrainerState.INTERRUPTED
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_10
from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_10, TrainerDeprecatedAPITillVer1_0
from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
from pytorch_lightning.trainer.distrib_parts import (TrainerDPMixin, _parse_gpu_ids, _parse_tpu_cores,
determine_root_gpu_device, pick_multiple_gpus)
Expand Down Expand Up @@ -100,6 +100,7 @@ class Trainer(
TrainerCallbackConfigMixin,
TrainerLRFinderMixin,
TrainerDeprecatedAPITillVer0_10,
TrainerDeprecatedAPITillVer1_0,
):
"""
Example:
Expand Down Expand Up @@ -394,7 +395,6 @@ def __init__(
self.optimizer_frequencies = []
self.global_step = 0
self.current_epoch = 0
self.interrupted = False
self.should_stop = False
self.running_sanity_check = False
self.state = TrainerState.INITIALIZING
Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,8 +419,7 @@ def train(self):
rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')

# user could press ctrl+c many times... only shutdown once
if not self.interrupted:
self.interrupted = True
if self.state != TrainerState.INTERRUPTED:
self.state = TrainerState.INTERRUPTED
self.on_keyboard_interrupt()

Expand Down
5 changes: 3 additions & 2 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv)
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
Expand Down Expand Up @@ -714,10 +715,10 @@ def on_keyboard_interrupt(self, trainer, pl_module):
logger=False,
default_root_dir=tmpdir,
)
assert not trainer.interrupted
assert trainer.state != TrainerState.INTERRUPTED
assert handle_interrupt_callback.exc_info is None
trainer.fit(model)
assert trainer.interrupted
assert trainer.state == TrainerState.INTERRUPTED
assert isinstance(handle_interrupt_callback.exc_info[1], KeyboardInterrupt)


Expand Down

0 comments on commit c1c96db

Please sign in to comment.