Skip to content

Commit

Permalink
improve vram use w gradient checkpointing (#1167) [skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Jan 23, 2024
1 parent b8e5603 commit 802f966
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/axolotl/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,13 @@ def normalize_config(cfg):
if isinstance(cfg.pretraining_dataset, dict):
cfg.pretraining_dataset = [cfg.pretraining_dataset]

if (
cfg.gradient_checkpointing
and cfg.unfrozen_parameters is None
and cfg.gradient_checkpointing_kwargs is None
):
cfg.gradient_checkpointing_kwargs = {"use_reentrant": True}

log_gpu_memory_usage(LOG, "baseline", cfg.device)


Expand Down

0 comments on commit 802f966

Please sign in to comment.