diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 1dcd541ca0610..f4c068f298e89 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -466,3 +466,6 @@ def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None:
             ' It will be removed in v1.5.'
         )
         self.setup_precision_plugin(plugin)
+
+    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath) -> None:
+        self.training_type_plugin.save_checkpoint(checkpoint, filepath)
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index a8706d54cb5c9..aee2b8914b579 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -106,8 +106,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         trainer.accelerator.setup_optimizers(trainer)
         trainer.precision_plugin.connect(self._model, None, None)
 
-        # replace trainer save_checkpoint to use `xm.save`
-        trainer.save_checkpoint = self.save_checkpoint
         self.barrier("pre-run-stage")
 
         results = trainer.run_stage()
@@ -298,14 +296,12 @@ def test_step(self, *args, **kwargs):
     def predict_step(self, *args, **kwargs):
         return self.lightning_module.predict_step(*args, **kwargs)
 
-    def save_checkpoint(self, filepath, weights_only: bool = False):
+    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
         Args:
+            checkpoint: dict containing model and trainer state
             filepath: write-target file's path
-            weights_only: saving model weights only
         """
-        # dump states as a checkpoint dictionary object
-        _checkpoint = self.lightning_module.trainer.checkpoint_connector.dump_checkpoint(weights_only)
         # Todo: TypeError: 'mappingproxy' object does not support item assignment
-        self.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath)
+        self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index 08dca63a7c925..ca097c32513c6 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -22,6 +22,8 @@
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.overrides.base import unwrap_lightning_module
 from pytorch_lightning.plugins.base_plugin import Plugin
+from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.cloud_io import atomic_save
 
 if TYPE_CHECKING:
     from pytorch_lightning.trainer.trainer import Trainer
@@ -192,3 +194,19 @@ def setup_optimizers_in_pre_dispatch(self) -> bool:
         Returns: If True, delay setup optimizers till pre_dispatch, else call within setup.
         """
         return False
+
+    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
+        # dump states as a checkpoint dictionary object
+        if self.is_global_zero:
+            checkpoint = self.on_save(checkpoint)
+            try:
+                # write the checkpoint dictionary on the file
+                atomic_save(checkpoint, filepath)
+            except AttributeError as err:
+                if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
+                    del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
+                rank_zero_warn(
+                    'Warning, `hyper_parameters` dropped from checkpoint.'
+                    f' An attribute is not picklable {err}'
+                )
+                atomic_save(checkpoint, filepath)
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 60c76b70bba50..8b602fa6caa69 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -386,27 +386,12 @@ def get_max_ckpt_path_from_folder(self, folder_path: Union[str, Path]) -> str:
         ckpt_number = max_suffix if max_suffix is not None else 0
         return f'{folder_path}/hpc_ckpt_{ckpt_number}.ckpt'
 
-    def save_checkpoint(self, filepath, weights_only: bool = False):
+    def save_checkpoint(self, filepath, weights_only: bool = False) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
         Args:
             filepath: write-target file's path
             weights_only: saving model weights only
         """
-        # dump states as a checkpoint dictionary object
-        checkpoint = self.dump_checkpoint(weights_only)
-        if self.trainer.is_global_zero:
-            # write the checkpoint dictionary on the file
-
-            if self.trainer.training_type_plugin:
-                checkpoint = self.trainer.training_type_plugin.on_save(checkpoint)
-            try:
-                atomic_save(checkpoint, filepath)
-            except AttributeError as err:
-                if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
-                    del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
-                rank_zero_warn(
-                    'Warning, `hyper_parameters` dropped from checkpoint.'
-                    f' An attribute is not picklable {err}'
-                )
-                atomic_save(checkpoint, filepath)
+        _checkpoint = self.dump_checkpoint(weights_only)
+        self.trainer.accelerator.save_checkpoint(_checkpoint, filepath)
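
For reviewers, the sketch below is a minimal, standalone illustration of the call chain this diff establishes (CheckpointConnector -> Accelerator -> TrainingTypePlugin). The class names mirror the diff, but everything here is a toy stand-in rather than the real Lightning implementation: pickle.dump replaces atomic_save / xm.save, and the dumped checkpoint dict is hard-coded.

# Illustrative sketch of the new save path; not the real Lightning code.
import pickle
from typing import Any, Dict


class TrainingTypePlugin:
    def on_save(self, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
        # hook for plugins to transform the checkpoint before writing
        return checkpoint

    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
        # default behaviour: transform, then write the dict to disk
        checkpoint = self.on_save(checkpoint)
        with open(filepath, "wb") as f:
            pickle.dump(checkpoint, f)


class TPULikePlugin(TrainingTypePlugin):
    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
        # a TPU plugin overrides the write path (it would call xm.save here)
        # instead of monkey-patching trainer.save_checkpoint as the removed code did
        trimmed = {k: v for k, v in checkpoint.items() if k != "callbacks"}
        super().save_checkpoint(trimmed, filepath)


class Accelerator:
    def __init__(self, plugin: TrainingTypePlugin) -> None:
        self.training_type_plugin = plugin

    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
        # the accelerator is a pass-through to its training-type plugin
        self.training_type_plugin.save_checkpoint(checkpoint, filepath)


class CheckpointConnector:
    def __init__(self, accelerator: Accelerator) -> None:
        self.accelerator = accelerator

    def save_checkpoint(self, filepath: str, weights_only: bool = False) -> None:
        # dump states as a checkpoint dictionary, then delegate the file write
        _checkpoint = {"state_dict": {"weight": 1.0}, "callbacks": {"early_stop": {}}}
        self.accelerator.save_checkpoint(_checkpoint, filepath)


CheckpointConnector(Accelerator(TPULikePlugin())).save_checkpoint("model.ckpt")

The design point: checkpoint dumping stays in the connector, while the file write is owned by the training-type plugin, so TPU (and other distributed) plugins customise saving by overriding save_checkpoint rather than patching the trainer at runtime.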