Skip to content

Commit

Permalink
Make trainer.state a read-only property (#3109)
Browse files Browse the repository at this point in the history
* Make trainer.state a read-only property

* Update states.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
zerogerc and Borda authored Aug 24, 2020
1 parent 8ebf4fe commit 2d42ec0
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 16 deletions.
11 changes: 4 additions & 7 deletions pytorch_lightning/trainer/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,17 @@ def wrapped_fn(self, *args, **kwargs):
if not isinstance(self, pytorch_lightning.Trainer):
return fn(self, *args, **kwargs)

state_before = self.state
state_before = self._state
if entering is not None:
self.state = entering
self._state = entering
result = fn(self, *args, **kwargs)

# The INTERRUPTED state can be set inside the run function. To indicate that run was interrupted
# we retain INTERRUPTED state
if self.state == TrainerState.INTERRUPTED:
if self._state == TrainerState.INTERRUPTED:
return result

if exiting is not None:
self.state = exiting
else:
self.state = state_before
self._state = exiting if exiting is not None else state_before
return result

return wrapped_fn
Expand Down
6 changes: 5 additions & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def __init__(
self.interrupted = False
self.should_stop = False
self.running_sanity_check = False
self.state = TrainerState.INITIALIZING
self._state = TrainerState.INITIALIZING

self._default_root_dir = default_root_dir or os.getcwd()
self._weights_save_path = weights_save_path or self._default_root_dir
Expand Down Expand Up @@ -611,6 +611,10 @@ def __init__(
# Callback system
self.on_init_end()

@property
def state(self) -> TrainerState:
return self._state

@property
def is_global_zero(self) -> bool:
return self.global_rank == 0
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ class TrainerTrainLoopMixin(ABC):
terminate_on_nan: bool
tpu_id: int
interactive_ddp_procs: ...
state: TrainerState
_state: TrainerState
amp_backend: AMPType
on_tpu: bool
accelerator_backend: ...
Expand Down Expand Up @@ -418,7 +418,7 @@ def train(self):
# user could press ctrl+c many times... only shutdown once
if not self.interrupted:
self.interrupted = True
self.state = TrainerState.INTERRUPTED
self._state = TrainerState.INTERRUPTED
self.on_keyboard_interrupt()

self.run_training_teardown()
Expand Down
7 changes: 1 addition & 6 deletions tests/trainer/test_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def test_method(self):
return self.state

trainer = Trainer(default_root_dir=tmpdir)
trainer.state = TrainerState.INITIALIZING

snapshot_state = test_method(trainer)

Expand All @@ -47,7 +46,6 @@ def test_method(self):
return self.state

trainer = Trainer(default_root_dir=tmpdir)
trainer.state = TrainerState.INITIALIZING

snapshot_state = test_method(trainer)

Expand All @@ -63,7 +61,6 @@ def test_method(self):
return self.state

trainer = Trainer(default_root_dir=tmpdir)
trainer.state = TrainerState.INITIALIZING

snapshot_state = test_method(trainer)

Expand All @@ -79,7 +76,6 @@ def test_method(self):
return self.state

trainer = Trainer(default_root_dir=tmpdir)
trainer.state = TrainerState.INITIALIZING

snapshot_state = test_method(trainer)

Expand All @@ -92,10 +88,9 @@ def test_state_decorator_interrupt(tmpdir):

@trainer_state(exiting=TrainerState.FINISHED)
def test_method(self):
self.state = TrainerState.INTERRUPTED
self._state = TrainerState.INTERRUPTED

trainer = Trainer(default_root_dir=tmpdir)
trainer.state = TrainerState.INITIALIZING

test_method(trainer)
assert trainer.state == TrainerState.INTERRUPTED
Expand Down

0 comments on commit 2d42ec0

Please sign in to comment.