feature: better device mapping for large models #918

Merged
9 changes: 8 additions & 1 deletion README.md
@@ -550,6 +550,11 @@ tf32: true # require >=ampere
bfloat16: true # require >=ampere
float16: true

# Limit the memory for each available GPU to this amount (if an integer, it is interpreted as GiB); default: unset
gpu_memory_limit: 20GiB
# Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge
lora_on_cpu: true

# A list of one or more datasets to finetune the model with
datasets:
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
@@ -1038,12 +1043,14 @@ Add below flag to train command above
python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model"
```

If you run out of CUDA memory, you can try to merge in system RAM with
You may need to use the `gpu_memory_limit` and/or `lora_on_cpu` config options to avoid running out of memory. If you still run out of CUDA memory, you can try to merge in system RAM with

```bash
CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora ...
```

although this will be very slow; using the config options above is recommended instead.
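
For reference, a minimal sketch of what hiding the GPUs does, assuming the merge is driven from a Python script rather than the CLI; the variable has to be set before `torch` initialises CUDA:

```python
# Sketch: force a CPU-only merge from Python, equivalent to prefixing the
# CLI command with CUDA_VISIBLE_DEVICES="".
import os

os.environ["CUDA_VISIBLE_DEVICES"] = ""  # must happen before torch touches CUDA

import torch  # noqa: E402  (imported after the env var on purpose)

assert not torch.cuda.is_available()  # everything below stays in system RAM
# ...load the base model, apply the LoRA adapter, and call merge_and_unload()
```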

## Common Errors 🧰

See also the [FAQ's](./docs/faq.md).
3 changes: 2 additions & 1 deletion src/axolotl/cli/__init__.py
@@ -71,14 +71,15 @@ def do_merge_lora(
safe_serialization = cfg.save_safetensors is True

LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model = model.merge_and_unload(progressbar=True)
model.to(dtype=cfg.torch_dtype)

if cfg.local_rank == 0:
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
progressbar=True,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))

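The diff above only adds progress reporting, but for orientation, here is a hedged sketch of the merge-and-save flow that `do_merge_lora` performs; the model id, adapter directory, and dtype are placeholders, and `progressbar` assumes a peft release that supports it:

```python
from pathlib import Path

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

output_dir = Path("./completed-model")  # placeholder output directory

# Load the base model and attach the trained LoRA adapter.
base = AutoModelForCausalLM.from_pretrained("base-model", torch_dtype=torch.float16)
model = PeftModel.from_pretrained(base, "./lora-out")

# Fold the LoRA weights into the base weights; progressbar=True reports
# per-module progress, which helps when merging very large models.
merged = model.merge_and_unload(progressbar=True)
merged.to(dtype=torch.float16)

merged.save_pretrained(str(output_dir / "merged"), safe_serialization=True)
```
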
5 changes: 5 additions & 0 deletions src/axolotl/utils/config.py
@@ -462,6 +462,11 @@ def validate_config(cfg):
"lora_modules_to_save not properly set yet adding new tokens. Please add `embed_tokens` and `lm_head` to `lora_modules_to_save`."
)

if cfg.max_memory is not None and cfg.gpu_memory_limit is not None:
raise ValueError(
"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
)

# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25
41 changes: 37 additions & 4 deletions src/axolotl/utils/models.py
@@ -2,7 +2,7 @@
import logging
import math
import os
from typing import Optional, Tuple # noqa: F401
from typing import Any, Optional, Tuple # noqa: F401

import addict
import bitsandbytes as bnb
@@ -280,8 +280,37 @@ def load_model(

model_kwargs = {}

model_kwargs["device_map"] = cfg.device_map
model_kwargs["max_memory"] = cfg.max_memory
max_memory = cfg.max_memory
device_map = cfg.device_map

if cfg.gpu_memory_limit:
gpu_memory_limit = (
str(cfg.gpu_memory_limit) + "GiB"
if isinstance(cfg.gpu_memory_limit, int)
else cfg.gpu_memory_limit
)

max_memory = {}
for i in range(torch.cuda.device_count()):
max_memory[i] = gpu_memory_limit
max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything

if max_memory is not None:
# Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py
from accelerate import infer_auto_device_map, init_empty_weights

with init_empty_weights():
model_canvas = AutoModelForCausalLM.from_config(model_config)
model_canvas.tie_weights()
device_map = infer_auto_device_map(
model_canvas,
max_memory=max_memory,
dtype=cfg.torch_dtype,
)
# We can discard max_memory now as we have a device map set up for us
max_memory = None

model_kwargs["device_map"] = device_map
model_kwargs["torch_dtype"] = cfg.torch_dtype

if is_deepspeed_zero3_enabled():
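
The hunk above is the core of the feature: the per-GPU `gpu_memory_limit` is expanded into a `max_memory` dict, and accelerate then plans a device map on an empty (meta) model so that no weights are materialised while the placement is computed. A standalone sketch of the same pattern, with an illustrative model id and memory limits:

```python
import torch
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

model_name = "meta-llama/Llama-2-7b-hf"  # illustrative model id
gpu_memory_limit = "20GiB"               # illustrative per-GPU cap

# Cap every visible GPU and leave generous headroom on the CPU.
max_memory = {i: gpu_memory_limit for i in range(torch.cuda.device_count())}
max_memory["cpu"] = "256GiB"

config = AutoConfig.from_pretrained(model_name)
with init_empty_weights():
    # Instantiated on the "meta" device: structure only, no real tensors.
    model_canvas = AutoModelForCausalLM.from_config(config)
model_canvas.tie_weights()

device_map = infer_auto_device_map(
    model_canvas,
    max_memory=max_memory,
    dtype=torch.bfloat16,
)

# The resulting dict maps module names to devices and can be passed to
# AutoModelForCausalLM.from_pretrained(..., device_map=device_map).
print(device_map)
```
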
@@ -406,7 +435,6 @@ def load_model(
model_kwargs["device"] = torch.cuda.current_device()
del model_kwargs["torch_dtype"]
del model_kwargs["device_map"]
del model_kwargs["max_memory"]

model = MambaLMHeadModel.from_pretrained(
base_model,
@@ -661,10 +689,15 @@ def load_lora(model, cfg, inference=False):

if cfg.lora_model_dir:
LOG.debug("Loading pretained PEFT - LoRA")
model_kwargs: Any = {}
if cfg.lora_on_cpu:
model_kwargs["max_memory"] = {"cpu": "256GiB"}
model_kwargs["device_map"] = {"": "cpu"}
model = PeftModel.from_pretrained(
model,
cfg.lora_model_dir,
is_trainable=(not inference),
**model_kwargs,
)
else:
model = get_peft_model(model, lora_config)
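
Finally, a hedged sketch of the `lora_on_cpu` path added above: the adapter weights are loaded and kept on CPU so they do not compete with the base model for VRAM. The model id and adapter directory are placeholders, and the extra keyword arguments simply mirror what the branch above passes through to `PeftModel.from_pretrained`:

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Base model placed across the available GPU(s) by accelerate.
base = AutoModelForCausalLM.from_pretrained(
    "base-model",                 # placeholder model id
    device_map="auto",
    torch_dtype=torch.float16,
)

# Keep the LoRA/PEFT modules on CPU, mirroring the lora_on_cpu branch.
model = PeftModel.from_pretrained(
    base,
    "./lora-out",                  # placeholder adapter directory
    is_trainable=False,            # inference / merge use case
    max_memory={"cpu": "256GiB"},  # generous CPU headroom for the adapter
    device_map={"": "cpu"},        # place the adapter weights on CPU
)
```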