Accelerator refactor sharded rpc #5854

Merged
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/accelerator.py
@@ -279,6 +279,7 @@ def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Cal
if make_optimizer_step:
self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs)

def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
optimizer.step(closure=lambda_closure, **kwargs)
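
As context for the new hook call above, a minimal sketch of how a training type plugin could override post_optimizer_step; the StepCountingDDPPlugin subclass and its counter are illustrative assumptions, not part of this PR:

from torch.optim import Optimizer

from pytorch_lightning.plugins.training_type.ddp import DDPPlugin


class StepCountingDDPPlugin(DDPPlugin):
    """Hypothetical plugin that counts optimizer steps via the new hook."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.steps_taken = 0

    def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, **kwargs) -> None:
        # Called by Accelerator.optimizer_step right after precision_plugin.post_optimizer_step;
        # the base hook is a no-op, so subclasses only add behaviour here.
        self.steps_taken += 1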
31 changes: 22 additions & 9 deletions pytorch_lightning/accelerators/accelerator_connector.py
@@ -161,6 +161,7 @@ def handle_given_plugins(self, plugins: Optional[Sequence]):
if isinstance(plug, TrainingTypePlugin):
if training_type is None:
training_type = plug

else:
raise MisconfigurationException(
'You can only specify one precision and one training type plugin. '
@@ -191,20 +192,20 @@ def handle_given_plugins(self, plugins: Optional[Sequence]):

self._training_type_plugin = training_type
self._precision_plugin = precision
self._cluster_environment = cluster_environment
self._cluster_environment = cluster_environment or self.select_cluster_environment()

@property
def precision_plugin(self) -> PrecisionPlugin:
if self._precision_plugin is None:
self._precision_plugin = self.select_precision_plugin()

return self._precision_plugin

@property
def training_type_plugin(self) -> TrainingTypePlugin:
if self._training_type_plugin is None:
self._training_type_plugin = self.select_training_type_plugin()

else:
self._training_type_plugin = self.resolve_training_type_plugin(self._training_type_plugin)
return self._training_type_plugin

@property
@@ -283,9 +284,6 @@ def select_precision_plugin(self):
if self.on_tpu:
return TPUHalfPrecisionPlugin()

if isinstance(self.training_type_plugin, RPCPlugin):
raise MisconfigurationException

if self.amp_type == "native":
if not _NATIVE_AMP_AVAILABLE:
rank_zero_warn(
@@ -324,9 +322,8 @@ def select_precision_plugin(self):
raise NotImplementedError("We only support precisions 32 and 16!")

def select_training_type_plugin(self):
cluster_environment = self.select_cluster_environment()
if self.use_ddp2:
plugin = DDP2Plugin(parallel_devices=self.parallel_devices, cluster_environment=cluster_environment)
plugin = DDP2Plugin(parallel_devices=self.parallel_devices, cluster_environment=self._cluster_environment)
elif self.use_ddp:
use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic
@@ -358,7 +355,7 @@ def select_training_type_plugin(self):
plugin = ddp_plugin_cls(
parallel_devices=self.parallel_devices,
num_nodes=self.num_nodes,
cluster_environment=cluster_environment,
cluster_environment=self.select_cluster_environment(),
sync_batchnorm=self.sync_batchnorm,
)
elif self.use_dp:
@@ -371,6 +368,22 @@ def select_training_type_plugin(self):
plugin = SingleDevicePlugin(device=torch.device(f"cuda:{self.root_gpu}" if self.on_gpu else "cpu"))
return plugin


def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin:
# necessary for RPC, when user has to provide balance
if hasattr(training_type, 'parallel_devices') and not getattr(training_type, 'parallel_devices'):
training_type.parallel_devices = self.parallel_devices
if hasattr(training_type, 'num_processes'):
training_type.num_processes = len(self.parallel_devices)

if hasattr(training_type, 'cluster_environment') and getattr(training_type, 'cluster_environment') is None:
training_type.cluster_environment = self.select_cluster_environment()

if hasattr(training_type, 'num_nodes') and getattr(training_type, 'num_nodes') is None:
training_type.num_nodes = self.num_nodes

return training_type

def select_accelerator(self):
if isinstance(self.distributed_backend, Accelerator):
# custom accelerator from user
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/ddp.py
@@ -73,7 +73,7 @@ def __init__(
self._has_spawned_children = False
self.task_idx = None
self.node_rank = 0
self.num_processes = len(parallel_devices)
self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices
Contributor:
If parallel_devices can be None, I think we need to change the root_device access property, because in DDP root_device = parallel_devices[local_rank].

Member Author:
They can't be None. They are optional in the constructor, since these plugins will be provided by the user, but before anything else is called they will be set in the accelerator_connector.
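
A minimal, self-contained sketch of the resolution flow this reply refers to (DummyPlugin and resolve are illustrative stand-ins written for this note, not Lightning APIs; the real logic is resolve_training_type_plugin in the accelerator_connector diff above):

import torch

# Stand-in for a user-provided plugin constructed without devices,
# mirroring the new RPCPlugin default of parallel_devices=().
class DummyPlugin:
    def __init__(self):
        self.parallel_devices = ()
        self.num_processes = None
        self.local_rank = 0

def resolve(plugin, parallel_devices):
    # Mirrors resolve_training_type_plugin: fill in only what the user left unset.
    if hasattr(plugin, 'parallel_devices') and not plugin.parallel_devices:
        plugin.parallel_devices = parallel_devices
        plugin.num_processes = len(parallel_devices)
    return plugin

plugin = resolve(DummyPlugin(), [torch.device('cuda:0'), torch.device('cuda:1')])
# Only after this resolution step does DDP index into parallel_devices
# (root_device = parallel_devices[local_rank]), so it is never None there.
assert plugin.parallel_devices[plugin.local_rank] == torch.device('cuda:0')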


@property
def root_device(self):
10 changes: 5 additions & 5 deletions pytorch_lightning/plugins/training_type/rpc.py
@@ -13,7 +13,7 @@
# limitations under the License.
import os
from contextlib import suppress
from typing import Optional
from typing import Optional, Sequence

import torch

@@ -40,11 +40,11 @@ class RPCPlugin(DDPPlugin):

def __init__(
self,
parallel_devices,
num_nodes=1,
cluster_environment: ClusterEnvironment = None,
sync_batchnorm=False,
rpc_timeout_sec: float = DEFAULT_RPC_TIMEOUT_SEC,
parallel_devices : Sequence[int] = (),
num_nodes: Optional[int] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
sync_batchnorm: Optional[bool] = None,
**kwargs
):
self.rpc_timeout_sec = rpc_timeout_sec
19 changes: 9 additions & 10 deletions pytorch_lightning/plugins/training_type/rpc_sequential.py
@@ -13,14 +13,15 @@
# limitations under the License
import logging
import os
from typing import Any, List, Optional
from typing import Any, List, Optional, Sequence

import torch
import torch.distributed as torch_distrib
from torch import nn
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning.core.lightning import LightningModule
from torch.optim import Optimizer
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.rpc import DEFAULT_RPC_TIMEOUT_SEC, RPCPlugin
@@ -42,11 +43,7 @@ class RPCSequentialPlugin(RPCPlugin):

def __init__(
self,
parallel_devices,
num_nodes: int = 1,
cluster_environment: ClusterEnvironment = None,
sync_batchnorm=False,
balance: Optional[List[int]] = None,
balance : List[int],
microbatches: int = 8,
checkpoint: str = 'except_last',
balance_mode: str = "balance_by_size",
@@ -93,10 +90,6 @@ def __init__(
"""
self._check_pipe_available()
super().__init__(
parallel_devices=parallel_devices,
num_nodes=num_nodes,
cluster_environment=cluster_environment,
sync_batchnorm=sync_batchnorm,
rpc_timeout_sec=rpc_timeout_sec,
**kwargs
)
@@ -324,6 +317,12 @@ def _check_pipe_available(self):
'PipeRPCPlugin requires FairScale and currently is only supported on PyTorch 1.6.'
)

def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, **kwargs) -> None:
"""Hook to do something after each optimizer step."""
if self.rpc_enabled and self.is_main_rpc_process:

# Initialize optimizer step on main process
self.worker_optimizer_step(model=self.lightning_module, opt_idx=optimizer_idx, **kwargs)

class LightningPipeModule(nn.Module):
"""
pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -75,6 +75,9 @@ def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, opti
def post_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
"""Run after precision plugin executes backward"""

def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, **kwargs) -> None:
"""Hook to do something after each optimizer step."""

@property
def model(self) -> Module:
"""Returns the potentially wrapped LightningModule"""
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/enums.py
@@ -65,6 +65,7 @@ class DistributedType(LightningEnum):
HOROVOD = 'horovod'
DDP_SHARDED = 'ddp_sharded'
DDP_SHARDED_SPAWN = 'ddp_sharded_spawn'
RPC_SEQUENTIAL_PLUGIN = 'rpc_sequential'


class DeviceType(LightningEnum):
12 changes: 6 additions & 6 deletions tests/plugins/legacy/test_ddp_sequential_plugin.py
@@ -20,7 +20,7 @@
from torch import nn

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins.legacy.ddp_sequential_plugin import DDPSequentialPlugin
from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin
from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import RandomDataset
@@ -47,7 +47,7 @@ def test_ddp_sequential_plugin_ddp_rpc_manual(tmpdir, args=None):
limit_test_batches=2,
gpus=2,
distributed_backend="ddp",
plugins=[DDPSequentialPlugin(balance=[2, 1], rpc_timeout_sec=5 * 60)],
plugins=[RPCSequentialPlugin(balance=[2, 1], rpc_timeout_sec=5 * 60)],
enable_pl_optimizer=True,
)

@@ -77,15 +77,15 @@ def test_ddp_sequential_plugin_ddp_rpc_manual_amp(tmpdir, args=None):
precision=16,
amp_backend="native",
distributed_backend="ddp",
plugins=[DDPSequentialPlugin(balance=[2, 1])],
plugins=[RPCSequentialPlugin(balance=[2, 1])],
)
try:
trainer.fit(model)

assert len(trainer.dev_debugger.pbar_added_metrics) > 0

except MisconfigurationException as e:
assert str(e) == 'DDPSequentialPlugin is currently not supported in Automatic Mixed Precision'
assert str(e) == 'RPCSequentialPlugin is currently not supported in Automatic Mixed Precision'


@pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@@ -102,7 +102,7 @@ def test_ddp_sequential_plugin_ddp_rpc_automatic(tmpdir, args=None):
limit_test_batches=2,
gpus=2,
distributed_backend="ddp",
plugins=[DDPSequentialPlugin(balance=[2, 1])],
plugins=[RPCSequentialPlugin(balance=[2, 1])],
)

trainer.fit(model)
@@ -130,7 +130,7 @@ def test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance(tmpdir, args=None):
limit_test_batches=2,
gpus=2,
distributed_backend="ddp",
plugins=[DDPSequentialPlugin(balance=[2, 2])],
plugins=[RPCSequentialPlugin(balance=[2, 2])],
)

try:
2 changes: 1 addition & 1 deletion tests/plugins/legacy/test_rpc_plugin.py
@@ -7,7 +7,7 @@

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins.legacy.rpc_plugin import RPCPlugin
from pytorch_lightning.plugins.training_type.rpc_sequential import RPCPlugin
from pytorch_lightning.utilities import _RPC_AVAILABLE
from tests.base.boring_model import BoringModel

2 changes: 1 addition & 1 deletion tests/special_tests.sh
@@ -21,7 +21,7 @@ python ${DEFAULTS} tests/plugins/legacy/test_ddp_sequential_plugin.py::test_ddp_
python ${DEFAULTS} tests/plugins/legacy/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual_amp
python ${DEFAULTS} tests/plugins/legacy/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_automatic
python ${DEFAULTS} tests/utilities/test_all_gather_grad.py::test_all_gather_collection
# python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance
python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance
python ${DEFAULTS} tests/trainer/test_trainer.py::test_trainer_predict_ddp
python ${DEFAULTS} tests/trainer/test_trainer.py::test_trainer_predict_dp
python ${DEFAULTS} tests/trainer/logging_/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp