DeepSpeed ZeRO Update #6546

Merged: 70 commits, merged Mar 30, 2021. Changes shown from 61 commits.

Commits (70)
2d3f617
Add context to call hook to handle all modules defined within the hook
Mar 10, 2021
99495e8
Expose some additional parameters
Mar 10, 2021
c3aac67
Added docs, exposed parameters
Mar 11, 2021
340f817
Make sure we only configure if necessary
Mar 11, 2021
ae02102
Merge branch 'master' into feat/ds_update
Mar 11, 2021
f192afc
Setup activation checkpointing regardless, saves the user having to d…
Mar 12, 2021
a2784a4
Add some tests that fail currently
Mar 12, 2021
b0dab3d
update
tchaton Mar 15, 2021
0c44f05
update
tchaton Mar 15, 2021
26655d7
update
tchaton Mar 15, 2021
ac19f36
add tests
tchaton Mar 16, 2021
d273393
change docstring
tchaton Mar 16, 2021
c91d128
resolve accumulate_grad_batches
tchaton Mar 16, 2021
959d7b7
resolve flake8
tchaton Mar 16, 2021
f0cb6e7
Update DeepSpeed to use latest version, add some comments
Mar 17, 2021
914de86
add metrics
tchaton Mar 17, 2021
5d16c74
Merge branch 'feat/ds_update' of https://github.com/PyTorchLightning/…
tchaton Mar 17, 2021
712814c
update
tchaton Mar 17, 2021
a1644c6
Small formatting fixes, clean up some code
Mar 17, 2021
64f624f
Few cleanups
Mar 17, 2021
89fbbcb
No need for default state
Mar 18, 2021
701d417
Fix tests, add some boilerplate that should move eventually
Mar 18, 2021
270d6ed
Add hook removal
Mar 22, 2021
2b71ed8
Merge branch 'master' into feat/ds_update
Mar 22, 2021
a236ff0
Add a context manager to handle hook
Mar 23, 2021
e1f865e
Small naming cleanup
Mar 25, 2021
80fb792
wip
tchaton Mar 26, 2021
d621b1f
Merge branch 'master' into feat/ds_update
tchaton Mar 26, 2021
1de2bcd
move save_checkpoint responsability to accelerator
tchaton Mar 26, 2021
90d6e03
resolve flake8
tchaton Mar 26, 2021
b6361b8
add BC
tchaton Mar 26, 2021
924d9e2
Merge branch 'feat/ds_update' of https://github.com/PyTorchLightning/…
tchaton Mar 26, 2021
6acaccb
Change recommended scale to 16
Mar 29, 2021
f7a373e
Merge branch 'master' into feat/ds_update
tchaton Mar 30, 2021
68b8a43
resolve flake8
tchaton Mar 30, 2021
a7dcb7b
update test
tchaton Mar 30, 2021
08df0b5
Merge branch 'feat/ds_update' of https://github.com/PyTorchLightning/…
tchaton Mar 30, 2021
6b08478
update install
tchaton Mar 30, 2021
45a49c5
update
tchaton Mar 30, 2021
a8da299
update test
tchaton Mar 30, 2021
99f1d96
update
tchaton Mar 30, 2021
89601d8
update
tchaton Mar 30, 2021
eb1495e
Merge branch 'feat/ds_update' of https://github.com/PyTorchLightning/…
tchaton Mar 30, 2021
389c60b
update test
tchaton Mar 30, 2021
de5f358
resolve flake8
tchaton Mar 30, 2021
301b1aa
update
tchaton Mar 30, 2021
b9542ae
update
tchaton Mar 30, 2021
48c0950
update on comments
tchaton Mar 30, 2021
c230407
Push
Mar 30, 2021
783265f
pull
Mar 30, 2021
c8f79f9
Update pytorch_lightning/plugins/training_type/deepspeed.py
tchaton Mar 30, 2021
61378de
Update pytorch_lightning/plugins/training_type/deepspeed.py
tchaton Mar 30, 2021
45c9569
update
tchaton Mar 30, 2021
deb2ea2
Apply suggestions from code review
SeanNaren Mar 30, 2021
122e911
Swap to using world size defined by plugin
Mar 30, 2021
dfb403b
update
tchaton Mar 30, 2021
9bd2821
Merge branch 'feat/ds_update' of https://github.com/PyTorchLightning/…
tchaton Mar 30, 2021
066e0f0
update todo
tchaton Mar 30, 2021
d41284e
Remove deepspeed from extra, keep it in the base cuda docker install
Mar 30, 2021
0c9836c
Push
Mar 30, 2021
d1c511e
pull
Mar 30, 2021
67d31fa
update
tchaton Mar 30, 2021
e65aaf3
Merge branch 'feat/ds_update' of https://github.com/PyTorchLightning/…
tchaton Mar 30, 2021
1740eed
update
tchaton Mar 30, 2021
300f3aa
update
tchaton Mar 30, 2021
40b1cc6
update
tchaton Mar 30, 2021
603caf1
Minor changes
carmocca Mar 30, 2021
62f67e8
duplicate
Borda Mar 30, 2021
5786c4b
format
Borda Mar 30, 2021
83e1343
format2
Borda Mar 30, 2021
4 changes: 1 addition & 3 deletions dockers/base-cuda/Dockerfile
@@ -114,9 +114,7 @@ RUN \
     rm -rf apex

 RUN \
-    # install DeepSpeed from source.
-    # todo: swap to pypi release once DeepSpeed releases a new version >= 0.3.10
-    pip install deepspeed@git+https://github.com/microsoft/DeepSpeed@ec8b1cb
+    pip install deepspeed>=0.3.13

 RUN \
     # Show what we have
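The Dockerfile now pulls DeepSpeed from its PyPI release instead of a pinned git commit. As orientation only (not part of this diff), a minimal sketch of how the released package is typically exercised through Lightning's DeepSpeed plugin; the `DeepSpeedPlugin` import path, its `stage` argument, and the Trainer flags are assumptions about the API of this Lightning era, not something stated in the PR.

# Illustrative only: a tiny end-to-end run with the PyPI DeepSpeed release via
# Lightning's DeepSpeed plugin. Plugin name/arguments are assumed for this era.
import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from pytorch_lightning.plugins import DeepSpeedPlugin


class RandomDataset(Dataset):
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32)


class BoringModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    trainer = pl.Trainer(
        gpus=2,
        precision=16,
        plugins=DeepSpeedPlugin(stage=3),  # ZeRO stage targeted by this update (assumed)
        max_epochs=1,
    )
    trainer.fit(BoringModel(), DataLoader(RandomDataset(), batch_size=8))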
5 changes: 4 additions & 1 deletion pytorch_lightning/accelerators/accelerator.py
@@ -441,7 +441,7 @@ def results(self) -> Any:
         return self.training_type_plugin.results

     @contextlib.contextmanager
-    def model_sharded_context(self) -> Generator:
+    def model_sharded_context(self) -> Generator[None, None, None]:
         """
         Provide hook to create modules in a distributed aware context. This is useful for when we'd like to
         shard the model instantly - useful for extremely large models. Can save memory and
@@ -511,3 +511,6 @@ def setup_optimizers_in_pre_dispatch(self) -> bool:
         Returns: If True, delay setup optimizers till pre_dispatch, else call within setup.
         """
         return self.training_type_plugin.setup_optimizers_in_pre_dispatch
+
+    def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
+        return self.training_type_plugin.update_global_step(total_batch_idx, current_global_step)
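The tightened `Generator[None, None, None]` annotation reflects that `model_sharded_context` is a no-argument, no-value context manager inside which submodules are instantiated. A self-contained sketch of that contract, purely illustrative and not taken from the PR; a real sharded plugin would patch module construction here, for example via an engine API such as deepspeed.zero.Init.

# Illustrative only: the shape of a context-manager hook typed Generator[None, None, None].
import contextlib
from typing import Generator

import torch


@contextlib.contextmanager
def model_sharded_context_sketch() -> Generator[None, None, None]:
    print("entering sharded-instantiation context")
    try:
        yield
    finally:
        print("leaving sharded-instantiation context")


if __name__ == "__main__":
    with model_sharded_context_sketch():
        layer = torch.nn.Linear(32, 2)  # parameters created inside the context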
216 changes: 194 additions & 22 deletions pytorch_lightning/plugins/training_type/deepspeed.py

Large diffs are not rendered by default.
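Since the deepspeed.py diff itself is not rendered here, the following is only a hedged sketch of the kind of DeepSpeed configuration dictionary a ZeRO-enabled plugin assembles. The keys follow DeepSpeed's documented config schema; the values are placeholders, not the defaults chosen in this PR.

# Hypothetical ZeRO configuration sketch; values are placeholders.
deepspeed_zero_config = {
    "zero_allow_untested_optimizer": True,
    "zero_optimization": {
        "stage": 2,                 # ZeRO stage (optimizer/gradient partitioning)
        "contiguous_gradients": True,
        "overlap_comm": True,
    },
    "fp16": {
        "enabled": True,
        "initial_scale_power": 16,  # loss scale starts at 2**16
    },
}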

@@ -13,7 +13,7 @@
 # limitations under the License.
 import contextlib
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Generator, Iterable, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, TYPE_CHECKING, Union

 import torch
 from torch.nn import Module
@@ -25,6 +25,7 @@
 from pytorch_lightning.plugins.base_plugin import Plugin
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import atomic_save
+from pytorch_lightning.utilities.cloud_io import load as pl_load

 if TYPE_CHECKING:
     from pytorch_lightning.trainer.trainer import Trainer
@@ -197,6 +198,43 @@ def setup_optimizers_in_pre_dispatch(self) -> bool:
         """
         return False

+    def restore_model_state_from_ckpt_path(self,
+                                           ckpt_path: str,
+                                           map_location=lambda storage, loc: storage) -> Tuple[Dict, bool]:
+        """
+        This function is used to load and restore the model state.
+
+        Args:
+            ckpt_path: Path to a checkpoint
+            map_location: lambda function to map checkpoint location
+
+        Returns:
+            checkpoint: the loaded checkpoint
+            bool: whether to load optimizer / lr_schedulers states from checkpoint
+
+        """
+        ckpt = pl_load(ckpt_path, map_location=map_location)
+        # restore datamodule states
+        if self.lightning_module.trainer.datamodule is not None:
+            self.lightning_module.trainer.datamodule.on_load_checkpoint(ckpt)
+
+        # hook: give user access to checkpoint if needed.
+        self.lightning_module.on_load_checkpoint(ckpt)
+        self.lightning_module.load_state_dict(ckpt['state_dict'])
+        return ckpt, True
+
+    def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
+        """
+        Provide a hook to count optimizer step calls.
+
+        Args:
+            total_batch_idx: Total number of batches seen for training
+            current_global_step: Current number of optimizer step calls
+
+        Returns: New optimizer step calls
+        """
+        return current_global_step + 1
+
     def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
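For orientation, a standalone sketch (not the plugin code above) of the contract this new hook establishes: load a checkpoint, restore the model weights, and return the checkpoint together with a flag telling the caller whether it should also restore optimizer and LR-scheduler state. A sharded plugin could return False and let its engine restore optimizer state itself; the function name and the restore_optimizer_here flag are invented for the illustration.

# Illustrative stand-in for the restore hook's return convention; not the PR's code.
from typing import Any, Dict, Tuple

import torch
from torch.nn import Module


def restore_model_state_sketch(
    model: Module, ckpt_path: str, restore_optimizer_here: bool = True
) -> Tuple[Dict[str, Any], bool]:
    checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint["state_dict"])
    # The boolean tells the checkpoint connector whether *it* should restore
    # optimizer / lr_scheduler states, or whether the plugin handles that itself.
    return checkpoint, restore_optimizer_here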
22 changes: 9 additions & 13 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -90,20 +90,17 @@ def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
             rank_zero_warn("No checkpoint file exists at `resume_from_checkpoint`. Start from scratch")
             return False

-        # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
-        checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
+        checkpoint, load_optimizer_states = self.trainer.training_type_plugin.restore_model_state_from_ckpt_path(
+            checkpoint_path, map_location=lambda storage, loc: storage
+        )

-        # acquire the model
         model = self.trainer.lightning_module

-        # restore model and datamodule state
-        self.restore_model_state(model, checkpoint)
-
         if on_gpu:
             model.cuda(self.trainer.root_gpu)

         # restore training state
-        self.restore_training_state(checkpoint)
+        self.restore_training_state(checkpoint, load_optimizer_states)

         rank_zero_info(f"Restored states from the checkpoint file at {checkpoint_path}")
         return True
@@ -123,15 +120,15 @@ def restore_model_state(self, model: LightningModule, checkpoint) -> None:
         # restore model state_dict
         model.load_state_dict(checkpoint['state_dict'])

-    def restore_training_state(self, checkpoint):
+    def restore_training_state(self, checkpoint, load_optimizer_states: bool = True):
         """
         Restore trainer state.
         Model will get its change to update
         :param checkpoint:
         :return:
         """
         # validation
-        if 'optimizer_states' not in checkpoint or 'lr_schedulers' not in checkpoint:
+        if load_optimizer_states and ('optimizer_states' not in checkpoint or 'lr_schedulers' not in checkpoint):
             raise KeyError(
                 'Trying to restore training state but checkpoint contains only the model.'
                 ' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.'
@@ -177,6 +174,9 @@ def restore_training_state(self, checkpoint):
                 " consider using an end of epoch checkpoint."
             )

+        if not load_optimizer_states:
+            return
+
         # restore the optimizers
         optimizer_states = checkpoint['optimizer_states']
         for optimizer, opt_state in zip(self.trainer.optimizers, optimizer_states):
@@ -238,10 +238,8 @@ def hpc_save(self, folderpath: str, logger):

     def dump_checkpoint(self, weights_only: bool = False) -> dict:
         """Creating a model checkpoint dictionary object from various component states.
-
         Args:
             weights_only: saving model weights only
-
         Return:
             structured dictionary: {
                 'epoch': training epoch
@@ -350,11 +348,9 @@ def hpc_load(self, checkpoint_path: str, on_gpu: bool):

     def max_ckpt_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Optional[int]:
         """List up files in `dir_path` with `name_key`, then yield maximum suffix number.
-
         Args:
             dir_path: path of directory which may contain files whose name include `name_key`
             name_key: file name prefix
-
         Returns:
             None if no-corresponding-file else maximum suffix number
         """
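A compact, standalone sketch of the behaviour the new load_optimizer_states flag buys (a simplified stand-in, not the connector's code): when the plugin reports that it restores optimizer state itself, both the weights-only key check and the optimizer/scheduler restore are skipped.

# Simplified illustration of the connector's new guard; names are local to this sketch.
from typing import Any, Dict, List

import torch


def restore_training_state_sketch(
    checkpoint: Dict[str, Any],
    optimizers: List[torch.optim.Optimizer],
    load_optimizer_states: bool = True,
) -> None:
    if load_optimizer_states and ('optimizer_states' not in checkpoint or 'lr_schedulers' not in checkpoint):
        raise KeyError("checkpoint contains only the model (save_weights_only=True?)")

    if not load_optimizer_states:
        # The training-type plugin (e.g. DeepSpeed) restores optimizer state itself.
        return

    for optimizer, opt_state in zip(optimizers, checkpoint['optimizer_states']):
        optimizer.load_state_dict(opt_state)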
16 changes: 10 additions & 6 deletions pytorch_lightning/trainer/trainer.py
@@ -58,7 +58,6 @@
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
 from pytorch_lightning.tuner.tuning import Tuner
 from pytorch_lightning.utilities import DeviceType, rank_zero_warn
-from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import recursive_detach
@@ -982,12 +981,13 @@ def __load_ckpt_weights(
         )

         # only one process running at this point for TPUs, as spawn isn't triggered yet
-        if self._device_type != DeviceType.TPU:
+        # todo: move this logic internally within the barrier.
+        if not self._device_type == DeviceType.TPU:
             self.training_type_plugin.barrier()

-        ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
-        model.load_state_dict(ckpt['state_dict'])
-
+        self.training_type_plugin.restore_model_state_from_ckpt_path(
+            ckpt_path, map_location=lambda storage, loc: storage
+        )
         return ckpt_path

     def predict(
@@ -1086,10 +1086,14 @@ def call_setup_hook(self, model: LightningModule) -> None:
     def call_configure_sharded_model(self, model: LightningModule) -> None:
         # Call configure sharded model hook if accelerator requests. In some cases
        # we will not call the hook; the hook has initialized the sharded model for example.
-        if self.accelerator.call_configure_sharded_model_hook:

+        # used on the model if the user re-create a trainer with resume_from_checkpoint
+        model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False)
+        if self.accelerator.call_configure_sharded_model_hook and not model_call_configure_sharded_model_hook:
             with self.accelerator.model_sharded_context():
-                model.configure_sharded_model()
+                self.configure_sharded_model(model)
+            model.call_configure_sharded_model_hook = True
             self.accelerator.call_configure_sharded_model_hook = False

     def call_teardown_hook(self, model: LightningModule) -> None:
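The guard above keeps the sharded-model hook from running twice when a Trainer is re-created with resume_from_checkpoint. For context only, a hedged sketch of the user-facing hook it protects: layers are built lazily so a sharded plugin can create their parameters inside its distributed-aware context. The module below is invented for the illustration and is not part of the PR.

# Hypothetical LightningModule illustrating the configure_sharded_model hook.
import torch
import pytorch_lightning as pl


class LazyShardedModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = None  # built lazily in configure_sharded_model

    def configure_sharded_model(self) -> None:
        # Runs inside the accelerator's model_sharded_context(); the new
        # call_configure_sharded_model_hook flag keeps it from re-running if a
        # fresh Trainer is created to resume from a checkpoint.
        if self.layer is None:
            self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        return self(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)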
4 changes: 3 additions & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -771,7 +771,9 @@ def increment_accumulated_grad_global_step(self):

         # progress global step according to grads progress
         if num_accumulated_batches_reached or num_training_batches_reached:
-            self.trainer.global_step += 1
+            self.trainer.global_step = self.trainer.accelerator.update_global_step(
+                self.trainer.total_batch_idx, self.trainer.global_step
+            )

     def _accumulated_batches_reached(self):
         return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0
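Delegating the increment to accelerator.update_global_step lets an engine that accumulates gradients internally advance Lightning's global step at its own pace. A standalone arithmetic sketch of the two behaviours; the function names are invented, and the base plugin simply returns current_global_step + 1.

# Illustrative only: default step counting vs. counting that respects an
# engine-side gradient-accumulation window.
def default_update_global_step(total_batch_idx: int, current_global_step: int) -> int:
    return current_global_step + 1


def accumulating_update_global_step(
    total_batch_idx: int, current_global_step: int, accumulate_grad_batches: int = 2
) -> int:
    # Only count an optimizer step once a full accumulation window has been seen.
    if (total_batch_idx + 1) % accumulate_grad_batches == 0:
        return current_global_step + 1
    return current_global_step


if __name__ == "__main__":
    step = 0
    for batch_idx in range(8):
        step = accumulating_update_global_step(batch_idx, step, accumulate_grad_batches=2)
    print(step)  # 4 optimizer steps for 8 batches with accumulation of 2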
1 change: 0 additions & 1 deletion pytorch_lightning/utilities/imports.py
@@ -69,7 +69,6 @@ def _compare_version(package: str, op, version) -> bool:
 _TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
 _TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
 _TORCH_GREATER_EQUAL_1_8 = _compare_version("torch", operator.ge, "1.8.0")
-
 _KINETO_AVAILABLE = torch.profiler.kineto_available() if _TORCH_GREATER_EQUAL_1_8 else False
 _APEX_AVAILABLE = _module_available("apex.amp")
 _BOLTS_AVAILABLE = _module_available('pl_bolts')