Zero loss and nan grad_norm when Flash Attention is enabled #1706

Open
fgdfgfthgr-fox opened this issue Jun 13, 2024 · 0 comments
Labels
bug Something isn't working

fgdfgfthgr-fox commented Jun 13, 2024

Please check that this issue hasn't been reported before.

  • I searched previous bug reports and didn't find any similar reports.

Expected Behavior

I expect similar loss and grad_norm values when training a model with the same settings, regardless of whether Flash Attention is enabled.

Current behaviour

From the very first training step, the logs show messages like
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 6.545084971874738e-06, 'epoch': 0.4}
for a few steps, after which the following error appears and training stops:

  File "/home/huada524/ondemand/data/sys/myjobs/projects/default/1/huada524-prune-env/lib/python3.10/site-packages/flash_attn/bert_padding.py", line 110, in unpad_input
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
RuntimeError: CUDA error: an illegal memory access was encountered
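For context, the failing call is in flash-attn's `unpad_input`, which (in flash-attn 2.x) turns a padded batch plus its `attention_mask` into the flat token positions, cumulative sequence lengths (`cu_seqlens`) and maximum length that the variable-length Flash Attention kernels consume. The sketch below mirrors only that index bookkeeping with plain Python lists instead of tensors; it is a simplified illustration, not the library's code, and `unpad_bookkeeping` is a hypothetical name. Note also that CUDA errors are reported asynchronously and `torch.nonzero` on a CUDA tensor forces a host-device sync, so the illegal memory access may have been triggered by an earlier kernel rather than by this exact line.

```python
from typing import List, Tuple


def unpad_bookkeeping(attention_mask: List[List[int]]) -> Tuple[List[int], List[int], int]:
    """List-based mirror of the index math in flash_attn's unpad_input (simplified).

    attention_mask is batch x seq_len, with 1 for real tokens and 0 for padding.
    Returns (indices, cu_seqlens, max_seqlen).
    """
    # Like torch.nonzero(attention_mask.flatten()).flatten(): positions of
    # real tokens in the flattened (batch * seq_len) view.
    flat = [v for row in attention_mask for v in row]
    indices = [i for i, v in enumerate(flat) if v != 0]

    # Like attention_mask.sum(dim=-1): number of real tokens per sequence.
    seqlens = [sum(row) for row in attention_mask]

    # Like F.pad(torch.cumsum(seqlens, dim=0), (1, 0)): prefix sums with a leading 0,
    # so sequence k occupies indices[cu_seqlens[k]:cu_seqlens[k + 1]].
    cu_seqlens = [0]
    for n in seqlens:
        cu_seqlens.append(cu_seqlens[-1] + n)

    max_seqlen = max(seqlens, default=0)
    return indices, cu_seqlens, max_seqlen


mask = [
    [1, 1, 1, 0],  # 3 real tokens, 1 pad
    [1, 1, 0, 0],  # 2 real tokens, 2 pads
]
print(unpad_bookkeeping(mask))  # ([0, 1, 2, 4, 5], [0, 3, 5], 3)
```

Any inconsistency between these three values and the actual tensor shapes would make the downstream kernels index out of bounds, which is one reason this code path is a natural suspect for an illegal memory access.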

However, with Flash Attention disabled (flash_attention: false), the model trains normally:
{'loss': 3.0972, 'grad_norm': 0.76171875, 'learning_rate': 3.4549150281252635e-06, 'epoch': 0.6}
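In recent `transformers` versions, the `grad_norm` the Trainer logs is the total gradient norm computed while clipping gradients, i.e. the square root of the sum of squares over every gradient element. A single NaN anywhere in the gradients therefore turns the reported norm into `nan`, which is what the Flash Attention run shows. The sketch below illustrates that propagation with plain lists; `global_grad_norm` is a hypothetical helper, not the Trainer's implementation.

```python
import math
from typing import Iterable, List


def global_grad_norm(grads: Iterable[List[float]]) -> float:
    """Total L2 norm over all elements of all per-parameter gradients (hypothetical helper)."""
    total = 0.0
    for g in grads:
        for x in g:
            total += x * x  # a NaN element makes `total` NaN from here on
    return math.sqrt(total)


healthy = [[3.0, 0.0], [4.0]]
broken = [[3.0, float("nan")], [4.0]]

print(global_grad_norm(healthy))              # 5.0
print(math.isnan(global_grad_norm(broken)))   # True
```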

Steps to reproduce

  1. I installed axolotl on a remote Slurm cluster with 3x L40 GPUs, using the following script:
module load python
module load cuda

echo "Setting up python venv..."
python -m venv venv
source venv/bin/activate
python -m pip install --upgrade pip
pip install -U wheel
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124 -I
pip install ninja
export TORCH_CUDA_ARCH_LIST="8.6;8.9"
export CUDA_VISIBLE_DEVICES=2
export LD_LIBRARY_PATH=/home/huada524/ondemand/data/sys/myjobs/projects/default/1/venv/lib64/python3.10/site-packages/nvidia/nvjitlink/lib:$LD_LIBRARY_PATH
pip install -v -U "git+https://github.com/facebookresearch/xformers.git@main#egg=xformers"

cd axolotl
git pull
pip install packaging
pip install -e '.[flash-attn,deepspeed]'
# I manually removed xformers from axolotl/requirements.txt so that it won't override the build I just compiled.
# I also had to apply the patch from https://github.com/microsoft/DeepSpeed/issues/5603 for axolotl to launch.
cd ..
  2. I started training with the following script:
module load python
module load cuda

source venv/bin/activate

export CUDA_VISIBLE_DEVICES=2

export WANDB_API_KEY=xxxxxxx
export LD_LIBRARY_PATH=/home/huada524/ondemand/data/sys/myjobs/projects/default/1/venv/lib64/python3.10/site-packages/nvidia/nvjitlink/lib:$LD_LIBRARY_PATH

accelerate launch -m axolotl.cli.train config_llama3_40B_dora.yaml
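The `TORCH_CUDA_ARCH_LIST="8.6;8.9"` export in the install script controls which GPU architectures the from-source xformers build targets. The L40 is an Ada Lovelace card with compute capability 8.9, so that entry (or PTX from an older architecture that can be JIT-compiled) needs to be present. The sketch below is a small sanity check, assuming numeric entries separated by `;` or spaces with an optional `+PTX` suffix; `parse_arch_list` and `arch_list_covers` are hypothetical helpers, and named entries such as `Ampere` are not handled.

```python
from typing import Set, Tuple


def parse_arch_list(value: str) -> Tuple[Set[str], Set[str]]:
    """Split a TORCH_CUDA_ARCH_LIST-style string into (native_archs, ptx_archs)."""
    native: Set[str] = set()
    ptx: Set[str] = set()
    for entry in value.replace(";", " ").split():
        arch, _, suffix = entry.partition("+")
        native.add(arch)  # "8.6+PTX" builds native code for 8.6 as well
        if suffix.upper() == "PTX":
            ptx.add(arch)
    return native, ptx


def _as_tuple(arch: str) -> Tuple[int, ...]:
    return tuple(int(part) for part in arch.split("."))


def arch_list_covers(value: str, capability: str) -> bool:
    """True if native code for `capability` is built, or PTX from an older arch can be JIT-compiled."""
    native, ptx = parse_arch_list(value)
    if capability in native:
        return True
    return any(_as_tuple(p) <= _as_tuple(capability) for p in ptx)


# The L40 reports compute capability 8.9.
print(arch_list_covers("8.6;8.9", "8.9"))      # True
print(arch_list_covers("8.6", "8.9"))          # False
print(arch_list_covers("8.0;8.6+PTX", "8.9"))  # True (via PTX JIT)
```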

Note: the model I am training the LoRA on is Meta's Llama-3-70B with some of its layers removed.
The GPUs report CUDA version 12.5, while the loaded cuda module is 12.3.


**** Axolotl Dependency Versions *****
accelerate: 0.30.1
peft: 0.11.1
transformers: 4.41.1
trl: 0.8.7.dev0
torch: 2.4.0.dev20240610+cu124
bitsandbytes: 0.43.1
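This setup mixes three CUDA versions: the GPUs report 12.5, the loaded `cuda` module is 12.3, and the PyTorch nightly was built for CUDA 12.4 (`+cu124`). The sketch below extracts and compares them under a conservative rule of thumb (flag any toolkit or runtime newer than the driver-reported version, and flag a module/PyTorch mismatch). The helper names and the rule itself are assumptions for illustration, not an official compatibility check, and it assumes from-source extensions such as the xformers build pick up the module's `nvcc`.

```python
import re
from typing import List, Optional, Tuple

Version = Tuple[int, int]


def parse_version(text: str) -> Version:
    """Parse a 'major.minor' CUDA version string, e.g. '12.5' -> (12, 5)."""
    major, minor = text.strip().split(".")[:2]
    return int(major), int(minor)


def torch_cuda_tag(torch_version: str) -> Optional[Version]:
    """Extract the CUDA tag from a PyTorch version, e.g. '...+cu124' -> (12, 4); None if absent."""
    match = re.search(r"\+cu(\d+)(\d)$", torch_version)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def fmt(v: Version) -> str:
    return f"{v[0]}.{v[1]}"


def check_versions(driver: str, module: str, torch_version: str) -> List[str]:
    """Rule-of-thumb consistency check (an assumption, not an official compatibility matrix)."""
    warnings: List[str] = []
    drv = parse_version(driver)
    toolkit = parse_version(module)
    built = torch_cuda_tag(torch_version)
    # Conservative rule: flag anything newer than what the driver reports.
    if toolkit > drv:
        warnings.append(f"CUDA module {fmt(toolkit)} is newer than the driver ({fmt(drv)})")
    if built is not None and built > drv:
        warnings.append(f"torch CUDA build {fmt(built)} is newer than the driver ({fmt(drv)})")
    # Assumed: extensions compiled from source use the module's nvcc, not torch's CUDA.
    if built is not None and built != toolkit:
        warnings.append(
            f"CUDA module {fmt(toolkit)} differs from the CUDA version torch was built with ({fmt(built)})"
        )
    return warnings


# Values taken from this report.
for w in check_versions("12.5", "12.3", "2.4.0.dev20240610+cu124"):
    print(w)
# CUDA module 12.3 differs from the CUDA version torch was built with (12.4)
```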


Config yaml

base_model: /home/huada524/ondemand/data/sys/myjobs/projects/default/1/PruneMe/slice_with_mergekit/merged
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

load_in_8bit: false
load_in_4bit: true
strict: false

pretraining_dataset:
  - path: HuggingFaceFW/fineweb-edu
    name: default
    type: completion
val_set_size: 0.05
max_steps: 10
output_dir: ./outputs/out

adapter: qlora
lora_r: 8
lora_alpha: 4
lora_dropout: 0.0
lora_target_linear: true
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj
#  - lm_head
peft_use_dora: false
lora_model_dir:

sequence_len: 1024
sample_packing: true
pad_to_sequence_len: true

wandb_mode: online
wandb_project: Creating Llama-3-40B
wandb_entity:
wandb_watch:
wandb_name: Experimental_Runs

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 1e-5
#loraplus_lr_embedding: 1e-6

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: false

warmup_steps: 0
evals_per_epoch: 1
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
#  - full_shard
#  - auto_wrap
fsdp_config:
#  fsdp_limit_all_gathers: true
#  fsdp_sync_module_states: true
#  fsdp_offload_params: true
#  fsdp_use_orig_params: false
#  fsdp_cpu_ram_efficient_loading: true
#  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
#  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
#  fsdp_state_dict_type: FULL_STATE_DICT
#  fsdp_sharding_strategy: FULL_SHARD
special_tokens:
  pad_token: <|end_of_text|>
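The logged learning rates are consistent with this config. With `learning_rate: 1e-5`, `lr_scheduler: cosine`, `warmup_steps: 0` and `max_steps: 10`, a half-cosine decay gives the values logged at epochs 0.4 and 0.6 (steps 4 and 6 of the 10-step schedule), so the scheduler itself behaves the same in both runs. The sketch below assumes the warmup-then-half-cosine formula of `transformers`' `get_cosine_schedule_with_warmup` with its default `num_cycles=0.5`; it is a worked check, not code from axolotl.

```python
import math


def cosine_lr(step: int, base_lr: float, total_steps: int, warmup_steps: int = 0) -> float:
    """Linear warmup, then a half-cosine decay from base_lr to zero (num_cycles=0.5)."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


for step in (4, 6):
    # roughly 6.545e-06 at step 4 and 3.455e-06 at step 6, matching the logs above
    print(step, cosine_lr(step, base_lr=1e-5, total_steps=10))
```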

Possible solution

No response

Which Operating Systems are you using?

  • Linux

Python Version

3.10

axolotl branch-commit

5783839

Acknowledgements

  • My issue title is concise, descriptive, and in title casing.
  • I have searched the existing issues to make sure this bug has not been reported yet.
  • I am using the latest version of axolotl.
  • I have provided enough information for the maintainers to reproduce and diagnose the issue.
fgdfgfthgr-fox added the bug label on Jun 13, 2024