[Fix] Move init dist connection into the setup function (#6506)
* Move connection setup into the setup function. Call setup hook after we set up the accelerator

* Added CHANGELOG.md

* fix setup order in callback test

* fix input arguments in test

* Mock distributed function, remove protection to turn into training type hook

* Remove import

* Add missing mock, ensure custom plugin does not create child processes

* Skip test on windows

* Update deepspeed to init connection in setup

* Do not initialize distributed module

* Move DeepSpeed tests to special tests since dist communication is being set up

* Mark the test as special to see if this fixes CI

* Delete accelerator connector test to see if its causing build to fail

* Delete deepspeed test

* Revert "Delete accelerator connector test to see if its causing build to fail"

This reverts commit edde60b

* Revert "Delete deepspeed test"

This reverts commit 9d317429

* Reverse hook

* Reverse setup hooks to debug again

* Add TODO so I know where I left off

* For single device move in pre_dispatch after setup function

* Add additional model to device hook if any additional parameters have been set

* See if we can enable deepspeed tests

* Revert "See if we can enable deepspeed tests"

This reverts commit b5450de

* See if this hook approach works

* Introduce new granular hooks

* Remove import, fix tpu spawn by moving the function to setup

* Added missing special test

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

(cherry picked from commit 4e9b453)
SeanNaren authored and Borda committed Mar 23, 2021
1 parent caebaea commit f35dda8
Showing 14 changed files with 171 additions and 95 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -119,6 +119,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416))


- Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506))


## [1.2.4] - 2021-03-16

### Changed
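In practice this fix means a `LightningModule` running with the `ddp` accelerator can rely on `torch.distributed` inside its `setup()` hook. A minimal, illustrative sketch (the model name here is hypothetical; the same behaviour is exercised by the new test at the bottom of this diff):

```python
import torch.distributed as dist
from pytorch_lightning import LightningModule


class DistAwareModel(LightningModule):
    def setup(self, stage=None):
        # Before #6506 this was False under DDP, because the process group
        # was only initialised later, in pre_dispatch.
        if dist.is_available() and dist.is_initialized():
            print(f"rank {dist.get_rank()} sees an initialised process group in setup()")
```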
26 changes: 11 additions & 15 deletions pytorch_lightning/accelerators/accelerator.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Iterable, Optional, Union
from typing import Any, Callable, Iterable, Optional, Union, Sequence

import torch
from torch.optim import Optimizer
@@ -53,21 +53,20 @@ def __init__(
self.precision_plugin = precision_plugin
self.training_type_plugin = training_type_plugin

self.optimizers = None
self.lr_schedulers = None
self.optimizer_frequencies = None
self.optimizers: Sequence = []
self.lr_schedulers: Sequence = []
self.optimizer_frequencies: Sequence = []

def setup(self, trainer, model: LightningModule) -> None:
"""
Connects the plugins to the training process, creates optimizers
Setup plugins for the trainer fit and creates optimizers.
Args:
trainer: the trainer instance to connect to
model: the model to train
trainer: the trainer instance
model: the LightningModule
"""
self.connect_training_type_plugin(self.training_type_plugin, model)
self.setup_training_type_plugin(self.training_type_plugin, model)
self.setup_optimizers(trainer)
self.connect_precision_plugin(self.precision_plugin)
self.setup_precision_plugin(self.precision_plugin)

def start_training(self, trainer):
self.training_type_plugin.start_training(trainer)
@@ -319,11 +318,8 @@ def setup_optimizers(self, trainer):
self.lr_schedulers = lr_schedulers
self.optimizer_frequencies = optimizer_frequencies

def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
"""Attaches the training type plugin to the accelerator.
Also transfers ownership of the model to this plugin
"""
def setup_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
"""Attaches the training type plugin to the accelerator."""
plugin.connect(model)

def connect_precision_plugin(self, plugin: PrecisionPlugin):
65 changes: 31 additions & 34 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -86,9 +86,7 @@ def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
return distributed_sampler_kwargs

def setup(self, model):
self._model = model

def setup_environment(self):
# start the other scripts
# TODO: refactor and let generic cluster env hold the information about who spawns the processes
if os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
@@ -97,6 +95,8 @@ def setup(self, model):
# set the task idx
self.task_idx = self.cluster_environment.local_rank()

self.setup_distributed()

def _call_children_scripts(self):

# bookkeeping of spawned processes
@@ -171,6 +171,34 @@ def _call_children_scripts(self):
delay = np.random.uniform(1, 5, 1)[0]
sleep(delay)

def setup_distributed(self):
# TODO: check if needed
seed = os.environ.get("PL_GLOBAL_SEED")
if seed is not None:
seed_everything(int(seed))

# determine which process we are and world size
self.set_world_ranks()

# set warning rank
rank_zero_only.rank = self.global_rank

# set up server using proc 0's ip address
# try to init for 20 times at max in case ports are taken
# where to store ip_table
self.init_ddp_connection(self.global_rank, self.world_size)

# on world_size=0 let everyone know training is starting
if self.is_global_zero and not torch.distributed.is_initialized():
log.info("-" * 100)
log.info(f"distributed_backend={self.distributed_backend}")
log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
log.info("-" * 100)

# set the ranks and devices
self.dist.rank = self.global_rank
self.dist.device = self.root_device

def _check_can_spawn_children(self):
if self._has_spawned_children:
raise RuntimeError(
@@ -226,37 +254,6 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)

def pre_dispatch(self):
# TODO: check if needed
seed = os.environ.get("PL_GLOBAL_SEED")
if seed is not None:
seed_everything(int(seed))

# determine which process we are and world size
self.set_world_ranks()

# set warning rank
rank_zero_only.rank = self.global_rank

# set up server using proc 0's ip address
# try to init for 20 times at max in case ports are taken
# where to store ip_table
self.init_ddp_connection(self.global_rank, self.world_size)

# TODO: we moved it to the trainer.fit after calling pre_dispatch
# ... need to double check that it is the correct place
# self.trainer.call_setup_hook(self.model)

# on world_size=0 let everyone know training is starting
if self.is_global_zero and not torch.distributed.is_initialized():
log.info("-" * 100)
log.info(f"distributed_backend={self.distributed_backend}")
log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
log.info("-" * 100)

# set the ranks and devices
self.dist.rank = self.global_rank
self.dist.device = self.root_device

if self.sync_batchnorm:
self.model = self.configure_sync_batchnorm(self.model)

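For context, the relocated `setup_distributed()` ultimately reaches `init_ddp_connection()`, which calls `torch_distrib.init_process_group` with the backend from `self.torch_distributed_backend`. A rough, self-contained stand-in for that final step — not the plugin's actual code path — using a single local process and placeholder address/port:

```python
import os

import torch.distributed as dist

# Placeholders for what the cluster environment normally provides.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

if dist.is_available() and not dist.is_initialized():
    # gloo keeps this runnable on CPU-only machines; the real plugin picks
    # the backend via self.torch_distributed_backend instead.
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

print("process group initialised:", dist.is_initialized())
```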
10 changes: 0 additions & 10 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -192,17 +192,7 @@ def _load_config(self, config):
return config

def pre_dispatch(self):
self.set_world_ranks()
self.init_ddp_connection(self.global_rank, self.world_size)

self.init_deepspeed()

# set warning rank
rank_zero_only.rank = self.global_rank

# set the ranks and devices
self.dist.rank = self.global_rank
self.dist.device = self.root_device
self.barrier()

def init_deepspeed(self):
8 changes: 0 additions & 8 deletions pytorch_lightning/plugins/training_type/parallel.py
@@ -60,14 +60,6 @@ def on_gpu(self):
def lightning_module(self):
return unwrap_lightning_module(self._model)

@abstractmethod
def setup(self, model):
raise NotImplementedError

def connect(self, model, *args, **kwargs):
self.setup(model)
return self.model

@property
def is_global_zero(self) -> bool:
return self.global_rank == 0
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/single_device.py
@@ -64,8 +64,7 @@ def model_to_device(self) -> None:

self._model.to(self.root_device)

def connect(self, model: torch.nn.Module) -> torch.nn.Module:
self._model = model
def setup(self, model: torch.nn.Module) -> torch.nn.Module:
self.model_to_device()
return self.model

7 changes: 1 addition & 6 deletions pytorch_lightning/plugins/training_type/single_tpu.py
@@ -26,13 +26,8 @@ def __init__(self, device: Union[torch.device, int]):
def on_tpu(self) -> bool:
return True

def connect(self, model: torch.nn.Module) -> torch.nn.Module:
self._model = model
self.model_to_device()
return self._model

def model_to_device(self) -> None:
self._model.to(self.root_device)
self.model.to(self.root_device)

def pre_dispatch(self) -> None:
if isinstance(self.device, int):
5 changes: 2 additions & 3 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -39,10 +39,9 @@ def __init__(
self.tpu_local_core_rank = 0
self.start_method = None

def connect(self, model: torch.nn.Module) -> torch.nn.Module:
def setup(self, model: torch.nn.Module) -> torch.nn.Module:
self.create_mp_queue()
self._model = model
return self._model
return self.model

def create_mp_queue(self):
self.start_method = 'fork'
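The three plugin hunks above all follow the same new contract: `connect()` only attaches the model (the base `TrainingTypePlugin` behaviour), while device placement or mp-queue creation now happens in `setup()`. A hypothetical subclass, only to illustrate the split and assuming the 1.2.x `SingleDevicePlugin` import path and constructor:

```python
import torch
from pytorch_lightning.plugins import SingleDevicePlugin


class VerboseSingleDevicePlugin(SingleDevicePlugin):
    def setup(self, model: torch.nn.Module) -> torch.nn.Module:
        # connect() has already attached the model by the time setup() runs,
        # so this is now the place where device placement happens.
        print(f"moving model to {self.root_device} inside setup()")
        return super().setup(model)


plugin = VerboseSingleDevicePlugin(device=torch.device("cpu"))
plugin.connect(torch.nn.Linear(2, 2))  # attach only, no device placement
plugin.setup(plugin.model)             # device placement happens here now
```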
4 changes: 3 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -455,8 +455,10 @@ def fit(
# ----------------------------
# SET UP TRAINING
# ----------------------------
self.call_setup_hook(model)
self.call_hook("on_before_accelerator_backend_setup", model)
self.accelerator.connect(model)
self.accelerator.setup_environment()
self.call_setup_hook(model) # allow user to setup lightning_module in accelerator environment
self.accelerator.setup(self, model) # note: this sets up self.lightning_module

# ----------------------------
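The reordering above is the heart of the fix: `setup_environment()` (and with it the distributed connection) now runs before `call_setup_hook`, and the accelerator-specific setup follows. A rough, self-contained sketch of the resulting sequence, using stub classes rather than the real Accelerator/plugin API:

```python
class StubTrainingTypePlugin:
    """Hypothetical stand-in for a TrainingTypePlugin; ordering only."""

    def connect(self, model):
        self.model = model  # attach the model, nothing else

    def setup_environment(self):
        print("1. setup_environment: spawn processes / init torch.distributed")

    def setup(self, model):
        print("3. setup: plugin-specific setup (optimizers, device placement, ...)")


class StubAccelerator:
    def __init__(self, training_type_plugin):
        self.training_type_plugin = training_type_plugin

    def connect(self, model):
        self.training_type_plugin.connect(model)

    def setup_environment(self):
        self.training_type_plugin.setup_environment()

    def setup(self, trainer, model):
        self.training_type_plugin.setup(model)


def fit(accelerator, model):
    accelerator.connect(model)
    accelerator.setup_environment()
    # LightningModule.setup() now runs here, after the environment exists,
    # which is why torch.distributed is available inside it under DDP.
    print("2. call_setup_hook: user setup() runs with dist initialised")
    accelerator.setup(None, model)


fit(StubAccelerator(StubTrainingTypePlugin()), model=object())
```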
30 changes: 21 additions & 9 deletions tests/accelerators/test_accelerator_connector.py
@@ -95,7 +95,8 @@ def test_accelerator_choice_ddp_spawn(cuda_available_mock, device_count_mock):
"SLURM_LOCALID": "10"
}
)
def test_accelerator_choice_ddp_slurm():
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_accelerator_choice_ddp_slurm(setup_distributed_mock):

class CB(Callback):

@@ -133,7 +134,8 @@ def on_fit_start(self, trainer, pl_module):
}
)
@mock.patch('torch.cuda.device_count', return_value=2)
def test_accelerator_choice_ddp2_slurm(device_count_mock):
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_accelerator_choice_ddp2_slurm(device_count_mock, setup_distributed_mock):

class CB(Callback):

@@ -162,7 +164,8 @@ def on_fit_start(self, trainer, pl_module):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU")
@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2", "LOCAL_RANK": "10", "NODE_RANK": "0"})
@mock.patch('torch.cuda.device_count', return_value=2)
def test_accelerator_choice_ddp_te(device_count_mock):
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_accelerator_choice_ddp_te(device_count_mock, setup_distributed_mock):

class CB(Callback):

@@ -190,7 +193,8 @@ def on_fit_start(self, trainer, pl_module):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU")
@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2", "LOCAL_RANK": "10", "NODE_RANK": "0"})
@mock.patch('torch.cuda.device_count', return_value=2)
def test_accelerator_choice_ddp2_te(device_count_mock):
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_accelerator_choice_ddp2_te(device_count_mock, setup_distributed_mock):

class CB(Callback):

@@ -221,7 +225,8 @@ def on_fit_start(self, trainer, pl_module):
"NODE_RANK": "0",
})
@mock.patch('torch.cuda.device_count', return_value=0)
def test_accelerator_choice_ddp_cpu_te(device_count_mock):
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_accelerator_choice_ddp_cpu_te(device_count_mock, setup_distributed_mock):

class CB(Callback):

@@ -256,7 +261,8 @@ def on_fit_start(self, trainer, pl_module):
}
)
@mock.patch('torch.cuda.device_count', return_value=0)
def test_accelerator_choice_ddp_cpu_slurm(device_count_mock):
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_accelerator_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock):

class CB(Callback):

@@ -291,7 +297,8 @@ def on_fit_start(self, trainer, pl_module):
}
)
@mock.patch('torch.cuda.device_count', return_value=0)
def test_accelerator_choice_ddp_cpu_custom_cluster(device_count_mock):
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_accelerator_choice_ddp_cpu_custom_cluster(device_count_mock, setup_distributed_mock):
"""
Test that we choose the custom cluster even when SLURM or TE flags are around
"""
@@ -301,6 +308,9 @@ class CustomCluster(ClusterEnvironment):
def master_address(self):
return 'asdf'

def creates_children(self) -> bool:
return True

class CB(Callback):

def on_fit_start(self, trainer, pl_module):
@@ -333,7 +343,8 @@ def on_fit_start(self, trainer, pl_module):
}
)
@mock.patch('torch.cuda.device_count', return_value=0)
def test_custom_accelerator(device_count_mock):
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_custom_accelerator(device_count_mock, setup_distributed_mock):

class Accel(Accelerator):
pass
@@ -368,7 +379,8 @@ class TrainTypePlugin(SingleDevicePlugin):
}
)
@mock.patch('torch.cuda.device_count', return_value=0)
def test_dist_backend_accelerator_mapping(device_count_mock):
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_dist_backend_accelerator_mapping(device_count_mock, setup_distributed_mock):

class CB(Callback):

30 changes: 29 additions & 1 deletion tests/accelerators/test_ddp.py
@@ -13,6 +13,8 @@
# limitations under the License.
import os
import platform
from typing import Optional
from unittest import mock
from unittest.mock import patch

import pytest
@@ -98,7 +100,6 @@ def test_torch_distributed_backend_env_variables(tmpdir):
_environ = {"PL_TORCH_DISTRIBUTED_BACKEND": "undefined", "CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2"}
with patch.dict(os.environ, _environ), \
patch('torch.cuda.device_count', return_value=2):

with pytest.raises(ValueError, match="Invalid backend: 'undefined'"):
model = BoringModel()
trainer = Trainer(
@@ -109,3 +110,30 @@
logger=False,
)
trainer.fit(model)


@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@mock.patch('torch.cuda.device_count', return_value=1)
@mock.patch('torch.cuda.is_available', return_value=True)
@mock.patch('torch.cuda.set_device')
@mock.patch.dict(os.environ, {'PL_TORCH_DISTRIBUTED_BACKEND': 'gloo'}, clear=True)
def test_ddp_torch_dist_is_available_in_setup(mock_set_device, mock_is_available, mock_device_count, tmpdir):
"""
Test to ensure torch distributed is available within the setup hook using ddp
"""

class TestModel(BoringModel):

def setup(self, stage: Optional[str] = None) -> None:
assert torch.distributed.is_initialized()
raise SystemExit()

model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
accelerator="ddp",
gpus=1,
)
with pytest.raises(SystemExit):
trainer.fit(model)