From 3518f9e09284099ddf623fe5ba9025a78b32397f Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Sun, 24 Dec 2023 23:04:04 +0100
Subject: [PATCH] Delay DeepSpeed config setup (#19209)

---
 src/lightning/pytorch/strategies/deepspeed.py | 283 +++++++++---------
 src/lightning/pytorch/trainer/trainer.py      |   2 +-
 tests/tests_pytorch/models/test_hooks.py      |   4 +-
 .../strategies/test_deepspeed.py              | 236 ++++++---------
 4 files changed, 236 insertions(+), 289 deletions(-)

diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py
index 3ad43125b3ad2..4e7a3bb122a55 100644
--- a/src/lightning/pytorch/strategies/deepspeed.py
+++ b/src/lightning/pytorch/strategies/deepspeed.py
@@ -313,20 +313,6 @@ def __init__(
         self.hysteresis = hysteresis
         self.min_loss_scale = min_loss_scale
 
-    def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Optional[Dict[str, Any]]:
-        if config is None and self.DEEPSPEED_ENV_VAR in os.environ:
-            rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable")
-            config = os.environ[self.DEEPSPEED_ENV_VAR]
-        if isinstance(config, (str, Path)):
-            if not os.path.isfile(config):
-                raise MisconfigurationException(
-                    f"You passed in a path to a DeepSpeed config but the path does not exist: {config}"
-                )
-            with open(config) as f:
-                config = json.load(f)
-        assert isinstance(config, dict) or config is None
-        return config
-
     @override
     def setup_environment(self) -> None:
         if not isinstance(self.accelerator, CUDAAccelerator):
@@ -343,12 +329,10 @@ def setup_distributed(self) -> None:
         reset_seed()
         self.set_world_ranks()
         self._init_deepspeed_distributed()
-        if not self._config_initialized:
-            self._format_config()
-            self._config_initialized = True
 
     @override
     def setup(self, trainer: "pl.Trainer") -> None:
+        self._init_config_if_needed()
         assert self.accelerator is not None
         self.accelerator.setup(trainer)
         # we set the device so that optimizers can be created with distributed comms.
@@ -529,7 +513,7 @@ def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[No
     def model_sharded_context(self) -> Generator[None, None, None]:
         import deepspeed
 
-        assert self._config_initialized
+        self._init_config_if_needed()
         with deepspeed.zero.Init(
             enabled=self.zero_stage_3,
             remote_device=self.remote_device,
@@ -610,134 +594,6 @@ def handles_gradient_accumulation(self) -> bool:
         """Whether the strategy handles gradient accumulation internally."""
         return True
 
-    def _format_config(self) -> None:
-        if self.config is None:
-            raise MisconfigurationException(
-                "To use DeepSpeed you must pass in a DeepSpeed config dict, or a path to a JSON config."
-                " See: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#deepspeed"
-            )
-        self._format_batch_size_and_grad_accum_config()
-        _format_precision_config(
-            config=self.config,
-            precision=self.precision_plugin.precision,
-            loss_scale=self.loss_scale,
-            loss_scale_window=self.loss_scale_window,
-            min_loss_scale=self.min_loss_scale,
-            initial_scale_power=self.initial_scale_power,
-            hysteresis=self.hysteresis,
-        )
-
-    def _format_batch_size_and_grad_accum_config(self) -> None:
-        # TODO: Using Fabric, we do not support these variables within the config
-        assert isinstance(self.config, dict)
-        if self.lightning_module is None:
-            return
-
-        if "gradient_accumulation_steps" in self.config:
-            raise MisconfigurationException(
-                "Do not set `gradient_accumulation_steps` in the DeepSpeed config"
-                " as this will be set with the `accumulate_grad_batches` argument passed via the Lightning Trainer."
-            )
-        self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches
-        if "train_micro_batch_size_per_gpu" not in self.config:
-            batch_size = self._auto_select_batch_size()
-            self.config["train_micro_batch_size_per_gpu"] = batch_size
-        if "gradient_clipping" not in self.config:
-            self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val or 0.0
-
-    def _auto_select_batch_size(self) -> int:
-        import deepspeed
-
-        # train_micro_batch_size_per_gpu is used for throughput logging purposes
-        # by default we try to use the batch size of the loader
-        assert self.lightning_module is not None
-        batch_size = 1
-        data_source = self.lightning_module.trainer.fit_loop._data_source
-        if data_source.is_defined():
-            try:
-                train_dataloader = data_source.dataloader()
-                if hasattr(train_dataloader, "batch_sampler"):
-                    batch_size = train_dataloader.batch_sampler.batch_size
-            # broad exception on purpose as `source.dataloader()` will fail if the dataloader requires `setup`
-            # to have been called before
-            except Exception:
-                if self.global_rank == 0:
-                    deepspeed.utils.logging.logger.warning(
-                        "Tried to infer the batch size for internal deepspeed logging from the `train_dataloader()`. "
-                        "To ensure DeepSpeed logging remains correct, please manually pass the strategy with the "
-                        "batch size, `Trainer(strategy=DeepSpeedStrategy(logging_batch_size_per_gpu=batch_size))`."
-                    )
-        return batch_size
-
-    def _create_default_config(
-        self,
-        zero_optimization: bool,
-        zero_allow_untested_optimizer: bool,
-        logging_batch_size_per_gpu: Union[str, int],
-        partition_activations: bool,
-        cpu_checkpointing: bool,
-        contiguous_memory_optimization: bool,
-        synchronize_checkpoint_boundary: bool,
-        offload_optimizer: bool,
-        offload_parameters: bool,
-        nvme_path: str,
-        offload_params_device: str,
-        params_buffer_count: int,
-        params_buffer_size: int,
-        max_in_cpu: int,
-        offload_optimizer_device: str,
-        optimizer_buffer_count: int,
-        pin_memory: bool,
-        block_size: int,
-        queue_depth: int,
-        single_submit: bool,
-        overlap_events: bool,
-        thread_count: int,
-        **zero_kwargs: Any,
-    ) -> Dict:
-        cfg = {
-            "activation_checkpointing": {
-                "partition_activations": partition_activations,
-                "cpu_checkpointing": cpu_checkpointing,
-                "contiguous_memory_optimization": contiguous_memory_optimization,
-                "synchronize_checkpoint_boundary": synchronize_checkpoint_boundary,
-            },
-            "aio": {
-                "block_size": block_size,
-                "queue_depth": queue_depth,
-                "single_submit": single_submit,
-                "overlap_events": overlap_events,
-                "thread_count": thread_count,
-            },
-        }
-        if zero_optimization:
-            zero_config = zero_kwargs
-
-            if offload_optimizer:
-                zero_config["offload_optimizer"] = {
-                    "device": offload_optimizer_device,
-                    "nvme_path": nvme_path,
-                    "buffer_count": optimizer_buffer_count,
-                    "pin_memory": pin_memory,
-                }
-            if offload_parameters:
-                zero_config["offload_param"] = {
-                    "device": offload_params_device,
-                    "nvme_path": nvme_path,
-                    "buffer_count": params_buffer_count,
-                    "buffer_size": params_buffer_size,
-                    "max_in_cpu": max_in_cpu,
-                    "pin_memory": pin_memory,
-                }
-            cfg = {
-                "zero_allow_untested_optimizer": zero_allow_untested_optimizer,
-                "zero_optimization": zero_config,
-                **cfg,
-            }
-        if logging_batch_size_per_gpu != "auto":
-            cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg}
-        return cfg
-
     @property
     def deepspeed_engine(self) -> "deepspeed.DeepSpeedEngine":
         return self.model
@@ -915,3 +771,138 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
             offload_params_device="nvme",
             offload_optimizer_device="nvme",
         )
+
+    def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Optional[Dict[str, Any]]:
+        if config is None and self.DEEPSPEED_ENV_VAR in os.environ:
+            rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable")
+            config = os.environ[self.DEEPSPEED_ENV_VAR]
+        if isinstance(config, (str, Path)):
+            if not os.path.isfile(config):
+                raise MisconfigurationException(
+                    f"You passed in a path to a DeepSpeed config but the path does not exist: {config}"
+                )
+            with open(config) as f:
+                config = json.load(f)
+        assert isinstance(config, dict) or config is None
+        return config
+
+    def _init_config_if_needed(self) -> None:
+        if not self._config_initialized:
+            self._format_config()
+            self._config_initialized = True
+
+    def _format_config(self) -> None:
+        if self.config is None:
+            raise MisconfigurationException(
+                "To use DeepSpeed you must pass in a DeepSpeed config dict, or a path to a JSON config."
+                " See: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#deepspeed"
+            )
+        self._format_batch_size_and_grad_accum_config()
+        _format_precision_config(
+            config=self.config,
+            precision=self.precision_plugin.precision,
+            loss_scale=self.loss_scale,
+            loss_scale_window=self.loss_scale_window,
+            min_loss_scale=self.min_loss_scale,
+            initial_scale_power=self.initial_scale_power,
+            hysteresis=self.hysteresis,
+        )
+
+    def _create_default_config(
+        self,
+        zero_optimization: bool,
+        zero_allow_untested_optimizer: bool,
+        logging_batch_size_per_gpu: Union[str, int],
+        partition_activations: bool,
+        cpu_checkpointing: bool,
+        contiguous_memory_optimization: bool,
+        synchronize_checkpoint_boundary: bool,
+        offload_optimizer: bool,
+        offload_parameters: bool,
+        nvme_path: str,
+        offload_params_device: str,
+        params_buffer_count: int,
+        params_buffer_size: int,
+        max_in_cpu: int,
+        offload_optimizer_device: str,
+        optimizer_buffer_count: int,
+        pin_memory: bool,
+        block_size: int,
+        queue_depth: int,
+        single_submit: bool,
+        overlap_events: bool,
+        thread_count: int,
+        **zero_kwargs: Any,
+    ) -> Dict:
+        cfg = {
+            "activation_checkpointing": {
+                "partition_activations": partition_activations,
+                "cpu_checkpointing": cpu_checkpointing,
+                "contiguous_memory_optimization": contiguous_memory_optimization,
+                "synchronize_checkpoint_boundary": synchronize_checkpoint_boundary,
+            },
+            "aio": {
+                "block_size": block_size,
+                "queue_depth": queue_depth,
+                "single_submit": single_submit,
+                "overlap_events": overlap_events,
+                "thread_count": thread_count,
+            },
+        }
+        if zero_optimization:
+            zero_config = zero_kwargs
+
+            if offload_optimizer:
+                zero_config["offload_optimizer"] = {
+                    "device": offload_optimizer_device,
+                    "nvme_path": nvme_path,
+                    "buffer_count": optimizer_buffer_count,
+                    "pin_memory": pin_memory,
+                }
+            if offload_parameters:
+                zero_config["offload_param"] = {
+                    "device": offload_params_device,
+                    "nvme_path": nvme_path,
+                    "buffer_count": params_buffer_count,
+                    "buffer_size": params_buffer_size,
+                    "max_in_cpu": max_in_cpu,
+                    "pin_memory": pin_memory,
+                }
+            cfg = {
+                "zero_allow_untested_optimizer": zero_allow_untested_optimizer,
+                "zero_optimization": zero_config,
+                **cfg,
+            }
+        if logging_batch_size_per_gpu != "auto":
+            cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg}
+        return cfg
+
+    def _format_batch_size_and_grad_accum_config(self) -> None:
+        # TODO: Using Fabric, we do not support these variables within the config
+        assert isinstance(self.config, dict)
+        if self.lightning_module is None:
+            return
+
+        if "gradient_accumulation_steps" in self.config:
+            raise MisconfigurationException(
+                "Do not set `gradient_accumulation_steps` in the DeepSpeed config"
+                " as this will be set with the `accumulate_grad_batches` argument passed via the Lightning Trainer."
+            )
+        self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches
+        if "train_micro_batch_size_per_gpu" not in self.config:
+            batch_size = self._auto_select_batch_size()
+            self.config["train_micro_batch_size_per_gpu"] = batch_size
+        if "gradient_clipping" not in self.config:
+            self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val or 0.0
+
+    def _auto_select_batch_size(self) -> int:
+        # train_micro_batch_size_per_gpu is used for throughput logging purposes
+        # by default we try to use the batch size of the loader
+        assert self.lightning_module is not None
+        batch_size = 1
+        data_source = self.lightning_module.trainer.fit_loop._data_source
+        if data_source.is_defined():
+            train_dataloader = data_source.dataloader()
+            if hasattr(train_dataloader, "batch_sampler"):
+                batch_size = train_dataloader.batch_sampler.batch_size
+        return batch_size
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index 6be7cb191157a..cc335c975e918 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -946,7 +946,7 @@ def _run(
         self.strategy.setup_environment()
         self.__setup_profiler()
 
-        call._call_setup_hook(self)  # allow user to setup lightning_module in accelerator environment
+        call._call_setup_hook(self)  # allow user to set up LightningModule in accelerator environment
 
         log.debug(f"{self.__class__.__name__}: configuring model")
         call._call_configure_model(self)
diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py
index 978fe0ab6b740..9d02d37368c46 100644
--- a/tests/tests_pytorch/models/test_hooks.py
+++ b/tests/tests_pytorch/models/test_hooks.py
@@ -477,10 +477,10 @@ def training_step(self, batch, batch_idx):
     expected = [
         {"name": "configure_callbacks"},
         {"name": "prepare_data"},
-        # DeepSpeed needs the batch size to figure out throughput logging
-        *([{"name": "train_dataloader"}] if using_deepspeed else []),
         {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "fit"}},
         {"name": "setup", "kwargs": {"stage": "fit"}},
+        # DeepSpeed needs the batch size to figure out throughput logging
+        *([{"name": "train_dataloader"}] if using_deepspeed else []),
         {"name": "configure_model"},
         {"name": "configure_optimizers"},
         {"name": "Callback.on_fit_start", "args": (trainer, model)},
diff --git a/tests/tests_pytorch/strategies/test_deepspeed.py b/tests/tests_pytorch/strategies/test_deepspeed.py
index c735b261ca2c6..93c9ee9f25183 100644
--- a/tests/tests_pytorch/strategies/test_deepspeed.py
+++ b/tests/tests_pytorch/strategies/test_deepspeed.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import contextlib
 import json
-import logging
 import os
 from re import escape
 from typing import Any, Dict
@@ -23,7 +22,7 @@
 import pytest
 import torch
 import torch.nn.functional as F
-from lightning.pytorch import LightningDataModule, LightningModule, Trainer
+from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.accelerators import CUDAAccelerator
 from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
@@ -105,13 +104,13 @@ def deepspeed_zero_config(deepspeed_config):
 
 
 @RunIf(deepspeed=True)
 @pytest.mark.parametrize("strategy", ["deepspeed", DeepSpeedStrategy])
-def test_deepspeed_strategy_string(tmpdir, strategy):
+def test_deepspeed_strategy_string(tmp_path, strategy):
     """Test to ensure that the strategy can be passed via string or instance, and parallel devices is correctly
     set."""
     trainer = Trainer(
         accelerator="cpu",
         fast_dev_run=True,
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         strategy=strategy if isinstance(strategy, str) else strategy(),
     )
@@ -120,14 +119,14 @@
 
 
 @RunIf(deepspeed=True)
-def test_deepspeed_strategy_env(tmpdir, monkeypatch, deepspeed_config):
+def test_deepspeed_strategy_env(tmp_path, monkeypatch, deepspeed_config):
     """Test to ensure that the strategy can be passed via a string with an environment variable."""
-    config_path = os.path.join(tmpdir, "temp.json")
+    config_path = os.path.join(tmp_path, "temp.json")
     with open(config_path, "w") as f:
         f.write(json.dumps(deepspeed_config))
     monkeypatch.setenv("PL_DEEPSPEED_CONFIG_PATH", config_path)
 
-    trainer = Trainer(accelerator="cpu", fast_dev_run=True, default_root_dir=tmpdir, strategy="deepspeed")
+    trainer = Trainer(accelerator="cpu", fast_dev_run=True, default_root_dir=tmp_path, strategy="deepspeed")
 
     strategy = trainer.strategy
     assert isinstance(strategy, DeepSpeedStrategy)
@@ -136,7 +135,7 @@
 
 
 @RunIf(deepspeed=True, mps=False)
-def test_deepspeed_precision_choice(cuda_count_1, tmpdir):
+def test_deepspeed_precision_choice(cuda_count_1, tmp_path):
     """Test to ensure precision plugin is also correctly chosen.
 
     DeepSpeed handles precision via Custom DeepSpeedPrecision
@@ -144,7 +143,7 @@
     """
     trainer = Trainer(
         fast_dev_run=True,
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         accelerator="gpu",
         strategy="deepspeed",
         precision="16-mixed",
     )
@@ -165,9 +164,9 @@
 
 
 @RunIf(deepspeed=True)
-def test_deepspeed_with_env_path(tmpdir, monkeypatch, deepspeed_config):
+def test_deepspeed_with_env_path(tmp_path, monkeypatch, deepspeed_config):
     """Test to ensure if we pass an env variable, we load the config from the path."""
-    config_path = os.path.join(tmpdir, "temp.json")
+    config_path = os.path.join(tmp_path, "temp.json")
     with open(config_path, "w") as f:
         f.write(json.dumps(deepspeed_config))
     monkeypatch.setenv("PL_DEEPSPEED_CONFIG_PATH", config_path)
@@ -184,7 +183,7 @@ def test_deepspeed_defaults():
 
 
 @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
-def test_warn_deepspeed_ignored(tmpdir):
+def test_warn_deepspeed_ignored(tmp_path):
     class TestModel(BoringModel):
         def backward(self, loss: Tensor, *args, **kwargs) -> None:
             return loss.backward()
@@ -192,7 +191,7 @@ def backward(self, loss: Tensor, *args, **kwargs) -> None:
     model = TestModel()
     trainer = Trainer(
         fast_dev_run=True,
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         strategy=DeepSpeedStrategy(),
         accelerator="gpu",
         devices=1,
@@ -211,33 +210,30 @@
 )
 @mock.patch("deepspeed.init_distributed", autospec=True)
 @mock.patch("lightning.pytorch.Trainer.log_dir", new_callable=mock.PropertyMock, return_value="abc")
-def test_deepspeed_auto_batch_size_config_select(mock_deepspeed_distributed, mock_log_dir, tmpdir, dataset_cls, value):
+def test_deepspeed_auto_batch_size_config_select(_, __, tmp_path, dataset_cls, value):
     """Test to ensure that the batch size is correctly set as expected for deepspeed logging purposes."""
 
     class TestModel(BoringModel):
         def train_dataloader(self):
             return DataLoader(dataset_cls(32, 64))
 
-    class AssertCallback(Callback):
-        def setup(self, trainer, pl_module, stage: str) -> None:
-            assert isinstance(trainer.strategy, DeepSpeedStrategy)
-            config = trainer.strategy.config
+        def configure_model(self) -> None:
+            assert isinstance(self.trainer.strategy, DeepSpeedStrategy)
+            config = self.trainer.strategy.config
 
             # int value overrides auto mode
             expected_value = value if isinstance(value, int) else 1
             if dataset_cls == RandomDataset:
-                expected_value = pl_module.train_dataloader().batch_size if value == "auto" else value
+                expected_value = self.train_dataloader().batch_size if value == "auto" else value
 
             assert config["train_micro_batch_size_per_gpu"] == expected_value
             raise SystemExit
 
-    ck = AssertCallback()
     model = TestModel()
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         fast_dev_run=True,
-        callbacks=ck,
-        accelerator="gpu",
+        accelerator="cuda",
        devices=1,
         strategy=DeepSpeedStrategy(logging_batch_size_per_gpu=value, zero_optimization=False),
     )
@@ -246,7 +242,7 @@
 
 
 @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
-def test_deepspeed_run_configure_optimizers(tmpdir):
+def test_deepspeed_run_configure_optimizers(tmp_path):
     """Test end to end that deepspeed works with defaults (without ZeRO as that requires compilation), whilst using
     configure_optimizers for optimizers and schedulers."""
 
@@ -267,13 +263,13 @@ def configure_optimizers(self):
     lr_monitor = LearningRateMonitor()
     trainer = Trainer(
         strategy=DeepSpeedStrategy(),  # disable ZeRO so our optimizers are not wrapped
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         accelerator="gpu",
         devices=1,
         fast_dev_run=True,
         precision="16-mixed",
         callbacks=[TestCB(), lr_monitor],
-        logger=CSVLogger(tmpdir),
+        logger=CSVLogger(tmp_path),
         enable_progress_bar=False,
         enable_model_summary=False,
     )
@@ -281,11 +277,11 @@ def configure_optimizers(self):
 
     assert lr_monitor.lrs == {"Sean": [0.1]}
 
-    _assert_save_model_is_equal(model, tmpdir, trainer)
+    _assert_save_model_is_equal(model, tmp_path, trainer)
 
 
 @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
-def test_deepspeed_config(tmpdir, deepspeed_zero_config):
+def test_deepspeed_config(tmp_path, deepspeed_zero_config):
     """Test to ensure deepspeed works correctly when passed a DeepSpeed config object including
     optimizers/schedulers and saves the model weights to load correctly."""
 
@@ -302,7 +298,7 @@ def on_train_start(self, trainer, pl_module) -> None:
     lr_monitor = LearningRateMonitor()
     trainer = Trainer(
         strategy=DeepSpeedStrategy(config=deepspeed_zero_config),
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         accelerator="gpu",
         devices=1,
         log_every_n_steps=1,
@@ -312,7 +308,7 @@ def on_train_start(self, trainer, pl_module) -> None:
         max_epochs=2,
         precision="16-mixed",
         callbacks=[TestCB(), lr_monitor],
-        logger=CSVLogger(tmpdir),
+        logger=CSVLogger(tmp_path),
         enable_progress_bar=False,
         enable_model_summary=False,
     )
@@ -324,7 +320,7 @@ def on_train_start(self, trainer, pl_module) -> None:
 
 
 @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
-def test_deepspeed_custom_precision_params(tmpdir):
+def test_deepspeed_custom_precision_params(tmp_path):
     """Ensure if we modify the FP16 parameters via the DeepSpeedStrategy, the deepspeed config contains these
     changes."""
 
@@ -342,7 +338,7 @@ def on_train_start(self, trainer, pl_module) -> None:
         loss_scale=10, initial_scale_power=11, loss_scale_window=12, hysteresis=13, min_loss_scale=14
     )
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         strategy=ds,
         precision="16-mixed",
         accelerator="gpu",
@@ -357,7 +353,7 @@ def on_train_start(self, trainer, pl_module) -> None:
 
 @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
 @pytest.mark.parametrize("precision", ["fp16", "bf16"])
-def test_deepspeed_inference_precision_during_inference(precision, tmpdir):
+def test_deepspeed_inference_precision_during_inference(precision, tmp_path):
     """Ensure if we modify the precision for deepspeed and execute inference-only, the deepspeed config contains
     these changes."""
 
@@ -370,7 +366,7 @@ def on_validation_start(self, trainer, pl_module) -> None:
 
     strategy = DeepSpeedStrategy(config={precision: {"enabled": True}})
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         strategy=strategy,
         accelerator="cuda",
         devices=1,
@@ -382,7 +378,7 @@ def on_validation_start(self, trainer, pl_module) -> None:
 
 
 @RunIf(deepspeed=True)
-def test_deepspeed_custom_activation_checkpointing_params(tmpdir):
+def test_deepspeed_custom_activation_checkpointing_params():
     """Ensure if we modify the activation checkpointing parameters, the deepspeed config contains these changes."""
     ds = DeepSpeedStrategy(
         partition_activations=True,
@@ -398,7 +394,7 @@
 
 
 @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
-def test_deepspeed_custom_activation_checkpointing_params_forwarded(tmpdir):
+def test_deepspeed_custom_activation_checkpointing_params_forwarded(tmp_path):
     """Ensure if we modify the activation checkpointing parameters, we pass these to
     deepspeed.checkpointing.configure correctly."""
     ds = DeepSpeedStrategy(
@@ -410,7 +406,7 @@ def test_deepspeed_custom_activation_checkpointing_params_forwarded(tmpdir):
 
     model = BoringModel()
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         fast_dev_run=1,
         strategy=ds,
         precision="16-mixed",
@@ -430,7 +426,7 @@ def test_deepspeed_custom_activation_checkpointing_params_forwarded(tmpdir):
 
 
 @RunIf(min_cuda_gpus=1, deepspeed=True)
-def test_deepspeed_assert_config_zero_offload_disabled(tmpdir, deepspeed_zero_config):
+def test_deepspeed_assert_config_zero_offload_disabled(tmp_path, deepspeed_zero_config):
     """Ensure if we use a config and turn off offload_optimizer, that this is set to False within the config."""
 
     deepspeed_zero_config["zero_optimization"]["offload_optimizer"] = False
@@ -441,7 +437,7 @@ def setup(self, trainer, pl_module, stage=None) -> None:
 
     model = BoringModel()
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         enable_progress_bar=False,
         max_epochs=1,
         strategy=DeepSpeedStrategy(config=deepspeed_zero_config),
@@ -455,11 +451,11 @@ def setup(self, trainer, pl_module, stage=None) -> None:
 
 
 @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
-def test_deepspeed_multigpu(tmpdir):
+def test_deepspeed_multigpu(tmp_path):
     """Test to ensure that DeepSpeed with multiple GPUs works and deepspeed distributed is initialized correctly."""
     model = BoringModel()
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         strategy=DeepSpeedStrategy(stage=3),
         accelerator="gpu",
         devices=2,
@@ -479,14 +475,14 @@ def test_deepspeed_multigpu(tmpdir):
     trainer.fit(model)
     mock_deepspeed_distributed.assert_called_once()
 
-    _assert_save_model_is_equal(model, tmpdir, trainer)
+    _assert_save_model_is_equal(model, tmp_path, trainer)
 
 
 @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
-def test_deepspeed_fp32_works(tmpdir):
+def test_deepspeed_fp32_works(tmp_path):
     model = BoringModel()
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         accelerator="gpu",
         devices=1,
         strategy="deepspeed_stage_3",
@@ -498,11 +494,11 @@ def test_deepspeed_fp32_works(tmpdir):
 
 
 @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
-def test_deepspeed_stage_3_save_warning(tmpdir):
+def test_deepspeed_stage_3_save_warning(tmp_path):
     """Test to ensure that DeepSpeed Stage 3 gives a warning when saving on rank zero."""
     model = BoringModel()
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         strategy=DeepSpeedStrategy(stage=3),
         accelerator="gpu",
         devices=2,
@@ -512,7 +508,7 @@ def test_deepspeed_stage_3_save_warning(tmpdir):
         enable_model_summary=False,
     )
     trainer.fit(model)
-    checkpoint_path = os.path.join(tmpdir, "model.pt")
+    checkpoint_path = os.path.join(tmp_path, "model.pt")
 
     # both ranks need to call save checkpoint, however only rank 0 needs to check the warning
     context_manager = (
@@ -525,16 +521,16 @@ def test_deepspeed_stage_3_save_warning(tmpdir):
 
 
 @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
-def test_deepspeed_multigpu_single_file(tmpdir):
+def test_deepspeed_multigpu_single_file(tmp_path):
     """Test to ensure that DeepSpeed loads from a single file checkpoint."""
     model = BoringModel()
-    checkpoint_path = os.path.join(tmpdir, "model.pt")
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator="cpu", devices=1)
+    checkpoint_path = os.path.join(tmp_path, "model.pt")
+    trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True, accelerator="cpu", devices=1)
     trainer.fit(model)
     trainer.save_checkpoint(checkpoint_path)
 
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         strategy=DeepSpeedStrategy(stage=3),
         accelerator="gpu",
         devices=1,
@@ -550,7 +546,7 @@ def test_deepspeed_multigpu_single_file(tmpdir):
     trainer.test(model, ckpt_path=checkpoint_path)
 
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         strategy=DeepSpeedStrategy(stage=3, load_full_weights=True),
         accelerator="gpu",
         devices=1,
@@ -653,11 +649,11 @@ def training_step(self, batch, batch_idx):
 
 
 @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
-def test_deepspeed_multigpu_stage_3(tmpdir):
+def test_deepspeed_multigpu_stage_3(tmp_path):
     """Test to ensure ZeRO Stage 3 works with a parallel model."""
     model = ModelParallelBoringModel()
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         strategy=DeepSpeedStrategy(stage=3),
         accelerator="gpu",
         devices=2,
@@ -669,15 +665,15 @@ def test_deepspeed_multigpu_stage_3(tmpdir):
     trainer.test(model)
     trainer.fit(model)
 
-    _assert_save_model_is_equal(model, tmpdir, trainer)
+    _assert_save_model_is_equal(model, tmp_path, trainer)
 
 
 @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
-def test_deepspeed_multigpu_stage_3_manual_optimization(tmpdir, deepspeed_config):
+def test_deepspeed_multigpu_stage_3_manual_optimization(tmp_path, deepspeed_config):
     """Test to ensure ZeRO Stage 3 works with a parallel model."""
     model = ModelParallelBoringModelManualOptim()
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         strategy=DeepSpeedStrategy(stage=3),
         accelerator="gpu",
         devices=2,
@@ -689,18 +685,18 @@ def test_deepspeed_multigpu_stage_3_manual_optimization(tmpdir, deepspeed_config
     trainer.test(model)
     trainer.fit(model)
 
-    _assert_save_model_is_equal(model, tmpdir, trainer)
+    _assert_save_model_is_equal(model, tmp_path, trainer)
 
 
 @pytest.mark.xfail(strict=False, reason="skipped due to deepspeed/#2449, keep track @rohitgr7")
 @pytest.mark.parametrize(("accumulate_grad_batches", "automatic_optimization"), [(1, False), (2, True)])
 @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True, sklearn=True)
-def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir, automatic_optimization, accumulate_grad_batches):
+def test_deepspeed_multigpu_stage_3_checkpointing(tmp_path, automatic_optimization, accumulate_grad_batches):
     model = ModelParallelClassificationModel() if automatic_optimization else ManualModelParallelClassificationModel()
     dm = ClassifDataModule()
     ck = ModelCheckpoint(monitor="val_acc", mode="max", save_last=True, save_top_k=-1)
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         max_epochs=10,
         strategy=DeepSpeedStrategy(stage=3),
         accelerator="gpu",
@@ -719,7 +715,7 @@
 
     model = ModelParallelClassificationModel() if automatic_optimization else ManualModelParallelClassificationModel()
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         accelerator="gpu",
         devices=2,
         strategy=DeepSpeedStrategy(stage=3),
@@ -731,14 +727,14 @@
 
 
 @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True, sklearn=True)
-def test_deepspeed_multigpu_stage_3_warns_resume_training(tmpdir):
+def test_deepspeed_multigpu_stage_3_warns_resume_training(tmp_path):
     """Test to ensure with Stage 3 and multiple GPUs that we can resume from training, throwing a warning that the
     optimizer state and scheduler states cannot be restored."""
     dm = ClassifDataModule()
     model = BoringModel()
-    checkpoint_path = os.path.join(tmpdir, "model.pt")
+    checkpoint_path = os.path.join(tmp_path, "model.pt")
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         fast_dev_run=True,
         enable_progress_bar=False,
         enable_model_summary=False,
@@ -749,7 +745,7 @@
     trainer.save_checkpoint(checkpoint_path)
 
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         fast_dev_run=True,
         strategy=DeepSpeedStrategy(stage=3, load_full_weights=True),
         accelerator="gpu",
@@ -768,14 +764,14 @@
 
 
 @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True, sklearn=True)
-def test_deepspeed_multigpu_stage_3_resume_training(tmpdir):
+def test_deepspeed_multigpu_stage_3_resume_training(tmp_path):
     """Test to ensure with Stage 3 and single GPU that we can resume training."""
     initial_model = ModelParallelClassificationModel()
     dm = ClassifDataModule()
 
     ck = ModelCheckpoint(monitor="val_acc", mode="max", save_last=True, save_top_k=-1)
     initial_trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         max_epochs=1,
         limit_train_batches=2,
         limit_val_batches=2,
@@ -817,7 +813,7 @@ def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) ->
 
     model = ModelParallelClassificationModel()
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         strategy=DeepSpeedStrategy(stage=3),
         accelerator="gpu",
         devices=1,
@@ -834,7 +830,7 @@ def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) ->
 
 @pytest.mark.parametrize("offload_optimizer", [False, True])
 @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True, sklearn=True)
-def test_deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, offload_optimizer):
+def test_deepspeed_multigpu_stage_2_accumulated_grad_batches(tmp_path, offload_optimizer):
     """Test to ensure with Stage 2 and multiple GPUs, accumulated grad batches works."""
 
     class VerificationCallback(Callback):
@@ -852,7 +848,7 @@ def on_train_batch_start(self, trainer, pl_module: LightningModule, batch: Any,
     strategy = DeepSpeedStrategy(stage=2, offload_optimizer=offload_optimizer)
     strategy.config["zero_force_ds_cpu_optimizer"] = False
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         # TODO: this test fails with max_epochs >1 as there are leftover batches per epoch.
         # there's divergence in how Lightning handles the last batch of the epoch with how DeepSpeed does it.
         # we step the optimizers on the last batch but DeepSpeed keeps the accumulation for the next epoch
@@ -874,11 +870,11 @@ def on_train_batch_start(self, trainer, pl_module: LightningModule, batch: Any,
 
 
 @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
-def test_deepspeed_multigpu_test(tmpdir):
+def test_deepspeed_multigpu_test(tmp_path):
     """Test to ensure we can use DeepSpeed with just test using ZeRO Stage 3."""
     model = ModelParallelBoringModel()
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         strategy=DeepSpeedStrategy(stage=3),
         accelerator="gpu",
         devices=2,
@@ -893,7 +889,7 @@
 
 # TODO(Sean): Once partial parameter partitioning is supported this test should be re-enabled
 @pytest.mark.xfail(strict=False, reason="Partial parameter partitioning for DeepSpeed is currently broken.")
 @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
-def test_deepspeed_multigpu_partial_partition_parameters(tmpdir):
+def test_deepspeed_multigpu_partial_partition_parameters(tmp_path):
     """Test to ensure that a module that defines a layer inside the ``__init__`` and ``configure_model`` correctly
     converts all parameters to float16 when ``precision=16`` and runs successfully."""
 
@@ -915,7 +911,7 @@ def on_train_epoch_start(self) -> None:
 
     model = TestModel()
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         strategy=DeepSpeedStrategy(stage=3),
         accelerator="gpu",
         devices=1,
@@ -928,7 +924,7 @@ def on_train_epoch_start(self) -> None:
 
 
 @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
-def test_deepspeed_multigpu_test_rnn(tmpdir):
+def test_deepspeed_multigpu_test_rnn(tmp_path):
     """Test to ensure that turning off explicit partitioning of the entire module for ZeRO Stage 3 works when
     training with certain layers which will crash with explicit partitioning."""
 
@@ -942,7 +938,7 @@ def on_train_epoch_start(self) -> None:
 
     model = TestModel()
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         strategy=DeepSpeedStrategy(stage=3),
         accelerator="gpu",
         devices=1,
@@ -957,13 +953,13 @@ def on_train_epoch_start(self) -> None:
 
 
 @RunIf(deepspeed=True, mps=False)
 @mock.patch("deepspeed.init_distributed", autospec=True)
 @pytest.mark.parametrize("platform", ["Linux", "Windows"])
-def test_deepspeed_strategy_env_variables(mock_deepspeed_distributed, tmpdir, platform):
+def test_deepspeed_strategy_env_variables(mock_deepspeed_distributed, tmp_path, platform):
     """Test to ensure that we setup distributed communication using correctly.
 
     When using windows, ranks environment variables should not be set, and deepspeed should handle this.
 
""" - trainer = Trainer(default_root_dir=tmpdir, strategy=DeepSpeedStrategy(stage=3)) + trainer = Trainer(default_root_dir=tmp_path, strategy=DeepSpeedStrategy(stage=3)) strategy = trainer.strategy assert isinstance(strategy, DeepSpeedStrategy) with mock.patch("platform.system", return_value=platform) as mock_platform: @@ -981,14 +977,14 @@ def test_deepspeed_strategy_env_variables(mock_deepspeed_distributed, tmpdir, pl assert os.environ["LOCAL_RANK"] == str(trainer.strategy.local_rank) -def _assert_save_model_is_equal(model, tmpdir, trainer): - checkpoint_path = os.path.join(tmpdir, "model.pt") +def _assert_save_model_is_equal(model, tmp_path, trainer): + checkpoint_path = os.path.join(tmp_path, "model.pt") checkpoint_path = trainer.strategy.broadcast(checkpoint_path) trainer.save_checkpoint(checkpoint_path) # carry out the check only on rank 0 if trainer.is_global_zero: - single_ckpt_path = os.path.join(tmpdir, "single_model.pt") + single_ckpt_path = os.path.join(tmp_path, "single_model.pt") convert_zero_checkpoint_to_fp32_state_dict(checkpoint_path, single_ckpt_path) state_dict = torch.load(single_ckpt_path) @@ -1002,11 +998,11 @@ def _assert_save_model_is_equal(model, tmpdir, trainer): @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) -def test_deepspeed_multigpu_no_schedulers(tmpdir): +def test_deepspeed_multigpu_no_schedulers(tmp_path): """Test to ensure ZeRO Stage 3 works with a parallel model and no schedulers.""" model = ModelParallelBoringModelNoSchedulers() trainer = Trainer( - default_root_dir=tmpdir, + default_root_dir=tmp_path, strategy=DeepSpeedStrategy(stage=3), accelerator="gpu", devices=2, @@ -1017,18 +1013,18 @@ def test_deepspeed_multigpu_no_schedulers(tmpdir): ) trainer.fit(model) - _assert_save_model_is_equal(model, tmpdir, trainer) + _assert_save_model_is_equal(model, tmp_path, trainer) @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True) -def test_deepspeed_skip_backward_raises(tmpdir): +def test_deepspeed_skip_backward_raises(tmp_path): class TestModel(BoringModel): def training_step(self, batch, batch_idx): return None model = TestModel() trainer = Trainer( - default_root_dir=tmpdir, + default_root_dir=tmp_path, strategy=DeepSpeedStrategy(), accelerator="gpu", devices=1, @@ -1041,52 +1037,12 @@ def training_step(self, batch, batch_idx): trainer.fit(model) -@RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True) -def test_deepspeed_setup_train_dataloader(tmpdir): - """Test DeepSpeed works when setup is required to call in the DataModule.""" - - class TestSetupIsCalledDataModule(LightningDataModule): - def __init__(self): - super().__init__() - self._setup = False - - def setup(self, stage: str) -> None: - self._setup = True - - def train_dataloader(self): - assert self._setup - return DataLoader(RandomDataset(32, 64), batch_size=2) - - def val_dataloader(self): - assert self._setup - return DataLoader(RandomDataset(32, 64), batch_size=2) - - def test_dataloader(self): - assert self._setup - return DataLoader(RandomDataset(32, 64), batch_size=2) - - model = BoringModel() - trainer = Trainer( - default_root_dir=tmpdir, - strategy=DeepSpeedStrategy(logging_level=logging.INFO), - accelerator="gpu", - devices=1, - fast_dev_run=True, - enable_progress_bar=False, - enable_model_summary=False, - ) - dm = TestSetupIsCalledDataModule() - with mock.patch("deepspeed.utils.logging.logger.warning", autospec=True) as mock_object: - trainer.fit(model, datamodule=dm) - assert any("Tried to infer the batch size" in str(arg) for arg in mock_object.call_args_list) - - 
 @mock.patch("torch.optim.lr_scheduler.StepLR.step", autospec=True)
 @pytest.mark.parametrize("interval", ["step", "epoch"])
 @pytest.mark.parametrize("max_epoch", [2])
 @pytest.mark.parametrize("limit_train_batches", [2])
 @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
-def test_scheduler_step_count(mock_step, tmpdir, max_epoch, limit_train_batches, interval):
+def test_scheduler_step_count(mock_step, tmp_path, max_epoch, limit_train_batches, interval):
     """Test to ensure that the scheduler is called the correct amount of times during training when scheduler is set
     to step or epoch."""
 
@@ -1101,7 +1057,7 @@ def configure_optimizers(self):
 
     model = TestModel()
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         limit_train_batches=limit_train_batches,
         limit_val_batches=0,
         max_epochs=max_epoch,
@@ -1121,7 +1077,7 @@ def configure_optimizers(self):
 
 
 @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
-def test_deepspeed_configure_gradient_clipping(tmpdir):
+def test_deepspeed_configure_gradient_clipping(tmp_path):
     """Test to ensure that a warning is raised when `LightningModule.configure_gradient_clipping` is overridden in
     case of deepspeed."""
 
@@ -1131,7 +1087,7 @@ def configure_gradient_clipping(self, optimizer, gradient_clip_val, gradient_cli
 
     model = TestModel()
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         accelerator="gpu",
         devices=1,
         strategy="deepspeed",
@@ -1144,11 +1100,11 @@ def configure_gradient_clipping(self, optimizer, gradient_clip_val, gradient_cli
 
 
 @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
-def test_deepspeed_gradient_clip_by_value(tmpdir):
+def test_deepspeed_gradient_clip_by_value(tmp_path):
     """Test to ensure that an exception is raised when using `gradient_clip_algorithm='value'`."""
     model = BoringModel()
    trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         accelerator="gpu",
         devices=1,
         strategy="deepspeed",
@@ -1161,7 +1117,7 @@
 
 
 @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
-def test_deepspeed_multi_save_same_filepath(tmpdir):
+def test_deepspeed_multi_save_same_filepath(tmp_path):
     """Test that verifies that deepspeed saves only latest checkpoint in the specified path and deletes the old
     sharded checkpoints."""
 
@@ -1172,7 +1128,7 @@ def training_step(self, *args, **kwargs):
 
     model = CustomModel()
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         strategy="deepspeed",
         accelerator="gpu",
         devices=2,
@@ -1196,11 +1152,11 @@ def training_step(self, *args, **kwargs):
 
 
 @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
-def test_deepspeed_with_bfloat16_precision(tmpdir):
+def test_deepspeed_with_bfloat16_precision(tmp_path):
     """Test that deepspeed works with bfloat16 precision."""
     model = BoringModel()
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         strategy="deepspeed_stage_3",
         accelerator="gpu",
         devices=2,
@@ -1220,10 +1176,10 @@ def test_deepspeed_with_bfloat16_precision(tmpdir):
 
 
 @RunIf(deepspeed=True)
-def test_error_with_invalid_accelerator(tmpdir):
+def test_error_with_invalid_accelerator(tmp_path):
     """Test DeepSpeedStrategy raises an exception if an invalid accelerator is used."""
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         accelerator="cpu",
         strategy="deepspeed",
         fast_dev_run=True,
@@ -1234,7 +1190,7 @@
 
 
 @RunIf(min_cuda_gpus=2, deepspeed=True, standalone=True)
-def test_deepspeed_configure_optimizer_device_set(tmpdir):
+def test_deepspeed_configure_optimizer_device_set(tmp_path):
     """Test to ensure that the LM has access to the device within the ``configure_optimizer`` function, and
     estimated_stepping_batches works correctly as a result."""
 
@@ -1246,7 +1202,7 @@ def configure_optimizers(self):
 
     model = TestModel()
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         fast_dev_run=True,
         accelerator="gpu",
         devices=2,
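
Note on the change above: the patch replaces eager config formatting in `setup_distributed()` with an idempotent guard, `_init_config_if_needed()`, called from both `setup()` and `model_sharded_context()`. A minimal, self-contained sketch of that lazy-initialization pattern follows; it is an illustrative stand-in, not the Lightning source, and the class name and config contents are placeholders:

    # lazy_config_sketch.py -- illustrative only, runs as plain Python
    from typing import Any, Dict, Optional


    class LazyConfigStrategy:
        def __init__(self, config: Optional[Dict[str, Any]] = None) -> None:
            self.config: Dict[str, Any] = config if config is not None else {}
            self._config_initialized = False

        def _init_config_if_needed(self) -> None:
            # Idempotent guard: whichever entry point runs first formats the
            # config exactly once; later calls are no-ops.
            if not self._config_initialized:
                self._format_config()
                self._config_initialized = True

        def _format_config(self) -> None:
            # Placeholder for the real work (batch size, grad accumulation,
            # precision). Deferring it to here means the trainer and
            # dataloaders already exist by the time it runs.
            self.config.setdefault("train_micro_batch_size_per_gpu", 1)

        def setup(self) -> None:  # mirrors Strategy.setup() in the diff
            self._init_config_if_needed()

        def model_sharded_context(self) -> None:  # mirrors the zero.Init path
            self._init_config_if_needed()


    strategy = LazyConfigStrategy()
    strategy.model_sharded_context()  # first caller formats the config
    strategy.setup()                  # second caller is a no-op
    assert strategy.config["train_micro_batch_size_per_gpu"] == 1

Because `_call_setup_hook` now runs before the config is formatted, `_auto_select_batch_size` can call `data_source.dataloader()` directly: any dataloader that requires `setup()` has already had it called. That is why the broad `except Exception` fallback, its logging warning, and the `test_deepspeed_setup_train_dataloader` warning test are removed above, and why the expected hook order in `test_hooks.py` moves `train_dataloader` after the `setup` hooks.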