Skip to content

Commit

Permalink
fix: warn user to install mamba_ssm package
Browse files Browse the repository at this point in the history
  • Loading branch information
NanoCode012 committed Dec 29, 2023
1 parent dec66d7 commit 726dfe0
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/axolotl/models/mamba/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@
Modeling module for Mamba models
"""

import importlib


def check_mamba_ssm_installed():
mamba_ssm_spec = importlib.util.find_spec("mamba_ssm")
if mamba_ssm_spec is None:
raise ImportError(
"MambaLMHeadModel requires mamba_ssm. Please install it with `pip install -e .[mamba-ssm,flash-attn]`"
)


def fix_mamba_attn_for_loss():
from mamba_ssm.models import mixer_seq_simple
Expand All @@ -10,3 +20,6 @@ def fix_mamba_attn_for_loss():

mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed
return mixer_seq_simple.MambaLMHeadModel # pylint: disable=invalid-name


check_mamba_ssm_installed()

0 comments on commit 726dfe0

Please sign in to comment.