[refactor] Move save_function to accelerator 1/n [DeepSpeed] #6689

Merged · 5 commits · Mar 29, 2021
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/accelerator.py
@@ -466,3 +466,6 @@ def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None:
' It will be removed in v1.5.'
)
self.setup_precision_plugin(plugin)

def save_checkpoint(self, checkpoint: Dict[str, Any], filepath) -> None:
self.training_type_plugin.save_checkpoint(checkpoint, filepath)
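
For orientation, this one-line hook sits in the middle of the delegation chain the PR sets up. A sketch of that chain, condensed from the hunks below (not verbatim library code):

# Sketch of the refactored save path; illustrative only.
# Trainer.save_checkpoint(filepath, weights_only)
#   -> CheckpointConnector.save_checkpoint(filepath, weights_only)   # dumps the checkpoint dict
#     -> Accelerator.save_checkpoint(checkpoint, filepath)           # the hook added here
#       -> TrainingTypePlugin.save_checkpoint(checkpoint, filepath)  # writes the file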
9 changes: 3 additions & 6 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -106,8 +106,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
trainer.accelerator.setup_optimizers(trainer)
trainer.precision_plugin.connect(self._model, None, None)

# replace trainer save_checkpoint to use `xm.save`
trainer.save_checkpoint = self.save_checkpoint
self.barrier("pre-run-stage")

results = trainer.run_stage()
@@ -298,14 +296,13 @@ def test_step(self, *args, **kwargs):
def predict_step(self, *args, **kwargs):
return self.lightning_module.predict_step(*args, **kwargs)

def save_checkpoint(self, filepath, weights_only: bool = False):
def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.

Args:
trainer: PyTorch Lightning Trainer
Member (review comment), suggested change:
trainer: PyTorch Lightning Trainer  ->  checkpoint: dict containing model and trainer state

Contributor Author: Done!

filepath: write-target file's path
weights_only: saving model weights only
"""
# dump states as a checkpoint dictionary object
_checkpoint = self.lightning_module.trainer.checkpoint_connector.dump_checkpoint(weights_only)
# Todo: TypeError: 'mappingproxy' object does not support item assignment
self.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath)
self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)
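
As a small standalone illustration of the filtering above, with placeholder dict values (the plugin's `save` helper is assumed to wrap `xm.save`, per the removed comment earlier in this file):

# Standalone illustration of the callback filtering; values are placeholders.
checkpoint = {"state_dict": {"layer.weight": 0.1}, "callbacks": {"ModelCheckpoint": {}}, "epoch": 3}
to_save = {k: v for k, v in checkpoint.items() if k != "callbacks"}
assert "callbacks" not in to_save and to_save["epoch"] == 3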
19 changes: 19 additions & 0 deletions pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -22,6 +22,8 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins.base_plugin import Plugin
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save

if TYPE_CHECKING:
from pytorch_lightning.trainer.trainer import Trainer
@@ -192,3 +194,20 @@ def setup_optimizers_in_pre_dispatch(self) -> bool:
Returns: If True, delay setup optimizers till pre_dispatch, else call within setup.
"""
return False

def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
# dump states as a checkpoint dictionary object
if self.is_global_zero:
# write the checkpoint dictionary on the file

checkpoint = self.on_save(checkpoint)
try:
atomic_save(checkpoint, filepath)
except AttributeError as err:
if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
rank_zero_warn(
'Warning, `hyper_parameters` dropped from checkpoint.'
f' An attribute is not picklable {err}'
)
atomic_save(checkpoint, filepath)
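
With the write path now owned by the plugin, a specialized plugin (for instance the DeepSpeed plugin this "1/n" series is building toward) can override the hook. A hypothetical sketch, assuming the remaining abstract methods of TrainingTypePlugin are implemented elsewhere and using torch.save as a stand-in for an engine-specific save:

from typing import Any, Dict

import torch

from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin


class EngineCheckpointPlugin(TrainingTypePlugin):  # hypothetical subclass, for illustration only
    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
        # the plugin fully owns serialization; torch.save stands in for an
        # engine-specific routine (e.g. a sharded or engine-managed save)
        if self.is_global_zero:
            torch.save(self.on_save(checkpoint), filepath)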
21 changes: 3 additions & 18 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -386,27 +386,12 @@ def get_max_ckpt_path_from_folder(self, folder_path: Union[str, Path]) -> str:
ckpt_number = max_suffix if max_suffix is not None else 0
return f'{folder_path}/hpc_ckpt_{ckpt_number}.ckpt'

def save_checkpoint(self, filepath, weights_only: bool = False):
def save_checkpoint(self, filepath, weights_only: bool = False) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.

Args:
filepath: write-target file's path
weights_only: saving model weights only
"""
# dump states as a checkpoint dictionary object
checkpoint = self.dump_checkpoint(weights_only)
if self.trainer.is_global_zero:
# write the checkpoint dictionary on the file

if self.trainer.training_type_plugin:
checkpoint = self.trainer.training_type_plugin.on_save(checkpoint)
try:
atomic_save(checkpoint, filepath)
except AttributeError as err:
if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
rank_zero_warn(
'Warning, `hyper_parameters` dropped from checkpoint.'
f' An attribute is not picklable {err}'
)
atomic_save(checkpoint, filepath)
_checkpoint = self.dump_checkpoint(weights_only)
self.trainer.accelerator.save_checkpoint(_checkpoint, filepath)
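
Putting the four hunks together, a minimal usage sketch of the refactored path (assumes an already configured `trainer`; the checkpoint path is a placeholder):

# Mirrors the two lines above; illustrative only.
checkpoint = trainer.checkpoint_connector.dump_checkpoint(weights_only=False)
trainer.accelerator.save_checkpoint(checkpoint, "example.ckpt")
# On global rank zero this ends in TrainingTypePlugin.save_checkpoint (atomic_save),
# or in TPUSpawnPlugin.save_checkpoint on TPU, which filters callback state before writing.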