LlamaForCausalLM at fp16 w/ FlashAttention gives NAN loss #27212

Closed
2 of 4 tasks
as3eem opened this issue Nov 1, 2023 · 5 comments

Comments

@as3eem

as3eem commented Nov 1, 2023

System Info

  • transformers version: 4.34.1
  • Platform: Linux-5.4.0-153-generic-x86_64-with-glibc2.31
  • Python version: 3.11.5
  • Huggingface_hub version: 0.17.3
  • Safetensors version: 0.3.2
  • Accelerate version: 0.24.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

cc: @SunMarc @ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch
from torch import optim
from transformers import LlamaForCausalLM, LlamaTokenizer

# MODEL, config and train_dataloader are defined elsewhere in the script (not shown).
tokenizer = LlamaTokenizer.from_pretrained(MODEL)
# The weights are loaded directly in fp16; this is the setup that produces NaNs.
model = LlamaForCausalLM.from_pretrained(MODEL, use_flash_attention_2=True, torch_dtype=torch.float16)

optimizer = optim.AdamW(model.parameters(), lr=config.LEARNING_RATE)
device = config.DEVICE
model.to(device)

clip_value = 1  # tried other values as well
# print(model.parameters())

model.train()
for epoch in range(config.NUM_EPOCHS):
    total_loss = 0
    for batch_idx, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        input_ids, attention_mask, labels = batch
        input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        loss.backward()

        # Clip the gradient norm before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
        optimizer.step()

        total_loss += loss.item()

    average_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch + 1}/{config.NUM_EPOCHS}, Average Training Loss: {average_loss}")

Constraint: the training loop is deliberately plain PyTorch, without a PEFT wrapper or the Trainer class, so that parts of it can be customized later.

After a lot of digging, the cause appears to be the fp16 data type. The gradients overflow to ±inf, and this then produces NaN logits and a NaN loss.
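You can confirm where the overflow first appears by checking the gradients right after `loss.backward()`. The sketch below is minimal and assumes plain PyTorch. `find_nonfinite_grads` is a hypothetical helper, not part of transformers. The sketch also shows the dtype limits behind the overflow: fp16 has no finite values above 65504.

```python
import torch

# Largest finite values per dtype: fp16 overflows to inf above 65504,
# while bf16 keeps roughly the same exponent range as fp32.
FP16_MAX = torch.finfo(torch.float16).max  # 65504.0
BF16_MAX = torch.finfo(torch.bfloat16).max  # about 3.39e38


def find_nonfinite_grads(model: torch.nn.Module) -> list[str]:
    """Hypothetical helper: names of parameters whose gradient has inf or NaN."""
    bad = []
    for name, param in model.named_parameters():
        # Skip parameters that received no gradient (frozen or unused).
        if param.grad is not None and not torch.isfinite(param.grad).all():
            bad.append(name)
    return bad
```

In the loop above, the call would sit between `loss.backward()` and `clip_grad_norm_`. Clipping cannot repair gradients that are already inf or NaN. If one gradient is inf, the total norm is inf, the clip coefficient becomes 0, and inf * 0 gives NaN.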

What didn't work:
  • torch.cuda.amp.GradScaler()
  • gradient clamping
  • a reduced learning rate
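Some context on why `GradScaler` alone is unlikely to help in this setup. It implements dynamic loss scaling, and the toy sketch below approximates that scheme in pure Python. `DynamicLossScale` is a hypothetical class, and its defaults are assumed to mirror `torch.cuda.amp.GradScaler`. Loss scaling keeps small gradients from underflowing and skips steps whose gradients overflow. It does not stop fp16 weights or activations from overflowing in the forward pass. `GradScaler` also expects fp32 parameters. As far as I know, in PyTorch 2.1 calling `scaler.step` on fp16 gradients raises `ValueError: Attempting to unscale FP16 gradients.`

```python
import math
from dataclasses import dataclass


@dataclass
class DynamicLossScale:
    """Toy dynamic loss scaler (hypothetical, not the PyTorch API).

    Defaults are meant to mirror torch.cuda.amp.GradScaler.
    """
    scale: float = 2.0 ** 16
    growth_factor: float = 2.0
    backoff_factor: float = 0.5
    growth_interval: int = 2000
    _good_steps: int = 0

    def unscale(self, scaled_grads: list[float]) -> list[float]:
        """Divide the gradients of the scaled loss back to their true values."""
        return [g / self.scale for g in scaled_grads]

    def update(self, scaled_grads: list[float]) -> bool:
        """Return True if the optimizer step should be applied."""
        if not all(math.isfinite(g) for g in scaled_grads):
            # Overflow: skip this step and retry with a smaller scale.
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps >= self.growth_interval:
            # Enough finite steps in a row: try a larger scale again.
            self.scale *= self.growth_factor
            self._good_steps = 0
        return True
```

If the forward pass itself produces inf or NaN, every step is skipped and the scale shrinks on each iteration. That matches the "didn't work" outcome.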

Expected behavior

The logits (and the loss) should not contain NaN values.

@younesbelkada
Contributor

Hi @as3eem
I think pure fp16 training should be avoided. For half-precision training you should either do pure bf16 fine-tuning or use automatic mixed precision.

@ArthurZucker
Collaborator

Also, if you are using padding, some of the NaNs could be fixed by #27114.


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@younesbelkada
Contributor

I am pretty sure this is now fixed by #28142, as @pacman100 managed to make it work!
You need to load the model in full precision and train it with fp16=True (i.e. with autocast). Make sure to use transformers main!
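The original bug was loading the model with `torch_dtype=torch.float16`. A small guard run before building the optimizer can catch it. This is a minimal sketch, and `check_full_precision` is a hypothetical helper, not a transformers API.

```python
import torch


def check_full_precision(model: torch.nn.Module) -> None:
    """Hypothetical helper: raise if any trainable parameter is not fp32."""
    low = sorted({str(p.dtype) for p in model.parameters()
                  if p.requires_grad and p.dtype != torch.float32})
    if low:
        raise TypeError(
            f"trainable parameters in {low}; load the model in full precision "
            "and let autocast run the forward pass in fp16"
        )
```

You would call it right after `from_pretrained(...)` and before `optim.AdamW(...)`, so that the optimizer state is also created in fp32.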

