[refactor] Move save_function to accelerator 1/n [DeepSpeed] #6689

Merged — 5 commits, Mar 29, 2021

Changes from 2 commits
5 changes: 5 additions & 0 deletions pytorch_lightning/accelerators/accelerator.py
@@ -17,6 +17,7 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
from pytorch_lightning.plugins.training_type import TrainingTypePlugin
@@ -466,3 +467,7 @@ def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None:
' It will be removed in v1.5.'
)
self.setup_precision_plugin(plugin)

def save_checkpoint(self, trainer: 'pl.Trainer', filepath, weights_only: bool = False) -> None:
# dump states as a checkpoint dictionary object
self.training_type_plugin.save_checkpoint(trainer, filepath, weights_only)
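The accelerator method above is a pure pass-through. A minimal, self-contained sketch of the call chain this PR establishes — plain stand-in classes, not the real Lightning API:

class StubTrainingTypePlugin:
    """Stand-in for TrainingTypePlugin: owns the actual write."""

    def save_checkpoint(self, trainer, filepath, weights_only=False):
        print(f"plugin writes {filepath} (weights_only={weights_only})")


class StubAccelerator:
    """Stand-in for Accelerator: forwards to its training-type plugin."""

    def __init__(self, training_type_plugin):
        self.training_type_plugin = training_type_plugin

    def save_checkpoint(self, trainer, filepath, weights_only=False):
        self.training_type_plugin.save_checkpoint(trainer, filepath, weights_only)


class StubTrainer:
    """Stand-in for Trainer: the entry point user code calls."""

    def __init__(self):
        self.accelerator = StubAccelerator(StubTrainingTypePlugin())

    def save_checkpoint(self, filepath, weights_only=False):
        self.accelerator.save_checkpoint(self, filepath, weights_only)


StubTrainer().save_checkpoint("example.ckpt")

Plugins that need special serialization (TPU, DeepSpeed) can then override the plugin-level method without any caller changing.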
9 changes: 5 additions & 4 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -19,11 +19,13 @@
import torch
import torch.multiprocessing as mp

import pytorch_lightning as pl
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import dump_checkpoint
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything
@@ -106,8 +108,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
trainer.accelerator.setup_optimizers(trainer)
trainer.precision_plugin.connect(self._model, None, None)

# replace trainer save_checkpoint to use `xm.save`
trainer.save_checkpoint = self.save_checkpoint
self.barrier("pre-run-stage")

results = trainer.run_stage()
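The two removed lines above drop a runtime monkey-patch: instead of rebinding `trainer.save_checkpoint` to the plugin's method, saving now reaches the plugin through the accelerator dispatch sketched earlier. A toy contrast of the two patterns (hypothetical names, not Lightning code):

class Plugin:
    def save_checkpoint(self, filepath):
        print(f"xm.save-style write to {filepath}")


class Trainer:
    def __init__(self, plugin):
        self.plugin = plugin

    def save_checkpoint(self, filepath):
        # New style: always dispatch through the owned plugin.
        self.plugin.save_checkpoint(filepath)


trainer = Trainer(Plugin())

# Old style (what the removed lines did): rebind the method at runtime.
# trainer.save_checkpoint = trainer.plugin.save_checkpoint  # fragile

trainer.save_checkpoint("example.ckpt")  # dispatch, no monkey-patching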
@@ -298,14 +298,15 @@ def test_step(self, *args, **kwargs):
def predict_step(self, *args, **kwargs):
return self.lightning_module.predict_step(*args, **kwargs)

def save_checkpoint(self, filepath, weights_only: bool = False):
def save_checkpoint(self, trainer: 'pl.Trainer', filepath, weights_only: bool = False) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.

Args:
trainer: PyTorch Lightning Trainer
Member — Suggested change:
-    trainer: PyTorch Lightning Trainer
+    checkpoint: dict containing model and trainer state

Contributor Author: Done!

filepath: write-target file's path
weights_only: saving model weights only
akihironitta marked this conversation as resolved.
"""
# dump states as a checkpoint dictionary object
_checkpoint = self.lightning_module.trainer.checkpoint_connector.dump_checkpoint(weights_only)
_checkpoint = dump_checkpoint(trainer, weights_only)
# Todo: TypeError: 'mappingproxy' object does not support item assignment
self.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath)
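The dict comprehension on the last line works around the `mappingproxy` TypeError noted in the TODO by excluding callback state from the TPU save. The filtering pattern in isolation, with plain `torch.save` standing in for the plugin's `self.save`:

import torch

checkpoint = {
    "state_dict": {"weight": torch.zeros(2)},
    "callbacks": {"EarlyStopping": {"wait_count": 0}},  # the problematic entry
}

# Keep every key except "callbacks", then hand the rest to the writer.
to_save = {k: v for k, v in checkpoint.items() if k != "callbacks"}
torch.save(to_save, "tpu_example.ckpt")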
21 changes: 21 additions & 0 deletions pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -19,9 +19,12 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins.base_plugin import Plugin
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save, dump_checkpoint

if TYPE_CHECKING:
from pytorch_lightning.trainer.trainer import Trainer
@@ -192,3 +195,21 @@ def setup_optimizers_in_pre_dispatch(self) -> bool:
Returns: If True, delay setup optimizers till pre_dispatch, else call within setup.
"""
return False

def save_checkpoint(self, trainer: 'pl.Trainer', filepath, weights_only: bool = False) -> None:
tchaton marked this conversation as resolved.
# dump states as a checkpoint dictionary object
checkpoint = dump_checkpoint(trainer, weights_only)
if trainer.is_global_zero:
# write the checkpoint dictionary on the file

checkpoint = self.on_save(checkpoint)
try:
atomic_save(checkpoint, filepath)
tchaton marked this conversation as resolved.
except AttributeError as err:
if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
rank_zero_warn(
'Warning, `hyper_parameters` dropped from checkpoint.'
f' An attribute is not picklable {err}'
)
atomic_save(checkpoint, filepath)
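The `try/except AttributeError` above retries the write after dropping hyper-parameters when they cannot be pickled. The same fallback pattern, sketched standalone with `torch.save` and a stand-in key name:

import warnings

import torch

HPARAMS_KEY = "hyper_parameters"  # stand-in for LightningModule.CHECKPOINT_HYPER_PARAMS_KEY


def save_with_fallback(checkpoint: dict, filepath: str) -> None:
    """Try a full save; if an attribute is not picklable, drop hparams and retry."""
    try:
        torch.save(checkpoint, filepath)
    except AttributeError as err:
        # Unpicklable hparams are the usual culprit: drop them and retry.
        if HPARAMS_KEY in checkpoint:
            del checkpoint[HPARAMS_KEY]
        warnings.warn(f"`hyper_parameters` dropped from checkpoint; not picklable: {err}")
        torch.save(checkpoint, filepath)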
126 changes: 7 additions & 119 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -19,27 +19,16 @@

import torch

import pytorch_lightning
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import (
_APEX_AVAILABLE,
_OMEGACONF_AVAILABLE,
AMPType,
DeviceType,
rank_zero_info,
rank_zero_warn,
)
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType, DeviceType, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save, dump_checkpoint, get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

if _APEX_AVAILABLE:
from apex import amp

if _OMEGACONF_AVAILABLE:
from omegaconf import Container


class CheckpointConnector:

@@ -236,94 +225,6 @@ def hpc_save(self, folderpath: str, logger):

return filepath

def dump_checkpoint(self, weights_only: bool = False) -> dict:
"""Creating a model checkpoint dictionary object from various component states.

Args:
weights_only: saving model weights only

Return:
structured dictionary: {
'epoch': training epoch
'global_step': training global step
'pytorch-lightning_version': PyTorch Lightning's version
'callbacks': "callback specific state"[] # if not weights_only
'optimizer_states': "PT optim's state_dict"[] # if not weights_only
'lr_schedulers': "PT sched's state_dict"[] # if not weights_only
'native_amp_scaling_state': PT amp's state_dict # if not weights_only and use native amp
'amp_scaling_state': Apex's state_dict # if not weights_only and use apex amp
'state_dict': Model's state_dict (e.g. network weights)
CHECKPOINT_HYPER_PARAMS_NAME:
CHECKPOINT_HYPER_PARAMS_KEY:
CHECKPOINT_HYPER_PARAMS_TYPE:
something_cool_i_want_to_save: anything you define through model.on_save_checkpoint
LightningDataModule.__class__.__name__: pl DataModule's state
}
"""

# dump epoch/global_step/pytorch-lightning_version
current_epoch = self.trainer.current_epoch
global_step = self.trainer.global_step
has_reached_max_steps = self.trainer.max_steps and self.trainer.max_steps <= global_step

global_step += 1
if not has_reached_max_steps:
current_epoch += 1

model = self.trainer.lightning_module

checkpoint = {
'epoch': current_epoch,
'global_step': global_step,
'pytorch-lightning_version': pytorch_lightning.__version__,
'state_dict': model.state_dict(),
}

if not weights_only:
# dump callbacks
checkpoint['callbacks'] = self.trainer.on_save_checkpoint(checkpoint)

optimizer_states = []
for i, optimizer in enumerate(self.trainer.optimizers):
# Rely on accelerator to dump optimizer state
optimizer_state = self.trainer.accelerator.optimizer_state(optimizer)
optimizer_states.append(optimizer_state)

checkpoint['optimizer_states'] = optimizer_states

# dump lr schedulers
lr_schedulers = []
for scheduler in self.trainer.lr_schedulers:
lr_schedulers.append(scheduler['scheduler'].state_dict())
checkpoint['lr_schedulers'] = lr_schedulers

# dump amp scaling
if (
self.trainer.amp_backend == AMPType.NATIVE and self.trainer._device_type != DeviceType.TPU
and self.trainer.scaler is not None
):
checkpoint['native_amp_scaling_state'] = self.trainer.scaler.state_dict()
elif self.trainer.amp_backend == AMPType.APEX:
checkpoint['amp_scaling_state'] = amp.state_dict()

# dump hyper-parameters
if model.hparams:
if hasattr(model, '_hparams_name'):
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
# dump arguments
if _OMEGACONF_AVAILABLE and isinstance(model.hparams, Container):
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
else:
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams)

# give the model a chance to dump a few things
model.on_save_checkpoint(checkpoint)
if self.trainer.datamodule is not None:
self.trainer.datamodule.on_save_checkpoint(checkpoint)

return checkpoint

def hpc_load(self, checkpoint_path: str, on_gpu: bool):
"""
Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc.
@@ -379,34 +280,21 @@ def max_ckpt_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_'

return max(ckpt_vs)

def dump_checkpoint(self, weights_only: bool = False) -> dict:
return dump_checkpoint(self.trainer, weights_only)

def get_max_ckpt_path_from_folder(self, folder_path: Union[str, Path]) -> str:
"""Get path of maximum-epoch checkpoint in the folder."""

max_suffix = self.max_ckpt_in_folder(folder_path)
ckpt_number = max_suffix if max_suffix is not None else 0
return f'{folder_path}/hpc_ckpt_{ckpt_number}.ckpt'

def save_checkpoint(self, filepath, weights_only: bool = False):
def save_checkpoint(self, filepath, weights_only: bool = False) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.

Args:
filepath: write-target file's path
weights_only: saving model weights only
"""
# dump states as a checkpoint dictionary object
checkpoint = self.dump_checkpoint(weights_only)
if self.trainer.is_global_zero:
# write the checkpoint dictionary on the file

if self.trainer.training_type_plugin:
checkpoint = self.trainer.training_type_plugin.on_save(checkpoint)
try:
atomic_save(checkpoint, filepath)
except AttributeError as err:
if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
rank_zero_warn(
'Warning, `hyper_parameters` dropped from checkpoint.'
f' An attribute is not picklable {err}'
)
atomic_save(checkpoint, filepath)
self.trainer.accelerator.save_checkpoint(self.trainer, filepath, weights_only)
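After this change the connector no longer serializes anything itself: `save_checkpoint` is a one-line delegation to the accelerator, and `dump_checkpoint` survives only as a compatibility shim over the relocated `cloud_io` function. The shim pattern in isolation (stand-in names, not Lightning code):

def dump_checkpoint_fn(trainer, weights_only=False) -> dict:
    # New home of the state-dump logic: a module-level function
    # that takes the trainer explicitly instead of holding it.
    return {"weights_only": weights_only}


class Connector:
    def __init__(self, trainer):
        self.trainer = trainer

    def dump_checkpoint(self, weights_only=False) -> dict:
        # Old call sites keep working; the body just forwards.
        return dump_checkpoint_fn(self.trainer, weights_only)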
98 changes: 98 additions & 0 deletions pytorch_lightning/utilities/cloud_io.py
@@ -20,6 +20,15 @@
import fsspec
import torch

import pytorch_lightning as pl
from pytorch_lightning.utilities import _APEX_AVAILABLE, _OMEGACONF_AVAILABLE, AMPType, DeviceType

if _APEX_AVAILABLE:
from apex import amp

if _OMEGACONF_AVAILABLE:
from omegaconf import Container


def load(path_or_url: Union[str, IO, Path], map_location=None):
if not isinstance(path_or_url, (str, Path)):
@@ -63,3 +72,92 @@ def atomic_save(checkpoint, filepath: str):
torch.save(checkpoint, bytesbuffer)
with fsspec.open(filepath, "wb") as f:
f.write(bytesbuffer.getvalue())


def dump_checkpoint(trainer: 'pl.Trainer', weights_only: bool = False) -> dict:
carmocca marked this conversation as resolved.
"""Creating a model checkpoint dictionary object from various component states.

Args:
trainer: PyTorch Lightning Trainer
weights_only: saving model weights only

Return:
structured dictionary: {
'epoch': training epoch
'global_step': training global step
'pytorch-lightning_version': PyTorch Lightning's version
'callbacks': "callback specific state"[] # if not weights_only
'optimizer_states': "PT optim's state_dict"[] # if not weights_only
'lr_schedulers': "PT sched's state_dict"[] # if not weights_only
'native_amp_scaling_state': PT amp's state_dict # if not weights_only and use native amp
'amp_scaling_state': Apex's state_dict # if not weights_only and use apex amp
'state_dict': Model's state_dict (e.g. network weights)
CHECKPOINT_HYPER_PARAMS_NAME:
CHECKPOINT_HYPER_PARAMS_KEY:
CHECKPOINT_HYPER_PARAMS_TYPE:
something_cool_i_want_to_save: anything you define through model.on_save_checkpoint
LightningDataModule.__class__.__name__: pl DataModule's state
}
"""

# dump epoch/global_step/pytorch-lightning_version
current_epoch = trainer.current_epoch
global_step = trainer.global_step
has_reached_max_steps = trainer.max_steps and trainer.max_steps <= global_step

global_step += 1
if not has_reached_max_steps:
current_epoch += 1

model = trainer.lightning_module

checkpoint = {
'epoch': current_epoch,
'global_step': global_step,
'pytorch-lightning_version': pl.__version__,
'state_dict': model.state_dict(),
}

if not weights_only:
# dump callbacks
checkpoint['callbacks'] = trainer.on_save_checkpoint(checkpoint)

optimizer_states = []
for i, optimizer in enumerate(trainer.optimizers):
# Rely on accelerator to dump optimizer state
optimizer_state = trainer.accelerator.optimizer_state(optimizer)
optimizer_states.append(optimizer_state)

checkpoint['optimizer_states'] = optimizer_states

# dump lr schedulers
lr_schedulers = []
for scheduler in trainer.lr_schedulers:
lr_schedulers.append(scheduler['scheduler'].state_dict())
checkpoint['lr_schedulers'] = lr_schedulers

# dump amp scaling
if (
trainer.amp_backend == AMPType.NATIVE and trainer._device_type != DeviceType.TPU
and trainer.scaler is not None
):
checkpoint['native_amp_scaling_state'] = trainer.scaler.state_dict()
elif trainer.amp_backend == AMPType.APEX:
checkpoint['amp_scaling_state'] = amp.state_dict()

# dump hyper-parameters
if model.hparams:
if hasattr(model, '_hparams_name'):
checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
# dump arguments
if _OMEGACONF_AVAILABLE and isinstance(model.hparams, Container):
checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
else:
checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams)

# give the model a chance to dump a few things
model.on_save_checkpoint(checkpoint)
if trainer.datamodule is not None:
trainer.datamodule.on_save_checkpoint(checkpoint)

return checkpoint
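To make the documented layout concrete, a small inspection snippet for a checkpoint produced this way (hypothetical file path; the optional keys are guarded because `weights_only=True` omits them):

import torch

ckpt = torch.load("example.ckpt", map_location="cpu")

print(ckpt["epoch"], ckpt["global_step"], ckpt["pytorch-lightning_version"])
print(sorted(ckpt["state_dict"]))          # parameter/buffer names
if "optimizer_states" in ckpt:             # present unless weights_only=True
    print(len(ckpt["optimizer_states"]))   # one state_dict per optimizer
if "lr_schedulers" in ckpt:
    print(len(ckpt["lr_schedulers"]))      # one state_dict per scheduler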