DeepSpeed ZeRO Update #6546

Merged
merged 70 commits into from
Mar 30, 2021
Changes from 50 commits
Commits
70 commits
2d3f617
Add context to call hook to handle all modules defined within the hook
Mar 10, 2021
99495e8
Expose some additional parameters
Mar 10, 2021
c3aac67
Added docs, exposed parameters
Mar 11, 2021
340f817
Make sure we only configure if necessary
Mar 11, 2021
ae02102
Merge branch 'master' into feat/ds_update
Mar 11, 2021
f192afc
Setup activation checkpointing regardless, saves the user having to d…
Mar 12, 2021
a2784a4
Add some tests that fail currently
Mar 12, 2021
b0dab3d
update
tchaton Mar 15, 2021
0c44f05
update
tchaton Mar 15, 2021
26655d7
update
tchaton Mar 15, 2021
ac19f36
add tests
tchaton Mar 16, 2021
d273393
change docstring
tchaton Mar 16, 2021
c91d128
resolve accumulate_grad_batches
tchaton Mar 16, 2021
959d7b7
resolve flake8
tchaton Mar 16, 2021
f0cb6e7
Update DeepSpeed to use latest version, add some comments
Mar 17, 2021
914de86
add metrics
tchaton Mar 17, 2021
5d16c74
Merge branch 'feat/ds_update' of https://github.com/PyTorchLightning/…
tchaton Mar 17, 2021
712814c
update
tchaton Mar 17, 2021
a1644c6
Small formatting fixes, clean up some code
Mar 17, 2021
64f624f
Few cleanups
Mar 17, 2021
89fbbcb
No need for default state
Mar 18, 2021
701d417
Fix tests, add some boilerplate that should move eventually
Mar 18, 2021
270d6ed
Add hook removal
Mar 22, 2021
2b71ed8
Merge branch 'master' into feat/ds_update
Mar 22, 2021
a236ff0
Add a context manager to handle hook
Mar 23, 2021
e1f865e
Small naming cleanup
Mar 25, 2021
80fb792
wip
tchaton Mar 26, 2021
d621b1f
Merge branch 'master' into feat/ds_update
tchaton Mar 26, 2021
1de2bcd
move save_checkpoint responsability to accelerator
tchaton Mar 26, 2021
90d6e03
resolve flake8
tchaton Mar 26, 2021
b6361b8
add BC
tchaton Mar 26, 2021
924d9e2
Merge branch 'feat/ds_update' of https://github.com/PyTorchLightning/…
tchaton Mar 26, 2021
6acaccb
Change recommended scale to 16
Mar 29, 2021
f7a373e
Merge branch 'master' into feat/ds_update
tchaton Mar 30, 2021
68b8a43
resolve flake8
tchaton Mar 30, 2021
a7dcb7b
update test
tchaton Mar 30, 2021
08df0b5
Merge branch 'feat/ds_update' of https://github.com/PyTorchLightning/…
tchaton Mar 30, 2021
6b08478
update install
tchaton Mar 30, 2021
45a49c5
update
tchaton Mar 30, 2021
a8da299
update test
tchaton Mar 30, 2021
99f1d96
update
tchaton Mar 30, 2021
89601d8
update
tchaton Mar 30, 2021
eb1495e
Merge branch 'feat/ds_update' of https://github.com/PyTorchLightning/…
tchaton Mar 30, 2021
389c60b
update test
tchaton Mar 30, 2021
de5f358
resolve flake8
tchaton Mar 30, 2021
301b1aa
update
tchaton Mar 30, 2021
b9542ae
update
tchaton Mar 30, 2021
48c0950
update on comments
tchaton Mar 30, 2021
c230407
Push
Mar 30, 2021
783265f
pull
Mar 30, 2021
c8f79f9
Update pytorch_lightning/plugins/training_type/deepspeed.py
tchaton Mar 30, 2021
61378de
Update pytorch_lightning/plugins/training_type/deepspeed.py
tchaton Mar 30, 2021
45c9569
update
tchaton Mar 30, 2021
deb2ea2
Apply suggestions from code review
SeanNaren Mar 30, 2021
122e911
Swap to using world size defined by plugin
Mar 30, 2021
dfb403b
update
tchaton Mar 30, 2021
9bd2821
Merge branch 'feat/ds_update' of https://github.com/PyTorchLightning/…
tchaton Mar 30, 2021
066e0f0
update todo
tchaton Mar 30, 2021
d41284e
Remove deepspeed from extra, keep it in the base cuda docker install
Mar 30, 2021
0c9836c
Push
Mar 30, 2021
d1c511e
pull
Mar 30, 2021
67d31fa
update
tchaton Mar 30, 2021
e65aaf3
Merge branch 'feat/ds_update' of https://github.com/PyTorchLightning/…
tchaton Mar 30, 2021
1740eed
update
tchaton Mar 30, 2021
300f3aa
update
tchaton Mar 30, 2021
40b1cc6
update
tchaton Mar 30, 2021
603caf1
Minor changes
carmocca Mar 30, 2021
62f67e8
duplicate
Borda Mar 30, 2021
5786c4b
format
Borda Mar 30, 2021
83e1343
format2
Borda Mar 30, 2021
4 changes: 1 addition & 3 deletions dockers/base-cuda/Dockerfile
@@ -114,9 +114,7 @@ RUN \
rm -rf apex

RUN \
# install DeepSpeed from source.
# todo: swap to pypi release once DeepSpeed releases a new version >= 0.3.10
pip install deepspeed@git+https://github.com/microsoft/DeepSpeed@ec8b1cb
pip install deepspeed>=0.3.13

RUN \
# Show what we have
5 changes: 4 additions & 1 deletion pytorch_lightning/accelerators/accelerator.py
@@ -441,7 +441,7 @@ def results(self) -> Any:
return self.training_type_plugin.results

@contextlib.contextmanager
def model_sharded_context(self) -> Generator:
def model_sharded_context(self) -> Generator[None, None, None]:
"""
Provide hook to create modules in a distributed aware context. This is useful for when we'd like to
shard the model instantly - useful for extremely large models. Can save memory and
@@ -511,3 +511,6 @@ def setup_optimizers_in_pre_dispatch(self) -> bool:
Returns: If True, delay setup optimizers till pre_dispatch, else call within setup.
"""
return self.training_type_plugin.setup_optimizers_in_pre_dispatch

def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
return self.training_type_plugin.update_global_step(total_batch_idx, current_global_step)
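
A minimal sketch (not part of this diff) of how the accelerator's model_sharded_context could be used: modules created inside the context can be partitioned as they are instantiated by the training type plugin (for example via deepspeed.zero.Init under ZeRO stage 3), instead of being materialized in full on a single device. The accelerator argument and helper function below are illustrative assumptions, not code from this PR.

import torch.nn as nn

def build_sharded_model(accelerator) -> nn.Module:
    # With a sharding-aware plugin this allocation is partitioned as it is
    # created; with the default implementation the context is a no-op.
    with accelerator.model_sharded_context():
        return nn.Sequential(
            nn.Linear(8192, 8192),
            nn.GELU(),
            nn.Linear(8192, 8192),
        )
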
207 changes: 185 additions & 22 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -11,17 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import json
import logging
import os
from collections import OrderedDict
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

import torch
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
@@ -37,6 +37,16 @@
import deepspeed


def remove_module_hooks(model: torch.nn.Module) -> None:
for module in model.modules():
module._backward_hooks = OrderedDict()
module._is_full_backward_hook = None
module._forward_hooks = OrderedDict()
module._forward_pre_hooks = OrderedDict()
module._state_dict_hooks = OrderedDict()
module._load_state_dict_pre_hooks = OrderedDict()


class LightningDeepSpeedModule(_LightningModuleWrapperBase):

def __init__(self, pl_module: LightningModule, precision: int):
@@ -67,6 +77,8 @@ def __init__(
zero_optimization: bool = True,
stage: int = 2,
cpu_offload: bool = False,
cpu_offload_params: bool = False,
cpu_offload_use_pin_memory: bool = False,
contiguous_gradients: bool = True,
overlap_comm: bool = True,
allgather_partitions: bool = True,
@@ -80,10 +92,14 @@
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
loss_scale: float = 0,
initial_scale_power: int = 32,
initial_scale_power: int = 16,
loss_scale_window: int = 1000,
hysteresis: int = 2,
min_loss_scale: int = 1
min_loss_scale: int = 1,
partition_activations: bool = False,
cpu_checkpointing: bool = False,
contiguous_memory_optimization: bool = False,
synchronize_checkpoint_boundary: bool = False,
) -> None:
"""

@@ -106,6 +122,10 @@ def __init__(

cpu_offload: Enable offloading optimizer memory and computation to CPU

cpu_offload_params: When using ZeRO stage 3, offload parameters to CPU

cpu_offload_use_pin_memory: When using ZeRO stage 3, pin memory on CPU

contiguous_gradients: Copies gradients to a contiguous buffer as they are produced.
Avoids memory fragmentation during backwards. Useful when training large models. (default: True)

@@ -144,6 +164,16 @@

min_loss_scale: The minimum FP16 dynamic loss scaling value (Default: 1)

partition_activations: Enables partitioning of activations when used with ZeRO stage 3.
Still requires you to wrap your forward functions in deepspeed.checkpointing.checkpoint.
See https://www.deepspeed.ai/tutorials/megatron/#deepspeed-activation-checkpoints-optional

cpu_checkpointing: Offloads partitioned activations to CPU if ``partition_activations`` is enabled

contiguous_memory_optimization: Copies partitioned activations so that they are contiguous in memory.
Not supported by all models

synchronize_checkpoint_boundary: Insert ``torch.cuda.synchronize()`` at each checkpoint boundary.
"""
if not _DEEPSPEED_AVAILABLE:
raise MisconfigurationException(
@@ -159,8 +189,14 @@ def __init__(
self.config = self._create_default_config(
zero_optimization,
zero_allow_untested_optimizer,
partition_activations=partition_activations,
cpu_checkpointing=cpu_checkpointing,
contiguous_memory_optimization=contiguous_memory_optimization,
synchronize_checkpoint_boundary=synchronize_checkpoint_boundary,
stage=stage,
cpu_offload=cpu_offload,
cpu_offload_params=cpu_offload_params,
cpu_offload_use_pin_memory=cpu_offload_use_pin_memory,
contiguous_gradients=contiguous_gradients,
overlap_comm=overlap_comm,
allgather_partitions=allgather_partitions,
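
For context, a hedged usage sketch of the new options: the plugin class name (DeepSpeedPlugin) and the Trainer arguments are assumed from the surrounding Lightning API rather than shown in this hunk. It enables ZeRO stage 3 with CPU offload of optimizer state and parameters plus partitioned activation checkpointing; per the precision check added below, ZeRO optimization requires precision=16.

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin  # assumed import path

trainer = Trainer(
    gpus=4,
    precision=16,  # ZeRO optimization requires 16-bit precision
    plugins=DeepSpeedPlugin(
        stage=3,
        cpu_offload=True,             # offload optimizer memory/computation to CPU
        cpu_offload_params=True,      # ZeRO stage 3: offload parameters to CPU
        cpu_offload_use_pin_memory=True,
        partition_activations=True,   # forward must use deepspeed.checkpointing.checkpoint
        cpu_checkpointing=True,
    ),
)
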
@@ -200,9 +236,14 @@ def init_deepspeed(self):
self._format_config()
self._config_initialized = True

self._handle_gradient_accumulation_steps()

precision = self.lightning_module.trainer.accelerator.precision
model = LightningDeepSpeedModule(pl_module=self.model, precision=precision)

if self.on_gpu:
torch.cuda.set_device(self.root_device)

if self.lightning_module.trainer and self.lightning_module.trainer.training:
self._initialize_deepspeed_train(model)
else:
@@ -220,9 +261,11 @@ def _init_scheduler_optimizer(self):
optimizer = optimizers[0]
return optimizer, scheduler, optimizer_frequencies

@property
def zero_stage_3(self) -> bool:
return self.config.get('zero_optimization') and self.config.get('zero_optimization').get('stage') == 3

def _initialize_deepspeed_train(self, model):
if self.on_gpu:
torch.cuda.set_device(self.root_device)
optimizer, lightning_scheduler, optimizer_frequencies = None, None, None
if "optimizer" not in self.config:
rank_zero_info(
@@ -232,28 +275,70 @@ def _initialize_deepspeed_train(self, model):
optimizer, lightning_scheduler, optimizer_frequencies = self._init_scheduler_optimizer()
model_parameters = filter(lambda p: p.requires_grad, self.model.parameters())
model, optimizer, _, lr_scheduler = deepspeed.initialize(
args=SimpleNamespace(local_rank=self.local_rank),
model=model,
model_parameters=model_parameters,
optimizer=optimizer,
lr_scheduler=lightning_scheduler,
config_params=self.config,
)
self._set_deepspeed_activation_checkpointing()

# set optimizer for save/load, but deepspeed manages the specific optimizer logic
self.lightning_module.trainer.optimizers = [optimizer]
self.lightning_module.trainer.schedulers = [lr_scheduler]
self.model = model

@contextlib.contextmanager
def model_sharded_context(self) -> Generator[None, None, None]:
if self.zero_stage_3:
model_parallel_context = deepspeed.zero.Init(remote_device="cpu", pin_memory=True)
else:
model_parallel_context = super().model_sharded_context()

with model_parallel_context:
yield

def _set_deepspeed_activation_checkpointing(self):
if self.config.get('activation_checkpointing'):
checkpoint_config = self.config['activation_checkpointing']
deepspeed.checkpointing.configure(
mpu_=None,
partition_activations=checkpoint_config.get('partition_activations'),
contiguous_checkpointing=checkpoint_config.get('contiguous_checkpointing'),
checkpoint_in_cpu=checkpoint_config.get('checkpoint_in_cpu'),
profile=checkpoint_config.get('profile'),
)

def _initialize_deepspeed_inference(self, model):
# move the model to the correct device
self.model_to_device()

self.pre_configure_ddp()
self.model = DistributedDataParallel(
model,
device_ids=self.determine_ddp_device_ids(),
**self._ddp_kwargs,
# todo: Currently DeepSpeed requires optimizers at inference to partition weights correctly
optimizer, lightning_scheduler, optimizer_frequencies = None, None, None
if "optimizer" not in self.config:
rank_zero_info(
"You have not specified an optimizer or scheduler within the DeepSpeed config."
"Using `configure_optimizers` to define optimizer and scheduler."
)
optimizer, lightning_scheduler, optimizer_frequencies = self._init_scheduler_optimizer()
inference_config = {
# todo: this is required for DeepSpeed throughput timers
'train_micro_batch_size_per_gpu': 1,
}
if 'fp16' in self.config:
inference_config.update({"fp16": self.config["fp16"]})
if self.zero_stage_3:
inference_config.update({
"zero_allow_untested_optimizer": self.config['zero_allow_untested_optimizer'],
"zero_optimization": self.config['zero_optimization'],
})
# Remove all module hooks before initializing new model
remove_module_hooks(model)
model, _, _, _ = deepspeed.initialize(
model=model,
optimizer=optimizer,
lr_scheduler=lightning_scheduler,
config_params=inference_config,
model_parameters=[],
)
self.model = model

def configure_scheduler(self, lr_scheduler):
scheduler = _get_default_scheduler_config()
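
A small illustration (standalone assumptions, not code from the diff) of the minimal configuration the inference path assembles: a dummy micro batch size for DeepSpeed's throughput timers, with the fp16 and ZeRO sections copied from the training config when they are present.

training_config = {
    "fp16": {"enabled": True},
    "zero_allow_untested_optimizer": True,
    "zero_optimization": {"stage": 3, "cpu_offload": True},
}

# Dummy micro batch size required by DeepSpeed's throughput timers
inference_config = {"train_micro_batch_size_per_gpu": 1}
if "fp16" in training_config:
    inference_config["fp16"] = training_config["fp16"]
if training_config.get("zero_optimization", {}).get("stage") == 3:
    inference_config.update({
        "zero_allow_untested_optimizer": training_config["zero_allow_untested_optimizer"],
        "zero_optimization": training_config["zero_optimization"],
    })
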
@@ -282,6 +367,13 @@ def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Calla
# internally, the engine has a reference to the optimizer already.
self.model.step(**kwargs)

def _handle_gradient_accumulation_steps(self):
if self.config.get("gradient_accumulation_steps") > 1:
self._original_accumulate_grad_batches = self.lightning_module.trainer.accumulate_grad_batches
self.lightning_module.trainer.accumulation_scheduler = GradientAccumulationScheduler({0: 1})
else:
self._original_accumulate_grad_batches = None

def _format_config(self):
if self.config is None:
raise MisconfigurationException(
@@ -300,14 +392,13 @@ def _format_batch_size_and_grad_accum_config(self):
if "train_micro_batch_size_per_gpu" not in self.config:
# train_micro_batch_size_per_gpu is used for throughput logging purposes
# by default we use the batch size of the loader which may be incorrect if a batch sampler is passed
batch_size = self.lightning_module.train_dataloader().batch_size
batch_size = self.lightning_module.train_dataloader().batch_sampler.batch_size
self.config["train_micro_batch_size_per_gpu"] = batch_size
self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches
if "gradient_clipping" not in self.config:
self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val

def _format_precision_config(self):

amp_type = self.lightning_module.trainer.accelerator_connector.amp_type
amp_level = self.lightning_module.trainer.accelerator_connector.amp_level
precision = self.lightning_module.trainer.accelerator_connector.precision
@@ -333,8 +424,80 @@ def _format_precision_config(self):
raise MisconfigurationException("To use DeepSpeed ZeRO Optimization, you must set precision=16.")

def _create_default_config(
self, zero_optimization: bool, zero_allow_untested_optimizer: bool, **zero_kwargs
self, zero_optimization: bool, zero_allow_untested_optimizer: bool, partition_activations: bool,
cpu_checkpointing: bool, contiguous_memory_optimization: bool, synchronize_checkpoint_boundary: bool,
**zero_kwargs
) -> Dict:
cfg = {
'activation_checkpointing': {
"partition_activations": partition_activations,
"cpu_checkpointing": cpu_checkpointing,
"contiguous_memory_optimization": contiguous_memory_optimization,
"synchronize_checkpoint_boundary": synchronize_checkpoint_boundary
}
}
if zero_optimization:
return {"zero_allow_untested_optimizer": zero_allow_untested_optimizer, "zero_optimization": zero_kwargs}
return {}
cfg = {
"zero_allow_untested_optimizer": zero_allow_untested_optimizer,
"zero_optimization": zero_kwargs,
**cfg
}
return cfg

def _filepath_to_dir(self, filepath: str):
return os.path.dirname(filepath)

@property
def deepspeed_engine(self):
return self.model

def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.

Args:
checkpoint: dict containing model and trainer state
filepath: write-target file's path
"""
if torch.distributed.get_world_size() > 1 and self.zero_stage_3:
# Use deepspeed's internal checkpointing function to handle partitioned weights across processes
# dump states as a checkpoint dictionary object
save_dir = self._filepath_to_dir(filepath)
_exclude_keys = ['state_dict', 'optimizer_states', 'lr_schedulers']
checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
self.deepspeed_engine.save_checkpoint(save_dir, client_state=checkpoint)

else:
super().save_checkpoint(checkpoint, filepath)

def restore_model_state_from_ckpt_path(self,
ckpt_path: str,
map_location=lambda storage, loc: storage) -> Tuple[Dict, bool]:
if torch.distributed.get_world_size() > 1:
from pytorch_lightning.trainer.states import TrainerState
stage_is_fit = self.lightning_module.trainer.state == TrainerState.FITTING
save_dir = self._filepath_to_dir(ckpt_path)

if self.zero_stage_3:
# TODO: Currently required as this call is missing within the deepspeed engine.
self.deepspeed_engine.optimizer._partition_all_parameters()

_, client_state = self.deepspeed_engine.load_checkpoint(
save_dir, load_optimizer_states=stage_is_fit, load_lr_scheduler_states=stage_is_fit
)

# restore datamodule states
if self.lightning_module.trainer.datamodule is not None:
self.lightning_module.trainer.datamodule.on_load_checkpoint(client_state)

# hook: give user access to checkpoint if needed.
self.lightning_module.on_load_checkpoint(client_state)
return client_state, False
return super().restore_model_state_from_ckpt_path(ckpt_path, map_location=map_location)

def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
if self._original_accumulate_grad_batches is None:
return super().update_global_step(total_batch_idx, current_global_step)
else:
if total_batch_idx % self._original_accumulate_grad_batches == 0:
current_global_step += 1
return current_global_step
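
To illustrate how _handle_gradient_accumulation_steps and update_global_step interact (a standalone sketch, not code from the diff): once DeepSpeed owns gradient accumulation, Lightning's accumulation scheduler is reset to accumulate over a single batch, and the global step advances only every original accumulate_grad_batches batches.

def update_global_step(total_batch_idx: int, current_global_step: int,
                       original_accumulate_grad_batches: int = 4) -> int:
    # Mirrors the plugin's behaviour when DeepSpeed handles accumulation internally.
    if total_batch_idx % original_accumulate_grad_batches == 0:
        current_global_step += 1
    return current_global_step

step = 0
for batch_idx in range(1, 9):
    step = update_global_step(batch_idx, step)
# step == 2: only batches 4 and 8 advanced the global step.
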