From d51b0ae7fc3157463c99e2be14072fc7a0b0794a Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Thu, 1 Jul 2021 16:54:37 +0100
Subject: [PATCH] Add `state_dict` to loops (#8197)

Co-authored-by: Carlos Mocholi
---
 CHANGELOG.md | 7 +--
 pytorch_lightning/loops/base.py | 14 ++++-
 .../loops/epoch/training_epoch_loop.py | 13 +++--
 pytorch_lightning/loops/fit_loop.py | 14 ++++-
 .../logger_connector/logger_connector.py | 2 +-
 pytorch_lightning/trainer/progress.py | 2 +-
 pytorch_lightning/trainer/properties.py | 4 +-
 pytorch_lightning/trainer/trainer.py | 4 +-
 tests/loops/__init__.py | 0
 tests/loops/test_loop_state_dict.py | 54 +++++++++++++++++++
 10 files changed, 99 insertions(+), 15 deletions(-)
 create mode 100644 tests/loops/__init__.py
 create mode 100644 tests/loops/test_loop_state_dict.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e5f9b758102d2..1cb387581a1a9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -84,13 +84,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fault-tolerant training
-    * Add `{,load_}state_dict` to `ResultCollection` ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948))
+    * Added `{,load_}state_dict` to `ResultCollection` ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948))
+    * Added `{,load_}state_dict` to `Loops` ([#8197](https://github.com/PyTorchLightning/pytorch-lightning/pull/8197))

-- Add `rank_zero_only` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))
+- Added `rank_zero_only` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))

-- Add `metric_attribute` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))
+- Added `metric_attribute` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))

 - Added a warning if `Trainer(log_every_n_steps)` is a value too high for the training dataloader ([#7734](https://github.com/PyTorchLightning/pytorch-lightning/pull/7734))

diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py
index 1d976aa3cd079..1edc997e715ce 100644
--- a/pytorch_lightning/loops/base.py
+++ b/pytorch_lightning/loops/base.py
@@ -13,11 +13,12 @@
 # limitations under the License.

 from abc import ABC, abstractmethod
-from typing import Any, Optional
+from typing import Any, Dict, Optional

 from deprecate import void

 import pytorch_lightning as pl
+from pytorch_lightning.utilities.exceptions import MisconfigurationException


 class Loop(ABC):
@@ -59,6 +60,10 @@ def skip(self) -> bool:
     def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
         """Connects Loop with all the necessary things like connectors and accelerators."""
         # TODO(@justusschock): Make the trainer a weakref/proxy
+        if not isinstance(trainer, pl.Trainer):
+            raise MisconfigurationException(
+                f"Loop {self.__class__.__name__} should be connected to a `Trainer`, found: {trainer}."
+            )
         self.trainer = trainer

     def on_skip(self) -> Optional[Any]:
@@ -128,3 +133,10 @@ def on_run_end(self) -> Any:
     def teardown(self) -> None:
         """The very last method called inside :meth:`run`. Use to release memory etc."""
+
+    def load_state_dict(self, state_dict: Dict) -> None:
+        """Restore the loop state from the provided state_dict."""
+
+    def state_dict(self) -> Dict:
+        """Return the current state of the loop."""
+        return {}
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index cd8b992b09d45..89891c0d6148a 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -47,8 +47,8 @@ def __init__(self, min_steps: int, max_steps: int):
         self.batches_seen: int = 0
         self.is_last_batch: Optional[bool] = None

-        self.batch_loop: Optional[TrainingBatchLoop] = None
-        self.val_loop: Optional[loops.EvaluationLoop] = None
+        self.batch_loop = TrainingBatchLoop()
+        self.val_loop = loops.EvaluationLoop()

         self._dataloader_idx: Optional[int] = None
         self._warning_cache: WarningCache = WarningCache()
@@ -80,9 +80,7 @@ def done(self) -> bool:
     def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
         """Connects the loop with all necessary parts like trainer and accelerators"""
         super().connect(trainer, *args, **kwargs)
-        self.batch_loop = TrainingBatchLoop()
         self.batch_loop.connect(trainer)
-        self.val_loop = loops.EvaluationLoop()
         self.val_loop.connect(trainer)

     def reset(self) -> None:
@@ -425,3 +423,10 @@ def _save_loggers_on_train_batch_end(self) -> None:
         should_flush_logs = self.trainer.logger_connector.should_flush_logs
         if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None:
             self.trainer.logger.save()
+
+    def state_dict(self) -> Dict:
+        return {"batch_loop": self.batch_loop.state_dict(), "val_loop": self.val_loop.state_dict()}
+
+    def load_state_dict(self, state_dict: Dict) -> None:
+        self.batch_loop.load_state_dict(state_dict["batch_loop"])
+        self.val_loop.load_state_dict(state_dict["val_loop"])
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index 655e102466931..bf42663fd5c9e 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -14,7 +14,7 @@

 import logging
 from contextlib import suppress
-from typing import Any, Optional
+from typing import Any, Dict, Optional

 import pytorch_lightning as pl
 from pytorch_lightning.loops import Loop
@@ -97,6 +97,12 @@ def min_steps(self) -> int:
         """Returns the minimum number of steps to run"""
         return self.epoch_loop.min_steps

+    @min_steps.setter
+    def min_steps(self, value: int) -> None:
+        """Sets the minimum number of steps (forwards to epoch_loop)"""
+        # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
+        self.epoch_loop.min_steps = value
+
     @property
     def max_steps(self) -> int:
         """Returns the maximum number of steps to run"""
@@ -274,3 +280,9 @@ def _check_checkpoint_callback(self, should_update: bool, is_last: bool = False)
             for cb in callbacks:
                 cb.on_validation_end(self.trainer, model)
+
+    def state_dict(self) -> Dict:
+        return {"epoch_loop": self.epoch_loop.state_dict()}
+
+    def load_state_dict(self, state_dict: Dict) -> None:
+        self.epoch_loop.load_state_dict(state_dict["epoch_loop"])
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index 88f089224ff2e..5a950d40f8f54 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -316,5 +316,5 @@ def progress_bar_metrics(self) -> Dict[str, float]:
     def teardown(self):
         self.trainer.fit_loop.epoch_loop._results.cpu()
         self.trainer.fit_loop.epoch_loop.val_loop._results.cpu()
-        self.trainer.validation_loop._results.cpu()
+        self.trainer.validate_loop._results.cpu()
         self.trainer.test_loop._results.cpu()
diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py
index caf4ab0bf1599..2d7a1d7e8f53a 100644
--- a/pytorch_lightning/trainer/progress.py
+++ b/pytorch_lightning/trainer/progress.py
@@ -239,7 +239,7 @@ class TrainingEpochProgress(EpochProgress):
         current: Tracks the current epoch progress.
         batch: Tracks batch progress.
         optim: Tracks optimization progress.
-        val: Tracks validation_loop progress.
+        val: Tracks val_loop progress.
     """

     optim: OptimizationProgress = field(default_factory=OptimizationProgress)
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index ea1164bdee861..b59066cb03b17 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -63,7 +63,7 @@ class TrainerProperties(ABC):
     logger_connector: LoggerConnector
     state: TrainerState
     fit_loop: FitLoop
-    validation_loop: EvaluationLoop
+    validate_loop: EvaluationLoop
     test_loop: EvaluationLoop
     """
     Accelerator properties
@@ -493,7 +493,7 @@ def evaluation_loop(self) -> EvaluationLoop:
         if self.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
             return self.fit_loop.epoch_loop.val_loop
         elif self.state.fn == TrainerFn.VALIDATING:
-            return self.validation_loop
+            return self.validate_loop
         if self.state.fn == TrainerFn.TESTING:
             return self.test_loop
         raise RuntimeError("The `Trainer.evaluation_loop` property isn't defined. Accessed outside of scope")
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 6e47702e38d67..008cdb2239df7 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -344,11 +344,11 @@ def __init__(
         self.tuner = Tuner(self)

         self.fit_loop = FitLoop(min_epochs, max_epochs, min_steps, max_steps)
-        self.validation_loop = EvaluationLoop()
+        self.validate_loop = EvaluationLoop()
         self.test_loop = EvaluationLoop()
         self.predict_loop = PredictionLoop()
         self.fit_loop.connect(self)
-        self.validation_loop.connect(self)
+        self.validate_loop.connect(self)
         self.test_loop.connect(self)
         self.predict_loop.connect(self)
diff --git a/tests/loops/__init__.py b/tests/loops/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py
new file mode 100644
index 0000000000000..1930dc46566fd
--- /dev/null
+++ b/tests/loops/test_loop_state_dict.py
@@ -0,0 +1,54 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from pytorch_lightning.loops import FitLoop
+from pytorch_lightning.trainer.trainer import Trainer
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+
+def test_loops_state_dict():
+    fit_loop = FitLoop()
+    with pytest.raises(MisconfigurationException, match="Loop FitLoop should be connected to a"):
+        fit_loop.connect(object())  # noqa
+
+    fit_loop.connect(Trainer())
+    state_dict = fit_loop.state_dict()
+    new_fit_loop = FitLoop()
+    new_fit_loop.load_state_dict(state_dict)
+    assert fit_loop.state_dict() == new_fit_loop.state_dict()
+
+
+def test_loops_state_dict_structure():
+    trainer = Trainer()
+    # structure saved by the checkpoint connector
+    state_dict = {
+        "fit_loop": trainer.fit_loop.state_dict(),
+        "validate_loop": trainer.validate_loop.state_dict(),
+        "test_loop": trainer.test_loop.state_dict(),
+        "predict_loop": trainer.predict_loop.state_dict(),
+    }
+    expected = {
+        "fit_loop": {
+            "epoch_loop": {
+                "batch_loop": {},
+                "val_loop": {},
+            }
+        },
+        "validate_loop": {},
+        "test_loop": {},
+        "predict_loop": {},
+    }
+    assert state_dict == expected
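
Note on usage: the hooks added above compose recursively, with each parent loop nesting its children's dictionaries under stable keys, which is exactly what `test_loops_state_dict_structure` asserts. Below is a minimal sketch of a custom loop participating in the same protocol. The `CounterLoop` class and its `count` attribute are hypothetical, invented only for illustration; the sketch assumes the base `Loop.run()` drives `advance()` until `done` is `True`, as in `pytorch_lightning/loops/base.py` at this commit.

    from typing import Any, Dict

    from pytorch_lightning.loops.base import Loop


    class CounterLoop(Loop):
        """Hypothetical loop that only counts how many times it has advanced."""

        def __init__(self) -> None:
            super().__init__()
            self.count = 0  # loop-local progress worth checkpointing

        @property
        def done(self) -> bool:
            return self.count >= 3

        def reset(self) -> None:
            pass

        def advance(self, *args: Any, **kwargs: Any) -> None:
            self.count += 1

        def state_dict(self) -> Dict:
            # everything needed to restore this loop after a failure
            return {"count": self.count}

        def load_state_dict(self, state_dict: Dict) -> None:
            self.count = state_dict["count"]


    loop = CounterLoop()
    loop.run()  # advances until `done` is True
    assert loop.state_dict() == {"count": 3}

    restored = CounterLoop()
    restored.load_state_dict(loop.state_dict())  # resume from the captured state
    assert restored.state_dict() == {"count": 3}

A parent loop would expose this state as, say, `{"counter_loop": self.counter_loop.state_dict()}`, mirroring how `FitLoop` wraps `TrainingEpochLoop` in this patch.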