Add state_dict to loops (#8197)
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
tchaton and carmocca authored Jul 1, 2021
1 parent c0caeb3 commit d51b0ae
Showing 10 changed files with 99 additions and 15 deletions.
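
In short: every `Loop` now exposes `state_dict()` and `load_state_dict()`, composite loops aggregate the state of their children, and `Loop.connect` now validates that it receives a `Trainer`. A minimal round-trip sketch of the new API (mirroring the tests added below; the empty dicts are this commit's defaults, since leaf loops do not persist anything yet):

    from pytorch_lightning import Trainer

    trainer = Trainer()
    # Composite loops nest their children's state; leaf loops return {} by default.
    state = trainer.fit_loop.state_dict()
    assert state == {"epoch_loop": {"batch_loop": {}, "val_loop": {}}}

    # Restoring is symmetric: each loop forwards the matching sub-dict to its children.
    Trainer().fit_loop.load_state_dict(state)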
7 changes: 4 additions & 3 deletions CHANGELOG.md
@@ -84,13 +84,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


 - Fault-tolerant training
-    * Add `{,load_}state_dict` to `ResultCollection` ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948))
+    * Added `{,load_}state_dict` to `ResultCollection` ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948))
+    * Added `{,load_}state_dict` to `Loops` ([#8197](https://github.com/PyTorchLightning/pytorch-lightning/pull/8197))
 
 
-- Add `rank_zero_only` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))
+- Added `rank_zero_only` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))
 
 
-- Add `metric_attribute` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))
+- Added `metric_attribute` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))
 
 
 - Added a warning if `Trainer(log_every_n_steps)` is a value too high for the training dataloader ([#7734](https://github.com/PyTorchLightning/pytorch-lightning/pull/7734))
14 changes: 13 additions & 1 deletion pytorch_lightning/loops/base.py
@@ -13,11 +13,12 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
-from typing import Any, Optional
+from typing import Any, Dict, Optional
 
 from deprecate import void
 
 import pytorch_lightning as pl
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
 class Loop(ABC):
@@ -59,6 +60,10 @@ def skip(self) -> bool:
     def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
         """Connects Loop with all the necessary things like connectors and accelerators."""
         # TODO(@justusschock): Make the trainer a weakref/proxy
+        if not isinstance(trainer, pl.Trainer):
+            raise MisconfigurationException(
+                f"Loop {self.__class__.__name__} should be connected to a `Trainer`, found: {trainer}."
+            )
         self.trainer = trainer
 
     def on_skip(self) -> Optional[Any]:
@@ -128,3 +133,10 @@ def on_run_end(self) -> Any:

     def teardown(self) -> None:
         """The very last method called inside :meth:`run`. Use to release memory etc."""
+
+    def load_state_dict(self, state_dict: Dict) -> None:
+        """Restore the loop state from the provided state_dict."""
+
+    def state_dict(self) -> Dict:
+        """Return the current state of the loop."""
+        return {}
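
The base class returns an empty dict, so subclasses opt in by overriding both hooks. A hedged sketch of how a custom loop might persist its own progress (the `CountingLoop` name and its stub `done`/`reset`/`advance` implementations are illustrative, not part of this commit):

    from typing import Any, Dict

    from pytorch_lightning.loops import Loop


    class CountingLoop(Loop):
        # Hypothetical loop that persists an iteration counter across restarts.
        def __init__(self) -> None:
            super().__init__()
            self.iteration = 0

        @property
        def done(self) -> bool:
            return self.iteration >= 3

        def reset(self) -> None:
            pass

        def advance(self, *args: Any, **kwargs: Any) -> None:
            self.iteration += 1

        def state_dict(self) -> Dict:
            return {"iteration": self.iteration}

        def load_state_dict(self, state_dict: Dict) -> None:
            self.iteration = state_dict["iteration"]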
13 changes: 9 additions & 4 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -47,8 +47,8 @@ def __init__(self, min_steps: int, max_steps: int):
         self.batches_seen: int = 0
         self.is_last_batch: Optional[bool] = None
 
-        self.batch_loop: Optional[TrainingBatchLoop] = None
-        self.val_loop: Optional[loops.EvaluationLoop] = None
+        self.batch_loop = TrainingBatchLoop()
+        self.val_loop = loops.EvaluationLoop()
 
         self._dataloader_idx: Optional[int] = None
         self._warning_cache: WarningCache = WarningCache()
@@ -80,9 +80,7 @@ def done(self) -> bool:
     def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
         """Connects the loop with all necessary parts like trainer and accelerators"""
         super().connect(trainer, *args, **kwargs)
-        self.batch_loop = TrainingBatchLoop()
         self.batch_loop.connect(trainer)
-        self.val_loop = loops.EvaluationLoop()
         self.val_loop.connect(trainer)
 
     def reset(self) -> None:
@@ -425,3 +423,10 @@ def _save_loggers_on_train_batch_end(self) -> None:
         should_flush_logs = self.trainer.logger_connector.should_flush_logs
         if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None:
             self.trainer.logger.save()
+
+    def state_dict(self) -> Dict:
+        return {"batch_loop": self.batch_loop.state_dict(), "val_loop": self.val_loop.state_dict()}
+
+    def load_state_dict(self, state_dict: Dict) -> None:
+        self.batch_loop.load_state_dict(state_dict["batch_loop"])
+        self.val_loop.load_state_dict(state_dict["val_loop"])
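
Since the child loops are now created in `__init__` rather than in `connect`, the epoch loop can report its nested state without a trainer attached. A quick sketch (assuming `TrainingEpochLoop` is importable from `pytorch_lightning.loops`, as other loop classes are in this commit):

    from pytorch_lightning.loops import TrainingEpochLoop

    # The children's state_dict() defaults to {}, so only the nesting is visible here.
    assert TrainingEpochLoop(min_steps=0, max_steps=1).state_dict() == {
        "batch_loop": {},
        "val_loop": {},
    }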
14 changes: 13 additions & 1 deletion pytorch_lightning/loops/fit_loop.py
@@ -14,7 +14,7 @@

 import logging
 from contextlib import suppress
-from typing import Any, Optional
+from typing import Any, Dict, Optional
 
 import pytorch_lightning as pl
 from pytorch_lightning.loops import Loop
@@ -97,6 +97,12 @@ def min_steps(self) -> int:
"""Returns the minimum numnber of steps to run"""
return self.epoch_loop.min_steps

@min_steps.setter
def min_steps(self, value: int) -> None:
"""Sets the minimum number of steps (forwards to epoch_loop)"""
# TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
self.epoch_loop.min_steps = value

@property
def max_steps(self) -> int:
"""Returns the maximum number of steps to run"""
@@ -274,3 +280,9 @@ def _check_checkpoint_callback(self, should_update: bool, is_last: bool = False)

         for cb in callbacks:
             cb.on_validation_end(self.trainer, model)
+
+    def state_dict(self) -> Dict:
+        return {"epoch_loop": self.epoch_loop.state_dict()}
+
+    def load_state_dict(self, state_dict: Dict) -> None:
+        self.epoch_loop.load_state_dict(state_dict["epoch_loop"])
pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -316,5 +316,5 @@ def progress_bar_metrics(self) -> Dict[str, float]:
     def teardown(self):
         self.trainer.fit_loop.epoch_loop._results.cpu()
         self.trainer.fit_loop.epoch_loop.val_loop._results.cpu()
-        self.trainer.validation_loop._results.cpu()
+        self.trainer.validate_loop._results.cpu()
         self.trainer.test_loop._results.cpu()
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/progress.py
@@ -239,7 +239,7 @@ class TrainingEpochProgress(EpochProgress):
         current: Tracks the current epoch progress.
         batch: Tracks batch progress.
         optim: Tracks optimization progress.
-        val: Tracks validation_loop progress.
+        val: Tracks val_loop progress.
     """
 
     optim: OptimizationProgress = field(default_factory=OptimizationProgress)
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/properties.py
@@ -63,7 +63,7 @@ class TrainerProperties(ABC):
     logger_connector: LoggerConnector
     state: TrainerState
     fit_loop: FitLoop
-    validation_loop: EvaluationLoop
+    validate_loop: EvaluationLoop
     test_loop: EvaluationLoop
     """
     Accelerator properties
@@ -493,7 +493,7 @@ def evaluation_loop(self) -> EvaluationLoop:
         if self.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
             return self.fit_loop.epoch_loop.val_loop
         elif self.state.fn == TrainerFn.VALIDATING:
-            return self.validation_loop
+            return self.validate_loop
         if self.state.fn == TrainerFn.TESTING:
             return self.test_loop
         raise RuntimeError("The `Trainer.evaluation_loop` property isn't defined. Accessed outside of scope")
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -344,11 +344,11 @@ def __init__(
         self.tuner = Tuner(self)
 
         self.fit_loop = FitLoop(min_epochs, max_epochs, min_steps, max_steps)
-        self.validation_loop = EvaluationLoop()
+        self.validate_loop = EvaluationLoop()
         self.test_loop = EvaluationLoop()
         self.predict_loop = PredictionLoop()
         self.fit_loop.connect(self)
-        self.validation_loop.connect(self)
+        self.validate_loop.connect(self)
         self.test_loop.connect(self)
         self.predict_loop.connect(self)
Empty file added tests/loops/__init__.py
54 changes: 54 additions & 0 deletions tests/loops/test_loop_state_dict.py
@@ -0,0 +1,54 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from pytorch_lightning.loops import FitLoop
from pytorch_lightning.trainer.trainer import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def test_loops_state_dict():
    fit_loop = FitLoop()
    with pytest.raises(MisconfigurationException, match="Loop FitLoop should be connected to a"):
        fit_loop.connect(object())  # noqa

    fit_loop.connect(Trainer())
    state_dict = fit_loop.state_dict()
    new_fit_loop = FitLoop()
    new_fit_loop.load_state_dict(state_dict)
    assert fit_loop.state_dict() == new_fit_loop.state_dict()


def test_loops_state_dict_structure():
    trainer = Trainer()
    # structure saved by the checkpoint connector
    state_dict = {
        "fit_loop": trainer.fit_loop.state_dict(),
        "validate_loop": trainer.validate_loop.state_dict(),
        "test_loop": trainer.test_loop.state_dict(),
        "predict_loop": trainer.predict_loop.state_dict(),
    }
    expected = {
        "fit_loop": {
            "epoch_loop": {
                "batch_loop": {},
                "val_loop": {},
            }
        },
        "validate_loop": {},
        "test_loop": {},
        "predict_loop": {},
    }
    assert state_dict == expected
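
These tests can be run on their own from a repository checkout with a standard pytest invocation, e.g. `python -m pytest tests/loops/test_loop_state_dict.py`.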
