Skip to content

Commit

Permalink
improve checks for various falcon model sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Sep 13, 2023
1 parent 7070825 commit 266ce55
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,11 @@ def load_model(
if cfg.flash_attention:
replace_btlm_attn_with_flash_attn()

if hasattr(model_config, "model_type") and model_config.model_type == "falcon":
if hasattr(model_config, "model_type") and model_config.model_type in [
"falcon",
"RefinedWebModel",
"RefinedWeb",
]:
if cfg.flash_attention:
replace_falcon_attn_with_flash_attn()

Expand Down

0 comments on commit 266ce55

Please sign in to comment.