diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index 8056aa7039d9b..110752a95cda6 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -62,6 +62,11 @@ jobs:
       pip list
     displayName: 'Install dependencies'
 
+  - bash: |
+      # Temporary fix until the DeepSpeed release; move this into the CUDA image
+      pip install deepspeed@git+https://github.com/microsoft/DeepSpeed@ec8b1cb
+    displayName: 'Install DeepSpeed'
+
   - script: |
       python tests/collect_env_details.py
     displayName: 'Env details'
@@ -76,7 +81,9 @@ jobs:
       python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50
     displayName: 'Testing: standard'
 
-  - script: |
+  - bash: |
+      # The Ninja binary required for building extensions is installed at this location
+      export PATH=$PATH:/home/AzDevOps_azpcontainer/.local/bin
       sh tests/special_tests.sh
     displayName: 'Testing: special'
diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst
index 6df26ab4cb689..8424d5f1f8dd1 100644
--- a/docs/source/advanced/multi_gpu.rst
+++ b/docs/source/advanced/multi_gpu.rst
@@ -613,6 +613,8 @@ Lightning currently offers the following methods to leverage model parallelism:
 - Sharded Training (partitioning your gradients and optimizer state across multiple GPUs, for reduced memory overhead with **no performance loss**)
 - Sequential Model Parallelism with Checkpointing (partition your :class:`nn.Sequential ` module across multiple GPUs, leverage checkpointing and microbatching for further memory improvements and device utilization)
 
+.. _sharded:
+
 Sharded Training
 ^^^^^^^^^^^^^^^^
 Lightning integration of optimizer sharded training provided by `FairScale `_.
@@ -678,6 +680,149 @@ Sharded Training can work across all DDP variants by adding the additional ``--p
 Internally we re-initialize your optimizers and shard them across your machines and processes. We handle all communication using PyTorch distributed, so no code changes are required.
 
+----------
+
+.. _deep_speed:
+
+DeepSpeed
+^^^^^^^^^
+
+.. note::
+    The DeepSpeed plugin is in beta and the API is subject to change. Please create an `issue `_ if you run into any issues.
+
+`DeepSpeed `_ offers additional CUDA deep learning training optimizations, similar to `FairScale `_. It provides lower-level training optimizations as well as efficient optimizers such as `1-bit Adam `_.
+Using this plugin, we were able to **train models with 10 billion parameters and above**; see this `benchmark `_ and the DeepSpeed `docs `_ for more information.
+We recommend using DeepSpeed in environments where speed and memory optimizations are important (such as training large, billion-parameter models). In addition, we recommend trying :ref:`sharded` before moving on to DeepSpeed's further optimizations, primarily because FairScale Sharded Training is easier to use in scenarios such as multiple optimizers/schedulers.
+
+To use DeepSpeed, you first need to install DeepSpeed using the command below.
+
+.. code-block:: bash
+
+    pip install deepspeed mpi4py
+
+If you run into an issue with the install or later in training, ensure that the CUDA version of the PyTorch you've installed matches your locally installed CUDA toolkit (you can check the latter by running ``nvcc --version``).
+Additionally, if you run into any issues installing mpi4py, ensure you have OpenMPI installed via ``sudo apt install libopenmpi-dev`` or ``brew install mpich`` before running ``pip install mpi4py``.
+
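+One quick sanity check for the CUDA mismatch mentioned above is to print the CUDA version your PyTorch build was compiled against and compare it to the output of ``nvcc --version``:
+
+.. code-block:: python
+
+    import torch
+
+    # CUDA version this PyTorch build was compiled against (None for CPU-only builds);
+    # it should match the locally installed CUDA toolkit reported by `nvcc --version`.
+    print(torch.version.cuda)
+    print(torch.cuda.is_available())
+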
+.. note::
+    Currently ``resume_from_checkpoint`` and manual optimization are not supported.
+
+    DeepSpeed supports only a single optimizer and a single (optional) scheduler.
+
+ZeRO-Offload
+""""""""""""
+
+Below we show an example of running `ZeRO-Offload `_. ZeRO-Offload leverages the host CPU to offload optimizer memory/computation, reducing the overall memory consumption.
+For an additional speed benefit, DeepSpeed offers an optimized CPU implementation of Adam to run the offloaded computation, which is faster than the standard PyTorch implementation. ZeRO-Offload is enabled by default.
+
+.. note::
+    To use ZeRO-Offload, you must use ``precision=16`` or set precision via `the DeepSpeed config `_.
+
+.. code-block:: python
+
+    from pytorch_lightning import Trainer
+
+    model = MyModel()
+    trainer = Trainer(gpus=4, plugins='deepspeed', precision=16)
+    trainer.fit(model)
+
+
+This can also be done via the command line using a PyTorch Lightning script:
+
+.. code-block:: bash
+
+    python train.py --plugins deepspeed --precision 16 --gpus 4
+
+
+You can also modify the ZeRO-Offload parameters via the plugin as below.
+
+.. code-block:: python
+
+    from pytorch_lightning import Trainer
+    from pytorch_lightning.plugins import DeepSpeedPlugin
+
+    model = MyModel()
+    trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(allgather_bucket_size=5e8, reduce_bucket_size=5e8), precision=16)
+    trainer.fit(model)
+
+
+.. note::
+    We suggest tuning the ``allgather_bucket_size`` and ``reduce_bucket_size`` parameters to find the optimum values for your model size.
+    These control how large a buffer we limit the model to using when reducing gradients/gathering updated parameters. Smaller values use less memory, at the cost of speed.
+
+    DeepSpeed allocates a reduce buffer size `multiplied by 4.5x `_, so take that into consideration when tweaking the parameters.
+
+    The plugin sets a reasonable default of ``2e8``, which should work for most low-VRAM GPUs (less than ``7GB``), allocating roughly ``3.6GB`` of VRAM as buffer. Higher-VRAM GPUs should aim for values around ``5e8``.
+
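+As a rough back-of-the-envelope check of the numbers in the note above (assuming 32-bit buffer elements and the ~4.5x allocation factor), the default bucket size works out to roughly ``3.6GB``:
+
+.. code-block:: python
+
+    # Rough estimate of the reduce-buffer allocation for the default bucket size.
+    bucket_elements = 2e8      # default reduce_bucket_size / allgather_bucket_size
+    bytes_per_element = 4      # assuming fp32 buffer elements
+    allocation_factor = 4.5    # DeepSpeed allocates roughly 4.5x the reduce bucket size
+    print(f"{bucket_elements * bytes_per_element * allocation_factor / 1e9:.1f} GB")  # ~3.6 GB
+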
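+ZeRO-Offload itself can also be toggled through the plugin. As a sketch based on the plugin arguments introduced here, the example below keeps ZeRO stage 2 partitioning but skips offloading the optimizer to the host CPU, which may be preferable if the CPU-side optimizer step becomes the bottleneck:
+
+.. code-block:: python
+
+    from pytorch_lightning import Trainer
+    from pytorch_lightning.plugins import DeepSpeedPlugin
+
+    model = MyModel()
+    # Keep ZeRO stage 2 (optimizer/gradient partitioning) but disable CPU offloading.
+    trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(cpu_offload=False), precision=16)
+    trainer.fit(model)
+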
+Custom DeepSpeed Config
+"""""""""""""""""""""""
+
+DeepSpeed allows use of custom DeepSpeed optimizers and schedulers defined within a config file. This allows you to enable optimizers such as `1-bit Adam `_.
+
+.. note::
+    All plugin default parameters will be ignored when a config object is passed.
+    All compatible arguments can be seen in the `DeepSpeed docs `_.
+
+.. code-block:: python
+
+    from pytorch_lightning import Trainer
+    from pytorch_lightning.plugins import DeepSpeedPlugin
+
+    deepspeed_config = {
+        "zero_allow_untested_optimizer": True,
+        "optimizer": {
+            "type": "OneBitAdam",
+            "params": {
+                "lr": 3e-5,
+                "betas": [0.998, 0.999],
+                "eps": 1e-5,
+                "weight_decay": 1e-9,
+                "cuda_aware": True,
+            },
+        },
+        'scheduler': {
+            "type": "WarmupLR",
+            "params": {
+                "last_batch_iteration": -1,
+                "warmup_min_lr": 0,
+                "warmup_max_lr": 3e-5,
+                "warmup_num_steps": 100,
+            }
+        },
+        "zero_optimization": {
+            "stage": 2,  # Enable Stage 2 ZeRO (Optimizer/Gradient state partitioning)
+            "cpu_offload": True,  # Enable Offloading optimizer state/calculation to the host CPU
+            "contiguous_gradients": True,  # Reduce gradient fragmentation.
+            "overlap_comm": True,  # Overlap reduce/backward operation of gradients for speed.
+            "allgather_bucket_size": 2e8,  # Number of elements to all gather at once.
+            "reduce_bucket_size": 2e8,  # Number of elements we reduce/allreduce at once.
+        }
+    }
+
+    model = MyModel()
+    trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(deepspeed_config), precision=16)
+    trainer.fit(model)
+
+
+We support taking the config as a JSON-formatted file:
+
+.. code-block:: python
+
+    from pytorch_lightning import Trainer
+    from pytorch_lightning.plugins import DeepSpeedPlugin
+
+    model = MyModel()
+    trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin("/path/to/deepspeed_config.json"), precision=16)
+    trainer.fit(model)
+
+
+You can also use an environment variable via your PyTorch Lightning script:
+
+.. code-block:: bash
+
+    PL_DEEPSPEED_CONFIG_PATH=/path/to/deepspeed_config.json python train.py --plugins deepspeed
+
+
 ----------
 
 .. _sequential-parallelism:
 
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 967b6a85c878b..3b95668aa9cad 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -284,7 +284,7 @@ def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Cal
         self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs)
 
     def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
-        optimizer.step(closure=lambda_closure, **kwargs)
+        self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
 
     def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
         """Zeros all model parameter's gradients"""
@@ -315,9 +315,11 @@ def setup_optimizers(self, trainer: "Trainer"):
             trainer: the Trainer, these optimizers should be connected to
             model: the model to be optimized by the created optimizers
         """
-        if trainer.testing is True:
+        if trainer.testing:
             return
-        optimizers, lr_schedulers, optimizer_frequencies = trainer.init_optimizers(self.lightning_module)
+        optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers(
+            trainer=trainer, model=self.lightning_module
+        )
         self.optimizers = optimizers
         self.lr_schedulers = lr_schedulers
         self.optimizer_frequencies = optimizer_frequencies
diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py
index 2d9086c2e18ad..dec672d025294 100644
--- a/pytorch_lightning/plugins/__init__.py
+++ b/pytorch_lightning/plugins/__init__.py
@@ -1,5 +1,6 @@
 from pytorch_lightning.plugins.base_plugin import Plugin  # noqa: F401
 from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin  # noqa: F401
+from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin  # noqa: F401
@@ -7,6 +8,7 @@
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin  # noqa: F401
 from 
pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401 @@ -25,6 +27,8 @@ "DDP2Plugin", "DDPPlugin", "DDPSpawnPlugin", + "DeepSpeedPlugin", + "DeepSpeedPrecisionPlugin", "HorovodPlugin", "NativeMixedPrecisionPlugin", "PrecisionPlugin", diff --git a/pytorch_lightning/plugins/precision/__init__.py b/pytorch_lightning/plugins/precision/__init__.py index 1b085c92aafd6..fc60deffcbb77 100644 --- a/pytorch_lightning/plugins/precision/__init__.py +++ b/pytorch_lightning/plugins/precision/__init__.py @@ -1,4 +1,5 @@ from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401 diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py new file mode 100644 index 0000000000000..711ede2f7ded4 --- /dev/null +++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py @@ -0,0 +1,61 @@ +from typing import Callable, Union + +import torch +from torch.optim import Optimizer + +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin +from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.warnings import WarningCache + +warning_cache = WarningCache() + + +class DeepSpeedPrecisionPlugin(PrecisionPlugin): + + def __init__(self, precision): + super().__init__() + self.precision = precision + + def pre_optimizer_step( + self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs + ) -> bool: + deepspeed_engine = pl_module.trainer.model + # DeepSpeed not support closures. + lambda_closure() + + if not pl_module.automatic_optimization: + pl_module.trainer.call_hook("on_after_backward") + + deepspeed_engine.step() + + return False + + def backward( + self, + lightning_module: LightningModule, + closure_loss: torch.Tensor, + optimizer: torch.optim.Optimizer, + opt_idx: int, + should_accumulate: bool, + *args, + **kwargs, + ): + if is_overridden('backward', lightning_module): + warning_cache.warn( + "Overridden backward hook in the LightningModule will be ignored since DeepSpeed handles" + "backward logic outside of the LightningModule" + ) + # todo: hack around for deepspeed engine to call backward + deepspeed_engine = lightning_module.trainer.model + deepspeed_engine.backward(closure_loss, **kwargs) + # once backward has been applied, release graph + closure_loss = closure_loss.detach() + + return closure_loss + + def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)): + """ + DeepSpeed handles clipping gradients via the training type plugin. 
+ """ + pass diff --git a/pytorch_lightning/plugins/training_type/__init__.py b/pytorch_lightning/plugins/training_type/__init__.py index a5a644fc6568c..b73c6351de181 100644 --- a/pytorch_lightning/plugins/training_type/__init__.py +++ b/pytorch_lightning/plugins/training_type/__init__.py @@ -1,6 +1,7 @@ from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin +from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py new file mode 100644 index 0000000000000..354ef5944ef42 --- /dev/null +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -0,0 +1,323 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +from pathlib import Path +from types import SimpleNamespace +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from torch.nn.parallel import DistributedDataParallel + +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.training_type.ddp import DDPPlugin +from pytorch_lightning.utilities import AMPType +from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE + +if _DEEPSPEED_AVAILABLE: + import deepspeed + + +class LightningDeepSpeedModule(_LightningModuleWrapperBase): + + def __init__(self, pl_module: LightningModule, precision: int): + super().__init__(pl_module) + self.precision = precision + + def forward(self, *inputs, **kwargs): + if self.precision == 16: + inputs = self._move_float_tensors_to_half(inputs) + + return super().forward(*inputs, **kwargs) + + @staticmethod + def batch_to(data): + return data.half() + + def _move_float_tensors_to_half(self, batch: Any): + batch = apply_to_collection(batch, (torch.FloatTensor, torch.cuda.FloatTensor), function=self.batch_to) + return batch + + +class DeepSpeedPlugin(DDPPlugin): + distributed_backend = "deepspeed" + DEEPSPEED_ENV_VAR = "PL_DEEPSPEED_CONFIG_PATH" + + def __init__( + self, + zero_optimization: bool = True, + stage: int = 2, + cpu_offload: bool = True, + contiguous_gradients: bool = True, + overlap_comm: bool = True, + allgather_partitions: bool = True, + 
reduce_scatter: bool = True, + allgather_bucket_size: int = 2e8, + reduce_bucket_size: int = 2e8, + zero_allow_untested_optimizer: bool = True, + config: Optional[Union[Path, str, dict]] = None, + logging_level: int = logging.WARN, + num_nodes: int = 1, + parallel_devices: Optional[List[torch.device]] = None, + cluster_environment: Optional[ClusterEnvironment] = None, + ) -> None: + """ + + Provides capabilities to run training using the DeepSpeed library, + with training optimizations for large billion parameter models. + `For more information: https://www.deepspeed.ai/`. + + .. warning:: ``DeepSpeedPlugin`` is in beta and subject to change. + + Defaults have been set to enable ZeRO-Offload and some have been taken from the link below. + These defaults have been set generally, but may require tuning for optimum performance based on your model size. + `For more information: https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training`. + + Arguments: + + zero_optimization: Enable ZeRO optimization. This is only compatible with precision=16. (default: True) + + stage: Different stages of the ZeRO Optimizer. 0 is disabled, + 1 is optimizer state partitioning, 2 is optimizer+gradient state partitioning (default: 2) + + cpu_offload: Enable offloading optimizer memory and computation to CPU (default: True) + + contiguous_gradients: Copies gradients to a continuous buffer as they are produced. + Avoids memory fragmentation during backwards. Useful when training large models. (default: True) + + overlap_comm: Overlap the reduction (synchronization) of gradients with the backwards computation. + This is a speed optimization when training across multiple GPUs/machines. (default: True) + + allgather_partitions: All gather updated parameters at the end of training step, + instead of using a series of broadcast collectives (default: True) + + reduce_scatter: Use reduce/scatter instead of allreduce to average gradients (default:True) + + allgather_bucket_size: Number of elements to allgather at once. + Used to limit the memory required for larger model sizes, with a tradeoff with speed. (default: 2e8) + + reduce_bucket_size: Number of elements to reduce at once. + Used to limit the memory required for larger model sizes, with a tradeoff with speed (default: 2e8) + + zero_allow_untested_optimizer: Allow untested optimizers to be used with ZeRO. Currently only Adam is a + DeepSpeed supported optimizer when using ZeRO (default: True) + + config: Pass in a deepspeed formatted config dict, + or path to a deepspeed config: https://www.deepspeed.ai/docs/config-json. + All defaults will be ignored if a config is passed in. (Default: ``None``) + + logging_level: Set logging level for deepspeed. (Default: ``logging.WARN``) + + """ + if not _DEEPSPEED_AVAILABLE: + raise MisconfigurationException( + "To use the DeepSpeed plugin, you must have DeepSpeed installed." 
+ " pip install deepspeed mpi4py" + ) + super().__init__( + parallel_devices=parallel_devices, num_nodes=num_nodes, cluster_environment=cluster_environment + ) + self.config = self._load_config(config) + if self.config is None: + # User has not overridden config, set defaults + self.config = self._create_default_config( + zero_optimization, + zero_allow_untested_optimizer, + stage=stage, + cpu_offload=cpu_offload, + contiguous_gradients=contiguous_gradients, + overlap_comm=overlap_comm, + allgather_partitions=allgather_partitions, + reduce_scatter=reduce_scatter, + allgather_bucket_size=allgather_bucket_size, + reduce_bucket_size=reduce_bucket_size + ) + self._config_initialized = False + deepspeed.utils.logging.logger.setLevel(logging_level) + + def _load_config(self, config): + if config is None and self.DEEPSPEED_ENV_VAR in os.environ: + rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable") + config = os.environ[self.DEEPSPEED_ENV_VAR] + if isinstance(config, str) or isinstance(config, Path): + if os.path.exists(config): + with open(config) as f: + config = json.load(f) + else: + raise MisconfigurationException( + f"You passed in a path to a DeepSpeed config but the path does not exist: {config}" + ) + return config + + def pre_dispatch(self): + self.set_world_ranks() + self.init_ddp_connection(self.global_rank, self.world_size) + + self.init_deepspeed() + + # set warning rank + rank_zero_only.rank = self.global_rank + + # set the ranks and devices + self.dist.rank = self.global_rank + self.dist.device = self.root_device + self.barrier() + + def init_deepspeed(self): + if not self._config_initialized: + self._format_config() + self._config_initialized = True + + precision = self.lightning_module.trainer.accelerator_backend.precision + model = LightningDeepSpeedModule(pl_module=self.model, precision=precision) + + if self.lightning_module.trainer.training: + self._initialize_deepspeed_train(model) + else: + self._initialize_deepspeed_inference(model) + + def _init_scheduler_optimizer(self): + optimizers, schedulers, optimizer_frequencies = self.lightning_module.trainer.init_optimizers( + self.lightning_module + ) + if (len(optimizers) != 1) or len(schedulers) > 1: + raise MisconfigurationException( + "DeepSpeed currently only supports single optimizer, single optional scheduler." + ) + scheduler = schedulers[0]['scheduler'] if len(schedulers) == 1 else None + optimizer = optimizers[0] + return optimizer, scheduler, optimizer_frequencies + + def _initialize_deepspeed_train(self, model): + optimizer, lightning_scheduler, optimizer_frequencies = None, None, None + if "optimizer" not in self.config: + rank_zero_info( + "You have not specified an optimizer or scheduler within the DeepSpeed config." + "Using `configure_optimizers` to define optimizer and scheduler." 
+ ) + optimizer, lightning_scheduler, optimizer_frequencies = self._init_scheduler_optimizer() + model_parameters = filter(lambda p: p.requires_grad, self.model.parameters()) + model, optimizer, _, lr_scheduler = deepspeed.initialize( + args=SimpleNamespace(local_rank=self.local_rank), + model=model, + model_parameters=model_parameters, + optimizer=optimizer, + lr_scheduler=lightning_scheduler, + config_params=self.config, + ) + + # set optimizer for save/load, but deepspeed manages the specific optimizer logic + trainer = self.lightning_module.trainer + trainer.optimizers = [optimizer] + self.model = model + + def _initialize_deepspeed_inference(self, model): + # move the model to the correct device + self.model_to_device() + + self.pre_configure_ddp() + self._model = DistributedDataParallel( + model, + device_ids=self.determine_ddp_device_ids(), + **self._ddp_kwargs, + ) + + def configure_scheduler(self, lr_scheduler): + # this duplicates the defaults from init_optimizers + scheduler = { + 'scheduler': lr_scheduler, + 'name': None, # no custom name + 'interval': 'epoch', # after epoch is over + 'frequency': 1, # every epoch/batch + 'reduce_on_plateau': False, # most often not ReduceLROnPlateau scheduler + 'monitor': None, # value to monitor for ReduceLROnPlateau + 'strict': True, # enforce that the monitor exists for ReduceLROnPlateau + } + return [scheduler] + + @property + def lightning_module(self): + # the model may not be wrapped with DeepEngine & LightningDeepSpeedModule if calling this too early + module = getattr(self.model, "module", self.model) + return module.module if isinstance(module, LightningDeepSpeedModule) else module + + @property + def distributed_sampler_kwargs(self): + distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank) + return distributed_sampler_kwargs + + def init_optimizers(self, trainer: "Trainer", model: LightningModule) -> Tuple[List, List, List]: + # Skip initializing optimizers here as DeepSpeed handles optimizers via config. + # User may have specified config options instead in configure_optimizers, but this is handled + # via `_initialize_deepspeed_train` + return [], [], [] # empty optimizers, schedulers and frequencies + + def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs): + # note: We rely on the deepspeed engine to carry out the step rather than the optimizer. + # internally, the engine has a reference to the optimizer already. + self.model.step(**kwargs) + + def _format_config(self): + if self.config is None: + raise MisconfigurationException( + "To use DeepSpeed you must pass in a DeepSpeed config dict, or a path to a JSON config." + " See: https://pytorch-lightning.readthedocs.io/en/latest/advanced/multi_gpu.html#deepspeed" + ) + self._format_batch_size_and_grad_accum_config() + self._format_precision_config() + + def _format_batch_size_and_grad_accum_config(self): + if "gradient_accumulation_steps" in self.config: + raise MisconfigurationException( + "Within the DeepSpeed config, do not set gradient_accumulation_steps" + " as this will be set via accumulate_grad_batches=x argument passed via the Lightning Trainer." 
+ ) + if "train_micro_batch_size_per_gpu" not in self.config: + # train_micro_batch_size_per_gpu is used for throughput logging purposes + # by default we use the batch size of the loader which may be incorrect if a batch sampler is passed + batch_size = self.lightning_module.train_dataloader().batch_size + self.config["train_micro_batch_size_per_gpu"] = batch_size + self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches + if "gradient_clipping" not in self.config: + self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val + + def _format_precision_config(self): + + amp_type = self.lightning_module.trainer.accelerator_connector.amp_type + amp_level = self.lightning_module.trainer.accelerator_connector.amp_level + precision = self.lightning_module.trainer.accelerator_connector.precision + if precision == 16: + if "amp" not in self.config and amp_type == AMPType.NATIVE: + self.config["fp16"] = {"enabled": True} + elif "apex" not in self.config and amp_type == AMPType.APEX: + self.config["amp"] = { + "enabled": True, + "opt_level": amp_level, + } + if "zero_optimization" in self.config and not ("amp" in self.config or "fp16" in self.config): + raise MisconfigurationException("To use DeepSpeed ZeRO Optimization, you must set precision=16.") + + def _create_default_config( + self, zero_optimization: bool, zero_allow_untested_optimizer: bool, **zero_kwargs + ) -> Dict: + if zero_optimization: + return {"zero_allow_untested_optimizer": zero_allow_untested_optimizer, "zero_optimization": zero_kwargs} + return {} diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 938a17249e9f6..d7c3b4d4d77e1 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Iterable, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING, Union import torch from torch.nn import Module @@ -152,3 +152,9 @@ def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[I dataloader: iterable. 
Ideally of type: :class:`torch.utils.data.DataLoader` """ return dataloader + + def init_optimizers(self, trainer: "Trainer", model: LightningModule): + return trainer.init_optimizers(model) + + def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs): + optimizer.step(closure=lambda_closure, **kwargs) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 2c4eafb6ed0e8..5549a29473fc7 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -30,8 +30,11 @@ DDPShardedPlugin, DDPSpawnPlugin, DDPSpawnShardedPlugin, + DeepSpeedPlugin, + DeepSpeedPrecisionPlugin, HorovodPlugin, NativeMixedPrecisionPlugin, + Plugin, PrecisionPlugin, ShardedNativeMixedPrecisionPlugin, SingleDevicePlugin, @@ -144,7 +147,7 @@ def __init__( self.replace_sampler_ddp = replace_sampler_ddp - def handle_given_plugins(self, plugins: Optional[Sequence]): + def handle_given_plugins(self, plugins: Optional[Union[Plugin, Sequence]]): plugins = plugins if plugins is not None else [] if isinstance(plugins, str): @@ -243,7 +246,7 @@ def use_dp(self) -> bool: def use_ddp(self) -> bool: return self._distrib_type in ( DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED, - DistributedType.DDP_SHARDED_SPAWN + DistributedType.DDP_SHARDED_SPAWN, DistributedType.DEEPSPEED ) @property @@ -254,6 +257,10 @@ def use_ddp2(self) -> bool: def use_horovod(self) -> bool: return self._distrib_type == DistributedType.HOROVOD + @property + def use_deepspeed(self) -> bool: + return self._distrib_type == DistributedType.DEEPSPEED + @property def is_distributed(self) -> bool: is_distributed = self.use_ddp or self.use_ddp2 or self.use_horovod @@ -290,15 +297,19 @@ def is_using_torchelastic(self) -> bool: return te_flags_passed def select_precision_plugin(self) -> PrecisionPlugin: + # set precision type + self.amp_type = AMPType.from_str(self.amp_type) + + if self._distrib_type == DistributedType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedPlugin): + return DeepSpeedPrecisionPlugin(self.precision) + if self.precision == 32: - self.amp_type = None return PrecisionPlugin() elif self.precision == 16: if self.on_tpu: return TPUHalfPrecisionPlugin() - self.amp_type = AMPType(self.amp_type) if self.amp_type == AMPType.NATIVE: if self.on_cpu: raise MisconfigurationException( @@ -338,6 +349,12 @@ def select_precision_plugin(self) -> PrecisionPlugin: def select_training_type_plugin(self) -> TrainingTypePlugin: if self.use_ddp2: plugin = DDP2Plugin(parallel_devices=self.parallel_devices, cluster_environment=self.cluster_environment) + elif self.use_ddp and self.use_deepspeed: + plugin = DeepSpeedPlugin( + num_nodes=self.num_nodes, + cluster_environment=self.select_cluster_environment(), + parallel_devices=self.parallel_devices + ) elif self.use_ddp: use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 46ca290b24d34..2d6ddfd23abba 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -29,6 +29,7 @@ from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.step_result import Result from pytorch_lightning.loggers import LightningLoggerBase +from pytorch_lightning.plugins import 
Plugin from pytorch_lightning.profiler import BaseProfiler from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin from pytorch_lightning.trainer.configuration_validator import ConfigValidator @@ -130,7 +131,7 @@ def __init__( terminate_on_nan: bool = False, auto_scale_batch_size: Union[str, bool] = False, prepare_data_per_node: bool = True, - plugins: Optional[Union[str, list]] = None, + plugins: Optional[Union[Plugin, str, list]] = None, amp_backend: str = 'native', amp_level: str = 'O2', distributed_backend: Optional[str] = None, diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 889ed96f43679..cf3aa06f305b8 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -26,6 +26,7 @@ from pytorch_lightning.utilities.imports import ( # noqa: F401 _APEX_AVAILABLE, _BOLTS_AVAILABLE, + _DEEPSPEED_AVAILABLE, _FAIRSCALE_AVAILABLE, _FAIRSCALE_PIPE_AVAILABLE, _GROUP_AVAILABLE, diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index c7796b433f1ed..3e4add4fb68d1 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -62,6 +62,7 @@ class DistributedType(LightningEnum): DDP = 'ddp' DDP2 = 'ddp2' DDP_SPAWN = 'ddp_spawn' + DEEPSPEED = 'deepspeed' HOROVOD = 'horovod' DDP_SHARDED = 'ddp_sharded' DDP_SHARDED_SPAWN = 'ddp_sharded_spawn' diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 4d1b38eaf5949..b4c30097fad4e 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -55,6 +55,7 @@ def _compare_version(package: str, op, version) -> bool: _TORCH_QUANTIZE_AVAILABLE = _module_available('torch.ops.quantized') _APEX_AVAILABLE = _module_available("apex.amp") _BOLTS_AVAILABLE = _module_available('pl_bolts') +_DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed') _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available('fairscale.nn.data_parallel') _FAIRSCALE_PIPE_AVAILABLE = _TORCH_GREATER_EQUAL_1_6 and _compare_version("fairscale", operator.le, "0.1.3") _GROUP_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.group') diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py new file mode 100644 index 0000000000000..1d25c529dd963 --- /dev/null +++ b/tests/plugins/test_deepspeed_plugin.py @@ -0,0 +1,292 @@ +import json +import os + +import pytest +import torch +from torch import Tensor +from torch.optim import Optimizer + +from pytorch_lightning import Trainer +from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin +from pytorch_lightning.utilities import _APEX_AVAILABLE, _DEEPSPEED_AVAILABLE, _NATIVE_AMP_AVAILABLE +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers.boring_model import BoringModel + + +@pytest.fixture +def deepspeed_config(): + return { + "optimizer": { + "type": "SGD", + "params": { + "lr": 3e-5, + }, + }, + 'scheduler': { + "type": "WarmupLR", + "params": { + "last_batch_iteration": -1, + "warmup_min_lr": 0, + "warmup_max_lr": 3e-5, + "warmup_num_steps": 100, + } + } + } + + +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +def test_deepspeed_plugin_string(tmpdir): + """ + Test to ensure that the plugin can be passed via string, and parallel devices is correctly set. 
+ """ + + trainer = Trainer( + fast_dev_run=True, + default_root_dir=tmpdir, + plugins='deepspeed', + ) + + assert isinstance(trainer.accelerator_backend.training_type_plugin, DeepSpeedPlugin) + assert trainer.accelerator_backend.training_type_plugin.parallel_devices == [torch.device('cpu')] + + +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +def test_deepspeed_plugin(tmpdir): + """ + Test to ensure that the plugin can be passed directly, and parallel devices is correctly set. + """ + + trainer = Trainer( + fast_dev_run=True, + default_root_dir=tmpdir, + plugins=[DeepSpeedPlugin()], + ) + + assert isinstance(trainer.accelerator_backend.training_type_plugin, DeepSpeedPlugin) + assert trainer.accelerator_backend.training_type_plugin.parallel_devices == [torch.device('cpu')] + + +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +def test_deepspeed_plugin_env(tmpdir, monkeypatch, deepspeed_config): + """ + Test to ensure that the plugin can be passed via a string with an environment variable. + """ + config_path = os.path.join(tmpdir, 'temp.json') + with open(config_path, 'w') as f: + f.write(json.dumps(deepspeed_config)) + monkeypatch.setenv("PL_DEEPSPEED_CONFIG_PATH", config_path) + + trainer = Trainer( + fast_dev_run=True, + default_root_dir=tmpdir, + plugins='deepspeed', + ) + + plugin = trainer.accelerator_backend.training_type_plugin + assert isinstance(plugin, DeepSpeedPlugin) + assert plugin.parallel_devices == [torch.device('cpu')] + assert plugin.config == deepspeed_config + + +@pytest.mark.parametrize( + "amp_backend", [ + pytest.param("native", marks=pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")), + pytest.param("apex", marks=pytest.mark.skipif(not _APEX_AVAILABLE, reason="Requires Apex")), + ] +) +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP") +def test_deepspeed_precision_choice(amp_backend, tmpdir): + """ + Test to ensure precision plugin is also correctly chosen. + DeepSpeed handles precision via Custom DeepSpeedPrecisionPlugin + """ + + trainer = Trainer( + fast_dev_run=True, default_root_dir=tmpdir, plugins='deepspeed', amp_backend=amp_backend, precision=16 + ) + + assert isinstance(trainer.accelerator_backend.training_type_plugin, DeepSpeedPlugin) + assert isinstance(trainer.accelerator_backend.precision_plugin, DeepSpeedPrecisionPlugin) + assert trainer.accelerator_backend.precision_plugin.precision == 16 + + +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +def test_deepspeed_with_invalid_config_path(tmpdir): + """ + Test to ensure if we pass an invalid config path we throw an exception. + """ + + with pytest.raises( + MisconfigurationException, match="You passed in a path to a DeepSpeed config but the path does not exist" + ): + DeepSpeedPlugin(config='invalid_path.json') + + +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +def test_deepspeed_with_env_path(tmpdir, monkeypatch, deepspeed_config): + """ + Test to ensure if we pass an env variable, we load the config from the path. 
+ """ + config_path = os.path.join(tmpdir, 'temp.json') + with open(config_path, 'w') as f: + f.write(json.dumps(deepspeed_config)) + monkeypatch.setenv("PL_DEEPSPEED_CONFIG_PATH", config_path) + plugin = DeepSpeedPlugin() + assert plugin.config == deepspeed_config + + +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +def test_deepspeed_defaults(tmpdir): + """ + Ensure that defaults are correctly set as a config for DeepSpeed if no arguments are passed. + """ + plugin = DeepSpeedPlugin() + assert plugin.config is not None + assert isinstance(plugin.config["zero_optimization"], dict) + + +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +def test_invalid_deepspeed_defaults_no_precision(tmpdir): + """ + Test to ensure that using defaults, if precision is not set to 16, we throw an exception. + """ + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + default_root_dir=tmpdir, + plugins='deepspeed', + ) + with pytest.raises( + MisconfigurationException, match='To use DeepSpeed ZeRO Optimization, you must set precision=16.' + ): + trainer.fit(model) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +@pytest.mark.skipif( + not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" +) +def test_warn_deepspeed_override_backward(tmpdir): + """ + Test to ensure that if the backward hook in the LightningModule is overridden, we throw a warning. + """ + + class TestModel(BoringModel): + + def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None: + return loss.backward() + + model = TestModel() + trainer = Trainer( + fast_dev_run=True, + default_root_dir=tmpdir, + plugins=DeepSpeedPlugin(zero_optimization=False), + gpus=1, + ) + with pytest.warns(UserWarning, match='Overridden backward hook in the LightningModule will be ignored'): + trainer.fit(model) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +@pytest.mark.skipif( + not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" +) +def test_deepspeed_run_configure_optimizers(tmpdir): + """ + Test end to end that deepspeed works with defaults (without ZeRO as that requires compilation), + whilst using configure_optimizers for optimizers and schedulers. 
+ """ + + class TestModel(BoringModel): + + def on_train_start(self) -> None: + assert isinstance(self.trainer.optimizers[0], torch.optim.SGD) + assert self.trainer.lr_schedulers == [] # DeepSpeed manages LR scheduler internally + # Ensure DeepSpeed engine has initialized with our optimizer/lr_scheduler + assert isinstance(self.trainer.model.lr_scheduler, torch.optim.lr_scheduler.StepLR) + + model = TestModel() + trainer = Trainer( + plugins=DeepSpeedPlugin(zero_optimization=False), + default_root_dir=tmpdir, + gpus=1, + fast_dev_run=True, + ) + + trainer.fit(model) + + _assert_save_model_is_equal(model, tmpdir, trainer) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +@pytest.mark.skipif( + not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" +) +def test_deepspeed_config(tmpdir, deepspeed_config): + """ + Test to ensure deepspeed works correctly when passed a DeepSpeed config object including optimizers/schedulers + and saves the model weights to load correctly. + """ + + class TestModel(BoringModel): + + def on_train_start(self) -> None: + import deepspeed + assert isinstance(self.trainer.optimizers[0], torch.optim.SGD) + assert self.trainer.lr_schedulers == [] # DeepSpeed manages LR scheduler internally + assert isinstance(self.trainer.model.optimizer, torch.optim.SGD) + assert isinstance(self.trainer.model.lr_scheduler, deepspeed.runtime.lr_schedules.WarmupLR) + + model = TestModel() + trainer = Trainer( + plugins=[DeepSpeedPlugin(config=deepspeed_config)], + default_root_dir=tmpdir, + gpus=1, + fast_dev_run=True, + ) + + trainer.fit(model) + trainer.test(model) + + _assert_save_model_is_equal(model, tmpdir, trainer) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif( + not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" +) +def test_deepspeed_multigpu(tmpdir, deepspeed_config): + """ + Test to ensure that DeepSpeed with multiple GPUs works, without ZeRO Optimization as this requires compilation. 
+ """ + model = BoringModel() + trainer = Trainer( + plugins=[DeepSpeedPlugin(zero_optimization=False)], + default_root_dir=tmpdir, + gpus=2, + fast_dev_run=True, + precision=16, + ) + trainer.fit(model) + trainer.test(model) + + _assert_save_model_is_equal(model, tmpdir, trainer) + + +def _assert_save_model_is_equal(model, tmpdir, trainer): + checkpoint_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(checkpoint_path) + # carry out the check only on rank 0 + if trainer.global_rank == 0: + saved_model = BoringModel.load_from_checkpoint(checkpoint_path) + saved_model = saved_model.float() + model = model.float().cpu() + # Assert model parameters are identical after loading + for orig_param, trained_model_param in zip(model.parameters(), saved_model.parameters()): + assert torch.equal(orig_param, trained_model_param) diff --git a/tests/special_tests.sh b/tests/special_tests.sh index ff174b5cad648..472f7afda5e9e 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -17,6 +17,10 @@ export PL_RUNNING_SPECIAL_TESTS=1 DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no" python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp python ${DEFAULTS} tests/models/test_sync_batchnorm.py::test_sync_batchnorm_ddp +python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_warn_deepspeed_override_backward +python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_run_configure_optimizers +python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_config +python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_multigpu python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_manual python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_manual_amp