fine tuning the updated Phi-2 with flash-attn-2 produces very high loss > 2 #28488

Closed
4 tasks
abacaj opened this issue Jan 13, 2024 · 50 comments · Fixed by #28537

Comments

@abacaj

abacaj commented Jan 13, 2024

System Info

The updated phi-2 code produces a high loss. I have tried fp16, bf16, DeepSpeed and FSDP, and the result is the same: the loss starts at 2 and keeps climbing. Setting use_flash_attention_2=False fixes this, as does using the old phi-2 modeling file.

torch==2.1.2
flash-attn==2.4.2
transformers==4.37.0.dev0

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Fine-tune the updated phi-2 model using transformers trainer

Expected behavior

The loss goes down.

@NicolasMejiaPetit

I experienced the same thing! Over 3 epochs with the same setup, just the updated code and flash attention, the loss went from 6 to 2. On the old code without flash attention it went from .60 to ~.29. Very strange.

@amyeroberts
Collaborator

cc @younesbelkada @ArthurZucker

@younesbelkada
Contributor

Hi @abacaj, as per @pacman100's guidelines in #28142 (see #28142 (comment)), you need to make sure to load your model in full precision and train with autocast (bf16=True). Also, can you share more details on how you train your model (do you load the model in bf16/fp16, do you use PEFT, packing, etc.)?
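
To make the loading side of this suggestion concrete, here is a minimal sketch under assumptions. `build_load_kwargs` is a hypothetical helper (it is not part of transformers) that only assembles keyword arguments for `from_pretrained`. It deliberately leaves out `torch_dtype`, so the weights load in the default full precision and mixed precision is left to `bf16=True` in the training arguments.

```python
def build_load_kwargs(use_flash_attention: bool) -> dict:
    """Assemble keyword arguments for AutoModelForCausalLM.from_pretrained.

    Hypothetical helper for illustration; it is not part of transformers.
    """
    return {
        # Taken from the loading snippets posted in this thread.
        "trust_remote_code": True,
        # "eager" selects the plain PyTorch attention path.
        "attn_implementation": "flash_attention_2" if use_flash_attention else "eager",
        # Deliberately no "torch_dtype": weights load in the default precision,
        # and mixed precision is left to TrainingArguments(bf16=True).
    }


# Intended use (needs transformers and a GPU, so it is left commented out):
# model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", **build_load_kwargs(True))
# args = TrainingArguments(output_dir="out", bf16=True)
print(build_load_kwargs(True))
```

Keeping the two attention paths behind one switch makes it easier to run the with/without FA2 comparison with everything else held fixed.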

@abacaj
Author

abacaj commented Jan 15, 2024

Hi @younesbelkada, this is a full fine-tune using the HF Trainer. Padding only. The model is loaded in bf16. I tried loading in "fp32" but got this error:

ValueError: Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. You passed torch.float32, this might lead to unexpected behaviour.

model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    trust_remote_code=True,
    config=config,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float32,
    cache_dir=training_args.cache_dir,
)

@younesbelkada
Contributor

OK, thanks @abacaj for getting back! I think you get that error because the patch #28142 has not been released on PyPI. Can you try to build transformers from source?

pip install -U git+https://github.com/huggingface/transformers.git

That should hopefully solve it. Let me know if you run into more issues!

@abacaj
Author

abacaj commented Jan 15, 2024

OK, so I removed the explicit torch_dtype, following the comments in your link. The loss is still very high with flash-attn-2 on the phi-2 model.

@younesbelkada
Contributor

@abacaj which padding side are you using for training?

@abacaj
Author

abacaj commented Jan 15, 2024

I use padding_side="left". Here is how the loss goes with and without FA2 (green line has FA2) using phi-2:

image

@abacaj
Author

abacaj commented Jan 15, 2024

FWIW, changing the padding side doesn't change the loss; it's the same.
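
For readers following the padding question, here is a small plain-Python sketch of what the two padding sides produce. `pad_batch` and the pad id `0` are illustrative assumptions, not the tokenizer's actual implementation; in a real run this is controlled by the tokenizer's `padding_side` setting.

```python
def pad_batch(sequences, pad_id, side="left"):
    """Pad token-id lists to a common length and build attention masks.

    Illustrative only: a real tokenizer does this internally.
    """
    if side not in ("left", "right"):
        raise ValueError("side must be 'left' or 'right'")
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        pads, real = [pad_id] * n_pad, [1] * len(seq)
        if side == "left":
            input_ids.append(pads + list(seq))
            attention_mask.append([0] * n_pad + real)
        else:
            input_ids.append(list(seq) + pads)
            attention_mask.append(real + [0] * n_pad)
    return input_ids, attention_mask


ids, mask = pad_batch([[11, 12, 13], [21]], pad_id=0, side="left")
print(ids)
print(mask)
```

Either way, the mask marks the same real tokens; only their position in the row changes.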

@younesbelkada
Contributor

I see, as a sanity check, can you share your TrainingArguments ?

@abacaj
Author

abacaj commented Jan 15, 2024

adafactor=False,
adam_beta1=0.9,
adam_beta2=0.95,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=True,
bf16_full_eval=False,
cache_dir=None,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_persistent_workers=False,
dataloader_pin_memory=True,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=src/configs/deepspeed_2_config.json,
disable_tqdm=False,
dispatch_batches=None,
do_eval=False,
do_predict=False,
do_train=True,
eval_accumulation_steps=None,
eval_delay=0,
eval_steps=0.0,
evaluation_strategy=no,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
gradient_accumulation_steps=2,
gradient_checkpointing=True,
gradient_checkpointing_kwargs=None,
greater_is_better=None,
group_by_length=False,
half_precision_backend=auto,
hub_always_push=False,
hub_model_id=None,
hub_private_repo=False,
hub_strategy=every_save,
hub_token=<HUB_TOKEN>,
ignore_data_skip=False,
include_inputs_for_metrics=False,
include_num_input_tokens_seen=False,
include_tokens_per_second=False,
inference_length=2048,
jit_mode_eval=False,
label_names=None,
label_smoothing_factor=0.0,
learning_rate=5e-05,
length_column_name=length,
load_best_model_at_end=False,
local_rank=1,
log_level=passive,
log_level_replica=warning,
log_on_each_node=True,
logging_dir=checkpoints/results/2k-2k-dynamic-5e-5/runs/Jan15_13-42-15_sgpu,
logging_first_step=False,
logging_nan_inf_filter=True,
logging_steps=1.0,
logging_strategy=steps,
lr_scheduler_kwargs={},
lr_scheduler_type=cosine,
max_grad_norm=1.0,
max_steps=-1,
metric_for_best_model=None,
model_max_position_embeddings=2048,
mp_parameters=,
neftune_noise_alpha=None,
no_cuda=False,
num_train_epochs=3.0,
optim=adamw_torch,
optim_args=None,
output_dir=checkpoints/results/2k-2k-dynamic-5e-5,
overwrite_output_dir=False,
past_index=-1,
per_device_eval_batch_size=4,
per_device_train_batch_size=4,
prediction_loss_only=False,
push_to_hub=False,
push_to_hub_model_id=None,
push_to_hub_organization=None,
push_to_hub_token=<PUSH_TO_HUB_TOKEN>,
ray_scope=last,
remove_unused_columns=True,
report_to=['tensorboard'],
resume_from_checkpoint=None,
rope_scaling_factor=1.0,
rope_scaling_type=dynamic,
run_name=checkpoints/results/2k-2k-dynamic-5e-5,
save_on_each_node=False,
save_only_model=False,
save_safetensors=True,
save_steps=100.0,
save_strategy=epoch,
save_total_limit=None,
seed=70,
skip_memory_metrics=True,
split_batches=False,
tf32=None,
torch_compile=False,
torch_compile_backend=None,
torch_compile_mode=None,
torchdynamo=None,
tpu_metrics_debug=False,
tpu_num_cores=None,
use_cpu=False,
use_ipex=False,
use_legacy_prediction_loop=False,
use_mps_device=False,
warmup_ratio=0.02,
warmup_steps=0,
weight_decay=0.1

@NicolasMejiaPetit

NicolasMejiaPetit commented Jan 15, 2024

During my testing, I used bf16, trust_remote_code, and no gradient checkpointing, for SFT with flash attention. The resulting model was terrible. I knew from the loss (6 to 2) that it wasn’t going to perform, but in testing it was worse than expected, with very mangled answers. However, when I trained the model with the same arguments, just using the old phi repo code and no flash attention, I got a great model, with the loss going from .6 to .29. Both were full fine-tunes. Flash attention is critical for 24 GB cards, as without it training spills into shared memory. I can help out more with testing when the current run is done training, in ~30 hours off shared memory 😭. The script I used is in #28381. (Keep in mind the script doesn’t reflect my use of bf16; however, both times I trained the model I did have the compute dtype set to bf16.)

@gugarosa
Contributor

Hello everyone!

Could you all please test using the latest revision on microsoft/phi-2 and report the results? We might have found the issue.

Regards,
Gustavo.

@abacaj
Author

abacaj commented Jan 18, 2024

FWIW - the model still comes out significantly worse using FA2. If anyone wants to fine-tune this model, I recommend you use it without FA2 currently. On code benchmarks, the FA2 model scores < 50% on HumanEval (heval). Without FA2 (all other hparams identical, including seed) it scores > 60%.

@survivi

survivi commented Jan 18, 2024

The first graph is a comparison between using and not using flash attention 2. It seems that the loss doesn't change much with fa2 (yellowish curve).
Screenshot 2024-01-18 16 43 57

@gugarosa
Contributor

@abacaj could you please provide a minimal snippet to reproduce your fine-tuning?

We want to investigate this further and attempt to find the root of the problem. We are doing a line-by-line comparison between the new model's code and the previous one.

@NicolasMejiaPetit

NicolasMejiaPetit commented Jan 18, 2024

FWIW - the model still comes out significantly worse using FA2. If anyone wants to fine-tune this model, I recommend you use it without FA2 currently. On code benchmarks, the FA2 model scores < 50% on HumanEval (heval). Without FA2 (all other hparams identical, including seed) it scores > 60%.

I second this. I just woke up and checked training: after 3 epochs with FA2 the loss went from .54 to ~.40, while with no FA2 it went from .60 to .30. Both were full fine-tunes. I’m going to train the FA2 checkpoint for an additional epoch to see whether it reaches the same loss as without FA2, or whether it overfits.

EDIT:
The loss is off to a terrible start. It went as low as .39, then as high as .49. (It’s only at .07 of an epoch, but I’m training a checkpoint that has already trained on this exact dataset for 3 epochs.) This is significantly better than before, with the softmax scaling issues, but something is still off.

IMG_2363

The loss behaves quite randomly compared to no FA2, where it consistently went down.

SECOND UPDATE: I restarted the training with the same checkpoint and upped the learning rate by an order of magnitude, from 2e5 to 2e6, and now the loss is more consistent. It's confusing that this hyperparameter has to differ between training with FA2 and without it.
IMG_2364

Not perfect but better.

THIRD UPDATE: I tried retraining the base model with FA2 and the loss isn't going anywhere after 1.5 epochs. It's almost as if the weights aren’t being updated at all, or only very marginally. The loss just stays between .5 and .4, varying randomly at every logging step.
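
One way to put a number on "consistently going down" versus "random" is the least-squares slope of the logged losses over logging steps. The sketch below is plain Python; `loss_slope` is a hypothetical helper and the two loss lists are made-up examples, not values from these runs.

```python
def loss_slope(losses):
    """Least-squares slope of loss against logging-step index.

    A clearly negative slope means the loss trends down; a slope near zero
    with large step-to-step swings matches a curve that looks random.
    """
    n = len(losses)
    if n < 2:
        raise ValueError("need at least two logged losses")
    mean_x = (n - 1) / 2
    mean_y = sum(losses) / n
    cov = sum((i - mean_x) * (y - mean_y) for i, y in enumerate(losses))
    var = sum((i - mean_x) ** 2 for i in range(n))
    return cov / var


# Made-up curves for illustration only.
steady = [0.60, 0.50, 0.40, 0.30]
noisy = [0.50, 0.40, 0.50, 0.40]
print(loss_slope(steady), loss_slope(noisy))
```

Comparing the slope for the FA2 and non-FA2 runs over the same number of steps gives a rough, plot-free check of whether the weights are actually learning.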

@geronimi73

sorry to comment on this closed issue but I still have issues with FA2

  1. loss is different with FA2 compared to without
  2. loss is different even between two FA2 runs (used set_seed. doesn't happen without FA2 - loss always exactly the same)

W B Chart 26_01_2024, 22_43_04
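
A quick way to check point 2 without reading it off a chart is to compare the logged losses of two seeded runs step by step. This is a sketch in plain Python; `max_abs_diff` and `runs_identical` are hypothetical helpers and the loss values are made up for illustration.

```python
def max_abs_diff(run_a, run_b):
    """Largest per-step absolute difference between two loss logs."""
    if len(run_a) != len(run_b):
        raise ValueError("runs must log the same number of steps")
    return max(abs(a - b) for a, b in zip(run_a, run_b))


def runs_identical(run_a, run_b, tol=0.0):
    """True if the two runs match at every step within tol."""
    return max_abs_diff(run_a, run_b) <= tol


# Made-up logs: the first pair is bit-identical, the second drifts apart.
print(runs_identical([2.10, 1.95, 1.80], [2.10, 1.95, 1.80]))
print(max_abs_diff([2.10, 1.95, 1.80], [2.10, 1.97, 1.86]))
```

With `tol=0.0` this tests for exact reproducibility, which is what the runs without FA2 show; a small positive `tol` separates numerical noise from a real divergence.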

@abacaj
Author

abacaj commented Jan 27, 2024

I did another two runs, with FA2 and without FA2 (blue line). Testing the models using vLLM, the model without FA2 (blue line) scores 58% on HumanEval and the model with FA2 scores 49%. I basically stopped using FA2 because the model comes out so much worse and I can't pinpoint why (the runs are identical with the exception of use_flash_2).

image

@NicolasMejiaPetit

Hi @ArthurZucker,
Could you consider reopening this issue? I think it’s worth reopening, as training with flash attention on phi-2 is still not viable, and the performance gains are almost essential. I appreciate it, thank you!

@younesbelkada younesbelkada reopened this Feb 1, 2024
@farhang87

Just wanted to acknowledge that I have the same issue using Flash Attention 2 with phi-2: the training loss hardly decreases with FA2 turned on, and training works pretty well with it turned off.

@LinB203

LinB203 commented Feb 7, 2024

same question...

@geronimi73

We want to investigate this further and attempt to find the root of the problem. We are doing a line-by-line comparison between the new model's code and the previous one.

@gugarosa is there any update on fixing FA2 for this amazing model?

@ArthurZucker
Collaborator

There is a PR that should fix it, #28673, but it is hard to merge.

@akjindal53244

akjindal53244 commented Feb 13, 2024

I am also seeing a similar issue, where the loss is trending downwards but is quite unstable, and the model seems to learn very slowly. I am running a full fine-tune of the latest Phi-2 model on my dataset.
Screenshot 2024-02-12 at 9 22 57 PM

@ArthurZucker I just started another run after reinstalling transformers with changes from #28673 to see if it fixes the issue (still using FA2). will post loss curve in next few hours.

[Incorrect] Update-1: Loss curve after reinstalling transformers with changes from #28673. Looks like there is no change.
W B Chart 2_13_2024, 8_30_06 AM

Update-2: Looks like my new transformers installation didn't include the changes from #28673, so essentially both plots should be the same. I tried reinstalling transformers again with the PR changes, and now training is failing:

  File "/home/minimalist/miniconda3/envs/axolotl_Feb12/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/minimalist/miniconda3/envs/axolotl_Feb12/lib/python3.9/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/minimalist/miniconda3/envs/axolotl_Feb12/lib/python3.9/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/minimalist/work/projects/transformers/src/transformers/models/phi/modeling_phi.py", line 318, in forward
    query_states = self.q_proj(hidden_states)
  File "/home/minimalist/miniconda3/envs/axolotl_Feb12/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/minimalist/miniconda3/envs/axolotl_Feb12/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/minimalist/miniconda3/envs/axolotl_Feb12/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16

@NicolasMejiaPetit

NicolasMejiaPetit commented Feb 14, 2024

Update: it was a torch error. It's training now, but the loss is the same as before. I reduced my dataset to 1k examples over 3 epochs with an lr of 2e6, and the loss is still random, never consistently going down.

@NicolasMejiaPetit

NicolasMejiaPetit commented Feb 21, 2024

How are you guys testing this? It does seem to matter whether you do a full fine-tune or a LoRA fine-tune. Using FA2, I could never get the loss to even be consistent with a full fine-tune (with SFT). Right now I am doing DPO on phi-2 with QLoRA, and the loss is not only consistent, it’s consistently going down, from .69 to .27 in just a single epoch.

I have not tried SFT with LoRA, but maybe if we want to use FA2 it's just better to stick with LoRA.

@younesbelkada
Contributor

younesbelkada commented Feb 21, 2024

Hi there, now that SDPA has been merged (#29108), you can use FA-2 through the PyTorch interface:

0- Install PyTorch 2.2
1- Make sure to load SDPA attention by passing attn_implementation="sdpa" in from_pretrained
2- Force-dispatch the SDPA kernel to use FA2 as follows:

- trainer.train()
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+    trainer.train()

Perhaps this will fix all instability issues with respect to FA2!

@NicolasMejiaPetit

@younesbelkada Hey! Thanks for your response! Before starting my training run, I got PyTorch 2.2, pulled the latest commits from transformers, and installed from source. I’m using dpo.py from TRL, and I saw the commit, so I tried to pass "--attn_implementation SPDA", but it gave me an error saying SPDA is not currently supported. I wish I still had the error up; I’ll try it again once my training run ends in a little less than an hour. However, I had only tried passing it as a flag, not the way you just described.

@younesbelkada
Contributor

Hi @NickWithBotronics !
You need to use transformers from source: pip install -U git+https://github.com/huggingface/transformers

@NicolasMejiaPetit

@younesbelkada sorry, I don’t think I said what I meant properly. I had used ‘git pull’ to get the latest commits into my local transformers checkout, then installed with ‘pip uninstall transformers’ followed by ‘pip install .’. I did that after I saw the SPDA commit, so I thought I did it right. However, I’ll try again with ‘pip install -U git+https://github.com/huggingface/transformers’

and try adding the line:

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    trainer.train()

@younesbelkada
Contributor

Ah yes, sorry, I overlooked that part. Can you share your full snippet?

@NicolasMejiaPetit

I’m currently using this script, and when I tried using spda I used the argument "--attn_implementation spda":
https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py

@younesbelkada
Contributor

@NickWithBotronics I think it failed because you passed SDPA in full capitals (reading it from #28488 (comment)) can you try again with sdpa instead of SDPA?

@NicolasMejiaPetit

NicolasMejiaPetit commented Feb 21, 2024

CMD.txt
CMD-AFTER-LINE_EDIT.txt

Actually, it ran after the line edit, without using the --attn flag. I don't know if it's actually using it, as it uses the same memory as before and spills into my shared memory; meanwhile, with FA2, DPO fits with my current arguments.

I think I tried a variation of everything; using "--attn_implementation flash_attention_2" does work, though.
@younesbelkada

@younesbelkada
Contributor

Hi @NickWithBotronics - thanks for getting back!
I just tried to run on transformers main:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "microsoft/phi-2"
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="sdpa")
print(model)

And I can confirm that I am able to successfully load phi-2 with SDPA. Can you try that script? If it works on your end, I'll check out the dpo script and see if it fails there for me 🙏

@NicolasMejiaPetit

@younesbelkada I was able to test out the script you provided, and it worked. (I'm assuming it has something to do with the DPO trainer.)
CMDlog.txt

@younesbelkada
Contributor

OK, thanks for getting back. I will try that on my end and get back to you.

@abacaj
Author

abacaj commented Mar 1, 2024

Hi @NickWithBotronics - thanks for getting back! I just tried to run on transformers main:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "microsoft/phi-2"
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="sdpa")
print(model)

And I can confirm that I am able to successfully load phi-2 with SDPA. Can you try that script? If it works on your end, I'll check out the dpo script and see if it fails there for me 🙏

Does this work? I'm on main and have torch==2.2.0

ValueError: PhiForCausalLM does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet. Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`

@younesbelkada
Contributor

Hi @abacaj
I just ran that script and it works on my end on transformers main - can you double check?

PhiForCausalLM(
  (model): PhiModel(
    (embed_tokens): Embedding(51200, 2560)
    (embed_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-31): 32 x PhiDecoderLayer(
        (self_attn): PhiSdpaAttention(
          (q_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (k_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (v_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (dense): Linear(in_features=2560, out_features=2560, bias=True)
          (rotary_emb): PhiRotaryEmbedding()
        )
        (mlp): PhiMLP(
          (activation_fn): NewGELUActivation()
          (fc1): Linear(in_features=2560, out_features=10240, bias=True)
          (fc2): Linear(in_features=10240, out_features=2560, bias=True)
        )
        (input_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (final_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=2560, out_features=51200, bias=True)
)

@abacaj
Author

abacaj commented Mar 9, 2024

Hi @abacaj I just ran that script and it works on my end on transformers main - can you double check?

I tried again and got the same error:

ValueError: PhiForCausalLM does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet. Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`

I'm using:
transformers 4.39.0.dev0

@abacaj
Author

abacaj commented Mar 9, 2024

Ah, I see the issue: it seems to throw that error when I create the model with AutoConfig. It is working now; I will give it a try.

@abacaj
Author

abacaj commented Mar 9, 2024

Got this error during a training step:

RuntimeError: No available kernel. Aborting execution.

@geronimi73

Got this error during a training step:

I don't get this error (or any error), but the loss is still completely different from the loss with eager attention.

@Becomebright

Has there been any update?
I've discovered that SDPA + BF16 results in a significantly higher (~10x) gradient norm and slower convergence.
Disabling BF16 or switching to FP16 resolves the issue.
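
For anyone who wants to reproduce this comparison, the quantity in question is the global L2 norm over all parameter gradients. Below is a plain-Python sketch of that computation; `global_grad_norm` is a hypothetical helper working on flattened lists, and the gradient values are made up. In PyTorch the analogous number is the total norm returned by `torch.nn.utils.clip_grad_norm_`.

```python
import math


def global_grad_norm(grads):
    """Global L2 norm over a list of flattened per-parameter gradients."""
    return math.sqrt(sum(g * g for tensor in grads for g in tensor))


# Made-up gradients standing in for two precision settings.
norm_a = global_grad_norm([[0.3, 0.4], [0.0]])
norm_b = global_grad_norm([[3.0, 4.0], [0.0]])
print(norm_a, norm_b, norm_b / norm_a)
```

Logging this value per step for each dtype setting, with the same batch and seed, makes a ~10x gap easy to confirm or rule out.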

@YiqunChen1999

Has there been any update? I've discovered that SDPA + BF16 results in a significantly higher (~10x) gradient norm and slower convergence. Disabling BF16 or switching to FP16 resolves the issue.

Same here.

@hackyon
Contributor

hackyon commented Mar 24, 2024

@Becomebright Thanks for looking into this. I'm curious, you're saying it actually works fine (no high loss) for you with SDPA + FP16?

Reading through #28673, I'm not sure there's a way to get around the attention overflow issues. It seems Phi-2 might only be compatible with FP32 (short of retraining the model), and the only way to use the model for now is to run it in FP32. (Or maybe someone can try running it in FP32 while fine-tuning an FP16/BF16 LoRA to see if that works?)
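
As background for the overflow point, here is a standard-library sketch of the two half-precision formats. FP16 cannot represent values above 65504, while BF16 keeps the FP32 exponent range but only 8 bits of significand. `fits_in_fp16` and `to_bf16` are hypothetical helpers, and `to_bf16` truncates instead of rounding to nearest, which is a simplification of what hardware does.

```python
import struct

FP16_MAX = 65504.0  # largest finite IEEE 754 half-precision value


def fits_in_fp16(x):
    """True if x can be packed as a finite half-precision float."""
    try:
        struct.pack("<e", x)
        return True
    except OverflowError:
        return False


def to_bf16(x):
    """Approximate bfloat16 by keeping the top 16 bits of the float32 encoding.

    Truncation is used here for simplicity; real conversions usually round.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


print(fits_in_fp16(FP16_MAX), fits_in_fp16(1e5))
print(to_bf16(70000.0))
```

So a large attention logit overflows in FP16 but survives in BF16 with reduced precision, which is consistent with the two formats failing in different ways in this thread.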

@Becomebright

Becomebright commented Mar 24, 2024

@hackyon I'm not certain whether FP16 is effective since I've only tried over-fitting one sample.
With FP16, the grad_norm is 0 for the first two iterations, but it gets normal afterward, and the loss converges as quickly as with FP32. Conversely, using BF16 results in a significantly larger grad_norm, and the loss decreases slowly.
Will try finetuning with FP16.

Update: Finetuning Phi-2 using DeepSpeed Zero2, SDPA, and FP16 has encountered an overflow issue: Current loss scale already at minimum - cannot decrease scale anymore.

@farhang87

farhang87 commented Apr 15, 2024

Just wanted to acknowledge that I have the same issue using Flash Attention 2 with phi-2: the training loss hardly decreases with FA2 turned on, and training works pretty well with it turned off.

Update: I just tried finetuning Phi-2 again with the same dataset and settings, once with and once without FA2. I can confirm the differences are now marginal, while FA2 was a lot faster. Here are the results:

FA2: TrainOutput(global_step=356, training_loss=1.24568430895216, metrics={'train_runtime': 3220.3242, 'train_samples_per_second': 3.538, 'train_steps_per_second': 0.111, 'total_flos': 1.862385933484032e+17, 'train_loss': 1.24568430895216, 'epoch': 1.0})

Without FA2: TrainOutput(global_step=356, training_loss=1.1397285970409265, metrics={'train_runtime': 4415.1268, 'train_samples_per_second': 2.581, 'train_steps_per_second': 0.081, 'total_flos': 1.862385933484032e+17, 'train_loss': 1.1397285970409265, 'epoch': 1.0})

The only thing is, I finetuned Phi-2 for a summarisation task. I also compared the results with and without FA2, and here are the ROUGE differences (higher is better). Without FA2: Rouge1: 66.5, RougeL: 49.3. With FA2: Rouge1: 62.9, RougeL: 54.8. Does this mean that with FA2 one has to finetune for longer to get the same results? That doesn't really square with the whole idea of getting faster training with FA2. What am I missing here?
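
For a quick read of the two TrainOutput results, the speedup and the relative loss gap can be computed directly from the reported numbers. The dictionary layout below is only for illustration; the values are copied from the outputs above.

```python
# Values copied from the two TrainOutput results reported above.
fa2 = {"train_runtime": 3220.3242, "train_loss": 1.24568430895216}
no_fa2 = {"train_runtime": 4415.1268, "train_loss": 1.1397285970409265}

# How much faster the FA2 run finished the same 356 steps.
speedup = no_fa2["train_runtime"] / fa2["train_runtime"]

# How much higher the FA2 training loss is, relative to the run without FA2.
loss_gap = (fa2["train_loss"] - no_fa2["train_loss"]) / no_fa2["train_loss"]

print(f"speedup {speedup:.2f}x, relative loss gap {loss_gap:.1%}")
```

This works out to roughly a 1.37x speedup against a training loss about 9.3% higher after one epoch, which puts the trade-off being asked about in concrete terms.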


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
