Commit

fix for qwen w lora (axolotl-ai-cloud#906)
winglian committed Nov 30, 2023
1 parent a2edaf0 commit c16aa3c
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/axolotl/utils/models.py
@@ -412,15 +412,22 @@ def load_model(
                 module.to(torch.float32)
 
     needs_fa2_dtype = cfg.adapter or cfg.fsdp
+    skip_prepare_model_for_kbit_training = False
+
+    if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
+        # Qwen doesn't play nicely with LoRA if this is enabled
+        skip_prepare_model_for_kbit_training = True
+
     if (cfg.adapter == "lora" and load_in_8bit) or (
         cfg.adapter == "qlora" and cfg.load_in_4bit
     ):
         LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
         if cfg.gradient_checkpointing:
             model.gradient_checkpointing_enable()
-        model = prepare_model_for_kbit_training(
-            model, use_gradient_checkpointing=cfg.gradient_checkpointing
-        )
+        if not skip_prepare_model_for_kbit_training:
+            model = prepare_model_for_kbit_training(
+                model, use_gradient_checkpointing=cfg.gradient_checkpointing
+            )
         needs_fa2_dtype = True
 
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
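For context, the guarded call path introduced by this commit can be summarized with the short standalone sketch below. This is an illustration only, not axolotl's actual load_model: the hypothetical helper name maybe_prepare_for_kbit and its arguments simply mirror the cfg fields and load_in_8bit flag used in the diff, while prepare_model_for_kbit_training is the PEFT utility the existing code already calls.

# Illustration only: condensed sketch of the call path after this commit.
# maybe_prepare_for_kbit is a hypothetical helper, not part of axolotl.
from peft import prepare_model_for_kbit_training


def maybe_prepare_for_kbit(model, cfg, load_in_8bit):
    # Qwen + LoRA skips prepare_model_for_kbit_training entirely, since the
    # two don't play nicely together (per the comment in the diff).
    skip_prepare = cfg.model_config_type == "qwen" and cfg.adapter == "lora"

    if (cfg.adapter == "lora" and load_in_8bit) or (
        cfg.adapter == "qlora" and cfg.load_in_4bit
    ):
        # Gradient checkpointing is still enabled on the model either way.
        if cfg.gradient_checkpointing:
            model.gradient_checkpointing_enable()
        if not skip_prepare:
            model = prepare_model_for_kbit_training(
                model, use_gradient_checkpointing=cfg.gradient_checkpointing
            )
    return model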
