From 842feb538bf3956b4b0f85d1c8c7023e65aa035d Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 26 Mar 2021 16:09:50 +0000 Subject: [PATCH 1/5] move save_checkpoint responsability to accelerator --- pytorch_lightning/accelerators/accelerator.py | 5 + .../plugins/training_type/tpu_spawn.py | 9 +- .../training_type/training_type_plugin.py | 21 ++++ .../connectors/checkpoint_connector.py | 109 ++---------------- pytorch_lightning/utilities/cloud_io.py | 98 ++++++++++++++++ 5 files changed, 136 insertions(+), 106 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 1dcd541ca0610..225016dfdf6b0 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -17,6 +17,7 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader +import pytorch_lightning as pl from pytorch_lightning.core import LightningModule from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin from pytorch_lightning.plugins.training_type import TrainingTypePlugin @@ -466,3 +467,7 @@ def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None: ' It will be removed in v1.5.' ) self.setup_precision_plugin(plugin) + + def save_checkpoint(self, trainer: 'pl.Trainer', filepath, weights_only: bool = False) -> None: + # dump states as a checkpoint dictionary object + self.training_type_plugin.save_checkpoint(trainer, filepath, weights_only) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index a8706d54cb5c9..6dbeddcff8213 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -19,11 +19,13 @@ import torch import torch.multiprocessing as mp +import pytorch_lightning as pl from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn +from pytorch_lightning.utilities.cloud_io import dump_checkpoint from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.seed import seed_everything @@ -106,8 +108,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None: trainer.accelerator.setup_optimizers(trainer) trainer.precision_plugin.connect(self._model, None, None) - # replace trainer save_checkpoint to use `xm.save` - trainer.save_checkpoint = self.save_checkpoint self.barrier("pre-run-stage") results = trainer.run_stage() @@ -298,14 +298,15 @@ def test_step(self, *args, **kwargs): def predict_step(self, *args, **kwargs): return self.lightning_module.predict_step(*args, **kwargs) - def save_checkpoint(self, filepath, weights_only: bool = False): + def save_checkpoint(self, trainer: 'pl.Trainer', filepath, weights_only: bool = False) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: + trainer: PyTorch Lightning Trainer filepath: write-target file's path weights_only: saving model weights only """ # dump states as a checkpoint dictionary object - _checkpoint = self.lightning_module.trainer.checkpoint_connector.dump_checkpoint(weights_only) + _checkpoint = dump_checkpoint(trainer, weights_only) # Todo: TypeError: 'mappingproxy' object does not support item assignment self.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 08dca63a7c925..97250dd5917fc 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -19,9 +19,12 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader +import pytorch_lightning as pl from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins.base_plugin import Plugin +from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.cloud_io import atomic_save, dump_checkpoint if TYPE_CHECKING: from pytorch_lightning.trainer.trainer import Trainer @@ -192,3 +195,21 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: Returns: If True, delay setup optimizers till pre_dispatch, else call within setup. """ return False + + def save_checkpoint(self, trainer: 'pl.Trainer', filepath, weights_only: bool = False) -> None: + # dump states as a checkpoint dictionary object + checkpoint = dump_checkpoint(trainer, weights_only) + if trainer.is_global_zero: + # write the checkpoint dictionary on the file + + checkpoint = self.on_save(checkpoint) + try: + atomic_save(checkpoint, filepath) + except AttributeError as err: + if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: + del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] + rank_zero_warn( + 'Warning, `hyper_parameters` dropped from checkpoint.' + f' An attribute is not picklable {err}' + ) + atomic_save(checkpoint, filepath) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 60c76b70bba50..77f511808bf12 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -19,17 +19,9 @@ import torch -import pytorch_lightning from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities import ( - _APEX_AVAILABLE, - _OMEGACONF_AVAILABLE, - AMPType, - DeviceType, - rank_zero_info, - rank_zero_warn, -) -from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem +from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType, DeviceType, rank_zero_info, rank_zero_warn +from pytorch_lightning.utilities.cloud_io import atomic_save, dump_checkpoint, get_filesystem from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS @@ -37,9 +29,6 @@ if _APEX_AVAILABLE: from apex import amp -if _OMEGACONF_AVAILABLE: - from omegaconf import Container - class CheckpointConnector: @@ -236,94 +225,6 @@ def hpc_save(self, folderpath: str, logger): return filepath - def dump_checkpoint(self, weights_only: bool = False) -> dict: - """Creating a model checkpoint dictionary object from various component states. - - Args: - weights_only: saving model weights only - - Return: - structured dictionary: { - 'epoch': training epoch - 'global_step': training global step - 'pytorch-lightning_version': PyTorch Lightning's version - 'callbacks': "callback specific state"[] # if not weights_only - 'optimizer_states': "PT optim's state_dict"[] # if not weights_only - 'lr_schedulers': "PT sched's state_dict"[] # if not weights_only - 'native_amp_scaling_state': PT amp's state_dict # if not weights_only and use native amp - 'amp_scaling_state': Apex's state_dict # if not weights_only and use apex amp - 'state_dict': Model's state_dict (e.g. network weights) - CHECKPOINT_HYPER_PARAMS_NAME: - CHECKPOINT_HYPER_PARAMS_KEY: - CHECKPOINT_HYPER_PARAMS_TYPE: - something_cool_i_want_to_save: anything you define through model.on_save_checkpoint - LightningDataModule.__class__.__name__: pl DataModule's state - } - """ - - # dump epoch/global_step/pytorch-lightning_version - current_epoch = self.trainer.current_epoch - global_step = self.trainer.global_step - has_reached_max_steps = self.trainer.max_steps and self.trainer.max_steps <= global_step - - global_step += 1 - if not has_reached_max_steps: - current_epoch += 1 - - model = self.trainer.lightning_module - - checkpoint = { - 'epoch': current_epoch, - 'global_step': global_step, - 'pytorch-lightning_version': pytorch_lightning.__version__, - 'state_dict': model.state_dict(), - } - - if not weights_only: - # dump callbacks - checkpoint['callbacks'] = self.trainer.on_save_checkpoint(checkpoint) - - optimizer_states = [] - for i, optimizer in enumerate(self.trainer.optimizers): - # Rely on accelerator to dump optimizer state - optimizer_state = self.trainer.accelerator.optimizer_state(optimizer) - optimizer_states.append(optimizer_state) - - checkpoint['optimizer_states'] = optimizer_states - - # dump lr schedulers - lr_schedulers = [] - for scheduler in self.trainer.lr_schedulers: - lr_schedulers.append(scheduler['scheduler'].state_dict()) - checkpoint['lr_schedulers'] = lr_schedulers - - # dump amp scaling - if ( - self.trainer.amp_backend == AMPType.NATIVE and self.trainer._device_type != DeviceType.TPU - and self.trainer.scaler is not None - ): - checkpoint['native_amp_scaling_state'] = self.trainer.scaler.state_dict() - elif self.trainer.amp_backend == AMPType.APEX: - checkpoint['amp_scaling_state'] = amp.state_dict() - - # dump hyper-parameters - if model.hparams: - if hasattr(model, '_hparams_name'): - checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name - # dump arguments - if _OMEGACONF_AVAILABLE and isinstance(model.hparams, Container): - checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams - checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams) - else: - checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams) - - # give the model a chance to dump a few things - model.on_save_checkpoint(checkpoint) - if self.trainer.datamodule is not None: - self.trainer.datamodule.on_save_checkpoint(checkpoint) - - return checkpoint - def hpc_load(self, checkpoint_path: str, on_gpu: bool): """ Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc. @@ -379,6 +280,9 @@ def max_ckpt_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_' return max(ckpt_vs) + def dump_checkpoint(self, weights_only: bool = False) -> dict: + return dump_checkpoint(self.trainer, weights_only) + def get_max_ckpt_path_from_folder(self, folder_path: Union[str, Path]) -> str: """Get path of maximum-epoch checkpoint in the folder.""" @@ -386,13 +290,14 @@ def get_max_ckpt_path_from_folder(self, folder_path: Union[str, Path]) -> str: ckpt_number = max_suffix if max_suffix is not None else 0 return f'{folder_path}/hpc_ckpt_{ckpt_number}.ckpt' - def save_checkpoint(self, filepath, weights_only: bool = False): + def save_checkpoint(self, filepath, weights_only: bool = False) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: filepath: write-target file's path weights_only: saving model weights only """ + self.trainer.accelerator.save_checkpoint(self.trainer, filepath, weights_only) # dump states as a checkpoint dictionary object checkpoint = self.dump_checkpoint(weights_only) if self.trainer.is_global_zero: diff --git a/pytorch_lightning/utilities/cloud_io.py b/pytorch_lightning/utilities/cloud_io.py index e94934020107d..2df6dcadab2ca 100644 --- a/pytorch_lightning/utilities/cloud_io.py +++ b/pytorch_lightning/utilities/cloud_io.py @@ -20,6 +20,15 @@ import fsspec import torch +import pytorch_lightning as pl +from pytorch_lightning.utilities import _APEX_AVAILABLE, _OMEGACONF_AVAILABLE, AMPType, DeviceType + +if _APEX_AVAILABLE: + from apex import amp + +if _OMEGACONF_AVAILABLE: + from omegaconf import Container + def load(path_or_url: Union[str, IO, Path], map_location=None): if not isinstance(path_or_url, (str, Path)): @@ -63,3 +72,92 @@ def atomic_save(checkpoint, filepath: str): torch.save(checkpoint, bytesbuffer) with fsspec.open(filepath, "wb") as f: f.write(bytesbuffer.getvalue()) + + +def dump_checkpoint(trainer: 'pl.Trainer', weights_only: bool = False) -> dict: + """Creating a model checkpoint dictionary object from various component states. + + Args: + weights_only: saving model weights only + + Return: + structured dictionary: { + 'epoch': training epoch + 'global_step': training global step + 'pytorch-lightning_version': PyTorch Lightning's version + 'callbacks': "callback specific state"[] # if not weights_only + 'optimizer_states': "PT optim's state_dict"[] # if not weights_only + 'lr_schedulers': "PT sched's state_dict"[] # if not weights_only + 'native_amp_scaling_state': PT amp's state_dict # if not weights_only and use native amp + 'amp_scaling_state': Apex's state_dict # if not weights_only and use apex amp + 'state_dict': Model's state_dict (e.g. network weights) + CHECKPOINT_HYPER_PARAMS_NAME: + CHECKPOINT_HYPER_PARAMS_KEY: + CHECKPOINT_HYPER_PARAMS_TYPE: + something_cool_i_want_to_save: anything you define through model.on_save_checkpoint + LightningDataModule.__class__.__name__: pl DataModule's state + } + """ + + # dump epoch/global_step/pytorch-lightning_version + current_epoch = trainer.current_epoch + global_step = trainer.global_step + has_reached_max_steps = trainer.max_steps and trainer.max_steps <= global_step + + global_step += 1 + if not has_reached_max_steps: + current_epoch += 1 + + model = trainer.lightning_module + + checkpoint = { + 'epoch': current_epoch, + 'global_step': global_step, + 'pytorch-lightning_version': pl.__version__, + 'state_dict': model.state_dict(), + } + + if not weights_only: + # dump callbacks + checkpoint['callbacks'] = trainer.on_save_checkpoint(checkpoint) + + optimizer_states = [] + for i, optimizer in enumerate(trainer.optimizers): + # Rely on accelerator to dump optimizer state + optimizer_state = trainer.accelerator.optimizer_state(optimizer) + optimizer_states.append(optimizer_state) + + checkpoint['optimizer_states'] = optimizer_states + + # dump lr schedulers + lr_schedulers = [] + for scheduler in trainer.lr_schedulers: + lr_schedulers.append(scheduler['scheduler'].state_dict()) + checkpoint['lr_schedulers'] = lr_schedulers + + # dump amp scaling + if ( + trainer.amp_backend == AMPType.NATIVE and trainer._device_type != DeviceType.TPU + and trainer.scaler is not None + ): + checkpoint['native_amp_scaling_state'] = trainer.scaler.state_dict() + elif trainer.amp_backend == AMPType.APEX: + checkpoint['amp_scaling_state'] = amp.state_dict() + + # dump hyper-parameters + if model.hparams: + if hasattr(model, '_hparams_name'): + checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name + # dump arguments + if _OMEGACONF_AVAILABLE and isinstance(model.hparams, Container): + checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams + checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams) + else: + checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams) + + # give the model a chance to dump a few things + model.on_save_checkpoint(checkpoint) + if trainer.datamodule is not None: + trainer.datamodule.on_save_checkpoint(checkpoint) + + return checkpoint From 3b4b681604b37e21f733233c2cf67ac9b928fd41 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 26 Mar 2021 16:23:09 +0000 Subject: [PATCH 2/5] update --- .../trainer/connectors/checkpoint_connector.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 77f511808bf12..f005676789bdb 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -298,20 +298,3 @@ def save_checkpoint(self, filepath, weights_only: bool = False) -> None: weights_only: saving model weights only """ self.trainer.accelerator.save_checkpoint(self.trainer, filepath, weights_only) - # dump states as a checkpoint dictionary object - checkpoint = self.dump_checkpoint(weights_only) - if self.trainer.is_global_zero: - # write the checkpoint dictionary on the file - - if self.trainer.training_type_plugin: - checkpoint = self.trainer.training_type_plugin.on_save(checkpoint) - try: - atomic_save(checkpoint, filepath) - except AttributeError as err: - if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: - del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] - rank_zero_warn( - 'Warning, `hyper_parameters` dropped from checkpoint.' - f' An attribute is not picklable {err}' - ) - atomic_save(checkpoint, filepath) From bbfaa6d7f0af2e3b2ead9218d6e76f80cc6f8881 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 29 Mar 2021 08:46:11 +0100 Subject: [PATCH 3/5] remove trainer for checkpoint --- pytorch_lightning/accelerators/accelerator.py | 6 +- .../plugins/training_type/tpu_spawn.py | 8 +- .../training_type/training_type_plugin.py | 8 +- .../connectors/checkpoint_connector.py | 104 +++++++++++++++++- pytorch_lightning/utilities/cloud_io.py | 98 ----------------- 5 files changed, 106 insertions(+), 118 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 225016dfdf6b0..f4c068f298e89 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -17,7 +17,6 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -import pytorch_lightning as pl from pytorch_lightning.core import LightningModule from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin from pytorch_lightning.plugins.training_type import TrainingTypePlugin @@ -468,6 +467,5 @@ def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None: ) self.setup_precision_plugin(plugin) - def save_checkpoint(self, trainer: 'pl.Trainer', filepath, weights_only: bool = False) -> None: - # dump states as a checkpoint dictionary object - self.training_type_plugin.save_checkpoint(trainer, filepath, weights_only) + def save_checkpoint(self, checkpoint: Dict[str, Any], filepath) -> None: + self.training_type_plugin.save_checkpoint(checkpoint, filepath) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 6dbeddcff8213..aee2b8914b579 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -19,13 +19,11 @@ import torch import torch.multiprocessing as mp -import pytorch_lightning as pl from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn -from pytorch_lightning.utilities.cloud_io import dump_checkpoint from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.seed import seed_everything @@ -298,7 +296,7 @@ def test_step(self, *args, **kwargs): def predict_step(self, *args, **kwargs): return self.lightning_module.predict_step(*args, **kwargs) - def save_checkpoint(self, trainer: 'pl.Trainer', filepath, weights_only: bool = False) -> None: + def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: @@ -306,7 +304,5 @@ def save_checkpoint(self, trainer: 'pl.Trainer', filepath, weights_only: bool = filepath: write-target file's path weights_only: saving model weights only """ - # dump states as a checkpoint dictionary object - _checkpoint = dump_checkpoint(trainer, weights_only) # Todo: TypeError: 'mappingproxy' object does not support item assignment - self.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath) + self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 97250dd5917fc..252fb205d16f5 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -19,12 +19,11 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -import pytorch_lightning as pl from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins.base_plugin import Plugin from pytorch_lightning.utilities import rank_zero_warn -from pytorch_lightning.utilities.cloud_io import atomic_save, dump_checkpoint +from pytorch_lightning.utilities.cloud_io import atomic_save if TYPE_CHECKING: from pytorch_lightning.trainer.trainer import Trainer @@ -196,10 +195,9 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: """ return False - def save_checkpoint(self, trainer: 'pl.Trainer', filepath, weights_only: bool = False) -> None: + def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: # dump states as a checkpoint dictionary object - checkpoint = dump_checkpoint(trainer, weights_only) - if trainer.is_global_zero: + if self.is_global_zero: # write the checkpoint dictionary on the file checkpoint = self.on_save(checkpoint) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index f005676789bdb..819e248082ed9 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import os import re from pathlib import Path @@ -19,9 +18,17 @@ import torch +import pytorch_lightning from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType, DeviceType, rank_zero_info, rank_zero_warn -from pytorch_lightning.utilities.cloud_io import atomic_save, dump_checkpoint, get_filesystem +from pytorch_lightning.utilities import ( + _APEX_AVAILABLE, + _OMEGACONF_AVAILABLE, + AMPType, + DeviceType, + rank_zero_info, + rank_zero_warn, +) +from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS @@ -29,6 +36,9 @@ if _APEX_AVAILABLE: from apex import amp +if _OMEGACONF_AVAILABLE: + from omegaconf import Container + class CheckpointConnector: @@ -281,7 +291,90 @@ def max_ckpt_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_' return max(ckpt_vs) def dump_checkpoint(self, weights_only: bool = False) -> dict: - return dump_checkpoint(self.trainer, weights_only) + """Creating a model checkpoint dictionary object from various component states. + Args: + weights_only: saving model weights only + Return: + structured dictionary: { + 'epoch': training epoch + 'global_step': training global step + 'pytorch-lightning_version': PyTorch Lightning's version + 'callbacks': "callback specific state"[] # if not weights_only + 'optimizer_states': "PT optim's state_dict"[] # if not weights_only + 'lr_schedulers': "PT sched's state_dict"[] # if not weights_only + 'native_amp_scaling_state': PT amp's state_dict # if not weights_only and use native amp + 'amp_scaling_state': Apex's state_dict # if not weights_only and use apex amp + 'state_dict': Model's state_dict (e.g. network weights) + CHECKPOINT_HYPER_PARAMS_NAME: + CHECKPOINT_HYPER_PARAMS_KEY: + CHECKPOINT_HYPER_PARAMS_TYPE: + something_cool_i_want_to_save: anything you define through model.on_save_checkpoint + LightningDataModule.__class__.__name__: pl DataModule's state + } + """ + + # dump epoch/global_step/pytorch-lightning_version + current_epoch = self.trainer.current_epoch + global_step = self.trainer.global_step + has_reached_max_steps = self.trainer.max_steps and self.trainer.max_steps <= global_step + + global_step += 1 + if not has_reached_max_steps: + current_epoch += 1 + + model = self.trainer.lightning_module + + checkpoint = { + 'epoch': current_epoch, + 'global_step': global_step, + 'pytorch-lightning_version': pytorch_lightning.__version__, + 'state_dict': model.state_dict(), + } + + if not weights_only: + # dump callbacks + checkpoint['callbacks'] = self.trainer.on_save_checkpoint(checkpoint) + + optimizer_states = [] + for i, optimizer in enumerate(self.trainer.optimizers): + # Rely on accelerator to dump optimizer state + optimizer_state = self.trainer.accelerator.optimizer_state(optimizer) + optimizer_states.append(optimizer_state) + + checkpoint['optimizer_states'] = optimizer_states + + # dump lr schedulers + lr_schedulers = [] + for scheduler in self.trainer.lr_schedulers: + lr_schedulers.append(scheduler['scheduler'].state_dict()) + checkpoint['lr_schedulers'] = lr_schedulers + + # dump amp scaling + if ( + self.trainer.amp_backend == AMPType.NATIVE and self.trainer._device_type != DeviceType.TPU + and self.trainer.scaler is not None + ): + checkpoint['native_amp_scaling_state'] = self.trainer.scaler.state_dict() + elif self.trainer.amp_backend == AMPType.APEX: + checkpoint['amp_scaling_state'] = amp.state_dict() + + # dump hyper-parameters + if model.hparams: + if hasattr(model, '_hparams_name'): + checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name + # dump arguments + if _OMEGACONF_AVAILABLE and isinstance(model.hparams, Container): + checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams + checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams) + else: + checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams) + + # give the model a chance to dump a few things + model.on_save_checkpoint(checkpoint) + if self.trainer.datamodule is not None: + self.trainer.datamodule.on_save_checkpoint(checkpoint) + + return checkpoint def get_max_ckpt_path_from_folder(self, folder_path: Union[str, Path]) -> str: """Get path of maximum-epoch checkpoint in the folder.""" @@ -297,4 +390,5 @@ def save_checkpoint(self, filepath, weights_only: bool = False) -> None: filepath: write-target file's path weights_only: saving model weights only """ - self.trainer.accelerator.save_checkpoint(self.trainer, filepath, weights_only) + _checkpoint = self.dump_checkpoint(weights_only) + self.trainer.accelerator.save_checkpoint(_checkpoint, filepath) diff --git a/pytorch_lightning/utilities/cloud_io.py b/pytorch_lightning/utilities/cloud_io.py index 2df6dcadab2ca..e94934020107d 100644 --- a/pytorch_lightning/utilities/cloud_io.py +++ b/pytorch_lightning/utilities/cloud_io.py @@ -20,15 +20,6 @@ import fsspec import torch -import pytorch_lightning as pl -from pytorch_lightning.utilities import _APEX_AVAILABLE, _OMEGACONF_AVAILABLE, AMPType, DeviceType - -if _APEX_AVAILABLE: - from apex import amp - -if _OMEGACONF_AVAILABLE: - from omegaconf import Container - def load(path_or_url: Union[str, IO, Path], map_location=None): if not isinstance(path_or_url, (str, Path)): @@ -72,92 +63,3 @@ def atomic_save(checkpoint, filepath: str): torch.save(checkpoint, bytesbuffer) with fsspec.open(filepath, "wb") as f: f.write(bytesbuffer.getvalue()) - - -def dump_checkpoint(trainer: 'pl.Trainer', weights_only: bool = False) -> dict: - """Creating a model checkpoint dictionary object from various component states. - - Args: - weights_only: saving model weights only - - Return: - structured dictionary: { - 'epoch': training epoch - 'global_step': training global step - 'pytorch-lightning_version': PyTorch Lightning's version - 'callbacks': "callback specific state"[] # if not weights_only - 'optimizer_states': "PT optim's state_dict"[] # if not weights_only - 'lr_schedulers': "PT sched's state_dict"[] # if not weights_only - 'native_amp_scaling_state': PT amp's state_dict # if not weights_only and use native amp - 'amp_scaling_state': Apex's state_dict # if not weights_only and use apex amp - 'state_dict': Model's state_dict (e.g. network weights) - CHECKPOINT_HYPER_PARAMS_NAME: - CHECKPOINT_HYPER_PARAMS_KEY: - CHECKPOINT_HYPER_PARAMS_TYPE: - something_cool_i_want_to_save: anything you define through model.on_save_checkpoint - LightningDataModule.__class__.__name__: pl DataModule's state - } - """ - - # dump epoch/global_step/pytorch-lightning_version - current_epoch = trainer.current_epoch - global_step = trainer.global_step - has_reached_max_steps = trainer.max_steps and trainer.max_steps <= global_step - - global_step += 1 - if not has_reached_max_steps: - current_epoch += 1 - - model = trainer.lightning_module - - checkpoint = { - 'epoch': current_epoch, - 'global_step': global_step, - 'pytorch-lightning_version': pl.__version__, - 'state_dict': model.state_dict(), - } - - if not weights_only: - # dump callbacks - checkpoint['callbacks'] = trainer.on_save_checkpoint(checkpoint) - - optimizer_states = [] - for i, optimizer in enumerate(trainer.optimizers): - # Rely on accelerator to dump optimizer state - optimizer_state = trainer.accelerator.optimizer_state(optimizer) - optimizer_states.append(optimizer_state) - - checkpoint['optimizer_states'] = optimizer_states - - # dump lr schedulers - lr_schedulers = [] - for scheduler in trainer.lr_schedulers: - lr_schedulers.append(scheduler['scheduler'].state_dict()) - checkpoint['lr_schedulers'] = lr_schedulers - - # dump amp scaling - if ( - trainer.amp_backend == AMPType.NATIVE and trainer._device_type != DeviceType.TPU - and trainer.scaler is not None - ): - checkpoint['native_amp_scaling_state'] = trainer.scaler.state_dict() - elif trainer.amp_backend == AMPType.APEX: - checkpoint['amp_scaling_state'] = amp.state_dict() - - # dump hyper-parameters - if model.hparams: - if hasattr(model, '_hparams_name'): - checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name - # dump arguments - if _OMEGACONF_AVAILABLE and isinstance(model.hparams, Container): - checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams - checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams) - else: - checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams) - - # give the model a chance to dump a few things - model.on_save_checkpoint(checkpoint) - if trainer.datamodule is not None: - trainer.datamodule.on_save_checkpoint(checkpoint) - - return checkpoint From 35605951ac66325f9fd628250694013417a21c83 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 29 Mar 2021 12:44:20 +0100 Subject: [PATCH 4/5] switch back to master --- .../connectors/checkpoint_connector.py | 113 +++++++++--------- 1 file changed, 58 insertions(+), 55 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 819e248082ed9..8b602fa6caa69 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os import re from pathlib import Path @@ -235,65 +236,12 @@ def hpc_save(self, folderpath: str, logger): return filepath - def hpc_load(self, checkpoint_path: str, on_gpu: bool): - """ - Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc. - All restored states are listed in return value description of `dump_checkpoint`. - """ - - # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path` - checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) - - # acquire the model - model = self.trainer.lightning_module - - # restore model and datamodule state - self.restore_model_state(model, checkpoint) - - if self.trainer.root_gpu is not None: - model.cuda(self.trainer.root_gpu) - - # restore training state - self.restore_training_state(checkpoint) - - # call hpc specific hook - model.on_hpc_load(checkpoint) - - def max_ckpt_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Optional[int]: - """List up files in `dir_path` with `name_key`, then yield maximum suffix number. - - Args: - dir_path: path of directory which may contain files whose name include `name_key` - name_key: file name prefix - - Returns: - None if no-corresponding-file else maximum suffix number - """ - - # check directory existence - fs = get_filesystem(dir_path) - if not fs.exists(dir_path): - return None - - # check corresponding file existence - files = [os.path.basename(f["name"]) for f in fs.listdir(dir_path)] - files = [x for x in files if name_key in x] - if len(files) == 0: - return None - - # extract suffix number - ckpt_vs = [] - for name in files: - name = name.split(name_key)[-1] - name = re.sub('[^0-9]', '', name) - ckpt_vs.append(int(name)) - - return max(ckpt_vs) - def dump_checkpoint(self, weights_only: bool = False) -> dict: """Creating a model checkpoint dictionary object from various component states. + Args: weights_only: saving model weights only + Return: structured dictionary: { 'epoch': training epoch @@ -376,6 +324,61 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: return checkpoint + def hpc_load(self, checkpoint_path: str, on_gpu: bool): + """ + Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc. + All restored states are listed in return value description of `dump_checkpoint`. + """ + + # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path` + checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) + + # acquire the model + model = self.trainer.lightning_module + + # restore model and datamodule state + self.restore_model_state(model, checkpoint) + + if self.trainer.root_gpu is not None: + model.cuda(self.trainer.root_gpu) + + # restore training state + self.restore_training_state(checkpoint) + + # call hpc specific hook + model.on_hpc_load(checkpoint) + + def max_ckpt_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Optional[int]: + """List up files in `dir_path` with `name_key`, then yield maximum suffix number. + + Args: + dir_path: path of directory which may contain files whose name include `name_key` + name_key: file name prefix + + Returns: + None if no-corresponding-file else maximum suffix number + """ + + # check directory existence + fs = get_filesystem(dir_path) + if not fs.exists(dir_path): + return None + + # check corresponding file existence + files = [os.path.basename(f["name"]) for f in fs.listdir(dir_path)] + files = [x for x in files if name_key in x] + if len(files) == 0: + return None + + # extract suffix number + ckpt_vs = [] + for name in files: + name = name.split(name_key)[-1] + name = re.sub('[^0-9]', '', name) + ckpt_vs.append(int(name)) + + return max(ckpt_vs) + def get_max_ckpt_path_from_folder(self, folder_path: Union[str, Path]) -> str: """Get path of maximum-epoch checkpoint in the folder.""" From d5aa78ea381b6ad2e6bdf1c038b245071479262c Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Mon, 29 Mar 2021 14:54:35 +0100 Subject: [PATCH 5/5] Update pytorch_lightning/plugins/training_type/training_type_plugin.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- .../plugins/training_type/training_type_plugin.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 252fb205d16f5..ca097c32513c6 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -198,10 +198,9 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: # dump states as a checkpoint dictionary object if self.is_global_zero: - # write the checkpoint dictionary on the file - checkpoint = self.on_save(checkpoint) try: + # write the checkpoint dictionary on the file atomic_save(checkpoint, filepath) except AttributeError as err: if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: