DeepSpeed Integration (#5954)
* Add initial deepspeed changes

* Address code review

* Move static method outside of function

* Fixes

* Add missing annotation

* Remove seed setting

* Doc changes

* Doc changes, add address reviews

* Fix docs

* Try fixing issue by moving to torch adam

* Clean up check

* Changes, better APIs!

* Add wrapper, swap to git install revision

* Add special test

* Add warning

* Address review

* Add better disclaimer

* Turn off ZeRO for testing due to compilation

* Add description on modifying parameters via the plugin

* Doc strings clear

* Small doc fixes

* Fix hash, reduce test

* Added CI change

* Move to azure pipeline

* Fix test name

* Add missing flag

* Remove sudo...

* Try conda instead

* Swap to conda base

* Try suggested install

* Apply suggestions from code review

* Apply suggestions from code review

* Revert "Apply suggestions from code review"

This reverts commit 41cca05

* Revert "Apply suggestions from code review"

This reverts commit e06ec29

* Remove setter

* Address most review

* Move out function, remove DeepSpeed from requirements

* Install deepspeed/mpi4py within container

* Use special tests, move to master commit for deepspeed

* Export path

* Force compile to happen first

* Remove!

* Debugging ninja

* Fix error in optimizer step logic

* Attempt to fix symbolic link

* Reverse to aid debugging

* Export path again

* Clean up mess

* var

* Revert "var"

This reverts commit 3450eac

* Address review, add todo

* Add note about unsupported functionality

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: tchaton <thomas@grid.ai>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
4 people authored Feb 17, 2021
1 parent 6a409c7 commit 7189d67
Showing 16 changed files with 877 additions and 10 deletions.
9 changes: 8 additions & 1 deletion azure-pipelines.yml
@@ -62,6 +62,11 @@ jobs:
pip list
displayName: 'Install dependencies'
- bash: |
# Temporary fix till DeepSpeed release, move this into CUDA image
pip install deepspeed@git+https://github.com/microsoft/DeepSpeed@ec8b1cb
displayName: 'Install DeepSpeed'
- script: |
python tests/collect_env_details.py
displayName: 'Env details'
@@ -76,7 +81,9 @@ jobs:
python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50
displayName: 'Testing: standard'
- script: |
- bash: |
# Required for Ninja binary for building extensions, which is installed at this location
export PATH=$PATH:/home/AzDevOps_azpcontainer/.local/bin
sh tests/special_tests.sh
displayName: 'Testing: special'
145 changes: 145 additions & 0 deletions docs/source/advanced/multi_gpu.rst
@@ -613,6 +613,8 @@ Lightning currently offers the following methods to leverage model parallelism:
- Sharded Training (partitioning your gradients and optimizer state across multiple GPUs, for reduced memory overhead with **no performance loss**)
- Sequential Model Parallelism with Checkpointing (partition your :class:`nn.Sequential <torch.nn.Sequential>` module across multiple GPUs, leverage checkpointing and microbatching for further memory improvements and device utilization)

.. _sharded:

Sharded Training
^^^^^^^^^^^^^^^^
Lightning integration of optimizer sharded training provided by `FairScale <https://github.com/facebookresearch/fairscale>`_.
@@ -678,6 +680,149 @@ Sharded Training can work across all DDP variants by adding the additional ``--p

Internally we re-initialize your optimizers and shard them across your machines and processes. We handle all communication using PyTorch distributed, so no code changes are required.

----------

.. _deep_speed:

DeepSpeed
^^^^^^^^^

.. note::
    The DeepSpeed plugin is in beta and the API is subject to change. Please create an `issue <https://github.com/PyTorchLightning/pytorch-lightning/issues>`_ if you run into any issues.

`DeepSpeed <https://github.com/microsoft/DeepSpeed>`_ offers additional CUDA deep learning training optimizations, similar to `FairScale <https://github.com/facebookresearch/fairscale>`_. DeepSpeed provides lower-level training optimizations as well as efficient optimizers such as `1-bit Adam <https://www.deepspeed.ai/tutorials/onebit-adam/>`_.
Using the plugin, we were able to **train models with 10 billion parameters and above**; this `benchmark <https://github.com/huggingface/transformers/issues/9996>`_ and the DeepSpeed `docs <https://www.deepspeed.ai/tutorials/megatron/>`_ contain a lot of useful information.
We recommend using DeepSpeed in environments where speed and memory optimizations are important (such as training large, billion-parameter models). In addition, we recommend trying :ref:`sharded` first, before moving on to DeepSpeed's further optimizations, primarily because FairScale's sharded training is easier to use in scenarios such as multiple optimizers/schedulers.

To use DeepSpeed, first install DeepSpeed and mpi4py using the command below.

.. code-block:: bash

    pip install deepspeed mpi4py

If you run into an issue with the install, or later during training, ensure that the CUDA version of the PyTorch you have installed matches your locally installed CUDA (you can check the local version with ``nvcc --version``).
Additionally, if you run into any issues installing mpi4py, ensure you have OpenMPI installed by running ``sudo apt install libopenmpi-dev`` or ``brew install mpich`` before running ``pip install mpi4py``.
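
For example, a quick way to check which CUDA version your PyTorch build was compiled against, to compare with the output of ``nvcc --version``:

.. code-block:: python

    import torch

    # CUDA version PyTorch was built with; this should match the locally
    # installed CUDA toolkit reported by ``nvcc --version``
    print(torch.version.cuda)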

.. note::
    Currently ``resume_from_checkpoint`` and manual optimization are not supported.

DeepSpeed only supports a single optimizer and a single scheduler.
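
For reference, here is a minimal sketch of a compatible ``configure_optimizers``; the layer, optimizer, and scheduler choices are placeholders, not requirements:

.. code-block:: python

    import torch
    from pytorch_lightning import LightningModule


    class MyModel(LightningModule):

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def configure_optimizers(self):
            # A single optimizer paired with (at most) a single scheduler
            optimizer = torch.optim.AdamW(self.parameters(), lr=3e-5)
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
            return [optimizer], [scheduler]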

ZeRO-Offload
""""""""""""

Below we show an example of running `ZeRO-Offload <https://www.deepspeed.ai/tutorials/zero-offload/>`_. ZeRO-Offload leverages the host CPU to offload optimizer memory/computation, reducing the overall memory consumption.
For an even greater speed benefit, DeepSpeed offers an optimized CPU version of Adam to run the offloaded computation, which is faster than the standard PyTorch implementation. ZeRO-Offload is enabled by default.

.. note::
    To use ZeRO-Offload, you must use ``precision=16`` or set precision via `the DeepSpeed config <https://www.deepspeed.ai/docs/config-json/#fp16-training-options>`_.

.. code-block:: python

    from pytorch_lightning import Trainer

    model = MyModel()
    trainer = Trainer(gpus=4, plugins='deepspeed', precision=16)
    trainer.fit(model)

This can also be done via the command line using a PyTorch Lightning script:

.. code-block:: bash

    python train.py --plugins deepspeed --precision 16 --gpus 4

You can also modify the ZeRO-Offload parameters via the plugin as below.

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DeepSpeedPlugin

    model = MyModel()
    trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(allgather_bucket_size=5e8, reduce_bucket_size=5e8), precision=16)
    trainer.fit(model)

.. note::
    We suggest tuning the ``allgather_bucket_size`` and ``reduce_bucket_size`` parameters to find the optimal values for your model size.
    These control how large a buffer the model is limited to when reducing gradients/gathering updated parameters. Smaller values use less memory, at the cost of speed.

DeepSpeed allocates the reduce buffer at a size `multiplied by 4.5x <https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage2.py#L1594-L1607>`_, so take that into consideration when tweaking the parameters.

The plugin sets a reasonable default of ``2e8``, which should work for most low VRAM GPUs (less than ``7GB``), allocating roughly ``3.6GB`` of VRAM as buffer. Higher VRAM GPUs should aim for values around ``5e8``.
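
As a rough back-of-the-envelope estimate (this assumes fp32 buffer elements and the 4.5x multiplier mentioned above; the helper below is purely illustrative):

.. code-block:: python

    def approx_reduce_buffer_gb(reduce_bucket_size: float, bytes_per_element: int = 4) -> float:
        """Very rough estimate of the reduce buffer allocation, in GB."""
        return reduce_bucket_size * bytes_per_element * 4.5 / 1e9


    approx_reduce_buffer_gb(2e8)  # ~3.6 GB, the plugin default
    approx_reduce_buffer_gb(5e8)  # ~9.0 GB, for higher VRAM GPUs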


Custom DeepSpeed Config
"""""""""""""""""""""""

DeepSpeed allows the use of custom DeepSpeed optimizers and schedulers defined within a config file. This lets you enable optimizers such as `1-bit Adam <https://www.deepspeed.ai/tutorials/onebit-adam/>`_.

.. note::
    All plugin default parameters will be ignored when a config object is passed.
    All compatible arguments can be seen in the `DeepSpeed docs <https://www.deepspeed.ai/docs/config-json/>`_.

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DeepSpeedPlugin

    deepspeed_config = {
        "zero_allow_untested_optimizer": True,
        "optimizer": {
            "type": "OneBitAdam",
            "params": {
                "lr": 3e-5,
                "betas": [0.998, 0.999],
                "eps": 1e-5,
                "weight_decay": 1e-9,
                "cuda_aware": True,
            },
        },
        "scheduler": {
            "type": "WarmupLR",
            "params": {
                "last_batch_iteration": -1,
                "warmup_min_lr": 0,
                "warmup_max_lr": 3e-5,
                "warmup_num_steps": 100,
            },
        },
        "zero_optimization": {
            "stage": 2,  # Enable Stage 2 ZeRO (optimizer/gradient state partitioning)
            "cpu_offload": True,  # Offload optimizer state/calculation to the host CPU
            "contiguous_gradients": True,  # Reduce gradient fragmentation
            "overlap_comm": True,  # Overlap the reduce/backward operations on gradients for speed
            "allgather_bucket_size": 2e8,  # Number of elements to all-gather at once
            "reduce_bucket_size": 2e8,  # Number of elements to reduce/all-reduce at once
        },
    }

    model = MyModel()
    trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(deepspeed_config), precision=16)
    trainer.fit(model)

We also support passing the config as a JSON-formatted file:

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DeepSpeedPlugin

    model = MyModel()
    trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin("/path/to/deepspeed_config.json"), precision=16)
    trainer.fit(model)
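
Such a file simply mirrors the dictionary shown above; for instance, you could generate it once from Python (the path is a placeholder and ``deepspeed_config`` refers to the dictionary defined earlier):

.. code-block:: python

    import json

    with open("/path/to/deepspeed_config.json", "w") as f:
        json.dump(deepspeed_config, f)
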
You can also use an environment variable via your PyTorch Lightning script:

.. code-block:: bash

    PL_DEEPSPEED_CONFIG_PATH=/path/to/deepspeed_config.json python train.py --plugins deepspeed

----------

.. _sequential-parallelism:
8 changes: 5 additions & 3 deletions pytorch_lightning/accelerators/accelerator.py
@@ -284,7 +284,7 @@ def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Cal
self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs)

def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
optimizer.step(closure=lambda_closure, **kwargs)
self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)

def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
"""Zeros all model parameter's gradients"""
@@ -315,9 +315,11 @@ def setup_optimizers(self, trainer: "Trainer"):
trainer: the Trainer, these optimizers should be connected to
model: the model to be optimized by the created optimizers
"""
if trainer.testing is True:
if trainer.testing:
return
optimizers, lr_schedulers, optimizer_frequencies = trainer.init_optimizers(self.lightning_module)
optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers(
trainer=trainer, model=self.lightning_module
)
self.optimizers = optimizers
self.lr_schedulers = lr_schedulers
self.optimizer_frequencies = optimizer_frequencies
4 changes: 4 additions & 0 deletions pytorch_lightning/plugins/__init__.py
@@ -1,12 +1,14 @@
from pytorch_lightning.plugins.base_plugin import Plugin # noqa: F401
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.tpu_bfloat import TPUHalfPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401
@@ -25,6 +27,8 @@
"DDP2Plugin",
"DDPPlugin",
"DDPSpawnPlugin",
"DeepSpeedPlugin",
"DeepSpeedPrecisionPlugin",
"HorovodPlugin",
"NativeMixedPrecisionPlugin",
"PrecisionPlugin",
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/precision/__init__.py
@@ -1,4 +1,5 @@
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401
61 changes: 61 additions & 0 deletions pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -0,0 +1,61 @@
from typing import Callable, Union

import torch
from torch.optim import Optimizer

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()


class DeepSpeedPrecisionPlugin(PrecisionPlugin):

def __init__(self, precision):
super().__init__()
self.precision = precision

def pre_optimizer_step(
self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs
) -> bool:
deepspeed_engine = pl_module.trainer.model
# DeepSpeed does not support closures.
lambda_closure()

if not pl_module.automatic_optimization:
pl_module.trainer.call_hook("on_after_backward")

deepspeed_engine.step()

return False

def backward(
self,
lightning_module: LightningModule,
closure_loss: torch.Tensor,
optimizer: torch.optim.Optimizer,
opt_idx: int,
should_accumulate: bool,
*args,
**kwargs,
):
if is_overridden('backward', lightning_module):
warning_cache.warn(
"Overridden backward hook in the LightningModule will be ignored since DeepSpeed handles"
"backward logic outside of the LightningModule"
)
# TODO: workaround so that the DeepSpeed engine performs the backward call itself
deepspeed_engine = lightning_module.trainer.model
deepspeed_engine.backward(closure_loss, **kwargs)
# once backward has been applied, release graph
closure_loss = closure_loss.detach()

return closure_loss

def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)):
"""
DeepSpeed handles clipping gradients via the training type plugin.
"""
pass
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/__init__.py
@@ -1,6 +1,7 @@
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin