Llama 3 8b OOM with GaLore on 2x A100s (Mistral 7b is fine?) #1641

Open · e-p-armstrong opened this issue May 19, 2024 · 5 comments
Labels: bug (Something isn't working), possibly_solved

@e-p-armstrong

Please check that this issue hasn't been reported before.

  • I searched previous Bug Reports and didn't find any similar reports.

Expected Behavior

Llama 3 8b, with only about one billion more parameters than Mistral 7b, should presumably be able to train with GaLore on at least 2x A100s (Mistral v0.2 trains with GaLore on a single A100).
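
For scale, going by the published parameter counts (roughly 8.03B for Llama 3 8b and 7.24B for Mistral 7b, if I have those figures right), the weights themselves differ by less than 2 GB in bf16:

$$
8.03\text{B} \times 2\ \text{bytes} \approx 16.1\ \text{GB} \quad \text{vs.} \quad 7.24\text{B} \times 2\ \text{bytes} \approx 14.5\ \text{GB}
$$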

Current behaviour

Llama 3 8b OOMs almost immediately when tuned with GaLore at 8k sequence length, even when obscene amounts of GPU memory are thrown at it.

Steps to reproduce

  1. Rent 2x A100s on Vast.ai or any other provider
  2. Run a training run with the provided config (any pretraining and finetuning data can stand in for the unavailable datasets)
  3. Observe near-immediate OOM.

Config yaml

base_model: meta-llama/Meta-Llama-3-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: json
    data_files: pretraining_vision.json
    ds_type: json
    type: completion
  - path: json
    data_files: simplified_data_rag_VISION.jsonl
    ds_type: json
    type: sharegpt
  - path: json
    data_files: simplified_data_rag_VISION.jsonl
    ds_type: json
    type: sharegpt
  - path: json
    data_files: pretraining_wiki.json
    ds_type: json
    type: completion
  - path: json
    data_files: simplified_data_rag_WIKI.jsonl
    ds_type: json
    type: sharegpt
  - path: json
    data_files: simplified_data_no_rag_WIKI.jsonl
    ds_type: json
    type: sharegpt
  - path: json
    data_files: pretraining_api.json
    ds_type: json
    type: completion
  - path: json
    data_files: simplified_data_rag_API.jsonl
    ds_type: json
    type: sharegpt
  - path: json
    data_files: simplified_data_no_rag_API.jsonl
    ds_type: json
    type: sharegpt
  - path: json
    data_files: pretraining_docs.json
    ds_type: json
    type: completion
  - path: json
    data_files: simplified_data_rag_DOCS.jsonl
    ds_type: json
    type: sharegpt
  - path: json
    data_files: simplified_data_no_rag_DOCS.jsonl
    ds_type: json
    type: sharegpt
dataset_prepared_path: last_run_prepared
output_dir: ./verus-out

sequence_len: 8100
sample_packing: true
pad_to_sequence_len: true

wandb_project: verus-llama-experiment-2
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 6
eval_batch_size: 6
num_epochs: 5
optimizer: galore_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0000035
cosine_min_lr_ratio: 0
weight_decay: 0 # no weight decay to maximize fact memorization (thanks cgato!)
# adamw hyperparams
adam_beta1: 0.9
adam_beta2: 0.999
adam_epsilon: 0.00000001
# Gradient clipping max norm
max_grad_norm: 1.0
noisy_embedding_alpha: 0 # no noisy embedding to ensure maximal memorization 

optim_args:
# For Galore Optimizers the following optim_args are available
    rank: 256 # type: int
    update_proj_gap: 200  # type: int
    scale: 0.25  # type: float
    proj_type: "std" # type: str, default = std

optim_target_modules: 
  - mlp
  - self_attn
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint: 
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
auto_resume_from_checkpoints: false
eval_steps: 10
saves_per_epoch: 1
eval_sample_packing: false
save_total_limit: 2
debug:
deepspeed: deepspeed_configs/zero2.json
special_tokens:
  pad_token: "<|end_of_text|>"

Possible solution

I was able to finetune Llama 3 8b Instruct by reducing the sequence length to around 2000 tokens. However, unless I'm missing something big that makes Llama 3 very memory-inefficient with GaLore, I should presumably be able to do a full finetune at 8000 sequence length with two whole A100s.
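
For reference, a minimal sketch of the workaround (the value is approximate; the run fit at roughly 2000 tokens, with everything else left as in the config above):

```yaml
# Only change relative to the config above (approximate value)
sequence_len: 2048   # was 8100
```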

Which Operating Systems are you using?

  • Linux
  • macOS
  • Windows

Python Version

3.10.14

axolotl branch-commit

Whatever the official Docker image is on.

Acknowledgements

  • My issue title is concise, descriptive, and in title casing.
  • I have searched the existing issues to make sure this bug has not been reported yet.
  • I am using the latest version of axolotl.
  • I have provided enough information for the maintainers to reproduce and diagnose the issue.
e-p-armstrong added the bug (Something isn't working) label on May 19, 2024
@e-p-armstrong (Author)

Edit: it also OOMs on 4x A100s; something is definitely wrong here.

@winglian (Collaborator)

How much VRAM does Mistral 7B use with GaLore on your setup? Keep in mind that Llama 3 has a much larger embedding layer than Mistral (128k vs 32k vocabulary), which significantly increases VRAM use.

I would start by decreasing the batch sizes and using unsloth gradient checkpointing:

micro_batch_size: 4
eval_batch_size: 4
gradient_checkpointing: unsloth
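
For a rough sense of scale (a back-of-the-envelope estimate, assuming the Transformers Llama/Mistral implementations upcast the logits to fp32 for the loss, and vocabulary sizes of 128,256 vs 32,000): at micro_batch_size 6 and sequence_len 8100, the logits tensor alone is roughly

$$
6 \times 8100 \times 128{,}256 \times 4\ \text{bytes} \approx 23\ \text{GB per GPU}
$$

for Llama 3, versus about 5.8 GB for Mistral's 32k vocabulary, on top of roughly 1 GB (vs ~0.25 GB) each for the embedding and lm_head weights in bf16.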

@e-p-armstrong (Author) commented May 20, 2024

VRAM usage with Mistral 7b is only about 58 GB on a single A100.

As for Llama 3, I decreased the batch size to 1, set gradient checkpointing to unsloth, increased gradient accumulation steps to 6, and rented 8x A6000s. It still OOMs even with 383.9 GB of VRAM available for finetuning an 8b model. This can't be right, can it?
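
For clarity, the changes relative to the original config (a sketch reconstructed from the description above; eval_batch_size is my assumption):

```yaml
micro_batch_size: 1              # was 6
eval_batch_size: 1               # assumption: matched to micro_batch_size (was 6)
gradient_accumulation_steps: 6   # was 1
gradient_checkpointing: unsloth  # was true
```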


Edit: this may be related to issue #1448

Edit 2: I was using the wrong config -- it does not OOM on 8x A6000s. However, it IS using 273 GB of VRAM, which seems a bit high.

@e-p-armstrong (Author) commented May 25, 2024

Update: I'm seeing 343 GB of VRAM usage when finetuning Llama 3 8b. This has to be wrong, but I have no idea how to even begin addressing it.

@Abhis-123

@e-p-armstrong did you find any workaround?
