Error on finetuning mistral-nemo-2407 model with peft LoRA #1937

Closed
YeonwooSung opened this issue Jul 19, 2024 · 2 comments

Comments

@YeonwooSung

System Info

peft 0.11.1
transformers 4.42.4
Python 3.10.13
trl 0.9.6

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

I'm trying to fine-tune the Mistral-Nemo model with SFTTrainer, but I keep hitting a RuntimeError:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-29-95658755d61a> in <cell line: 3>()
      1 RESUME_FROM_CHECKPOINT=False
      2 
----> 3 trainer.train(resume_from_checkpoint=RESUME_FROM_CHECKPOINT)

36 frames
/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py in train(self, *args, **kwargs)
    449             self.model = self._trl_activate_neftune(self.model)
    450 
--> 451         output = super().train(*args, **kwargs)
    452 
    453         # After training we make sure to retrieve back the original forward pass method

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1930                 hf_hub_utils.enable_progress_bars()
   1931         else:
-> 1932             return inner_training_loop(
   1933                 args=args,
   1934                 resume_from_checkpoint=resume_from_checkpoint,

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in _inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2266 
   2267                 with self.accelerator.accumulate(model):
-> 2268                     tr_loss_step = self.training_step(model, inputs)
   2269 
   2270                 if (

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in training_step(self, model, inputs)
   3305 
   3306         with self.compute_loss_context_manager():
-> 3307             loss = self.compute_loss(model, inputs)
   3308 
   3309         del inputs

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in compute_loss(self, model, inputs, return_outputs)
   3336         else:
   3337             labels = None
-> 3338         outputs = model(**inputs)
   3339         # Save past state if it exists
   3340         # TODO: this needs to be fixed and made cleaner later.

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:

/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py in forward(*args, **kwargs)
    817 
    818     def forward(*args, **kwargs):
--> 819         return model_forward(*args, **kwargs)
    820 
    821     # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`

/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py in __call__(self, *args, **kwargs)
    805 
    806     def __call__(self, *args, **kwargs):
--> 807         return convert_to_fp32(self.model_forward(*args, **kwargs))
    808 
    809     def __getstate__(self):

/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py in decorate_autocast(*args, **kwargs)
     14     def decorate_autocast(*args, **kwargs):
     15         with autocast_instance:
---> 16             return func(*args, **kwargs)
     17 
     18     decorate_autocast.__script_unsupported = "@autocast() decorator is not supported in script mode"  # type: ignore[attr-defined]

/usr/local/lib/python3.10/dist-packages/peft/peft_model.py in forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
   1428             with self._enable_peft_forward_hooks(**kwargs):
   1429                 kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1430                 return self.base_model(
   1431                     input_ids=input_ids,
   1432                     attention_mask=attention_mask,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:

/usr/local/lib/python3.10/dist-packages/peft/tuners/tuners_utils.py in forward(self, *args, **kwargs)
    177 
    178     def forward(self, *args: Any, **kwargs: Any):
--> 179         return self.model.forward(*args, **kwargs)
    180 
    181     def _pre_injection_hook(self, model: nn.Module, config: PeftConfig, adapter_name: str) -> None:

/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py in new_forward(module, *args, **kwargs)
    167                 output = module._old_forward(*args, **kwargs)
    168         else:
--> 169             output = module._old_forward(*args, **kwargs)
    170         return module._hf_hook.post_forward(module, output)
    171 

/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
   1198 
   1199         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1200         outputs = self.model(
   1201             input_ids=input_ids,
   1202             attention_mask=attention_mask,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:

/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py in new_forward(module, *args, **kwargs)
    167                 output = module._old_forward(*args, **kwargs)
    168         else:
--> 169             output = module._old_forward(*args, **kwargs)
    170         return module._hf_hook.post_forward(module, output)
    171 

/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    963 
    964             if self.gradient_checkpointing and self.training:
--> 965                 layer_outputs = self._gradient_checkpointing_func(
    966                     decoder_layer.__call__,
    967                     hidden_states,

/usr/local/lib/python3.10/dist-packages/torch/_compile.py in inner(*args, **kwargs)
     22             import torch._dynamo
     23 
---> 24             return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
     25 
     26         return inner

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py in _fn(*args, **kwargs)
    449             prior = set_eval_frame(callback)
    450             try:
--> 451                 return fn(*args, **kwargs)
    452             finally:
    453                 set_eval_frame(prior)

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py in inner(*args, **kwargs)
     34     @functools.wraps(fn)
     35     def inner(*args, **kwargs):
---> 36         return fn(*args, **kwargs)
     37 
     38     return inner

/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py in checkpoint(function, use_reentrant, context_fn, determinism_check, debug, *args, **kwargs)
    485                 "use_reentrant=False."
    486             )
--> 487         return CheckpointFunction.apply(function, preserve, *args)
    488     else:
    489         gen = _checkpoint_without_reentrant_generator(

/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py in apply(cls, *args, **kwargs)
    596             # See NOTE: [functorch vjp and autograd interaction]
    597             args = _functorch.utils.unwrap_dead_wrappers(args)
--> 598             return super().apply(*args, **kwargs)  # type: ignore[misc]
    599 
    600         if not is_setup_ctx_defined:

/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py in forward(ctx, run_function, preserve_rng_state, *args)
    260 
    261         with torch.no_grad():
--> 262             outputs = run_function(*args)
    263         return outputs
    264 

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:

/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py in new_forward(module, *args, **kwargs)
    167                 output = module._old_forward(*args, **kwargs)
    168         else:
--> 169             output = module._old_forward(*args, **kwargs)
    170         return module._hf_hook.post_forward(module, output)
    171 

/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py in forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
    716 
    717         # Self Attention
--> 718         hidden_states, self_attn_weights, present_key_value = self.self_attn(
    719             hidden_states=hidden_states,
    720             attention_mask=attention_mask,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:

/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py in new_forward(module, *args, **kwargs)
    167                 output = module._old_forward(*args, **kwargs)
    168         else:
--> 169             output = module._old_forward(*args, **kwargs)
    170         return module._hf_hook.post_forward(module, output)
    171 

/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py in forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
    611         bsz, q_len, _ = hidden_states.size()
    612 
--> 613         query_states = self.q_proj(hidden_states)
    614         key_states = self.k_proj(hidden_states)
    615         value_states = self.v_proj(hidden_states)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:

/usr/local/lib/python3.10/dist-packages/peft/tuners/lora/bnb.py in forward(self, x, *args, **kwargs)
    478                         output = output.to(expected_dtype)
    479 
--> 480                     result = result + output
    481 
    482             return result

RuntimeError: The size of tensor a (4096) must match the size of tensor b (5120) at non-singleton dimension 2
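A plausible reading of the two numbers, under stated assumptions: as far as I recall, the config.json of Mistral-Nemo-Base-2407 sets hidden_size=5120, num_attention_heads=32 and an explicit head_dim=128 (please verify against the model repo). If the attention code derives head_dim as hidden_size // num_attention_heads instead of reading it from the config, the q_proj width it expects no longer matches the checkpoint. A minimal sketch of that arithmetic (the helper name is hypothetical):

```python
# Assumed Mistral-Nemo-Base-2407 config values (verify against config.json).
HIDDEN_SIZE = 5120
NUM_HEADS = 32
CONFIG_HEAD_DIM = 128


def q_proj_out_features(hidden_size, num_heads, head_dim=None):
    """Output width of q_proj, i.e. num_heads * head_dim.

    If head_dim is None, fall back to hidden_size // num_heads, which is the
    value you get when the config's explicit head_dim is ignored.
    """
    if head_dim is None:
        head_dim = hidden_size // num_heads
    return num_heads * head_dim


# Width implied by the explicit head_dim in the config.
from_config = q_proj_out_features(HIDDEN_SIZE, NUM_HEADS, CONFIG_HEAD_DIM)
# Width implied by deriving head_dim from hidden_size (5120 // 32 = 160).
derived = q_proj_out_features(HIDDEN_SIZE, NUM_HEADS)

print(from_config, derived)  # 4096 5120
```

These match the two sizes in the error: in peft/tuners/lora/bnb.py, `result` (tensor a, 4096) is the base layer output and `output` (tensor b, 5120) is the LoRA branch. This is an interpretation of the traceback, not something confirmed in this thread.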

Below is my code:

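import torch
from datasets import load_dataset
from peft import LoraConfig
from peft.tuners.lora import LoraLayer
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from trl import SFTTrainer

# g_seq_len, OUTPUT_DIR, formatting_prompts_func and collator are defined
# elsewhere in the notebook and are not shown in this snippet.

class Arguments:
    local_rank: int = -1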
    per_device_train_batch_size = 1
    per_device_eval_batch_size = 1

    learning_rate = 2e-4
    max_grad_norm = 0.3
    weight_decay = 0.001

    lora_alpha = 16
    lora_dropout = 0.05
    lora_r = 64
    max_seq_length = g_seq_len

    # <https://huggingface.co/mistralai/Mistral-Nemo-Base-2407>
    model_name = "mistralai/Mistral-Nemo-Base-2407"
    dataset_name = "kaist-ai/CoT-Collection"

    new_model = "Mistral-Nemo-CoT"

    use_4bit = True
    use_nested_quant = False
    bnb_4bit_compute_dtype = "float16"
    bnb_4bit_quant_type = "nf4"

    num_train_epochs = 3

    fp16 = False
    bf16 = True

    gradient_accumulation_steps = 4 #1
    packing = False
    gradient_checkpointing = True
    optim = "paged_adamw_32bit"
    lr_scheduler_type = "constant" # Constant a bit better than cosine, and has advantage for analysis

    max_steps: int = 80000
    warmup_ratio: float = 0.03
    group_by_length: bool = True # "Group sequences into batches with same length. Saves memory and speeds up training considerably."

    save_steps: int = 200
    logging_steps: int = 200


script_args = Arguments()

def create_and_prepare_model(args):
    compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=args.use_4bit,
        bnb_4bit_quant_type=args.bnb_4bit_quant_type,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=args.use_nested_quant,
    )

    if compute_dtype == torch.float16 and args.use_4bit:
        major, _ = torch.cuda.get_device_capability()
        if major >= 8:
            if not args.bf16:
                print("=" * 80)
                print("Your GPU supports bfloat16, you can accelerate training with the argument --bf16")
                print("=" * 80)

                args.bf16 = True
            else:
                print("=" * 80)
                print("Using --bf16 option to accelerate training")
                print("=" * 80)

    device_map = {"": 0}

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        quantization_config=bnb_config,
        device_map=device_map,
        trust_remote_code=True,
    )
    model.enable_input_require_grads()
    print(model)

    peft_config = LoraConfig(
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        r=args.lora_r,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            'q_proj',
            'k_proj',
            'v_proj',
            'o_proj',
        ],
    )

    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token

    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    return model, peft_config, tokenizer

training_arguments = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=script_args.per_device_train_batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    optim=script_args.optim,
    save_steps=script_args.save_steps,
    logging_steps=script_args.logging_steps,
    learning_rate=script_args.learning_rate,
    fp16=script_args.fp16,
    bf16=script_args.bf16,
    max_grad_norm=script_args.max_grad_norm,
    max_steps=script_args.max_steps,
    warmup_ratio=script_args.warmup_ratio,
    group_by_length=script_args.group_by_length,
    lr_scheduler_type=script_args.lr_scheduler_type,
    run_name="mistral-nemo-cot",
)


model, peft_config, tokenizer = create_and_prepare_model(script_args)
model.config.use_cache = False

dataset = load_dataset(script_args.dataset_name, split="train[11%:15%](pct1_dropremainder)", trust_remote_code=True)

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    formatting_func=formatting_prompts_func,
    data_collator=collator,
    max_seq_length=script_args.max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=script_args.packing,
)

for name, module in trainer.model.named_modules():
    if isinstance(module, LoraLayer):
        if script_args.bf16:
            module = module.to(torch.bfloat16)
    if "norm" in name:
        module = module.to(torch.float32)
    if "lm_head" in name or "embed_tokens" in name:
        if hasattr(module, "weight"):
            if script_args.bf16 and module.weight.dtype == torch.float32:
                module = module.to(torch.bfloat16)


trainer.train()
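For context on where tensor b comes from: in standard LoRA, a wrapped linear layer computes base(x) + (lora_alpha / r) * B(A(x)), where A maps in_features to r and B maps r to out_features. As I understand it, PEFT takes those sizes from the wrapped layer's declared in_features/out_features. A shape-only sketch using the script's lora_r=64 and lora_alpha=16, with the assumed Nemo q_proj dimensions (the helper name is hypothetical):

```python
def lora_shapes(in_features, out_features, r, lora_alpha):
    """Return (A shape, B shape, scaling, trainable params) for one LoRA layer.

    Shapes follow torch.nn.Linear's (out_features, in_features) weight layout.
    """
    a_shape = (r, in_features)   # lora_A: in_features -> r
    b_shape = (out_features, r)  # lora_B: r -> out_features
    scaling = lora_alpha / r     # standard LoRA scaling (not rsLoRA)
    params = r * (in_features + out_features)
    return a_shape, b_shape, scaling, params


# q_proj with the width implied by the config's head_dim (assumed 32 * 128).
print(lora_shapes(5120, 4096, r=64, lora_alpha=16))
# ((64, 5120), (4096, 64), 0.25, 589824)

# q_proj if the layer instead declares out_features = 32 * 160 = 5120.
print(lora_shapes(5120, 5120, r=64, lora_alpha=16))
# ((64, 5120), (5120, 64), 0.25, 655360)
```

The LoRA branch's output width is lora_B's first dimension, so if the layer declares 5120 while the quantized base weight actually produces 4096, the final `result + output` fails as in the traceback (assuming the head_dim reading above is right).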

Expected behavior

The same code works fine with other models (Mistral-7B, Llama3-8B, etc.), but I keep getting a tensor size mismatch error with the Mistral-Nemo model.
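One way to see why only Mistral-Nemo would fail: check whether a config's explicit head_dim disagrees with hidden_size // num_attention_heads. The config values below are assumptions recalled from the public config.json files (Mistral-7B and Llama-3-8B: hidden_size 4096 with 32 heads; Mistral-Nemo: hidden_size 5120, 32 heads, head_dim 128), and the function name is hypothetical:

```python
def head_dim_mismatch(config):
    """Return (explicit, derived) if the config's head_dim differs from
    hidden_size // num_attention_heads, otherwise None."""
    derived = config["hidden_size"] // config["num_attention_heads"]
    explicit = config.get("head_dim")
    if explicit is not None and explicit != derived:
        return explicit, derived
    return None


# Assumed values; check them against each model's config.json.
CONFIGS = {
    "Mistral-7B": {"hidden_size": 4096, "num_attention_heads": 32},
    "Llama-3-8B": {"hidden_size": 4096, "num_attention_heads": 32},
    "Mistral-Nemo-Base-2407": {
        "hidden_size": 5120,
        "num_attention_heads": 32,
        "head_dim": 128,
    },
}

for name, cfg in CONFIGS.items():
    print(name, head_dim_mismatch(cfg))
# Mistral-7B None
# Llama-3-8B None
# Mistral-Nemo-Base-2407 (128, 160)
```

With transformers installed, the real values can be passed in via `AutoConfig.from_pretrained(model_name).to_dict()` instead of the hard-coded dictionaries.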

@vicgalle

The HF transformers main branch already has a fix for this (huggingface/transformers#32065), so if you install transformers from main, it will work (I can confirm that fine-tuning with peft works with this fix).
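After reinstalling, it may help to confirm which transformers build the notebook is actually using (already-imported modules stay loaded until the runtime restarts). A stdlib-only sketch; the minimum version 4.43.0 is an assumption about the first release that includes huggingface/transformers#32065, so check the release notes before relying on it:

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Assumption: first release containing the fix; verify against the release notes.
MIN_TRANSFORMERS = "4.43.0"


def numeric_version(v):
    """Leading numeric components of a version string, e.g. '4.43.0.dev0' -> (4, 43, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", v)
    if match is None:
        raise ValueError(f"unparseable version: {v!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def is_at_least(installed, minimum):
    """True if installed >= minimum, ignoring dev/rc suffixes.

    A main-branch build such as '4.43.0.dev0' therefore counts as 4.43.0,
    which is only correct if that build already contains the fix.
    """
    a, b = numeric_version(installed), numeric_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))  # pad so '4.43' compares equal to '4.43.0'
    b += (0,) * (width - len(b))
    return a >= b


if __name__ == "__main__":
    try:
        installed = version("transformers")
    except PackageNotFoundError:
        installed = None
    if installed is None:
        print("transformers is not installed")
    else:
        print(installed, "OK" if is_at_least(installed, MIN_TRANSFORMERS) else "too old")
```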

@YeonwooSung
Author

@vicgalle Cool! Thanks for the reply :)
I'll close this issue then.


3 participants
@vicgalle @YeonwooSung and others