Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add gptneox embeddings, fix phi2 inputs, also fix the casting #1083

Merged
merged 3 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/axolotl/utils/lora_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,7 @@ def get_linear_embedding_layers(model_type):
returns the linear embedding layers needed for loras, dependent on the model arch
"""
if model_type == "phi-msft":
return ["embd", "lm_head.linear"]
return ["lm_head", "embed_tokens"]
return ["embd.wte", "lm_head.linear"]
if model_type == "gpt_neox":
return ["embed_in", "embed_out"]
return ["embed_tokens", "lm_head"]
10 changes: 4 additions & 6 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,13 +588,14 @@ def load_model(
log_gpu_memory_usage(LOG, "after model load", model.device)

# make sure these are fp32 per Ramesh et al. (2021)
embedding_modules = get_linear_embedding_layers(cfg.model_config_type)
for name, module in model.named_modules():
if "norm" in name:
module.to(torch.float32)
if model_config.model_type == "btlm":
# don't upcast lm_head for btlm
continue
if "lm_head" in name or "embed_tokens" in name:
if any(m in name for m in embedding_modules):
if hasattr(module, "weight"):
module.to(torch.float32)

Expand All @@ -619,15 +620,12 @@ def load_model(

# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility.
if needs_fa2_dtype or (
cfg.flash_attention
and (cfg.is_llama_derived_model or cfg.is_mistral_derived_model)
):
if needs_fa2_dtype or cfg.flash_attention:
LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
for name, module in model.named_modules():
winglian marked this conversation as resolved.
Show resolved Hide resolved
if "norm" in name:
module.to(cfg.torch_dtype)
if "lm_head" in name or "embed_tokens" in name:
if any(m in name for m in embedding_modules):
if hasattr(module, "weight"):
module.to(cfg.torch_dtype)

Expand Down
1 change: 1 addition & 0 deletions tests/core/test_trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def fixture_cfg():
"adam_epsilon": 0.00001,
"dataloader_num_workers": 1,
"dataloader_pin_memory": True,
"model_config_type": "llama",
}
)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@ def test_phi2_add_tokens_adapter(self):
"adapter": "qlora",
"load_in_4bit": True,
"tokens": ["<|imstart|>"],
"lora_modules_to_save": ["embd", "lm_head.linear"],
"lora_modules_to_save": ["embd.wte", "lm_head.linear"],
}
)

Expand Down