Skip to content

Commit

Permalink
make sure to patch for z3 (DeepSpeed ZeRO-3)
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Jan 24, 2024
1 parent 0fc6645 commit 05be258
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,10 @@ def load_model(
)

LOG.info("patching mixtral with flash attention")
- replace_mixtral_attn_with_multipack_flash_attn()
+ mixtral_patch_kwargs = {}
+ if is_deepspeed_zero3_enabled():
+     mixtral_patch_kwargs["for_zero3"] = True
+ replace_mixtral_attn_with_multipack_flash_attn(**mixtral_patch_kwargs)

if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
from axolotl.monkeypatch.falcon import (
Expand Down

0 comments on commit 05be258

Please sign in to comment.