Delay DeepSpeed config setup (Lightning-AI#19209)
awaelchli authored Dec 24, 2023
1 parent 91ef190 commit 3518f9e
Showing 4 changed files with 236 additions and 289 deletions.
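The core of the change: the eager call to `_format_config()` in `setup_distributed()` is replaced by a lazy, idempotent guard, `_init_config_if_needed()`, which is invoked from `setup()` and `model_sharded_context()` instead. A minimal standalone sketch of that guarded lazy-initialization pattern (illustrative names only, not the actual Lightning class):

from typing import Any, Dict, Optional


class LazyConfigStrategy:
    """Simplified sketch of the deferred config setup this commit introduces."""

    def __init__(self, config: Optional[Dict[str, Any]] = None) -> None:
        self.config = config
        self._config_initialized = False  # formatting is postponed until first use

    def _init_config_if_needed(self) -> None:
        # Idempotent guard: formatting runs at most once, and only at a point
        # where everything it depends on (e.g. the train dataloader) exists.
        if not self._config_initialized:
            self._format_config()
            self._config_initialized = True

    def _format_config(self) -> None:
        if self.config is None:
            raise ValueError("A config dict (or a path to a JSON config) is required.")
        # ... fill in batch size, gradient accumulation, precision, etc.

    def setup(self) -> None:
        # Runs after the user setup hooks, so config completion can rely on them.
        self._init_config_if_needed()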
283 changes: 137 additions & 146 deletions src/lightning/pytorch/strategies/deepspeed.py
@@ -313,20 +313,6 @@ def __init__(
self.hysteresis = hysteresis
self.min_loss_scale = min_loss_scale

def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Optional[Dict[str, Any]]:
if config is None and self.DEEPSPEED_ENV_VAR in os.environ:
rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable")
config = os.environ[self.DEEPSPEED_ENV_VAR]
if isinstance(config, (str, Path)):
if not os.path.isfile(config):
raise MisconfigurationException(
f"You passed in a path to a DeepSpeed config but the path does not exist: {config}"
)
with open(config) as f:
config = json.load(f)
assert isinstance(config, dict) or config is None
return config

@override
def setup_environment(self) -> None:
if not isinstance(self.accelerator, CUDAAccelerator):
@@ -343,12 +329,10 @@ def setup_distributed(self) -> None:
reset_seed()
self.set_world_ranks()
self._init_deepspeed_distributed()
if not self._config_initialized:
self._format_config()
self._config_initialized = True

@override
def setup(self, trainer: "pl.Trainer") -> None:
self._init_config_if_needed()
assert self.accelerator is not None
self.accelerator.setup(trainer)
# we set the device so that optimizers can be created with distributed comms.
@@ -529,7 +513,7 @@ def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[No
def model_sharded_context(self) -> Generator[None, None, None]:
import deepspeed

assert self._config_initialized
self._init_config_if_needed()
with deepspeed.zero.Init(
enabled=self.zero_stage_3,
remote_device=self.remote_device,
@@ -610,134 +594,6 @@ def handles_gradient_accumulation(self) -> bool:
"""Whether the strategy handles gradient accumulation internally."""
return True

def _format_config(self) -> None:
if self.config is None:
raise MisconfigurationException(
"To use DeepSpeed you must pass in a DeepSpeed config dict, or a path to a JSON config."
" See: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#deepspeed"
)
self._format_batch_size_and_grad_accum_config()
_format_precision_config(
config=self.config,
precision=self.precision_plugin.precision,
loss_scale=self.loss_scale,
loss_scale_window=self.loss_scale_window,
min_loss_scale=self.min_loss_scale,
initial_scale_power=self.initial_scale_power,
hysteresis=self.hysteresis,
)

def _format_batch_size_and_grad_accum_config(self) -> None:
# TODO: Using Fabric, we do not support these variables within the config
assert isinstance(self.config, dict)
if self.lightning_module is None:
return

if "gradient_accumulation_steps" in self.config:
raise MisconfigurationException(
"Do not set `gradient_accumulation_steps` in the DeepSpeed config"
" as this will be set with the `accumulate_grad_batches` argument passed via the Lightning Trainer."
)
self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches
if "train_micro_batch_size_per_gpu" not in self.config:
batch_size = self._auto_select_batch_size()
self.config["train_micro_batch_size_per_gpu"] = batch_size
if "gradient_clipping" not in self.config:
self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val or 0.0

def _auto_select_batch_size(self) -> int:
import deepspeed

# train_micro_batch_size_per_gpu is used for throughput logging purposes
# by default we try to use the batch size of the loader
assert self.lightning_module is not None
batch_size = 1
data_source = self.lightning_module.trainer.fit_loop._data_source
if data_source.is_defined():
try:
train_dataloader = data_source.dataloader()
if hasattr(train_dataloader, "batch_sampler"):
batch_size = train_dataloader.batch_sampler.batch_size
# broad exception on purpose as `source.dataloader()` will fail if the dataloader requires `setup`
# to have been called before
except Exception:
if self.global_rank == 0:
deepspeed.utils.logging.logger.warning(
"Tried to infer the batch size for internal deepspeed logging from the `train_dataloader()`. "
"To ensure DeepSpeed logging remains correct, please manually pass the strategy with the "
"batch size, `Trainer(strategy=DeepSpeedStrategy(logging_batch_size_per_gpu=batch_size))`."
)
return batch_size

def _create_default_config(
self,
zero_optimization: bool,
zero_allow_untested_optimizer: bool,
logging_batch_size_per_gpu: Union[str, int],
partition_activations: bool,
cpu_checkpointing: bool,
contiguous_memory_optimization: bool,
synchronize_checkpoint_boundary: bool,
offload_optimizer: bool,
offload_parameters: bool,
nvme_path: str,
offload_params_device: str,
params_buffer_count: int,
params_buffer_size: int,
max_in_cpu: int,
offload_optimizer_device: str,
optimizer_buffer_count: int,
pin_memory: bool,
block_size: int,
queue_depth: int,
single_submit: bool,
overlap_events: bool,
thread_count: int,
**zero_kwargs: Any,
) -> Dict:
cfg = {
"activation_checkpointing": {
"partition_activations": partition_activations,
"cpu_checkpointing": cpu_checkpointing,
"contiguous_memory_optimization": contiguous_memory_optimization,
"synchronize_checkpoint_boundary": synchronize_checkpoint_boundary,
},
"aio": {
"block_size": block_size,
"queue_depth": queue_depth,
"single_submit": single_submit,
"overlap_events": overlap_events,
"thread_count": thread_count,
},
}
if zero_optimization:
zero_config = zero_kwargs

if offload_optimizer:
zero_config["offload_optimizer"] = {
"device": offload_optimizer_device,
"nvme_path": nvme_path,
"buffer_count": optimizer_buffer_count,
"pin_memory": pin_memory,
}
if offload_parameters:
zero_config["offload_param"] = {
"device": offload_params_device,
"nvme_path": nvme_path,
"buffer_count": params_buffer_count,
"buffer_size": params_buffer_size,
"max_in_cpu": max_in_cpu,
"pin_memory": pin_memory,
}
cfg = {
"zero_allow_untested_optimizer": zero_allow_untested_optimizer,
"zero_optimization": zero_config,
**cfg,
}
if logging_batch_size_per_gpu != "auto":
cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg}
return cfg

@property
def deepspeed_engine(self) -> "deepspeed.DeepSpeedEngine":
return self.model
@@ -915,3 +771,138 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
offload_params_device="nvme",
offload_optimizer_device="nvme",
)

def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Optional[Dict[str, Any]]:
if config is None and self.DEEPSPEED_ENV_VAR in os.environ:
rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable")
config = os.environ[self.DEEPSPEED_ENV_VAR]
if isinstance(config, (str, Path)):
if not os.path.isfile(config):
raise MisconfigurationException(
f"You passed in a path to a DeepSpeed config but the path does not exist: {config}"
)
with open(config) as f:
config = json.load(f)
assert isinstance(config, dict) or config is None
return config

def _init_config_if_needed(self) -> None:
if not self._config_initialized:
self._format_config()
self._config_initialized = True

def _format_config(self) -> None:
if self.config is None:
raise MisconfigurationException(
"To use DeepSpeed you must pass in a DeepSpeed config dict, or a path to a JSON config."
" See: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#deepspeed"
)
self._format_batch_size_and_grad_accum_config()
_format_precision_config(
config=self.config,
precision=self.precision_plugin.precision,
loss_scale=self.loss_scale,
loss_scale_window=self.loss_scale_window,
min_loss_scale=self.min_loss_scale,
initial_scale_power=self.initial_scale_power,
hysteresis=self.hysteresis,
)

def _create_default_config(
self,
zero_optimization: bool,
zero_allow_untested_optimizer: bool,
logging_batch_size_per_gpu: Union[str, int],
partition_activations: bool,
cpu_checkpointing: bool,
contiguous_memory_optimization: bool,
synchronize_checkpoint_boundary: bool,
offload_optimizer: bool,
offload_parameters: bool,
nvme_path: str,
offload_params_device: str,
params_buffer_count: int,
params_buffer_size: int,
max_in_cpu: int,
offload_optimizer_device: str,
optimizer_buffer_count: int,
pin_memory: bool,
block_size: int,
queue_depth: int,
single_submit: bool,
overlap_events: bool,
thread_count: int,
**zero_kwargs: Any,
) -> Dict:
cfg = {
"activation_checkpointing": {
"partition_activations": partition_activations,
"cpu_checkpointing": cpu_checkpointing,
"contiguous_memory_optimization": contiguous_memory_optimization,
"synchronize_checkpoint_boundary": synchronize_checkpoint_boundary,
},
"aio": {
"block_size": block_size,
"queue_depth": queue_depth,
"single_submit": single_submit,
"overlap_events": overlap_events,
"thread_count": thread_count,
},
}
if zero_optimization:
zero_config = zero_kwargs

if offload_optimizer:
zero_config["offload_optimizer"] = {
"device": offload_optimizer_device,
"nvme_path": nvme_path,
"buffer_count": optimizer_buffer_count,
"pin_memory": pin_memory,
}
if offload_parameters:
zero_config["offload_param"] = {
"device": offload_params_device,
"nvme_path": nvme_path,
"buffer_count": params_buffer_count,
"buffer_size": params_buffer_size,
"max_in_cpu": max_in_cpu,
"pin_memory": pin_memory,
}
cfg = {
"zero_allow_untested_optimizer": zero_allow_untested_optimizer,
"zero_optimization": zero_config,
**cfg,
}
if logging_batch_size_per_gpu != "auto":
cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg}
return cfg

def _format_batch_size_and_grad_accum_config(self) -> None:
# TODO: Using Fabric, we do not support these variables within the config
assert isinstance(self.config, dict)
if self.lightning_module is None:
return

if "gradient_accumulation_steps" in self.config:
raise MisconfigurationException(
"Do not set `gradient_accumulation_steps` in the DeepSpeed config"
" as this will be set with the `accumulate_grad_batches` argument passed via the Lightning Trainer."
)
self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches
if "train_micro_batch_size_per_gpu" not in self.config:
batch_size = self._auto_select_batch_size()
self.config["train_micro_batch_size_per_gpu"] = batch_size
if "gradient_clipping" not in self.config:
self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val or 0.0

def _auto_select_batch_size(self) -> int:
# train_micro_batch_size_per_gpu is used for throughput logging purposes
# by default we try to use the batch size of the loader
assert self.lightning_module is not None
batch_size = 1
data_source = self.lightning_module.trainer.fit_loop._data_source
if data_source.is_defined():
train_dataloader = data_source.dataloader()
if hasattr(train_dataloader, "batch_sampler"):
batch_size = train_dataloader.batch_sampler.batch_size
return batch_size
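Note that the relocated `_auto_select_batch_size` above drops the broad `try/except` and the rank-zero warning: because config initialization now happens in `setup()`, the train dataloader has already been set up by the time the batch size is inferred. A standalone sketch of the same inference with plain PyTorch (the helper name is hypothetical and not part of Lightning):

import torch
from torch.utils.data import DataLoader, TensorDataset


def infer_micro_batch_size(train_dataloader: DataLoader, default: int = 1) -> int:
    # Mirrors the idea in `_auto_select_batch_size`: DeepSpeed only uses
    # `train_micro_batch_size_per_gpu` for throughput logging, so falling
    # back to a default of 1 is acceptable when the sampler exposes no size.
    batch_sampler = getattr(train_dataloader, "batch_sampler", None)
    if batch_sampler is not None and getattr(batch_sampler, "batch_size", None):
        return batch_sampler.batch_size
    return default


loader = DataLoader(TensorDataset(torch.randn(64, 4)), batch_size=8)
print(infer_micro_batch_size(loader))  # prints 8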
2 changes: 1 addition & 1 deletion src/lightning/pytorch/trainer/trainer.py
@@ -946,7 +946,7 @@ def _run(
self.strategy.setup_environment()
self.__setup_profiler()

call._call_setup_hook(self) # allow user to setup lightning_module in accelerator environment
call._call_setup_hook(self) # allow user to set up LightningModule in accelerator environment
log.debug(f"{self.__class__.__name__}: configuring model")
call._call_configure_model(self)

4 changes: 2 additions & 2 deletions tests/tests_pytorch/models/test_hooks.py
@@ -477,10 +477,10 @@ def training_step(self, batch, batch_idx):
expected = [
{"name": "configure_callbacks"},
{"name": "prepare_data"},
# DeepSpeed needs the batch size to figure out throughput logging
*([{"name": "train_dataloader"}] if using_deepspeed else []),
{"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "fit"}},
{"name": "setup", "kwargs": {"stage": "fit"}},
# DeepSpeed needs the batch size to figure out throughput logging
*([{"name": "train_dataloader"}] if using_deepspeed else []),
{"name": "configure_model"},
{"name": "configure_optimizers"},
{"name": "Callback.on_fit_start", "args": (trainer, model)},