CUDA device error with llama2_chat strategy #568

Closed
6 of 8 tasks
kaldeberger opened this issue Sep 14, 2023 · 15 comments
Labels
bug Something isn't working

Comments


kaldeberger commented Sep 14, 2023

Please check that this issue hasn't been reported before.

  • I searched previous Bug Reports and didn't find any similar reports.

Expected Behavior

I am trying to finetune code-llama on runpod with the provided docker image using this command: accelerate launch scripts/finetune.py.

Using the config from examples/code-llama/13B/qlora.yml with a dataset from the local file system in llama2_chat format should work.

Note:
Other formats, e.g. alpaca or sharegpt:chat, do work fine with this config. The problem seems to be with the llama2 prompt strategy.

Current behaviour

After loading the model shards there are many of these error messages:

indexSelectLargeIndex: block: [0,0,0], thread: [64,0,0] Assertion `srcIndex < srcSelectDimSize` failed.

followed by

CUDA Error: device-side assert triggered /tmp/pip-install-3n5798ar/dropout-layer-norm_1323beba825f4e58852d39754f678e64/csrc/layer_norm/ln_fwd_kernels.cuh 236

The finetune.py process terminates.

Steps to reproduce

  1. Create a sample dataset in llama2_chat format as a jsonl file, e.g. {"conversations": [{"from": "human", "value": "Who are you?"}, {"from": "gpt", "value": "I am your chat assistant."}]} (see the sketch after this list)
  2. Use the provided qlora.yml config from examples/code-llama/13B/ and add your sample dataset:
datasets:
  - path: dataset.jsonl
    ds_type: json
    type: llama2_chat
  3. Start with accelerate launch scripts/finetune.py qlora.yml
  4. Wait for tokenization and loading of shards to complete
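For step 1, a small helper script like the following can be used to write such a jsonl file. This is only an illustrative sketch, not part of the repo; the file name and sample conversations are placeholders:

import json

# Hypothetical sample conversations in the llama2_chat format: one JSON object
# per line, each with a "conversations" list of alternating human/gpt turns.
samples = [
    {
        "conversations": [
            {"from": "human", "value": "Who are you?"},
            {"from": "gpt", "value": "I am your chat assistant."},
        ]
    },
    {
        "conversations": [
            {"from": "human", "value": "What can you help me with?"},
            {"from": "gpt", "value": "I can answer questions and help with coding tasks."},
        ]
    },
]

with open("dataset.jsonl", "w", encoding="utf-8") as f:
    for sample in samples:
        f.write(json.dumps(sample) + "\n")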

Possible solution

No response

Which Operating Systems are you using?

  • Linux
  • macOS
  • Windows

Note: I am using the runpod template via the direct link included in the Readme.

Python Version

from the runpod template (docker image)

axolotl branch-commit

main

Acknowledgements

  • My issue title is concise, descriptive, and in title casing.
  • I have searched the existing issues to make sure this bug has not been reported yet.
  • I am using the latest version of axolotl.
  • I have provided enough information for the maintainers to reproduce and diagnose the issue.
kaldeberger added the bug label on Sep 14, 2023
@kaldeberger
Author

I just tried it again with the latest commits from main. The issue still occurs. Here is the full error message:

[2023-09-14 23:42:26,287] [INFO] [axolotl.utils.dataloader._len_est:264] [PID:921] [RANK:0] packing_efficiency_estimate: 1.0 total_num_tokens per device: 14143488
  0%|                                                                                                                                                                           | 0/530 [00:00<?, ?it/s][2023-09-14 23:42:26,652] [INFO] [axolotl.utils.dataloader._len_est:264] [PID:921] [RANK:0] packing_efficiency_estimate: 1.0 total_num_tokens per device: 14143488
[2023-09-14 23:42:26,652] [INFO] [axolotl.utils.dataloader.generate_batches:181] [PID:921] [RANK:0] generating packed batches
[2023-09-14 23:42:26,657] [INFO] [axolotl.utils.dataloader.generate_batches:187] [PID:921] [RANK:0] 99b24bb90d0ab3630c3d18fdada684c3498a32bfe5e9777c1cab8fc8fcf335bc
[2023-09-14 23:42:26,664] [INFO] [axolotl.utils.dataloader._len_est:264] [PID:921] [RANK:0] packing_efficiency_estimate: 1.0 total_num_tokens per device: 14143488
../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [342,0,0], thread: [96,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
[... the same `srcIndex < srcSelectDimSize` assertion repeats for many more threads and blocks ...]
../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [276,0,0], thread: [31,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
Traceback (most recent call last):
  File "/workspace/axolotl/scripts/finetune.py", line 287, in <module>
    fire.Fire(do_cli)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/fire/core.py", line 141, in Fire
   component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/workspace/axolotl/scripts/finetune.py", line 283, in do_cli
    train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
  File "/workspace/axolotl/src/axolotl/train.py", line 116, in train
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/transformers/trainer.py", line 1575, in train
    return inner_training_loop(
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/transformers/trainer.py", line 1875, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/transformers/trainer.py", line 2740, in training_step
    loss = self.compute_loss(model, inputs)
  File "/workspace/axolotl/src/axolotl/utils/trainer.py", line 296, in compute_loss
    return super().compute_loss(model, inputs, return_outputs=return_outputs)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/transformers/trainer.py", line 2765, in compute_loss
    outputs = model(**inputs)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/accelerate/utils/operations.py", line 636, in forward
    return model_forward(*args, **kwargs)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/accelerate/utils/operations.py", line 624, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
    return func(*args, **kwargs)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/peft/peft_model.py", line 946, in forward
    return self.base_model(
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 94, in forward
    return self.model.forward(*args, **kwargs)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 820, in forward
    outputs = self.model(
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/workspace/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_flash.py", line 517, in llama_model_forward
    layer_outputs = torch.utils.checkpoint.checkpoint(
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 249, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 107, in forward
    outputs = run_function(*args)
  File "/workspace/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_flash.py", line 513, in custom_forward
    return module(*inputs)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/workspace/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_flash.py", line 607, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/workspace/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_flash.py", line 141, in flashattn_forward
    query_states = self.q_proj(hidden_states)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/peft/tuners/lora/bnb.py", line 256, in forward
    result = super().forward(x)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/bitsandbytes-0.41.1-py3.10.egg/bitsandbytes/nn/modules.py", line 248, in forward
    out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/bitsandbytes-0.41.1-py3.10.egg/bitsandbytes/autograd/_functions.py", line 579, in matmul_4bit
    return MatMul4Bit.apply(A, B, out, bias, quant_state)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/bitsandbytes-0.41.1-py3.10.egg/bitsandbytes/autograd/_functions.py", line 516, in forward
    output = torch.nn.functional.linear(A, F.dequantize_4bit(B, state).to(A.dtype).t(), bias)
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`
  0%|                                                                                                                                                                           | 0/530 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "/root/miniconda3/envs/py3.10/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 47, in main
    args.func(args)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/accelerate/commands/launch.py", line 986, in launch_command
    simple_launcher(args)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/accelerate/commands/launch.py", line 628, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['/root/miniconda3/envs/py3.10/bin/python', 'scripts/finetune.py', 'qlora-codellama13b.yml']' returned non-zero exit status 1.

@vibhorag101

I am having the very same issue: `RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling cublasCreate(handle)`.
The logs shared above by @kaldeberger are also the same.
I don't think there is any issue with the environment. I have tried multiple docker images, Python 3.9 and 3.10, as well as older docker images from the past.
This error does not occur with any other prompting strategy such as sharegpt:chat, so it has to be some bug in the code.
My dataset looks something like this:

    {
        "id": "identity_0",
        "conversations": [
            {
                "from": "human",
                "value": "I've been feeling so sad and overwhelmed lately. Work has become such a massive source of stress for me."
            },
            {
                "from": "gpt",
                "value": "Hey there, I'm here to listen and support you. It sounds like work has been really challenging lately. Can you tell me more about what's been going on?"
            }
        ]
    }

@dimichgh

The same issue for me

@dimichgh

Other reports of similar errors are related to an index mismatch. I found one place that adds a pad token to the tokenizer in llama2_chat which is not accounted for in the model, or is maybe added too late to resize the model.
If you remove it, it will work, but the model and tokenizer should really be extended with this token before training.
The line I am talking about is
https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/prompt_strategies/llama2_chat.py#L85

def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.sequence_len = 4096
        self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
        # https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json
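If you do want a dedicated pad token, a minimal sketch of the safer order of operations (add the token first, then resize the model's embeddings before training) might look like this; the checkpoint name is only a placeholder and this is not the exact axolotl code:

from transformers import AutoModelForCausalLM, AutoTokenizer

base_model = "codellama/CodeLlama-13b-hf"  # placeholder; use the model you actually fine-tune

tokenizer = AutoTokenizer.from_pretrained(base_model)
model = AutoModelForCausalLM.from_pretrained(base_model)

# Register the pad token first ...
tokenizer.add_special_tokens({"pad_token": "<pad>"})

# ... then grow the embedding matrix so the new token id has a valid row.
# Without this step, looking up the new id in the embedding triggers the
# `srcIndex < srcSelectDimSize` device-side assert shown above.
model.resize_token_embeddings(len(tokenizer))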

@dimichgh

dimichgh commented Sep 20, 2023

Here's my llama2_chat.py version that worked for me; just make sure you remove the pad-token addition here: https://github.com/OpenAccess-AI-Collective/axolotl/blob/ec0958f4f846236ac2703dd644f6dac4365f64b4/src/axolotl/utils/models.py#L80
And if you still need padding, it is better to update the model separately, resizing it before you use it to train with this script.

"""
Prompt Strategy for finetuning Llama2 chat models
see also https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L213 for a reference implementation.

This implementation is based on the Vicuna PR and the fastchat repo, see also:
https://github.com/lm-sys/FastChat/blob/cdd7730686cb1bf9ae2b768ee171bdf7d1ff04f3/fastchat/conversation.py#L847

Use dataset type: "llama2_chat" in config.yml to use this prompt style.

E.g. in the config.yml:
datasets:
  - path: llama_finetune_train.jsonl
    type: llama2_chat

The dataset itself should look like this:
{'conversations':[{"from": "human", "value": "Who are you?"}, {"from": "gpt", "value": "I am Vicuna"},...]}

in a jsonl file. The first message should be from the human, the second from gpt.
For a custom system message, the first "from" can be "system" (followed by alternating "human" and "gpt" turns).

Important: Don't use "special_tokens:" in your config.yml if you are not sure what you are doing!
"""

import logging
from dataclasses import dataclass, field
from typing import Generator, List, Sequence

from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE
import traceback

@dataclass
class Llama2ChatConversation:
    """A class that manages prompt templates and keeps all conversation history.
    copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py"""

    name: str = "llama2"
    # The system prompt
    system: str = (
        "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
        "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
        "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
        "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
        "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
    )
    roles: Sequence[str] = ("[INST]", "[/INST]")
    messages: List[List[str]] = field(default_factory=list)
    offset: int = 0
    sep = " "
    sep2 = " </s><s>"
    sep3 = " </s>"
    stop_token_ids = [2]

    def get_prompt(self) -> str:
        """Get the prompt for generation."""
        seps = [self.sep, self.sep2]
        ret = ""
        for i, (role, message) in enumerate(self.messages):
            if (i == len(self.messages) - 1) and (role == self.roles[0]):
                # last message is from user (due to length),
                #  return prompt without it for training
                return ret
            if i == 0:
                ret += self.system + message.strip()
            else:
                if i == len(self.messages) - 1:
                    ret += role + " " + message.strip() + self.sep3
                else:
                    ret += role + " " + message.strip() + seps[i % 2]
        return ret

    def append_message(self, role: str, message: str):
        """Append a new message."""
        self.messages.append([role, message])


class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for ShareGPT prompts.
    adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.sequence_len = 4096
        # self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
        # https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json

    def tokenize_prompt(self, prompt):
        conv = next(self.prompter.build_prompt(prompt))
        conversation_str = conv.get_prompt()

        # Tokenize conversations
        input_ids = self.tokenizer(
            conversation_str,
            return_tensors="pt",
            ## padding="max_length",
            padding=False,
            max_length=self.sequence_len,
            truncation=True,
        ).input_ids[0]
        target = input_ids.clone()

        # Mask targets. Only compute loss on the assistant outputs.
        sep = conv.roles[1]

        ## total_len = int(target.ne(self.tokenizer.pad_token_id).sum())
        total_len = len(target)
        turns = conversation_str.split(conv.sep2)
        cur_len = 0
        target[:cur_len] = IGNORE_TOKEN_ID
        for turn in turns:
            if turn == "":
                break
            # turn_len = len(self.tokenizer(turn).input_ids) - 1
            turn_len = len(self.tokenizer(turn).input_ids)

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            # "-1" is hardcoded for the LLaMA tokenizer to make the offset correct.
            instruction_len = len(self.tokenizer(parts[0]).input_ids)

            # Ignore the user instructions
            target[cur_len - 1 : cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len 

        target[cur_len:] = IGNORE_TOKEN_ID

        if cur_len < self.sequence_len:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                logging.warning(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

        attention_mask = input_ids.ne(self.tokenizer.pad_token_id).tolist()
        input_ids = input_ids.tolist()
        target = target.tolist()
        # this is a fix for the tokenizer which tokenizes [ differently with eos tokens and
        # follows the original llama implementation
        for i in range(2, total_len - 2):
            if input_ids[i] == 29961:
                input_ids[i] = 518
            if target[i] == 29961:
                target[i] = 518
        return {
            "input_ids": input_ids,
            "labels": target,
            "attention_mask": attention_mask,
        }


class Llama2ChatPrompter:  # pylint: disable=too-few-public-methods
    """
    A prompter that generates prompts for Llama2 models.
    """

    system_prompt = (
        "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
        "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
        "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
        "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
        "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
    )

    def build_prompt(self, source) -> Generator[Llama2ChatConversation, None, None]:
        # see https://github.com/lm-sys/FastChat/blob/da0641e567cf93756b0978ab5a6b092e96f06240/fastchat/train/train.py#L78
        source = source["conversations"]  # fix data structure for datasets

        # if system prompt provided, use it
        if source[0]["from"] == "system":
            system = f"[INST] <<SYS>>\n{source[0]['value']}\n<</SYS>>\n\n"
            source = source[1:]
        else:
            system = self.system_prompt

        conv = Llama2ChatConversation(system=system)

        if len(source) < 2:
            # If there isn't a back and forth conversation, ignore it
            # also happens on the data splitting leaving empty conversations
            raise IndexError

        roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []  # pylint: disable=R0801
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
            if sentence["value"]:
                conv.append_message(role, sentence["value"])
        yield conv


def load(tokenizer, cfg) -> LLama2ChatTokenizingStrategy:
    return LLama2ChatTokenizingStrategy(
        Llama2ChatPrompter(),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )

@kaldeberger
Author

@dimichgh thanks. Can you let us know if you've gotten good results from your fine tuning with your changes?

I have resorted to modifying the AlpacaPrompter with the llama2 prompt format and this has yielded quite good results. I'm not sure if I want to mess with llama2_chat until some dev takes the time and fixes it.

@vibhorag101

Can you please share your AlpacaPrompter? Also, does it need the dataset to be formatted in alpaca style or sharegpt?

@winglian
Collaborator

One thing you can try is the branch in PR #578

Simply set:

type: sharegpt
conversation: llama-2

@kaldeberger
Author

kaldeberger commented Sep 26, 2023

@vibhorag101 I'm sorry but I don't feel comfortable sharing it, because I don't really know what I'm doing. I wouldn't want people to waste time on something if I'm not on the right track. But essentially, if you want to try it yourself -
I have taken the prompt format from this blog post: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
and edited it here: https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/prompt_strategies/alpaca_w_system.py#L55 (remove the "### System:" etc. formatting and replace it with <<SYS>><</SYS>> etc.)
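For reference, a minimal sketch of a prompt builder in that Llama-2 chat format, based on the blog post linked above and not the actual modified AlpacaPrompter, could look like this:

def build_llama2_prompt(system: str, instruction: str) -> str:
    # Illustrative sketch of the single-turn Llama-2 chat format from the
    # Hugging Face blog post; the BOS token is normally added by the
    # tokenizer and is therefore omitted here.
    return (
        "[INST] <<SYS>>\n"
        f"{system}\n"
        "<</SYS>>\n\n"
        f"{instruction} [/INST]"
    )

print(build_llama2_prompt("You are a helpful assistant.", "Who are you?"))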

@dimichgh I have tried your changes and the fine-tune process itself works, but the model will not produce meaningful output and cannot be quantized because of unexpected tensor dimension errors in llama.cpp.

@winglian thanks, will try it out later this week.

@kaldeberger
Author

One thing you can try is the branch in PR #578

Simply set:

type: sharegpt
conversation: llama-2

I tried it and it seems to somewhat work (at least fine-tuning, conversion, quantization and inference all finish without issues), but the model doesn't learn too well.

With my hacked-together alpaca prompt_strategy that uses the llama2 message format I get a loss of around 1.0, whereas the same dataset with your suggestion gets about 1.4. That also reflects in the much poorer response quality of the resulting model with your solution.

Plus, I was trying to understand your code but basing the classes off of sharegpt made it all even more difficult to understand. I really don't like the code layout there, makes it hard to follow what's used where, sorry to be so blunt.

@pshivraj

I see similar behaviour to what @kaldeberger saw: I was getting the loss trending down from 0.9 to 0.2 after an epoch on my dataset, however after switching to the new prompt strategy the loss trends down from 1.5 to 0.5 after an epoch.
@winglian Do you see similar behaviour based on any tests you might have done?

@winglian
Collaborator

I see similar behaviour to what @kaldeberger saw: I was getting the loss trending down from 0.9 to 0.2 after an epoch on my dataset, however after switching to the new prompt strategy the loss trends down from 1.5 to 0.5 after an epoch.

@winglian Do you see similar behaviour based on any tests you might have done?

@pshivraj can you clarify which prompt strategies were giving which loss results, please?

@pshivraj

Hi @winglian Sorry for not mentioning this beforehand.
I changed from
type: sharegpt:chat to

type: sharegpt
conversation: llama-2

@dimichgh

@vibhorag101 I'm sorry but I don't feel comfortable sharing it, because I don't really know what I'm doing. I wouldn't want people to waste time on something if I'm not on the right track. But essentially, if you want to try it yourself - I have taken the prompt format from this blog post: https://huggingface.co/blog/llama2#how-to-prompt-llama-2 and edited it here: https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/prompt_strategies/alpaca_w_system.py#L55 (remove the "### System:" etc. formatting and replace it with <<SYS>><</SYS>> etc.)

@dimichgh I have tried your changes and the fine-tune process itself works, but the model will not produce meaningful output and cannot be quantized because of unexpected tensor dimension errors in llama.cpp.

@winglian thanks, will try it out later this week.

@dimichgh I found some more issues after that, but did not have time to provide a patch for that yet.

@NanoCode012
Collaborator

The current code by default sets the LlamaTokenizer to use llama's EOS as the pad token, except in the llama2 chat class above.

And if you still need padding, it is better to update the model separately, resizing it before you use it to train with this script.

The model embedding length is automatically resized if it mismatches the tokenizer length, so you don't need to do it yourself.

It seems that the original issue is solved by using the other class. Regarding the weird loss, it could be due to how the data is tokenized, so providing the debugging output (see readme) could help. I'll close this for now as the original issue is solved. If you would like to dive into the loss, a separate issue might be better.
