VRAM usage regression in 2f2582e #1127

Closed · adamo1139 opened this issue Jan 16, 2024 · 11 comments

Labels
bug Something isn't working

Comments

@adamo1139

adamo1139 commented Jan 16, 2024

Please check that this issue hasn't been reported before.

  • I searched previous Bug Reports and didn't find any similar reports.

Expected Behavior

Fine-tuning with configurations that worked in previous versions should continue to work without OOMs.

Current behaviour

I attempted to replicate a fine-tune I ran a few weeks ago, on a slightly modified base model: same parameter count and same base, just with a LoRA adapter merged in, using the same config file. I was met with OOMs, and I was able to trace the regression back to the introduction of commit 2f2582e; the previous commit, 0ce1a65, does not exhibit the issue. When moving between versions, I made sure to run pip3 install -e '.[flash-attn,deepspeed]' after each checkout, and I tested a few different configuration parameters to be confident that I had identified the right commit.

On commit 0ce1a65 and earlier I am able to run a QLoRA SFT fine-tune of Yi-34B at a context length of 1400. On commit 2f2582e and later the limit is around 600 tokens; running with sequence length 1400 results in an OOM a few steps after starting. I have confirmed the issue on two different datasets (airoboros 3.1 and aezakmi v2 sharegpt). I am using a 24GB RTX 3090 Ti, with some VRAM reserved for the desktop environment (XFCE). Keeping the config file constant at sequence length 600, VRAM usage rose by around 1.3GB on commit 2f2582e. In the past, a 25-hour training session with that config file completed just fine.

Steps to reproduce

  1. Linux environment with PyTorch 2.0.1 + cu118, an Ampere GPU with 24GB of VRAM, and flash-attn 2.3.3
  2. Axolotl at commit 2f2582e
  3. Run accelerate launch -m axolotl.cli.train config.yml with the supplied config file, adjusting the base model to either the Yi-34B 4K llama-fied model or changing max_position_embeddings in the config of the 200K-context model to 4K, otherwise it will OOM at loading (a known, separate issue). A sketch of the full command sequence follows below.
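
A minimal sketch of that sequence, assuming axolotl is cloned locally and the config below is saved as config.yml (the checkout path is hypothetical; the commit hash, install extras, and train command are the ones quoted in this report):

# check out the commit under test and reinstall with the same extras
cd axolotl
git checkout 2f2582e
pip3 install -e '.[flash-attn,deepspeed]'

# launch training with the attached config
accelerate launch -m axolotl.cli.train config.yml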

Config yaml

base_model: ./yi-34b-200k-llamafied
base_model_config: ./yi-34b-200k-llamafied
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: false
is_llama_derived_model: true

load_in_8bit: false
load_in_4bit: true

strict: false
datasets:
  - path: jondurbin/airoboros-3.1
    type: sharegpt
    conversation: chatml
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
adapter: qlora
lora_model_dir:
sequence_len: 1400
sample_packing: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
  - q_proj
  - v_proj
  - k_proj
  - o_proj
  - gate_proj
  - down_proj
  - up_proj
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./qlora-yi-34b-200k-test
pad_to_sequence_len: true
micro_batch_size: 1
gradient_accumulation_steps: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
torchdistx_path:
lr_scheduler: constant
learning_rate: 0.00005
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
bfloat16: true
flash_optimum: false
gradient_checkpointing: true
early_stopping_patience:
save_safetensors:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
deepspeed:
seed: 42
warmup_steps: 100
eval_steps: 5000000
save_steps: 500
eval_table_size: 
eval_table_max_new_tokens:
debug:
weight_decay:
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<|startoftext|>"
  eos_token: "<|endoftext|>"
  unk_token: "<unk>"

Possible solution

No response

Which Operating Systems are you using?

  • Linux
  • macOS
  • Windows

Python Version

3.10

axolotl branch-commit

main/9cd27b2f91111e7ff991cfd464bccc3dc9ffa86a

Acknowledgements

  • My issue title is concise, descriptive, and in title casing.
  • I have searched the existing issues to make sure this bug has not been reported yet.
  • I am using the latest version of axolotl.
  • I have provided enough information for the maintainers to reproduce and diagnose the issue.
@winglian
Collaborator

@adamo1139 I know this seems unrelated, but can you give this PR a try? #1141? You'll need to delete the prepared dataset and rerun the preprocessing before training; a sketch of those steps is below.
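
A minimal sketch of that sequence, assuming the PR branch is checked out and dataset_prepared_path is last_run_prepared as in the config above (config.yml is the filename used earlier in this issue; preprocessing is invoked the same way as the train command):

# remove the previously prepared dataset so it is rebuilt from scratch
rm -rf last_run_prepared

# re-run preprocessing, then train
python -m axolotl.cli.preprocess config.yml
accelerate launch -m axolotl.cli.train config.yml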

@adamo1139
Author

adamo1139 commented Jan 19, 2024

@winglian I updated to cbecf3e, confirmed that the behaviour is still the same (a regression relative to 0ce1a65), then moved to the deprecate-max_packed_sequence_len branch, cleaned out last_run_prepared, and did some tests.
It looks like it settles at the same VRAM usage as the current main (cbecf3e), but it spikes even higher than cbecf3e around the first step. Because of that spike, OOMs happen at even lower context lengths.

BUNK TESTS, IGNORE

branch deprecate-max_packed_sequence_len and qlora of Yi-34B-200K on Airoboros 3.1

  • ctx 1400 OOM
  • ctx 600 OOM
  • ctx 400 OOM after spike to 24142 MiB used by axolotl process alone as per nvtop
  • ctx 300 runs, stable VRAM usage around 23570 MiB by axolotl process alone as per nvtop measured around step 50

@ehartford
Collaborator

I can't train a 20B llama with QLoRA on 4x A100. OOM every time with batch size 1 and 8K context.

How to fix?

@adamo1139
Author

@ehartford I suppose it's the llama-fied internlm2-20b, yes?
It has a max_position_embeddings of 32K, and that can create a spike when training is initiated. Assuming you already have a reasonably low lora_r and that didn't help, you can try editing max_position_embeddings to a lower value; it will reduce that initial spike in VRAM. That's why I edit it to 4K for Yi-34B-200K - at the default 200K value it's a guaranteed OOM. At 4K max_position_embeddings, Yi-34B-200K fine-tunes come out just fine, but if I change it to 16K the output adapter breaks the model, so YMMV. Are you doing SFT or DPO now?
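
A minimal sketch of that edit, assuming the local model directory from the config above and jq installed (the 4096 value mirrors the 4K setting described here; adjust the path to your own checkpoint):

# lower max_position_embeddings in the model's config.json to shrink the load-time VRAM spike
jq '.max_position_embeddings = 4096' yi-34b-200k-llamafied/config.json > config.json.tmp \
  && mv config.json.tmp yi-34b-200k-llamafied/config.json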

@ehartford
Collaborator

Yes, exactly. I will try this, thank you. I am doing SFT (dolphin).

@winglian
Collaborator

@adamo1139 single gpu or multigpu 24GB?

@adamo1139
Author

@winglian single GPU

@winglian
Collaborator

Is the model in the posted yml basically https://huggingface.co/01-ai/Yi-34B-200K/tree/main ?

@winglian
Collaborator

@adamo1139 I tried #1141 and it has the exact same VRAM usage for me as 0ce1a65, and it never spiked. Did you run axolotl.cli.preprocess first before training?

@winglian
Collaborator

The latest version of that PR, 799779f, should fix this now.

@adamo1139
Author

adamo1139 commented Jan 20, 2024

@winglian I typically don't do preprocessing as a separate step, since I didn't think it made much difference with small 50-100 MB datasets. I will be doing it from now on.

I noticed an issue with yesterday's testing of the deprecate-max_packed_sequence_len branch. I must have switched sample_packing to false at some point during testing; that produced OOMs even at lower ctx on that branch and skewed the results I reported. Today I caught this, switched sample_packing back to true, and replicated the issue at commit 9cd27b2 - the OOM was still present, even when running axolotl.cli.preprocess first. Then I confirmed that on 0ce1a65 I didn't have the issue. Finally, I am happy to report that I am no longer getting OOMs on 2ce5c0d, so my issue is resolved and I am closing the discussion.

Thanks for the help.
