feature: better device mapping for large models #918

Merged

Conversation

@kallewoof (Contributor) commented Dec 6, 2023

When a model does not fit completely into the GPU (at 16 bits, when merging with a LoRA), a crash occurs indicating that we need an offload dir. If we hide the GPUs and run purely on the CPU, it works, but then the GPUs are not used at all.

If we try to offload while using the GPUs, accelerate ends up trying to offload types that are only supported on the GPUs, which results in the possibly infamous NotImplementedError: Cannot copy out of meta tensor; no data! error.

This pull request adds two new configuration options. The first, gpu_memory_limit, is a convenience option that can be used instead of setting the max_memory config manually. (For per-GPU maximums, you still need to set max_memory by hand.) An integer value is interpreted as gigabytes; anything else is assumed to be a proper memory string (e.g. "123MiB").
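
As a rough illustration (not the PR's actual code; the helper name and the CPU ceiling below are assumptions), an integer limit could be expanded into the max_memory mapping that accelerate expects, applying the same cap to every visible GPU:

# Illustrative sketch only; not the PR's implementation.
import torch

def gpu_memory_limit_to_max_memory(gpu_memory_limit):
    # Integers are treated as a GiB count; strings pass through unchanged.
    if isinstance(gpu_memory_limit, int):
        limit = f"{gpu_memory_limit}GiB"
    else:
        limit = gpu_memory_limit
    # Same cap on every visible GPU; a generous CPU entry lets overflow
    # layers spill to system RAM instead of crashing.
    max_memory = {i: limit for i in range(torch.cuda.device_count())}
    max_memory["cpu"] = "256GiB"
    return max_memory

print(gpu_memory_limit_to_max_memory(20))        # e.g. {0: '20GiB', 'cpu': '256GiB'}
print(gpu_memory_limit_to_max_memory("123MiB"))  # e.g. {0: '123MiB', 'cpu': '256GiB'}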

The second, lora_on_cpu, is a flag which, if set, forces the PeftModel loading step to run entirely on the CPU. This slows things down, but when the model already takes up most of the GPU VRAM, the only alternatives are to crash and/or buy more GPUs.
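
The idea can be sketched roughly as follows (the function and kwargs here are illustrative assumptions, not the PR's exact code): the adapter is attached with a pure-CPU device map, so accelerate never tries to dispatch adapter weights onto the GPU.

# Sketch only: assumes a base model is already loaded and a local adapter
# directory is available; not the PR's actual implementation.
from peft import PeftModel

def attach_lora(model, lora_model_dir, lora_on_cpu=False):
    # With lora_on_cpu, keep every adapter module on the CPU so the
    # dispatch step never tries to place weights it cannot offload.
    device_map = {"": "cpu"} if lora_on_cpu else "auto"
    return PeftModel.from_pretrained(
        model,
        lora_model_dir,
        device_map=device_map,
    )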

Main branch, attempting to merge a 34B CodeLlama model with a LoRA on a 24GB A10G with 30GB RAM and 64GB swap:

[2023-12-06 06:20:36,148] [DEBUG] [axolotl.load_tokenizer:135] [PID:16542] [RANK:0] EOS: 2 / </s>
[2023-12-06 06:20:36,149] [DEBUG] [axolotl.load_tokenizer:136] [PID:16542] [RANK:0] BOS: 1 / <s>
[2023-12-06 06:20:36,149] [DEBUG] [axolotl.load_tokenizer:137] [PID:16542] [RANK:0] PAD: 2 / </s>
[2023-12-06 06:20:36,149] [DEBUG] [axolotl.load_tokenizer:138] [PID:16542] [RANK:0] UNK: 0 / <unk>
[2023-12-06 06:20:36,149] [INFO] [axolotl.common.cli.load_model_and_tokenizer:51] [PID:16542] [RANK:0] loading model and (optionally) peft_config...
[2023-12-06 06:20:36,150] [INFO] [axolotl.load_model:236] [PID:16542] [RANK:0] patching _expand_mask
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:02<00:00,  2.71it/s]
[2023-12-06 06:20:40,509] [INFO] [axolotl.load_model:425] [PID:16542] [RANK:0] GPU memory usage after model load: 17.622GB (+0.017GB cache, +1.173GB misc)
[2023-12-06 06:20:40,513] [INFO] [axolotl.load_model:460] [PID:16542] [RANK:0] converting modules to torch.bfloat16 for flash attention
[2023-12-06 06:20:40,516] [INFO] [axolotl.load_lora:562] [PID:16542] [RANK:0] found linear modules: ['gate_proj', 'q_proj', 'up_proj', 'down_proj', 'k_proj', 'v_proj', 'o_proj']
[2023-12-06 06:20:40,516] [DEBUG] [axolotl.load_lora:577] [PID:16542] [RANK:0] Loading pretained PEFT - LoRA
[2023-12-06 06:20:40,649] [WARNING] [auto_gptq.nn_modules.qlinear.qlinear_cuda.<module>:16] [PID:16542] CUDA extension not installed.
[2023-12-06 06:20:40,650] [WARNING] [auto_gptq.nn_modules.qlinear.qlinear_cuda_old.<module>:15] [PID:16542] CUDA extension not installed.
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/ec2-user/workspace/axolotl/src/axolotl/cli/merge_lora.py", line 27, in <module>
    fire.Fire(do_cli)
  File "/opt/conda/envs/ai/lib/python3.11/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/ai/lib/python3.11/site-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
                                ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/ai/lib/python3.11/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/workspace/axolotl/src/axolotl/cli/merge_lora.py", line 23, in do_cli
    do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
  File "/home/ec2-user/workspace/axolotl/src/axolotl/cli/__init__.py", line 70, in do_merge_lora
    model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/workspace/axolotl/src/axolotl/common/cli.py", line 52, in load_model_and_tokenizer
    model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/workspace/axolotl/src/axolotl/utils/models.py", line 468, in load_model
    model, lora_config = load_adapter(model, cfg, cfg.adapter)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/workspace/axolotl/src/axolotl/utils/models.py", line 503, in load_adapter
    return load_lora(model, cfg, inference=inference)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/workspace/axolotl/src/axolotl/utils/models.py", line 578, in load_lora
    model = PeftModel.from_pretrained(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/ai/lib/python3.11/site-packages/peft/peft_model.py", line 306, in from_pretrained
    model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs)
  File "/opt/conda/envs/ai/lib/python3.11/site-packages/peft/peft_model.py", line 636, in load_adapter
    dispatch_model(
  File "/opt/conda/envs/ai/lib/python3.11/site-packages/accelerate/big_modeling.py", line 368, in dispatch_model
    raise ValueError(
ValueError: We need an `offload_dir` to dispatch this model according to this `device_map`, the following submodules need to be offloaded: base_model.model.model.layers.19, base_model.model.model.layers.20, base_model.model.model.layers.21, base_model.model.model.layers.22, base_model.model.model.layers.23, base_model.model.model.layers.24, base_model.model.model.layers.25, base_model.model.model.layers.26, base_model.model.model.layers.27, base_model.model.model.layers.28, base_model.model.model.layers.29, base_model.model.model.layers.30, base_model.model.model.layers.31, base_model.model.model.layers.32, base_model.model.model.layers.33, base_model.model.model.layers.34, base_model.model.model.layers.35, base_model.model.model.layers.36, base_model.model.model.layers.37, base_model.model.model.layers.38, base_model.model.model.layers.39, base_model.model.model.layers.40, base_model.model.model.layers.41, base_model.model.model.layers.42, base_model.model.model.layers.43, base_model.model.model.layers.44, base_model.model.model.layers.45, base_model.model.model.layers.46, base_model.model.model.layers.47, base_model.model.model.norm, base_model.model.lm_head.

This pull request, with

gpu_memory_limit: 20GiB
lora_on_cpu: true
[2023-12-06 04:46:08,700] [INFO] [axolotl.normalize_config:141] [PID:14824] [RANK:0] GPU memory usage baseline: 0.000GB (+0.456GB misc)
[2023-12-06 04:46:08,700] [INFO] [axolotl.common.cli.load_model_and_tokenizer:49] [PID:14824] [RANK:0] loading tokenizer... ./curr-model
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
[2023-12-06 04:46:08,853] [DEBUG] [axolotl.load_tokenizer:135] [PID:14824] [RANK:0] EOS: 2 / </s>
[2023-12-06 04:46:08,853] [DEBUG] [axolotl.load_tokenizer:136] [PID:14824] [RANK:0] BOS: 1 / <s>
[2023-12-06 04:46:08,853] [DEBUG] [axolotl.load_tokenizer:137] [PID:14824] [RANK:0] PAD: 2 / </s>
[2023-12-06 04:46:08,853] [DEBUG] [axolotl.load_tokenizer:138] [PID:14824] [RANK:0] UNK: 0 / <unk>
[2023-12-06 04:46:08,853] [INFO] [axolotl.common.cli.load_model_and_tokenizer:51] [PID:14824] [RANK:0] loading model and (optionally) peft_config...
[2023-12-06 04:46:08,854] [INFO] [axolotl.load_model:236] [PID:14824] [RANK:0] patching _expand_mask
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:04<00:00,  1.68it/s]
[2023-12-06 04:46:14,187] [INFO] [axolotl.load_model:444] [PID:14824] [RANK:0] GPU memory usage after model load: 19.528GB (+0.017GB cache, +1.173GB misc)
[2023-12-06 04:46:14,191] [INFO] [axolotl.load_model:479] [PID:14824] [RANK:0] converting modules to torch.bfloat16 for flash attention
[2023-12-06 04:46:14,194] [INFO] [axolotl.load_lora:581] [PID:14824] [RANK:0] found linear modules: ['v_proj', 'up_proj', 'down_proj', 'k_proj', 'o_proj', 'q_proj', 'gate_proj']
[2023-12-06 04:46:14,194] [DEBUG] [axolotl.load_lora:598] [PID:14824] [RANK:0] Loading pretained PEFT - LoRA
[2023-12-06 04:46:14,351] [WARNING] [auto_gptq.nn_modules.qlinear.qlinear_cuda.<module>:16] [PID:14824] CUDA extension not installed.
[2023-12-06 04:46:14,352] [WARNING] [auto_gptq.nn_modules.qlinear.qlinear_cuda_old.<module>:15] [PID:14824] CUDA extension not installed.
trainable params: 217,841,664 || all params: 33,961,811,968 || trainable%: 0.6414312175253134
[2023-12-06 04:46:29,104] [INFO] [axolotl.load_model:508] [PID:14824] [RANK:0] GPU memory usage after adapters: 0.000GB ()
[2023-12-06 04:46:29,106] [INFO] [axolotl.scripts.do_merge_lora:73] [PID:14824] [RANK:0] running merge of LoRA with base model
Unloading and merging model: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 678/678 [16:09<00:00,  1.43s/it]
[2023-12-06 05:02:38,873] [INFO] [axolotl.scripts.do_merge_lora:78] [PID:14824] [RANK:0] saving merged model to: qlora-out/merged

(V)RAM stats (from a previous run, so log times will differ):

Peak (V)RAM usage:

$ nvidia-smi; free -h
Wed Dec  6 05:04:49 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A10G         On   | 00000000:00:1E.0 Off |                    0 |
|  0%   28C    P0    56W / 300W |  21590MiB / 23028MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A     14824      C   python                          21588MiB |
+-----------------------------------------------------------------------------+
              total        used        free      shared  buff/cache   available
Mem:            30G         30G        298M         13M        507M        339M
Swap:           63G         37G         26G

Post-run (V)RAM usage:

Wed Dec  6 05:38:29 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A10G         On   | 00000000:00:1E.0 Off |                    0 |
|  0%   28C    P0    60W / 300W |      0MiB / 23028MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
              total        used        free      shared  buff/cache   available
Mem:            30G        353M         29G        168K        998M         30G
Swap:           63G        110M         63G

@winglian (Collaborator) commented Dec 7, 2023

Apart from the one minor change, this looks good. I haven't had a chance to test this yet though.

@kallewoof (Contributor, Author):

Thanks. I could have sworn I did that and it failed when I wrote the initial code, but it works now.

kallewoof added a commit to kallewoof/axolotl that referenced this pull request Dec 18, 2023
Comment on lines 9 to 10
gpu_memory_limit: 20GiB
lora_on_cpu: true
Collaborator:

I don't think adding these to the default YAML would be good, as a user might not need them and could build off this YAML unknowingly.

Suggested change (remove these lines):
gpu_memory_limit: 20GiB
lora_on_cpu: true

@kallewoof (Contributor, Author):

Removed.

@kallewoof force-pushed the 202312-better-devicemapping branch 4 times, most recently from 2679cbd to 68982f8 on December 22, 2023 at 12:36
@kallewoof (Contributor, Author):

No idea why the test is failing. Feedback would be appreciated.

@NanoCode012 (Collaborator) commented Dec 22, 2023

No idea why the test is failing. Feedback would be appreciated.

Error: 023-12-22 12:41:20,563] [ERROR] [axolotl.load_model:490] [PID:1653943] [RANK:0] 'max_memory'
Traceback (most recent call last):
  File "/home/ubuntu/gh/actions-runner/_work/axolotl/axolotl/src/axolotl/utils/models.py", line 433, in load_model
    del model_kwargs["max_memory"]
KeyError: 'max_memory'

Should it check whether the key exists before deleting?

@kallewoof (Contributor, Author):

Should it check whether the key exists before deleting?

Thanks. That key no longer exists, as it has been mapped into a proper device map by this pull request. Removing the del entirely should fix it.
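
For illustration, the defensive variant suggested above would look like the sketch below (the kwargs dict is a stand-in); the fix actually adopted here is simply to drop the deletion, since max_memory has already been converted into a concrete device map by that point.

# Sketch of the defensive check discussed above; not the merged fix.
model_kwargs = {"device_map": "auto"}  # illustrative stand-in for the real kwargs
model_kwargs.pop("max_memory", None)   # no KeyError even if the key was never set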

@kallewoof force-pushed the 202312-better-devicemapping branch 2 times, most recently from dbbce93 to cd34680 on December 26, 2023 at 03:48
@KanhnaDT:

I was banging my head for hours over this issue, then I stumbled upon this PR, tried it, and it works flawlessly!

I was trying to merge a 13B base model with fine-tuned QLoRA adapters on an EC2 instance with a 24GB VRAM GPU (and 32GB of system RAM).
It kept failing with the error:
We need an `offload_dir` to dispatch this model
even with CUDA_VISIBLE_DEVICES= set.

So basically this PR allows the merging process to use as much GPU as possible and to spill over to RAM if needed! Very convenient when merging large models without enough VRAM!

Thank you very much @kallewoof !

@kallewoof mentioned this pull request Dec 27, 2023
@NanoCode012 (Collaborator) left a comment

Hey, I apologize for the late review. I think this is very nice, as seen in others' posts.

Can I ask for one last thing: could you add a validation check in def validate_config(cfg) within utils/config.py so that if both max_memory and gpu_memory_limit are passed, a ValueError is raised warning the user to use only one of them?

Afterwards, I think this is good to merge.
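
A hypothetical sketch of such a check (the exact condition and message are assumptions; the merged code may differ):

# Hypothetical sketch of the requested mutual-exclusion check for
# utils/config.py; not necessarily the exact code that was merged.
def validate_config(cfg):
    if cfg.get("max_memory") is not None and cfg.get("gpu_memory_limit") is not None:
        raise ValueError(
            "max_memory and gpu_memory_limit are mutually exclusive; please use only one."
        )
    # ... remaining validation checks ...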

@kallewoof (Contributor, Author):

@NanoCode012 Thanks for the review. Added to the validation checks.

@NanoCode012 merged commit bdfefaf into axolotl-ai-cloud:main Jan 5, 2024
4 checks passed
@NanoCode012 (Collaborator):

Thank you for the PR!

@kallewoof deleted the 202312-better-devicemapping branch on January 5, 2024 at 13:23