
Bug: Finetuning on multi-GPU (FSDP) does not initialize with the foundation model #652

Closed
Jeronymous opened this issue Oct 18, 2023 · 16 comments · Fixed by OpenLLM-France/Lit-Claire#3
@Jeronymous

When experimenting with adapting Falcon on multiple GPUs using finetune/lora.py, we got surprisingly bad results.
After investigation, we realized that we were actually training a randomly initialized model
(and since only the LoRA weights were checkpointed, that model trained from scratch was simply lost...).

In other words, the foundation model (Falcon) was not properly loaded.
It seems to be due to the use of fabric.init_module(empty_init=True) at this line:
https://github.com/Lightning-AI/lit-gpt/blob/bf60124fa72a56436c7d4fecc093c7fc48e84433/finetune/lora.py#L128
If we use empty_init=False it trains correctly. I am not sure it's the right fix, though.
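For reference, the one-line change we applied locally (I am not claiming this is the proper upstream fix, just what made training start from the pretrained weights in our runs):

with fabric.init_module(empty_init=False):  # workaround; was empty_init=(devices > 1), i.e. True on multi-GPU
    model = GPT(config)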

@Jeronymous Jeronymous changed the title Finetuning with LORA on multi-GPU (FSDP) does not correctly load the foundation model Finetuning on multi-GPU (FSDP) does not initialize with the foundation model Oct 25, 2023
@Jeronymous
Author

I got no answer on this, but it seems to me like quite a critical bug 🤔

It does not affect only LoRA: all the finetuning scripts seem to be affected.

(What makes it worse with LoRA compared to "full" finetuning is that you cannot even recover the model that was trained from scratch, given that only the LoRA weights are saved.)

@Jeronymous Jeronymous changed the title Finetuning on multi-GPU (FSDP) does not initialize with the foundation model Bug: Finetuning on multi-GPU (FSDP) does not initialize with the foundation model Oct 25, 2023
@carmocca
Contributor

carmocca commented Oct 30, 2023

Hi @Jeronymous. The idea in that piece of code is to initialize the model randomly, because the pretrained weights are loaded afterwards, at https://github.com/Lightning-AI/lit-gpt/blob/bf60124fa72a56436c7d4fecc093c7fc48e84433/finetune/lora.py#L147-L148

Can you verify that this is happening for you?

cc @awaelchli

@carmocca carmocca added the question Further information is requested label Oct 30, 2023
@Jeronymous
Author

Jeronymous commented Oct 31, 2023

Yes, I checked: load_checkpoint is called after the `with fabric.init_module(empty_init=(devices > 1))` block.
The sequence of relevant instructions is:

    with fabric.init_module(empty_init=(devices > 1)):
        model = GPT(config)
    mark_only_lora_as_trainable(model)

    model = fabric.setup_module(model)

    load_checkpoint(fabric, model, checkpoint_path, strict=False)

    fabric.seed_everything(seed + fabric.global_rank)

We are using:

Maybe what happens (just a guess) is that the weight tensors are not allocated when empty_init is True, and that this then causes a problem with the lazy loading of the model.
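To illustrate the guess (a toy snippet, not what Fabric actually does internally; depending on the strategy and Lightning version, empty_init may use the meta device or simply skip weight initialization):

import torch

# A tensor created on the meta device has a shape and dtype but no storage, so its
# values only become real once a checkpoint is materialized into it.
meta_weight = torch.nn.Parameter(torch.empty(4, 4, device="meta"))
print(meta_weight.is_meta)   # True -> no data behind it yet

# A plain torch.empty on a real device does allocate storage, but the contents are
# uninitialized, so training on it without a subsequent state-dict load is
# effectively training from arbitrary weights.
empty_weight = torch.nn.Parameter(torch.empty(4, 4))
print(empty_weight.shape)    # torch.Size([4, 4]); the values are undefined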

Also, I don't understand why the initialization strategy should differ between single-GPU and multi-GPU (the (devices > 1) condition that was added to fabric.init_module(empty_init=...) a few weeks ago).

@Jeronymous
Author

@carmocca @awaelchli Here is some evidence that this is likely a bug (and not a misuse on our side):
#686 (comment) -> others are experiencing problems when finetuning on multi-GPU

@DevasiaThomas

DevasiaThomas commented Nov 4, 2023

@Jeronymous I tried FSDP for the first time using the default code, because I was running into OOM on one GPU. The code would not proceed beyond setting the seed the first time (it stalled for about 2 hours).
I gave up, then saw your issue and changed the code - it proceeded after 20 mins, but had this output before it crashed due to OOM again (I'll figure that part out later):

/opt/conda/envs/pytorch/lib/python3.10/site-packages/lightning/fabric/wrappers.py:176: You are calling the method
`GPT.set_kv_cache()` from outside the model. This will bypass the wrapper from the strategy and result in incorrect
behavior in `.backward()`. You should pass your inputs through `GPT.forward()`.

Another thing I saw that was different (I don't know if this is important): single-GPU runs set the seed to the same number the second time, while the multi-GPU run did not.

For now, I reduced my sample max_seq_len and switched to a single GPU.

I'm using A10s; I would upgrade if I could, but AWS won't give me a single A100 (only 8 🤣).

EDIT: The finetuning just randomly crashes and doesn't really output an error message when it does.

@PoodleWang

PoodleWang commented Nov 9, 2023

I have a similar issue. If I set init_weight = True, the loss is around 8 or 9, which means the model doesn't load the checkpoint successfully. By the way, I use torchrun to initiate the program; I cannot use python3 xxx.py because my machine is set up differently.

@Jeronymous
Author

@DevasiaThomas Your problems seem to be memory overflows, which is a different issue from the one opened here (model initialization on multi-GPU).
When memory limits are hit, it's hard for a program to fail with a proper error message...

I guess you should use a lower micro_batch_size to solve this.

Multi-GPU shouldn't be of particular help here, because it does not reduce the memory used by each GPU (training just consumes samples faster; note that the actual batch size, i.e. the number of samples between two model updates, is "batch_size × number of devices", as in the quick arithmetic below).
However, FSDP on one GPU seems to consume less memory than vanilla training on one GPU (that might be another option, along with tuning micro_batch_size).
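To make that arithmetic concrete (a sketch using the script's default values as printed later in this thread, with devices=2 as in my own runs; adjust to your settings):

# Per-process settings from finetune/lora.py defaults (see the config dump further down)
batch_size = 128          # samples between two optimizer steps, per process
micro_batch_size = 4      # samples per forward/backward pass; this is what drives peak VRAM
devices = 2               # number of GPUs / processes

gradient_accumulation_iters = batch_size // micro_batch_size   # 32
actual_batch_size = batch_size * devices                       # samples between two model updates
print(gradient_accumulation_iters, actual_batch_size)          # 32 256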
Good luck!

@Andrei-Aksionov Andrei-Aksionov self-assigned this Nov 9, 2023
@Jeronymous
Author

Also here: #689 (comment)

I never understood why this issue was labeled a "question"...

@carmocca carmocca added bug Something isn't working and removed question Further information is requested labels Nov 17, 2023
@carmocca
Contributor

Sorry @Jeronymous. We'll look into this asap

@awaelchli
Member

@Jeronymous I looked at this again.

On lit-gpt main with lightning 2.1.2 and PyTorch 2.2 nightly I see the finetuning scripts working fine with FSDP (devices=2). The model gets loaded correctly and the loss quickly converges to < 1.0. I don't see anything obviously wrong here.

Then, given your info in #652 (comment), I checked out lit-gpt and lightning at that commit and ran again, with the same observations.

I did this with default settings, meaning it uses stabilityai/stablelm-base-alpha-3b. Please share any changes you've made locally to lit-gpt and the scripts, as well as the checkpoint/model family you are loading.

> (What makes it worse with LoRA compared to "full" finetuning is that you cannot even recover the model that was trained from scratch, given that only the LoRA weights are saved.)

Of course that's still possible. You can always merge the LoRA weights onto the original checkpoint; that's by design. See generate/lora.py for how that's done. If you find this inconvenient, you can change the line in the script to save the full checkpoint instead of just the LoRA weights, by removing the filter:
https://github.com/Lightning-AI/lit-gpt/blob/c85bf018f92140fab66190a53842519d8c8a7d29/finetune/lora.py#L304
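Roughly, the change would look like this (a sketch of the idea; the exact helper and filter names may differ slightly in your checkout):

def save_lora_checkpoint(fabric, model, file_path):
    fabric.print(f"Saving weights to {str(file_path)!r}")
    # original: fabric.save(file_path, {"model": model}, filter={"model": lora_filter})
    # dropping the `filter` argument saves the complete state dict, not only the LoRA tensors
    fabric.save(file_path, {"model": model})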

@DevasiaThomas FYI This warning can be ignored, and in the latest version of Lightning it won't appear anymore in this context.

@Andrei-Aksionov
Collaborator

> (What makes it worse with LoRA compared to "full" finetuning is that you cannot even recover the model that was trained from scratch, given that only the LoRA weights are saved.)

There is also a script to merge LoRA weights with the pre-trained ones.

@Jeronymous
Author

Jeronymous commented Nov 27, 2023

Thank you for having a look, @awaelchli. I am sorry you can't reproduce it. I'm giving it another try with the most recent versions.

I haven't modified lit-gpt, and we are finetuning tiiuae/falcon-7b.


Not important, but I think there was a misunderstanding on this:

> (What makes it worse with LoRA compared to "full" finetuning is that you cannot even recover the model that was trained from scratch, given that only the LoRA weights are saved.)

I was talking about the bug I faced: it trained from a randomly initialized model (instead of the original checkpoint), and that random initialization cannot be recovered. So in this setting, I don't have the base model onto which to apply the LoRA weights (to recover the full model that was trained from scratch).
I know about the scripts to merge LoRA and all that.
That was just a side note to say that the impact of the bug is worse with LoRA than with full training.
Sorry for the confusion.

@Jeronymous
Author

Jeronymous commented Nov 27, 2023

So I gave it another try, and I still have the same issue.

I updated to the latest version:

# last commit in lit-gpt from "Sun Nov 26"
lit-gpt @ git+https://github.com/Lightning-AI/lit-gpt@e05fc4a6a39808100cd76aff3d6c26bfae7417be
lightning @ git+https://github.com/Lightning-AI/lightning@6cbe9ceb560d798892bdae9186291acf9bf5d2e3
lightning-cloud==0.5.55
lightning-fabric==2.1.2
lightning-utilities==0.9.0

# last nightly torch build from "Mon Nov 27"
torch==2.2.0.dev20231127+cu121
pytorch-lightning==2.1.2

I retried finetuning (or rather continual pretraining) of tiiuae/falcon-7b with LoRA, on 2 GPUs.

If I use `with fabric.init_module(empty_init=(devices > 1))`, the first training logs are:

iter 1: {"val_loss": 7.6107, "val_time": 77.49, "peak_vram": "45.33 GB"}
iter 2 step 1: loss 7.8999, iter time: 4872.29ms
iter 3 step 2: loss 7.5855, iter time: 5235.53ms (optimizer.step)
iter 4 step 2: loss 7.7047, iter time: 4878.54ms
iter 5 step 3: loss 7.8596, iter time: 5291.41ms (optimizer.step)

If I use `with fabric.init_module(empty_init=False)`, the first training logs are:

iter 1: {"val_loss": 2.7418, "val_time": 77.95, "peak_vram": "45.33 GB"}
iter 2 step 1: loss 2.7784, iter time: 4941.50ms
iter 3 step 2: loss 2.7446, iter time: 5251.79ms (optimizer.step)
iter 4 step 2: loss 2.6549, iter time: 4910.97ms
iter 5 step 3: loss 2.6943, iter time: 5324.60ms (optimizer.step)

The loss ranges are quite different...

@awaelchli
Member

Seeing "peak_vram": "45.33 GB" in your output, together with the fact that the optimizer step takes place every other iteration, tells me that you still have some modifications in that script and are not running our script as provided. Could you please make sure? Run git stash && git checkout main to avoid any local changes being mixed in; otherwise what I observe will never match what you observe. This is very important!

I checked again: I re-downloaded the Falcon 7B checkpoint using scripts/download.py and re-converted it to the lit format using scripts/convert_hf_checkpoint.py, to make sure I don't have a bad one. Then I reran scripts/prepare_alpaca.py with that checkpoint dir passed in. Then I changed devices=2, precision="bf16-true" in the lora script (because I don't have the 80GB A100s) and changed the checkpoint dir in lora.py to match the newly downloaded checkpoint. Then I ran with both the empty_init=True and empty_init=False configurations, and the same loss trajectories are produced.

Please note that for a 7B pretrained checkpoint, the CE loss should be around ~2.5, and for a randomly initialized model it would definitely be larger than 7.
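For a rough back-of-the-envelope check (my own numbers, assuming Falcon's padded vocabulary of V = 65024 tokens, as in the config printed further down in this thread): a model that predicts tokens uniformly at random has a cross-entropy of

\mathcal{L}_{\text{uniform}} = -\log\tfrac{1}{V} = \ln 65024 \approx 11.1 \ \text{nats},

while a trained 7B model typically sits around 2-3 nats on natural text, so an initial loss in the 7-8 range points to largely uninitialized weights.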

@carmocca
Contributor

Also, please share the printed output of these two lines:

@Andrei-Aksionov
Collaborator

I also tried to do LoRA fine-tuning.

The latest code from main branch, the latest packages.

I used 4xA10G (the best I have) with bf16-true precision.
I downloaded and converted the model (Falcon-7B) as described in tutorials/download_falcon.md, then ran the prepare script for the Alpaca dataset with this model.
The only changes in finetune/lora.py are:

  1. devices=4
  2. max_iters=10
  3. commented out the validate call in the train function to make the log more concise.

This is what I got when I tried to fine-tune:

main ~/lit-gpt python finetune/lora.py --checkpoint_dir checkpoints/tiiuae/falcon-7b --precision bf16-true
{'eval_interval': 100, 'save_interval': 100, 'eval_iters': 100, 'eval_max_new_tokens': 100, 'log_interval': 1, 'devices': 4, 'learning_rate': 0.0003, 'batch_size': 128, 'micro_batch_size': 4, 'gradient_accumulation_iters': 32, 'max_iters': 10, 'weight_decay': 0.01, 'lora_r': 8, 'lora_alpha': 16, 'lora_dropout': 0.05, 'lora_query': True, 'lora_key': False, 'lora_value': True, 'lora_projection': False, 'lora_mlp': False, 'lora_head': False, 'warmup_steps': 100}
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4
Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/4
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/4
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 4 processes
----------------------------------------------------------------------------------------------------

[rank: 0] Seed set to 1337
[rank: 1] Seed set to 1337
[rank: 3] Seed set to 1337
[rank: 2] Seed set to 1337
Loading model 'checkpoints/tiiuae/falcon-7b/lit_model.pth' with {'name': 'falcon-7b', 'hf_config': {'org': 'tiiuae', 'name': 'falcon-7b'}, 'block_size': 2048, 'vocab_size': 65024, 'padding_multiple': 512, 'padded_vocab_size': 65024, 'n_layer': 32, 'n_head': 71, 'n_embd': 4544, 'rotary_percentage': 1.0, 'parallel_residual': True, 'bias': False, 'lm_head_bias': False, 'n_query_groups': 1, 'shared_attention_norm': True, '_norm_class': 'LayerNorm', 'norm_eps': 1e-05, '_mlp_class': 'GptNeoxMLP', 'gelu_approximate': 'none', 'intermediate_size': 18176, 'rope_condense_ratio': 1, 'rope_base': 10000, 'r': 8, 'alpha': 16, 'dropout': 0.05, 'to_query': True, 'to_key': False, 'to_value': True, 'to_projection': False, 'to_mlp': False, 'to_head': False, 'head_size': 64, 'rope_n_elem': 64}
Number of trainable parameters: 3,506,176
Number of non trainable parameters: 7,217,189,760
[rank: 3] Seed set to 1340
[rank: 0] Seed set to 1337
[rank: 2] Seed set to 1339
[rank: 1] Seed set to 1338
The longest sequence length in the train data is 1079, the model's maximum sequence length is 1079 and context length is 2048
iter 1 step 0: loss 1.7293, iter time: 10214.60ms
iter 2 step 0: loss 2.5372, iter time: 5275.93ms
iter 3 step 0: loss 2.3912, iter time: 5251.75ms
iter 4 step 0: loss 2.3706, iter time: 5457.99ms
iter 5 step 0: loss 2.1239, iter time: 5294.34ms
iter 6 step 0: loss 2.3765, iter time: 5302.96ms
iter 7 step 0: loss 2.0163, iter time: 5307.21ms
iter 8 step 0: loss 1.8228, iter time: 5372.66ms
iter 9 step 0: loss 2.7029, iter time: 5237.35ms
iter 10 step 0: loss 2.1403, iter time: 5389.18ms
Training time: 58.26s
Memory used: 20.53 GB
Saving LoRA weights to 'out/lora/alpaca/lit_model_lora_finetuned.pth'

The loss values were exactly the same for empty_init=True, empty_init=False, and empty_init=(devices > 1).

@Andrei-Aksionov Andrei-Aksionov removed the bug Something isn't working label Nov 27, 2023