diff --git a/CHANGELOG.md b/CHANGELOG.md
index d3c8f5c4e1672..84db9b385009e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -38,6 +38,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Changed `resolve_training_type_plugins` to allow setting `num_nodes` and `sync_batchnorm` from `Trainer` setting ([7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))

+- Changed `model.state_dict()` in `CheckpointConnector` to allow `training_type_plugin` to customize the model's `state_dict()` ([7474](https://github.com/PyTorchLightning/pytorch-lightning/pull/7474))
+
+
 ### Deprecated

diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index a05f0c31c9abd..8bf34229a9742 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -420,6 +420,12 @@ def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
         """
         return getattr(self.training_type_plugin, 'optimizer_state', lambda x: x.state_dict())(optimizer)

+    def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
+        """
+        Returns state of model. Allows for syncing/collating model state from processes in custom plugins.
+        """
+        return self.training_type_plugin.lightning_module_state_dict()
+
     def on_save(self, checkpoint: Dict[str, Union[Any, Tensor]]) -> Dict[str, Union[Any, Tensor]]:
         return self.training_type_plugin.on_save(checkpoint)

diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index ede5717258040..4abbb4dbb7c05 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -16,6 +16,7 @@
 from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, TypeVar, Union

 import torch
+from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
@@ -241,6 +242,11 @@ def update_global_step(self, total_batch_idx: int, current_global_step: int) ->
         """
         return current_global_step + 1

+    def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
+        """Returns model state."""
+        model = self.lightning_module
+        return model.state_dict()
+
     def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.

diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 1f61b33a74b9c..1181c4f3efd1e 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -273,7 +273,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
             'epoch': current_epoch,
             'global_step': global_step,
             'pytorch-lightning_version': pytorch_lightning.__version__,
-            'state_dict': model.state_dict(),
+            'state_dict': self.trainer.accelerator.lightning_module_state_dict(),
         }

         if not weights_only:
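
With this change, `CheckpointConnector.dump_checkpoint` no longer calls `model.state_dict()` directly: it asks the accelerator, which delegates to the training type plugin, so a custom plugin can decide how the model state is gathered at checkpoint time. The following sketch illustrates one way a plugin could use the new hook; the class name `CPUStateDictPlugin` and the CPU-offload behaviour are illustrative assumptions, not part of this PR.

# Illustrative sketch only: a hypothetical plugin that overrides the new
# `lightning_module_state_dict` hook so the checkpoint's 'state_dict' entry
# holds detached CPU copies of the weights. Subclassing SingleDevicePlugin
# keeps the example concrete; any training type plugin could override the
# hook in the same way.
from typing import Any, Dict, Union

from torch import Tensor

from pytorch_lightning.plugins import SingleDevicePlugin


class CPUStateDictPlugin(SingleDevicePlugin):  # hypothetical name

    def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
        # Start from the default behaviour (the LightningModule's state_dict) ...
        state_dict = super().lightning_module_state_dict()
        # ... then detach and move every tensor to CPU before it is dumped.
        return {
            key: value.detach().cpu() if isinstance(value, Tensor) else value
            for key, value in state_dict.items()
        }

If such a plugin is passed to the Trainer (e.g. via the `plugins` argument), `dump_checkpoint` picks up the customized state dict through `self.trainer.accelerator.lightning_module_state_dict()`, as shown in the last hunk above.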