add gptneox embeddings, fix phi2 inputs, also fix the casting (#1083)
winglian committed Jan 11, 2024
1 parent 23495a8 commit 78c5b19
Showing 4 changed files with 10 additions and 9 deletions.
src/axolotl/utils/lora_embeddings.py: 6 changes (4 additions, 2 deletions)
@@ -8,5 +8,7 @@ def get_linear_embedding_layers(model_type):
     returns the linear embedding layers needed for loras, dependent on the model arch
     """
     if model_type == "phi-msft":
-        return ["embd", "lm_head.linear"]
-    return ["lm_head", "embed_tokens"]
+        return ["embd.wte", "lm_head.linear"]
+    if model_type == "gpt_neox":
+        return ["embed_in", "embed_out"]
+    return ["embed_tokens", "lm_head"]
src/axolotl/utils/models.py: 10 changes (4 additions, 6 deletions)
@@ -588,13 +588,14 @@ def load_model(
     log_gpu_memory_usage(LOG, "after model load", model.device)
 
     # make sure these are fp32 per Ramesh et al. (2021)
+    embedding_modules = get_linear_embedding_layers(cfg.model_config_type)
     for name, module in model.named_modules():
         if "norm" in name:
             module.to(torch.float32)
         if model_config.model_type == "btlm":
             # don't upcast lm_head for btlm
             continue
-        if "lm_head" in name or "embed_tokens" in name:
+        if any(m in name for m in embedding_modules):
             if hasattr(module, "weight"):
                 module.to(torch.float32)
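
The new check is plain substring matching against the fully qualified names yielded by model.named_modules(), which is why the tighter "embd.wte" entry matters: it skips wrapper modules that merely contain the embedding. A minimal sketch with hypothetical phi-msft module names:

embedding_modules = ["embd.wte", "lm_head.linear"]  # phi-msft, per the helper above

# hypothetical fully qualified names, as model.named_modules() would yield them
names = [
    "transformer.embd",      # wrapper module; the old "embd" entry matched it too
    "transformer.embd.wte",  # the actual token embedding
    "lm_head.linear",
    "transformer.h.0.mlp.fc1",
]

matched = [name for name in names if any(m in name for m in embedding_modules)]
print(matched)  # ['transformer.embd.wte', 'lm_head.linear']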

@@ -619,15 +620,12 @@ def load_model(
 
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
     # convert them back to fp16/bf16 for flash-attn compatibility.
-    if needs_fa2_dtype or (
-        cfg.flash_attention
-        and (cfg.is_llama_derived_model or cfg.is_mistral_derived_model)
-    ):
+    if needs_fa2_dtype or cfg.flash_attention:
         LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
         for name, module in model.named_modules():
             if "norm" in name:
                 module.to(cfg.torch_dtype)
-            if "lm_head" in name or "embed_tokens" in name:
+            if any(m in name for m in embedding_modules):
                 if hasattr(module, "weight"):
                     module.to(cfg.torch_dtype)
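
Taken together with the hunk above, the flow is: upcast norm and embedding modules to fp32 after load, then cast them back to cfg.torch_dtype when flash attention requires it. A compact, self-contained sketch with a toy module standing in for the loaded model (module names and dtype are illustrative):

import torch
from torch import nn

# toy stand-in for the loaded model; real names come from the HF checkpoint
model = nn.ModuleDict(
    {
        "embed_tokens": nn.Embedding(32, 8),
        "input_layernorm": nn.LayerNorm(8),
        "lm_head": nn.Linear(8, 32, bias=False),
    }
).bfloat16()

embedding_modules = ["embed_tokens", "lm_head"]  # from get_linear_embedding_layers
torch_dtype = torch.bfloat16  # stands in for cfg.torch_dtype

# after load: upcast norm and embedding weights to fp32
for name, module in model.named_modules():
    if "norm" in name:
        module.to(torch.float32)
    if any(m in name for m in embedding_modules):
        if hasattr(module, "weight"):
            module.to(torch.float32)

# before training with flash attention: cast them back to the configured dtype
for name, module in model.named_modules():
    if "norm" in name:
        module.to(torch_dtype)
    if any(m in name for m in embedding_modules):
        if hasattr(module, "weight"):
            module.to(torch_dtype)

print(model["embed_tokens"].weight.dtype)  # torch.bfloat16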

tests/core/test_trainer_builder.py: 1 change (1 addition, 0 deletions)
@@ -30,6 +30,7 @@ def fixture_cfg():
             "adam_epsilon": 0.00001,
             "dataloader_num_workers": 1,
             "dataloader_pin_memory": True,
+            "model_config_type": "llama",
         }
     )

tests/test_validation.py: 2 changes (1 addition, 1 deletion)
@@ -770,7 +770,7 @@ def test_phi2_add_tokens_adapter(self):
                 "adapter": "qlora",
                 "load_in_4bit": True,
                 "tokens": ["<|imstart|>"],
-                "lora_modules_to_save": ["embd", "lm_head.linear"],
+                "lora_modules_to_save": ["embd.wte", "lm_head.linear"],
             }
         )

