diff --git a/CHANGELOG.md b/CHANGELOG.md
index 84a512202b0a01..dd8d54cde4929f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -119,6 +119,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416))
 
+- Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506))
+
+
 ## [1.2.4] - 2021-03-16
 
 ### Changed
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index f9ccc7a42fa066..f93e95a14981ec 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Iterable, Optional, Union
+from typing import Any, Callable, Iterable, Optional, Union, Sequence
 
 import torch
 from torch.optim import Optimizer
@@ -53,21 +53,20 @@ def __init__(
         self.precision_plugin = precision_plugin
         self.training_type_plugin = training_type_plugin
 
-        self.optimizers = None
-        self.lr_schedulers = None
-        self.optimizer_frequencies = None
+        self.optimizers: Sequence = []
+        self.lr_schedulers: Sequence = []
+        self.optimizer_frequencies: Sequence = []
 
     def setup(self, trainer, model: LightningModule) -> None:
         """
-        Connects the plugins to the training process, creates optimizers
-
+        Setup plugins for the trainer fit and creates optimizers.
         Args:
-            trainer: the trainer instance to connect to
-            model: the model to train
+            trainer: the trainer instance
+            model: the LightningModule
         """
-        self.connect_training_type_plugin(self.training_type_plugin, model)
+        self.setup_training_type_plugin(self.training_type_plugin, model)
         self.setup_optimizers(trainer)
-        self.connect_precision_plugin(self.precision_plugin)
+        self.setup_precision_plugin(self.precision_plugin)
 
     def start_training(self, trainer):
         self.training_type_plugin.start_training(trainer)
@@ -319,11 +318,8 @@ def setup_optimizers(self, trainer):
         self.lr_schedulers = lr_schedulers
         self.optimizer_frequencies = optimizer_frequencies
 
-    def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
-        """Attaches the training type plugin to the accelerator.
-        Also transfers ownership of the model to this plugin
-
-        """
+    def setup_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
+        """Attaches the training type plugin to the accelerator."""
         plugin.connect(model)
 
     def connect_precision_plugin(self, plugin: PrecisionPlugin):
diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index 007f898a27cc7a..8a9fd7124f247e 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -86,9 +86,7 @@ def distributed_sampler_kwargs(self):
         distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
         return distributed_sampler_kwargs
 
-    def setup(self, model):
-        self._model = model
-
+    def setup_environment(self):
         # start the other scripts
         # TODO: refactor and let generic cluster env hold the information about who spawns the processes
         if os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
@@ -97,6 +95,8 @@ def setup(self, model):
         # set the task idx
         self.task_idx = self.cluster_environment.local_rank()
 
+        self.setup_distributed()
+
     def _call_children_scripts(self):
 
         # bookkeeping of spawned processes
@@ -171,6 +171,34 @@ def _call_children_scripts(self):
             delay = np.random.uniform(1, 5, 1)[0]
             sleep(delay)
 
+    def setup_distributed(self):
+        # TODO: check if needed
+        seed = os.environ.get("PL_GLOBAL_SEED")
+        if seed is not None:
+            seed_everything(int(seed))
+
+        # determine which process we are and world size
+        self.set_world_ranks()
+
+        # set warning rank
+        rank_zero_only.rank = self.global_rank
+
+        # set up server using proc 0's ip address
+        # try to init for 20 times at max in case ports are taken
+        # where to store ip_table
+        self.init_ddp_connection(self.global_rank, self.world_size)
+
+        # on world_size=0 let everyone know training is starting
+        if self.is_global_zero and not torch.distributed.is_initialized():
+            log.info("-" * 100)
+            log.info(f"distributed_backend={self.distributed_backend}")
+            log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
+            log.info("-" * 100)
+
+        # set the ranks and devices
+        self.dist.rank = self.global_rank
+        self.dist.device = self.root_device
+
     def _check_can_spawn_children(self):
         if self._has_spawned_children:
             raise RuntimeError(
@@ -226,37 +254,6 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
             torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)
 
     def pre_dispatch(self):
-        # TODO: check if needed
-        seed = os.environ.get("PL_GLOBAL_SEED")
-        if seed is not None:
-            seed_everything(int(seed))
-
-        # determine which process we are and world size
-        self.set_world_ranks()
-
-        # set warning rank
-        rank_zero_only.rank = self.global_rank
-
-        # set up server using proc 0's ip address
-        # try to init for 20 times at max in case ports are taken
-        # where to store ip_table
-        self.init_ddp_connection(self.global_rank, self.world_size)
-
-        # TODO: we moved it to the trainer.fit after calling pre_dispatch
-        # ... need to double check that it is the correct place
-        # self.trainer.call_setup_hook(self.model)
-
-        # on world_size=0 let everyone know training is starting
-        if self.is_global_zero and not torch.distributed.is_initialized():
-            log.info("-" * 100)
-            log.info(f"distributed_backend={self.distributed_backend}")
-            log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
-            log.info("-" * 100)
-
-        # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device
-
         if self.sync_batchnorm:
             self.model = self.configure_sync_batchnorm(self.model)
 
diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index 58fd1304209bba..775882bed79434 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -192,17 +192,7 @@ def _load_config(self, config):
         return config
 
     def pre_dispatch(self):
-        self.set_world_ranks()
-        self.init_ddp_connection(self.global_rank, self.world_size)
-
         self.init_deepspeed()
-
-        # set warning rank
-        rank_zero_only.rank = self.global_rank
-
-        # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device
         self.barrier()
 
     def init_deepspeed(self):
diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py
index 9809443aff3fb6..a25537c71503b4 100644
--- a/pytorch_lightning/plugins/training_type/parallel.py
+++ b/pytorch_lightning/plugins/training_type/parallel.py
@@ -60,14 +60,6 @@ def on_gpu(self):
     def lightning_module(self):
         return unwrap_lightning_module(self._model)
 
-    @abstractmethod
-    def setup(self, model):
-        raise NotImplementedError
-
-    def connect(self, model, *args, **kwargs):
-        self.setup(model)
-        return self.model
-
     @property
     def is_global_zero(self) -> bool:
         return self.global_rank == 0
diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py
index 39fe06e1d46f2b..983a454f3d0913 100644
--- a/pytorch_lightning/plugins/training_type/single_device.py
+++ b/pytorch_lightning/plugins/training_type/single_device.py
@@ -64,8 +64,7 @@ def model_to_device(self) -> None:
 
         self._model.to(self.root_device)
 
-    def connect(self, model: torch.nn.Module) -> torch.nn.Module:
-        self._model = model
+    def setup(self, model: torch.nn.Module) -> torch.nn.Module:
        self.model_to_device()
         return self.model
 
diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py
index 3ddfd98128787a..e0c63a949308ed 100644
--- a/pytorch_lightning/plugins/training_type/single_tpu.py
+++ b/pytorch_lightning/plugins/training_type/single_tpu.py
@@ -26,13 +26,8 @@ def __init__(self, device: Union[torch.device, int]):
     def on_tpu(self) -> bool:
         return True
 
-    def connect(self, model: torch.nn.Module) -> torch.nn.Module:
-        self._model = model
-        self.model_to_device()
-        return self._model
-
     def model_to_device(self) -> None:
-        self._model.to(self.root_device)
+        self.model.to(self.root_device)
 
     def pre_dispatch(self) -> None:
         if isinstance(self.device, int):
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 1e951329b22ccb..647d0e5e3229d1 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -39,10 +39,9 @@ def __init__(
         self.tpu_local_core_rank = 0
         self.start_method = None
 
-    def connect(self, model: torch.nn.Module) -> torch.nn.Module:
+    def setup(self, model: torch.nn.Module) -> torch.nn.Module:
         self.create_mp_queue()
-        self._model = model
-        return self._model
+        return self.model
 
     def create_mp_queue(self):
         self.start_method = 'fork'
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index f378ee830d2617..51ee475eac2dc4 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -455,8 +455,10 @@ def fit(
         # ----------------------------
         # SET UP TRAINING
         # ----------------------------
-        self.call_setup_hook(model)
         self.call_hook("on_before_accelerator_backend_setup", model)
+        self.accelerator.connect(model)
+        self.accelerator.setup_environment()
+        self.call_setup_hook(model)  # allow user to setup lightning_module in accelerator environment
         self.accelerator.setup(self, model)  # note: this sets up self.lightning_module
 
         # ----------------------------
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index 9eaa4c7e2b57ea..75625c68a6cc57 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -95,7 +95,8 @@ def test_accelerator_choice_ddp_spawn(cuda_available_mock, device_count_mock):
         "SLURM_LOCALID": "10"
     }
 )
-def test_accelerator_choice_ddp_slurm():
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_accelerator_choice_ddp_slurm(setup_distributed_mock):
 
     class CB(Callback):
 
@@ -133,7 +134,8 @@ def on_fit_start(self, trainer, pl_module):
     }
 )
 @mock.patch('torch.cuda.device_count', return_value=2)
-def test_accelerator_choice_ddp2_slurm(device_count_mock):
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_accelerator_choice_ddp2_slurm(device_count_mock, setup_distributed_mock):
 
     class CB(Callback):
 
@@ -162,7 +164,8 @@ def on_fit_start(self, trainer, pl_module):
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU")
 @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2", "LOCAL_RANK": "10", "NODE_RANK": "0"})
 @mock.patch('torch.cuda.device_count', return_value=2)
-def test_accelerator_choice_ddp_te(device_count_mock):
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_accelerator_choice_ddp_te(device_count_mock, setup_distributed_mock):
 
     class CB(Callback):
 
@@ -190,7 +193,8 @@ def on_fit_start(self, trainer, pl_module):
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU")
 @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2", "LOCAL_RANK": "10", "NODE_RANK": "0"})
 @mock.patch('torch.cuda.device_count', return_value=2)
-def test_accelerator_choice_ddp2_te(device_count_mock):
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_accelerator_choice_ddp2_te(device_count_mock, setup_distributed_mock):
 
     class CB(Callback):
 
@@ -221,7 +225,8 @@ def on_fit_start(self, trainer, pl_module):
     "NODE_RANK": "0",
 })
 @mock.patch('torch.cuda.device_count', return_value=0)
-def test_accelerator_choice_ddp_cpu_te(device_count_mock):
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_accelerator_choice_ddp_cpu_te(device_count_mock, setup_distributed_mock):
 
     class CB(Callback):
 
@@ -256,7 +261,8 @@ def on_fit_start(self, trainer, pl_module):
     }
 )
 @mock.patch('torch.cuda.device_count', return_value=0)
-def test_accelerator_choice_ddp_cpu_slurm(device_count_mock):
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_accelerator_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock):
 
     class CB(Callback):
 
@@ -291,7 +297,8 @@ def on_fit_start(self, trainer, pl_module):
     }
 )
 @mock.patch('torch.cuda.device_count', return_value=0)
-def test_accelerator_choice_ddp_cpu_custom_cluster(device_count_mock):
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_accelerator_choice_ddp_cpu_custom_cluster(device_count_mock, setup_distributed_mock):
     """
     Test that we choose the custom cluster even when SLURM or TE flags are around
     """
@@ -301,6 +308,9 @@ class CustomCluster(ClusterEnvironment):
         def master_address(self):
             return 'asdf'
 
+        def creates_children(self) -> bool:
+            return True
+
     class CB(Callback):
 
         def on_fit_start(self, trainer, pl_module):
@@ -333,7 +343,8 @@ def on_fit_start(self, trainer, pl_module):
     }
 )
 @mock.patch('torch.cuda.device_count', return_value=0)
-def test_custom_accelerator(device_count_mock):
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_custom_accelerator(device_count_mock, setup_distributed_mock):
 
     class Accel(Accelerator):
         pass
@@ -368,7 +379,8 @@ class TrainTypePlugin(SingleDevicePlugin):
     }
 )
 @mock.patch('torch.cuda.device_count', return_value=0)
-def test_dist_backend_accelerator_mapping(device_count_mock):
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_dist_backend_accelerator_mapping(device_count_mock, setup_distributed_mock):
 
     class CB(Callback):
diff --git a/tests/accelerators/test_ddp.py b/tests/accelerators/test_ddp.py
index b582532cd710ea..566e70cd6a88d0 100644
--- a/tests/accelerators/test_ddp.py
+++ b/tests/accelerators/test_ddp.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 import os
 import platform
+from typing import Optional
+from unittest import mock
 from unittest.mock import patch
 
 import pytest
@@ -98,7 +100,6 @@ def test_torch_distributed_backend_env_variables(tmpdir):
     _environ = {"PL_TORCH_DISTRIBUTED_BACKEND": "undefined", "CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2"}
     with patch.dict(os.environ, _environ), \
          patch('torch.cuda.device_count', return_value=2):
-
         with pytest.raises(ValueError, match="Invalid backend: 'undefined'"):
             model = BoringModel()
             trainer = Trainer(
@@ -109,3 +110,30 @@ def test_torch_distributed_backend_env_variables(tmpdir):
                 logger=False,
             )
             trainer.fit(model)
+
+
+@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
+@mock.patch('torch.cuda.device_count', return_value=1)
+@mock.patch('torch.cuda.is_available', return_value=True)
+@mock.patch('torch.cuda.set_device')
+@mock.patch.dict(os.environ, {'PL_TORCH_DISTRIBUTED_BACKEND': 'gloo'}, clear=True)
+def test_ddp_torch_dist_is_available_in_setup(mock_set_device, mock_is_available, mock_device_count, tmpdir):
+    """
+    Test to ensure torch distributed is available within the setup hook using ddp
+    """
+
+    class TestModel(BoringModel):
+
+        def setup(self, stage: Optional[str] = None) -> None:
+            assert torch.distributed.is_initialized()
+            raise SystemExit()
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        accelerator="ddp",
+        gpus=1,
+    )
+    with pytest.raises(SystemExit):
+        trainer.fit(model)
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index 8d01841f3636cf..ae07bf64c8e08b 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -54,8 +54,8 @@ def test_trainer_callback_system(torch_save, tmpdir):
     assert callback_mock.method_calls == [
         call.on_init_start(trainer),
         call.on_init_end(trainer),
-        call.setup(trainer, model, 'fit'),
         call.on_before_accelerator_backend_setup(trainer, model),
+        call.setup(trainer, model, 'fit'),
         call.on_fit_start(trainer, model),
         call.on_pretrain_routine_start(trainer, model),
         call.on_pretrain_routine_end(trainer, model),
@@ -111,7 +111,6 @@ def test_trainer_callback_system(torch_save, tmpdir):
     assert callback_mock.method_calls == [
         call.on_init_start(trainer),
         call.on_init_end(trainer),
-        call.setup(trainer, model, 'test'),
         call.on_before_accelerator_backend_setup(trainer, model),
         call.on_fit_start(trainer, model),
         call.on_test_start(trainer, model),
@@ -129,6 +128,42 @@
     ]
 
 
+def test_trainer_callback_hook_system_validate(tmpdir):
+    """Test the callback hook system for validate."""
+
+    model = BoringModel()
+    callback_mock = MagicMock()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=[callback_mock],
+        max_epochs=1,
+        limit_val_batches=2,
+        progress_bar_refresh_rate=0,
+    )
+
+    trainer.validate(model)
+
+    assert callback_mock.method_calls == [
+        call.on_init_start(trainer),
+        call.on_init_end(trainer),
+        call.on_before_accelerator_backend_setup(trainer, model),
+        call.setup(trainer, model, 'validate'),
+        call.on_validation_start(trainer, model),
+        call.on_validation_epoch_start(trainer, model),
+        call.on_validation_batch_start(trainer, model, ANY, 0, 0),
+        call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
+        call.on_validation_batch_start(trainer, model, ANY, 1, 0),
+        call.on_validation_batch_end(trainer, model, ANY, ANY, 1, 0),
+        call.on_validation_epoch_end(trainer, model, ANY),
+        call.on_epoch_end(trainer, model),
+        call.on_validation_end(trainer, model),
+        call.teardown(trainer, model, 'validate'),
+    ]
+
+
+# TODO: add callback tests for predict and tune
+
+
 def test_callbacks_configured_in_model(tmpdir):
     """ Test the callback system with callbacks added through the model hook. """
 
@@ -165,9 +200,11 @@ def assert_expected_calls(_trainer, model_callback, trainer_callback):
     # .fit()
     trainer_options.update(callbacks=[trainer_callback_mock])
     trainer = Trainer(**trainer_options)
+
     assert trainer_callback_mock in trainer.callbacks
     assert model_callback_mock not in trainer.callbacks
     trainer.fit(model)
+
     assert model_callback_mock in trainer.callbacks
     assert trainer.callbacks[-1] == model_callback_mock
     assert_expected_calls(trainer, model_callback_mock, trainer_callback_mock)
diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py
index 8c4c7873681ad9..cb084734c5565e 100644
--- a/tests/plugins/test_deepspeed_plugin.py
+++ b/tests/plugins/test_deepspeed_plugin.py
@@ -151,7 +151,11 @@ def test_deepspeed_defaults(tmpdir):
     assert isinstance(plugin.config["zero_optimization"], dict)
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
 @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
+@pytest.mark.skipif(
+    not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
+)
 def test_invalid_deepspeed_defaults_no_precision(tmpdir):
     """
     Test to ensure that using defaults, if precision is not set to 16, we throw an exception.
@@ -170,6 +174,9 @@ def test_invalid_deepspeed_defaults_no_precision(tmpdir):
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
 @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
+@pytest.mark.skipif(
+    not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
+)
 def test_warn_deepspeed_override_backward(tmpdir):
     """
     Test to ensure that if the backward hook in the LightningModule is overridden, we throw a warning.
@@ -193,6 +200,9 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
 @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
+@pytest.mark.skipif(
+    not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
+)
 def test_deepspeed_run_configure_optimizers(tmpdir):
     """
     Test end to end that deepspeed works with defaults (without ZeRO as that requires compilation),
@@ -222,6 +232,9 @@ def on_train_start(self) -> None:
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
 @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
+@pytest.mark.skipif(
+    not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
+)
 def test_deepspeed_config(tmpdir, deepspeed_config):
     """
     Test to ensure deepspeed works correctly when passed a DeepSpeed config object including optimizers/schedulers
@@ -253,6 +266,9 @@ def on_train_start(self) -> None:
 
 @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
+@pytest.mark.skipif(
+    not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
+)
 def test_deepspeed_custom_precision_params(tmpdir):
     """
     Ensure if we modify the FP16 parameters via the DeepSpeedPlugin, the deepspeed config contains these changes.
@@ -269,11 +285,12 @@ def on_train_start(self) -> None:
             raise SystemExit()
 
     model = TestModel()
-    trainer = Trainer(
-        plugins=[
-            DeepSpeedPlugin(
+    ds = DeepSpeedPlugin(
                 loss_scale=10, initial_scale_power=10, loss_scale_window=10, hysteresis=10, min_loss_scale=10
             )
+    trainer = Trainer(
+        plugins=[
+            ds
         ],
         precision=16,
         gpus=1
@@ -284,6 +301,9 @@ def on_train_start(self) -> None:
 
 @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
+@pytest.mark.skipif(
+    not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
+)
 def test_deepspeed_assert_config_zero_offload_disabled(tmpdir, deepspeed_zero_config):
     """
     Ensure if we use a config and turn off cpu_offload, that this is set to False within the config.
@@ -298,7 +318,7 @@ def on_train_start(self) -> None:
             raise SystemExit()
 
     model = TestModel()
-    trainer = Trainer(plugins=[DeepSpeedPlugin(config=deepspeed_zero_config)], precision=16, gpus=1)
+    trainer = Trainer(plugins=[DeepSpeedPlugin(config=deepspeed_zero_config)], precision=16, gpus=1,)
     with pytest.raises(SystemExit):
         trainer.fit(model)
diff --git a/tests/special_tests.sh b/tests/special_tests.sh
index 43658721e92264..dd67af470c4ec5 100644
--- a/tests/special_tests.sh
+++ b/tests/special_tests.sh
@@ -17,6 +17,12 @@ export PL_RUNNING_SPECIAL_TESTS=1
 DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no"
 python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp
 python ${DEFAULTS} tests/models/test_sync_batchnorm.py::test_sync_batchnorm_ddp
+python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_invalid_deepspeed_defaults_no_precision
+python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_warn_deepspeed_override_backward
+python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_run_configure_optimizers
+python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_config
+python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_custom_precision_params
+python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_assert_config_zero_offload_disabled
 python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_multigpu
 python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp
 python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_manual
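
Note: the practical effect of the reordered fit setup above (accelerator `connect` and `setup_environment` now run before `call_setup_hook`) is that the DDP process group is already initialized when `LightningModule.setup` executes, which is exactly what the new `test_ddp_torch_dist_is_available_in_setup` asserts. Below is a minimal usage sketch of that behaviour under the 1.2-era `Trainer(accelerator="ddp", gpus=...)` flags used in the tests above; the `RankAwareModel` name and its toy data are illustrative only and not part of this PR.

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer


class RankAwareModel(LightningModule):
    """Illustrative model (not from the PR): queries distributed state inside ``setup``."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def setup(self, stage=None):
        # With the reordered hooks, torch.distributed is initialized before `setup`
        # runs under DDP, so rank queries and collectives are safe here.
        if dist.is_available() and dist.is_initialized():
            print(f"setup on rank {dist.get_rank()}/{dist.get_world_size()} (stage={stage})")

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        return DataLoader(dataset, batch_size=8)


if __name__ == "__main__":
    # 1.2-era flags, mirroring the tests in this diff; requires a multi-GPU machine.
    trainer = Trainer(fast_dev_run=True, accelerator="ddp", gpus=2)
    trainer.fit(RankAwareModel())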