From fe61a7fcea3fa35a51e34d1de56da3c6761e9fae Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 6 Feb 2021 15:51:54 +0000 Subject: [PATCH 01/10] rpc branch --- pytorch_lightning/accelerators/accelerator.py | 1 + .../accelerators/accelerator_connector.py | 22 +++++++++++-------- .../plugins/training_type/ddp.py | 2 +- .../plugins/training_type/rpc_sequential.py | 14 ++++++------ .../training_type/training_type_plugin.py | 3 +++ pytorch_lightning/utilities/enums.py | 1 + .../legacy/test_ddp_sequential_plugin.py | 12 +++++----- tests/plugins/legacy/test_rpc_plugin.py | 2 +- tests/special_tests.sh | 2 +- 9 files changed, 34 insertions(+), 25 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index a8e63776f93d8..968469f104cba 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -279,6 +279,7 @@ def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Cal if make_optimizer_step: self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs) self.precision_plugin.post_optimizer_step(optimizer, opt_idx) + self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs) def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs): optimizer.step(closure=lambda_closure, **kwargs) diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index 8393a21104704..b2a43b3083d80 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -191,20 +191,20 @@ def handle_given_plugins(self, plugins: Optional[Sequence]): self._training_type_plugin = training_type self._precision_plugin = precision - self._cluster_environment = cluster_environment + self._cluster_environment = cluster_environment or self.select_cluster_environment() @property def precision_plugin(self) -> PrecisionPlugin: if self._precision_plugin is None: self._precision_plugin = self.select_precision_plugin() - return self._precision_plugin @property def training_type_plugin(self) -> TrainingTypePlugin: if self._training_type_plugin is None: self._training_type_plugin = self.select_training_type_plugin() - + else: + self.resolve_training_type_plugin() return self._training_type_plugin @property @@ -283,9 +283,6 @@ def select_precision_plugin(self): if self.on_tpu: return TPUHalfPrecisionPlugin() - if isinstance(self.training_type_plugin, RPCPlugin): - raise MisconfigurationException - if self.amp_type == "native": if not _NATIVE_AMP_AVAILABLE: rank_zero_warn( @@ -324,9 +321,8 @@ def select_precision_plugin(self): raise NotImplementedError("We only support precisions 32 and 16!") def select_training_type_plugin(self): - cluster_environment = self.select_cluster_environment() if self.use_ddp2: - plugin = DDP2Plugin(parallel_devices=self.parallel_devices, cluster_environment=cluster_environment) + plugin = DDP2Plugin(parallel_devices=self.parallel_devices, cluster_environment=self._cluster_environment) elif self.use_ddp: use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic @@ -358,7 +354,7 @@ def select_training_type_plugin(self): plugin = ddp_plugin_cls( parallel_devices=self.parallel_devices, num_nodes=self.num_nodes, - cluster_environment=cluster_environment, + cluster_environment=self._cluster_environment, 
sync_batchnorm=self.sync_batchnorm, ) elif self.use_dp: @@ -371,6 +367,14 @@ def select_training_type_plugin(self): plugin = SingleDevicePlugin(device=torch.device(f"cuda:{self.root_gpu}" if self.on_gpu else "cpu")) return plugin + def resolve_training_type_plugin(self): + if self._training_type_plugin.num_processes is None: + self._training_type_plugin.num_processes = len(self.parallel_devices) + self._training_type_plugin.parallel_devices = self.parallel_devices + + if self._training_type_plugin.cluster_environment is None: + self._training_type_plugin.cluster_environment = self._cluster_environment + def select_accelerator(self): if isinstance(self.distributed_backend, Accelerator): # custom accelerator from user diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 274078d8a80d4..e2571b8a96eac 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -73,7 +73,7 @@ def __init__( self._has_spawned_children = False self.task_idx = None self.node_rank = 0 - self.num_processes = len(parallel_devices) + self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices @property def root_device(self): diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py index 79cecac3fbb4d..850a63cede1b4 100644 --- a/pytorch_lightning/plugins/training_type/rpc_sequential.py +++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py @@ -21,6 +21,7 @@ from torch.nn.parallel import DistributedDataParallel from pytorch_lightning.core.lightning import LightningModule +from torch.optim import Optimizer from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.rpc import DEFAULT_RPC_TIMEOUT_SEC, RPCPlugin @@ -42,10 +43,6 @@ class RPCSequentialPlugin(RPCPlugin): def __init__( self, - parallel_devices, - num_nodes: int = 1, - cluster_environment: ClusterEnvironment = None, - sync_batchnorm=False, balance: Optional[List[int]] = None, microbatches: int = 8, checkpoint: str = 'except_last', @@ -93,9 +90,6 @@ def __init__( """ self._check_pipe_available() super().__init__( - parallel_devices=parallel_devices, - num_nodes=num_nodes, - cluster_environment=cluster_environment, sync_batchnorm=sync_batchnorm, rpc_timeout_sec=rpc_timeout_sec, **kwargs @@ -324,6 +318,12 @@ def _check_pipe_available(self): 'PipeRPCPlugin requires FairScale and currently is only supported on PyTorch 1.6.' 
) + def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, **kwargs) -> None: + """Hook to do something after each optimizer step.""" + if self.rpc_enabled and self.is_main_rpc_process: + + # Initialize optimizer step on main process + self.worker_optimizer_step(model=self.lightning_module, opt_idx=optimizer_idx, **kwargs) class LightningPipeModule(nn.Module): """ diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index c26f5fbc1b743..acc8288a31ada 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -75,6 +75,9 @@ def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, opti def post_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): """Run after precision plugin executes backward""" + def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, **kwargs) -> None: + """Hook to do something after each optimizer step.""" + @property def model(self) -> Module: """Returns the potentially wrapped LightningModule""" diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index 6c539dec7fd3a..c7796b433f1ed 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -65,6 +65,7 @@ class DistributedType(LightningEnum): HOROVOD = 'horovod' DDP_SHARDED = 'ddp_sharded' DDP_SHARDED_SPAWN = 'ddp_sharded_spawn' + RPC_SEQUENTIAL_PLUGIN = 'rpc_sequential' class DeviceType(LightningEnum): diff --git a/tests/plugins/legacy/test_ddp_sequential_plugin.py b/tests/plugins/legacy/test_ddp_sequential_plugin.py index 8c6061d12cf11..00e163827269e 100644 --- a/tests/plugins/legacy/test_ddp_sequential_plugin.py +++ b/tests/plugins/legacy/test_ddp_sequential_plugin.py @@ -20,7 +20,7 @@ from torch import nn from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.plugins.legacy.ddp_sequential_plugin import DDPSequentialPlugin +from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base.boring_model import RandomDataset @@ -47,7 +47,7 @@ def test_ddp_sequential_plugin_ddp_rpc_manual(tmpdir, args=None): limit_test_batches=2, gpus=2, distributed_backend="ddp", - plugins=[DDPSequentialPlugin(balance=[2, 1], rpc_timeout_sec=5 * 60)], + plugins=[RPCSequentialPlugin(balance=[2, 1], rpc_timeout_sec=5 * 60)], enable_pl_optimizer=True, ) @@ -77,7 +77,7 @@ def test_ddp_sequential_plugin_ddp_rpc_manual_amp(tmpdir, args=None): precision=16, amp_backend="native", distributed_backend="ddp", - plugins=[DDPSequentialPlugin(balance=[2, 1])], + plugins=[RPCSequentialPlugin(balance=[2, 1])], ) try: trainer.fit(model) @@ -85,7 +85,7 @@ def test_ddp_sequential_plugin_ddp_rpc_manual_amp(tmpdir, args=None): assert len(trainer.dev_debugger.pbar_added_metrics) > 0 except MisconfigurationException as e: - assert str(e) == 'DDPSequentialPlugin is currently not supported in Automatic Mixed Precision' + assert str(e) == 'RPCSequentialPlugin is currently not supported in Automatic Mixed Precision' @pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed") @@ -102,7 +102,7 @@ def test_ddp_sequential_plugin_ddp_rpc_automatic(tmpdir, args=None): 
limit_test_batches=2, gpus=2, distributed_backend="ddp", - plugins=[DDPSequentialPlugin(balance=[2, 1])], + plugins=[RPCSequentialPlugin(balance=[2, 1])], ) trainer.fit(model) @@ -130,7 +130,7 @@ def test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance(tmpdir, args=None): limit_test_batches=2, gpus=2, distributed_backend="ddp", - plugins=[DDPSequentialPlugin(balance=[2, 2])], + plugins=[RPCSequentialPlugin(balance=[2, 2])], ) try: diff --git a/tests/plugins/legacy/test_rpc_plugin.py b/tests/plugins/legacy/test_rpc_plugin.py index 77937c16058dc..0bc17ca8ea2fa 100644 --- a/tests/plugins/legacy/test_rpc_plugin.py +++ b/tests/plugins/legacy/test_rpc_plugin.py @@ -7,7 +7,7 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import Callback -from pytorch_lightning.plugins.legacy.rpc_plugin import RPCPlugin +from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin from pytorch_lightning.utilities import _RPC_AVAILABLE from tests.base.boring_model import BoringModel diff --git a/tests/special_tests.sh b/tests/special_tests.sh index 577e49cec49d2..18c3feb415b33 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # Running special tests -set -e +#set -e export PL_RUNNING_SPECIAL_TESTS=1 DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no" python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp From b871d7100eebe3d117773e070f2fd2a79a79d4ae Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Sat, 6 Feb 2021 17:06:35 +0100 Subject: [PATCH 02/10] merge --- .../accelerators/accelerator_connector.py | 13 +++++++++++++ .../plugins/training_type/rpc_sequential.py | 9 ++++++--- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index b2a43b3083d80..5485f47bdb357 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -161,6 +161,19 @@ def handle_given_plugins(self, plugins: Optional[Sequence]): if isinstance(plug, TrainingTypePlugin): if training_type is None: training_type = plug + + # necessary for RPC, when user has to provide balance + if hasattr(training_type, 'parallel_devices') and not getattr(training_type, 'parallel_devices'): + training_type.parallel_devices = self.parallel_devices + if hasattr(training_type, 'num_processes'): + training_type.num_processes = len(self.parallel_devices) + + if hasattr(training_type, 'cluster_environment') and getattr(training_type, 'cluster_environment') is None: + training_type.cluster_environment = cluster_environment + + if hasattr(training_type, 'num_nodes') and getattr(training_type, 'num_nodes') is None: + training_type.num_nodes = self.num_nodes + else: raise MisconfigurationException( 'You can only specify one precision and one training type plugin. 
' diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py index 850a63cede1b4..6ff058b2d6c21 100644 --- a/pytorch_lightning/plugins/training_type/rpc_sequential.py +++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py @@ -13,7 +13,7 @@ # limitations under the License import logging import os -from typing import Any, List, Optional +from typing import Any, List, Optional, Sequence import torch import torch.distributed as torch_distrib @@ -43,7 +43,7 @@ class RPCSequentialPlugin(RPCPlugin): def __init__( self, - balance: Optional[List[int]] = None, + balance : List[int], microbatches: int = 8, checkpoint: str = 'except_last', balance_mode: str = "balance_by_size", @@ -90,7 +90,10 @@ def __init__( """ self._check_pipe_available() super().__init__( - sync_batchnorm=sync_batchnorm, + parallel_devices=(), + num_nodes=None, + cluster_environment=None, + sync_batchnorm=False, rpc_timeout_sec=rpc_timeout_sec, **kwargs ) From dbf08d03d311e1be55b70244c11d3b380b9c2ad8 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Sat, 6 Feb 2021 17:17:23 +0100 Subject: [PATCH 03/10] update handling of rpc --- .../accelerators/accelerator_connector.py | 36 +++++++++---------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index 5485f47bdb357..983ed1616ee1d 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -162,17 +162,7 @@ def handle_given_plugins(self, plugins: Optional[Sequence]): if training_type is None: training_type = plug - # necessary for RPC, when user has to provide balance - if hasattr(training_type, 'parallel_devices') and not getattr(training_type, 'parallel_devices'): - training_type.parallel_devices = self.parallel_devices - if hasattr(training_type, 'num_processes'): - training_type.num_processes = len(self.parallel_devices) - - if hasattr(training_type, 'cluster_environment') and getattr(training_type, 'cluster_environment') is None: - training_type.cluster_environment = cluster_environment - - if hasattr(training_type, 'num_nodes') and getattr(training_type, 'num_nodes') is None: - training_type.num_nodes = self.num_nodes + else: raise MisconfigurationException( @@ -217,7 +207,7 @@ def training_type_plugin(self) -> TrainingTypePlugin: if self._training_type_plugin is None: self._training_type_plugin = self.select_training_type_plugin() else: - self.resolve_training_type_plugin() + self._training_type_plugin = self.resolve_training_type_plugin(self._training_type_plugin) return self._training_type_plugin @property @@ -380,13 +370,21 @@ def select_training_type_plugin(self): plugin = SingleDevicePlugin(device=torch.device(f"cuda:{self.root_gpu}" if self.on_gpu else "cpu")) return plugin - def resolve_training_type_plugin(self): - if self._training_type_plugin.num_processes is None: - self._training_type_plugin.num_processes = len(self.parallel_devices) - self._training_type_plugin.parallel_devices = self.parallel_devices - - if self._training_type_plugin.cluster_environment is None: - self._training_type_plugin.cluster_environment = self._cluster_environment + + def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin: + # necessary for RPC, when user has to provide balance + if hasattr(training_type, 'parallel_devices') and not getattr(training_type, 
'parallel_devices'): + training_type.parallel_devices = self.parallel_devices + if hasattr(training_type, 'num_processes'): + training_type.num_processes = len(self.parallel_devices) + + if hasattr(training_type, 'cluster_environment') and getattr(training_type, 'cluster_environment') is None: + training_type.cluster_environment = self.select_cluster_environment() + + if hasattr(training_type, 'num_nodes') and getattr(training_type, 'num_nodes') is None: + training_type.num_nodes = self.num_nodes + + return training_type def select_accelerator(self): if isinstance(self.distributed_backend, Accelerator): From 17fdb38371bb84d2e84e7988cf27013b98eeae10 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Sat, 6 Feb 2021 17:31:42 +0100 Subject: [PATCH 04/10] make devices etc. Optional in RPC --- pytorch_lightning/plugins/training_type/rpc.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/rpc.py b/pytorch_lightning/plugins/training_type/rpc.py index 4aff83189b6bc..d1510414eba2b 100644 --- a/pytorch_lightning/plugins/training_type/rpc.py +++ b/pytorch_lightning/plugins/training_type/rpc.py @@ -40,20 +40,16 @@ class RPCPlugin(DDPPlugin): def __init__( self, - parallel_devices, - num_nodes=1, - cluster_environment: ClusterEnvironment = None, - sync_batchnorm=False, rpc_timeout_sec: float = DEFAULT_RPC_TIMEOUT_SEC, **kwargs ): self.rpc_timeout_sec = rpc_timeout_sec self._is_rpc_initialized = False super().__init__( - parallel_devices=parallel_devices, - num_nodes=num_nodes, - cluster_environment=cluster_environment, - sync_batchnorm=sync_batchnorm, + parallel_devices=(), + num_nodes=None, + cluster_environment=None, + sync_batchnorm=False, **kwargs ) From a98a046e23216d554ec0924feaa1e1e83df2d0b0 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Sat, 6 Feb 2021 17:32:03 +0100 Subject: [PATCH 05/10] set devices etc. later if necessary --- pytorch_lightning/accelerators/accelerator_connector.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index 983ed1616ee1d..5278027636f6a 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -162,8 +162,6 @@ def handle_given_plugins(self, plugins: Optional[Sequence]): if training_type is None: training_type = plug - - else: raise MisconfigurationException( 'You can only specify one precision and one training type plugin. 
' From 0a55ddaf9886a9a15130d857584247c2f1365604 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Sat, 6 Feb 2021 17:32:28 +0100 Subject: [PATCH 06/10] remove devices from sequential --- pytorch_lightning/plugins/training_type/rpc_sequential.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py index 6ff058b2d6c21..cf02776eb5881 100644 --- a/pytorch_lightning/plugins/training_type/rpc_sequential.py +++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py @@ -90,10 +90,6 @@ def __init__( """ self._check_pipe_available() super().__init__( - parallel_devices=(), - num_nodes=None, - cluster_environment=None, - sync_batchnorm=False, rpc_timeout_sec=rpc_timeout_sec, **kwargs ) From 77efba78322d96ec35e22dbb27ad70ec10b6baaa Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Sat, 6 Feb 2021 17:32:47 +0100 Subject: [PATCH 07/10] make devices optional in rpc --- pytorch_lightning/plugins/training_type/rpc.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/rpc.py b/pytorch_lightning/plugins/training_type/rpc.py index d1510414eba2b..dc1c731da4ffa 100644 --- a/pytorch_lightning/plugins/training_type/rpc.py +++ b/pytorch_lightning/plugins/training_type/rpc.py @@ -13,7 +13,7 @@ # limitations under the License. import os from contextlib import suppress -from typing import Optional +from typing import Optional, Sequence import torch @@ -41,15 +41,19 @@ class RPCPlugin(DDPPlugin): def __init__( self, rpc_timeout_sec: float = DEFAULT_RPC_TIMEOUT_SEC, + parallel_devices : Sequence[int] = (), + num_nodes: Optional[int] = None, + cluster_environment: Optional[ClusterEnvironment] = None, + sync_batchnorm: Optional[bool] = None, **kwargs ): self.rpc_timeout_sec = rpc_timeout_sec self._is_rpc_initialized = False super().__init__( - parallel_devices=(), - num_nodes=None, - cluster_environment=None, - sync_batchnorm=False, + parallel_devices=parallel_devices, + num_nodes=num_nodes, + cluster_environment=cluster_environment, + sync_batchnorm=sync_batchnorm, **kwargs ) From ed9b39fdad691050ae7c23f5bdfbe963e160522c Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Sat, 6 Feb 2021 17:32:55 +0100 Subject: [PATCH 08/10] fix import --- tests/plugins/legacy/test_rpc_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plugins/legacy/test_rpc_plugin.py b/tests/plugins/legacy/test_rpc_plugin.py index 0bc17ca8ea2fa..a1e28d22ace58 100644 --- a/tests/plugins/legacy/test_rpc_plugin.py +++ b/tests/plugins/legacy/test_rpc_plugin.py @@ -7,7 +7,7 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import Callback -from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin +from pytorch_lightning.plugins.training_type.rpc_sequential import RPCPlugin from pytorch_lightning.utilities import _RPC_AVAILABLE from tests.base.boring_model import BoringModel From 05c96253c04d49e1e3a30e566ae2134be416c2c7 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Sat, 6 Feb 2021 17:33:07 +0100 Subject: [PATCH 09/10] uncomment everything --- tests/special_tests.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/special_tests.sh b/tests/special_tests.sh index 18c3feb415b33..3da35696e44b7 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions 
and # limitations under the License. # Running special tests -#set -e +set -e export PL_RUNNING_SPECIAL_TESTS=1 DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no" python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp @@ -21,7 +21,7 @@ python ${DEFAULTS} tests/plugins/legacy/test_ddp_sequential_plugin.py::test_ddp_ python ${DEFAULTS} tests/plugins/legacy/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual_amp python ${DEFAULTS} tests/plugins/legacy/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_automatic python ${DEFAULTS} tests/utilities/test_all_gather_grad.py::test_all_gather_collection -# python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance +python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance python ${DEFAULTS} tests/trainer/test_trainer.py::test_trainer_predict_ddp python ${DEFAULTS} tests/trainer/test_trainer.py::test_trainer_predict_dp python ${DEFAULTS} tests/trainer/logging_/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp From 5c1aee55ecc0acc59bcea7929e903df47e97894c Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Sat, 6 Feb 2021 18:21:43 +0100 Subject: [PATCH 10/10] fix cluster selection --- pytorch_lightning/accelerators/accelerator_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index 5278027636f6a..3b3cfa50cc045 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -355,7 +355,7 @@ def select_training_type_plugin(self): plugin = ddp_plugin_cls( parallel_devices=self.parallel_devices, num_nodes=self.num_nodes, - cluster_environment=self._cluster_environment, + cluster_environment=self.select_cluster_environment(), sync_batchnorm=self.sync_batchnorm, ) elif self.use_dp:
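Taken together, PATCH 01 adds a post_optimizer_step hook to TrainingTypePlugin, calls it from Accelerator.optimizer_step right after the precision plugin's own post_optimizer_step, and has RPCSequentialPlugin override it so the main RPC process triggers the worker optimizer step. A minimal sketch of that control flow follows; the *Sketch classes are illustrative stand-ins for this explanation only, not the real Lightning classes, and the print call stands in for the real worker_optimizer_step RPC broadcast.

    from typing import Callable

    from torch.optim import Optimizer


    class TrainingTypePluginSketch:
        """Base-class sketch: the new hook is a no-op by default."""

        def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, **kwargs) -> None:
            """Hook to do something after each optimizer step."""


    class RPCSequentialSketch(TrainingTypePluginSketch):
        """Stand-in for RPCSequentialPlugin: only the main RPC process runs the
        optimizer locally, then asks the pipeline workers to step as well."""

        def __init__(self, is_main_rpc_process: bool = True) -> None:
            self.rpc_enabled = True
            self.is_main_rpc_process = is_main_rpc_process

        def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, **kwargs) -> None:
            if self.rpc_enabled and self.is_main_rpc_process:
                # The real plugin calls self.worker_optimizer_step(...) here,
                # which forwards the step over torch.distributed.rpc.
                print(f"main process: stepping workers for optimizer {optimizer_idx}")


    class AcceleratorSketch:
        """Mirrors the ordering in Accelerator.optimizer_step after PATCH 01:
        run the step, let the precision plugin clean up, then give the
        training-type plugin its turn."""

        def __init__(self, training_type_plugin: TrainingTypePluginSketch) -> None:
            self.training_type_plugin = training_type_plugin

        def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs) -> None:
            optimizer.step(closure=lambda_closure, **kwargs)
            # precision_plugin.post_optimizer_step(optimizer, opt_idx) runs here in Lightning
            self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs)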
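Over PATCH 02 through PATCH 07, RPCPlugin and RPCSequentialPlugin stop requiring parallel_devices, num_nodes, cluster_environment and sync_batchnorm at construction time, and the connector's new resolve_training_type_plugin backfills whatever the user-supplied plugin left unset. Assuming the API exercised in the updated tests, end-user configuration then looks roughly like the snippet below; the exact Trainer arguments are taken from test_ddp_sequential_plugin_ddp_rpc_manual and are not a prescription.

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin

    # The plugin is built with only the pipeline-specific arguments; devices,
    # node count and cluster environment are resolved later by the connector.
    trainer = Trainer(
        gpus=2,
        distributed_backend="ddp",
        plugins=[RPCSequentialPlugin(balance=[2, 1], rpc_timeout_sec=5 * 60)],
    )

Before training starts, the accelerator connector notices that parallel_devices, num_processes, cluster_environment and num_nodes are still empty on the provided plugin and fills them in from the Trainer arguments, which is exactly what the resolve_training_type_plugin added in PATCH 03 does.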