Skip to content

Commit

Permalink
[fix] Ensure we check deepspeed/sharded in multinode DDP (#6297)
Browse files Browse the repository at this point in the history
* Ensure we check deepspeed/sharded in multinode

* Add CHANGELOG.md

* Add CHANGELOG.md

* Drop mock, use actual multi-gpu node
  • Loading branch information
SeanNaren authored and lexierule committed Mar 5, 2021
1 parent fc95f00 commit c5e9d67
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 5 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Prevent `WandbLogger` from dropping values ([#5931](https://github.com/PyTorchLightning/pytorch-lightning/pull/5931))


- Fixed error thrown when using valid distributed mode in multi node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297))


## [1.2.1] - 2021-02-23

### Fixed
Expand Down
10 changes: 5 additions & 5 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,12 +531,12 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
if self.distributed_backend == "horovod":
self._set_horovod_backend()

# throw error to force user ddp or ddp2 choice
_ddp = (DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
if (self.num_nodes > 1 and self._distrib_type not in _ddp):
using_valid_distributed = self.use_ddp or self.use_ddp2
if self.num_nodes > 1 and not using_valid_distributed:
# throw error to force user to choose a supported distributed type such as ddp or ddp2
raise MisconfigurationException(
'DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. '
'To silence this warning set `accelerator="ddp"` or `accelerator="ddp2"`'
'Your chosen distributed type does not support num_nodes > 1. '
'Please set accelerator=ddp or accelerator=ddp2.'
)

rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self._device_type == DeviceType.GPU}')
Expand Down
27 changes: 27 additions & 0 deletions tests/accelerators/test_accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,14 @@
DDPPlugin,
DDPShardedPlugin,
DDPSpawnPlugin,
DDPSpawnShardedPlugin,
DeepSpeedPlugin,
PrecisionPlugin,
SingleDevicePlugin,
)
from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment
from pytorch_lightning.utilities import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel


Expand Down Expand Up @@ -400,3 +404,26 @@ def test_plugin_accelerator_choice(accelerator, plugin):

trainer = Trainer(plugins=plugin, num_processes=2)
assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)


@pytest.mark.parametrize(["accelerator", "plugin"], [
    ('ddp', DDPPlugin),
    ('ddp_spawn', DDPSpawnPlugin),
    ('ddp_sharded', DDPShardedPlugin),
    ('ddp_sharded_spawn', DDPSpawnShardedPlugin),
    pytest.param(
        'deepspeed',
        DeepSpeedPlugin,
        marks=pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
    ),
])
@mock.patch('torch.cuda.is_available', return_value=True)
@mock.patch('torch.cuda.device_count', return_value=2)
def test_accelerator_choice_multi_node_gpu(mock_device_count, mock_is_available, accelerator, plugin, tmpdir):
    """Ensure each valid distributed accelerator is accepted with ``num_nodes > 1``.

    CUDA availability and device count are mocked so the test resolves the
    training type plugin on a simulated 2-GPU node without real hardware.
    """
    # NOTE: stacked ``mock.patch`` decorators inject mocks bottom-up, so the
    # innermost decorator ('torch.cuda.device_count') binds to the FIRST
    # parameter — the names here must follow that order.
    trainer = Trainer(
        accelerator=accelerator,
        default_root_dir=tmpdir,
        num_nodes=2,
        gpus=2,
    )
    assert isinstance(trainer.training_type_plugin, plugin)

0 comments on commit c5e9d67

Please sign in to comment.