support for mamba #915

Merged · 15 commits · Dec 9, 2023
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -73,7 +73,7 @@ jobs:
run: |
pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
pip3 uninstall -y transformers accelerate
pip3 install -U -e .[flash-attn]
pip3 install -U -e .[flash-attn,mamba-ssm]
pip3 install -r requirements-tests.txt

- name: Run e2e tests
61 changes: 61 additions & 0 deletions examples/mamba/config.yml
@@ -0,0 +1,61 @@
base_model: state-spaces/mamba-2.8b
model_type: MambaLMHeadModel
tokenizer_type: AutoTokenizer
tokenizer_config: EleutherAI/gpt-neox-20b

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./out

sequence_len: 2048
sample_packing: false
pad_to_sequence_len: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 2
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 5e-5

train_on_inputs: false
group_by_length: true

bf16: true
fp16: false
tf32: true

gradient_checkpointing: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:

warmup_steps: 10
eval_steps:
eval_table_size:
eval_table_max_new_tokens: 128
save_steps: 0.25
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
tokens:
save_safetensors: False
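
For orientation: the keys above pair a Mamba checkpoint with the GPT-NeoX tokenizer, since the state-spaces checkpoints ship no tokenizer of their own. A minimal sketch of what base_model, tokenizer_config and bf16 resolve to, assuming the mamba-ssm extra is installed and a CUDA device is available (an illustration, not the exact loader axolotl runs):

import torch
from transformers import AutoTokenizer

from axolotl.models.mamba import fix_mamba_attn_for_loss

# Patch mamba_ssm first so the returned class carries the loss-aware forward().
MambaLMHeadModel = fix_mamba_attn_for_loss()

# tokenizer_config: EleutherAI/gpt-neox-20b
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

# base_model: state-spaces/mamba-2.8b, loaded in bf16 to match bf16: true
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-2.8b", device="cuda", dtype=torch.bfloat16
)

Training itself is launched the usual way once the config is saved, e.g. accelerate launch -m axolotl.cli.train examples/mamba/config.yml.
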
3 changes: 3 additions & 0 deletions setup.py
@@ -51,5 +51,8 @@ def parse_requirements():
"deepspeed": [
"deepspeed",
],
"mamba-ssm": [
"mamba-ssm==1.0.1",
],
},
)
55 changes: 48 additions & 7 deletions src/axolotl/core/trainer_builder.py
@@ -31,7 +31,10 @@
bench_eval_callback_factory,
log_prediction_callback_factory,
)
from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
MambaDataCollator,
)
from axolotl.utils.samplers import MultipackBatchSampler
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup

@@ -49,6 +52,9 @@ class AxolotlTrainingArguments(TrainingArguments):
Extend the base TrainingArguments for axolotl helpers
"""

model_type: Optional[str] = field(
default=None, metadata={"help": "HF model configuration model_type."}
)
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
@@ -285,6 +291,32 @@ def compute_loss(self, model, inputs, return_outputs=False):
return super().compute_loss(model, inputs, return_outputs=return_outputs)


class AxolotlMambaTrainer(AxolotlTrainer):
"""
Mamba specific trainer to handle loss calculation
"""

def compute_loss(
self,
model,
inputs,
return_outputs=False, # pylint: disable=unused-argument
):
input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits

labels = input_ids.to(lm_logits.device)
shift_logits = lm_logits[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()

loss_fct = torch.nn.CrossEntropyLoss()
lm_loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
)

return lm_loss


class OneCycleLRSchedulerTrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
@@ -462,6 +494,8 @@ def _get_trainer_cls(self):
return OneCycleLRSchedulerTrainer
if self.cfg.relora_steps:
return ReLoRATrainer
if self.cfg.model_config_type == "mamba":
return AxolotlMambaTrainer
return AxolotlTrainer

def build(self, total_num_steps):
@@ -529,7 +563,7 @@ def build(self, total_num_steps):
if self.cfg.hub_strategy:
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy

if self.cfg.save_safetensors:
if self.cfg.save_safetensors is not None:
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors

if self.cfg.sample_packing_eff_est:
@@ -677,6 +711,7 @@ def build(self, total_num_steps):
training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs
)
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_args = (
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
@@ -731,11 +766,7 @@ def build(self, total_num_steps):
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
args=training_args,
data_collator=BatchSamplerDataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
data_collator=self.build_collator(**data_collator_kwargs),
bench_data_collator=transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
@@ -755,3 +786,13 @@ def build(self, total_num_steps):
] = self.cfg.micro_batch_size

return trainer

def build_collator(self, **kwargs):
if self.cfg.model_config_type == "mamba":
return MambaDataCollator(tokenizer=self.tokenizer)

return BatchSamplerDataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**kwargs,
)
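
The loss in AxolotlMambaTrainer.compute_loss above is the standard causal-LM shift: the logits at position t are scored against the token at position t+1, with the input ids doubling as labels. A toy illustration with made-up shapes:

import torch

batch, seq_len, vocab = 2, 8, 16
lm_logits = torch.randn(batch, seq_len, vocab)         # stands in for model(input_ids).logits
input_ids = torch.randint(0, vocab, (batch, seq_len))  # the labels are the inputs themselves

shift_logits = lm_logits[:, :-1, :].contiguous()  # predictions for positions 0 .. n-2
shift_labels = input_ids[:, 1:].contiguous()      # targets are tokens 1 .. n-1

loss = torch.nn.CrossEntropyLoss()(
    shift_logits.view(-1, vocab), shift_labels.view(-1)
)
print(loss)  # scalar averaged over batch * (seq_len - 1) positions
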
12 changes: 12 additions & 0 deletions src/axolotl/models/mamba/__init__.py
@@ -0,0 +1,12 @@
"""
Modeling module for Mamba models
"""


def fix_mamba_attn_for_loss():
from mamba_ssm.models import mixer_seq_simple

from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed

mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed
return mixer_seq_simple.MambaLMHeadModel # pylint: disable=invalid-name
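
The helper is a plain monkey-patch: it rebinds the module-level class in mamba_ssm and hands back the patched class, so it must run before anything else resolves MambaLMHeadModel through mamba_ssm. A small sketch of the intended effect:

from mamba_ssm.models import mixer_seq_simple

from axolotl.models.mamba import fix_mamba_attn_for_loss

patched_cls = fix_mamba_attn_for_loss()

# Any later lookup through mixer_seq_simple now yields axolotl's subclass.
assert mixer_seq_simple.MambaLMHeadModel is patched_cls
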
42 changes: 42 additions & 0 deletions src/axolotl/models/mamba/configuration_mamba.py
@@ -0,0 +1,42 @@
"""
HF Transformers MambaConfig
"""
from transformers import PretrainedConfig


class MambaConfig(PretrainedConfig):
"""
modeling configuration for state space model/mamba
"""

model_type = "mamba"

def __init__(
self,
vocab_size=50280,
d_model=2560,
n_layer=64,
rms_norm=True,
residual_in_fp32=True,
fused_add_norm=True,
pad_vocab_size_multiple=8,
pad_token_id=50277,
bos_token_id=0,
eos_token_id=0,
tie_word_embeddings=False,
**kwargs,
):
self.vocab_size = vocab_size
self.d_model = d_model
self.n_layer = n_layer
self.rms_norm = rms_norm
self.residual_in_fp32 = residual_in_fp32
self.fused_add_norm = fused_add_norm
self.pad_vocab_size_multiple = pad_vocab_size_multiple
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
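
Because MambaConfig subclasses PretrainedConfig, it serializes like any other HF config; the defaults line up with the 2.8b checkpoint used in the example config. A short sketch (the output directory name is arbitrary):

from axolotl.models.mamba.configuration_mamba import MambaConfig

cfg = MambaConfig()    # defaults: d_model=2560, n_layer=64, vocab_size=50280
print(cfg.model_type)  # "mamba"

cfg.save_pretrained("./mamba-cfg")  # writes config.json via PretrainedConfig
roundtrip = MambaConfig.from_pretrained("./mamba-cfg")
assert roundtrip.d_model == cfg.d_model
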
128 changes: 128 additions & 0 deletions src/axolotl/models/mamba/modeling_mamba.py
@@ -0,0 +1,128 @@
# pylint: skip-file
import os
from collections import namedtuple
from functools import partial
from typing import Optional, Union

import torch
from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
from torch import nn
from torch.nn import CrossEntropyLoss

from axolotl.models.mamba.configuration_mamba import MambaConfig


class MambaLMHeadModel(nn.Module, GenerationMixin):
def __init__(
self,
d_model: int,
n_layer: int,
vocab_size: int,
initializer_cfg=None,
pad_vocab_size_multiple: int = 1,
device=None,
dtype=None,
**backbone_kwargs,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
if vocab_size % pad_vocab_size_multiple != 0:
vocab_size += pad_vocab_size_multiple - (
vocab_size % pad_vocab_size_multiple
)
self.config = MambaConfig(
vocab_size=vocab_size,
d_model=d_model,
n_layer=n_layer,
pad_vocab_size_multiple=pad_vocab_size_multiple,
)
self.backbone = MixerModel(
d_model=d_model,
n_layer=n_layer,
vocab_size=vocab_size,
initializer_cfg=initializer_cfg,
**backbone_kwargs,
**factory_kwargs,
)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

# Initialize weights and apply final processing
self.apply(
partial(
_init_weights,
n_layer=n_layer,
**(initializer_cfg if initializer_cfg is not None else {}),
)
)
self.tie_weights()

def tie_weights(self):
self.lm_head.weight = self.backbone.embedding.weight

def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.backbone.allocate_inference_cache(
batch_size, max_seqlen, dtype=dtype, **kwargs
)

def forward(
self,
input_ids,
position_ids=None,
inference_params=None,
num_last_tokens=0,
labels=None,
**kwargs,
):
"""
"position_ids" is just to be compatible with Transformer generation. We don't use it.
num_last_tokens: if > 0, only return the logits for the last n tokens
"""
hidden_states = self.backbone(input_ids, inference_params=inference_params)
if num_last_tokens > 0:
hidden_states = hidden_states[:, -num_last_tokens:]
lm_logits = self.lm_head(hidden_states)

loss = None
if labels is not None:
logits = lm_logits
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "loss"])
return CausalLMOutput(logits=lm_logits, loss=loss)

else:
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
return CausalLMOutput(logits=lm_logits)

def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
state_dict: Optional[dict] = None,
safe_serialization: Optional[bool] = None, # pylint: disable=unused-argument
):
if state_dict is None:
state_dict = self.state_dict()
torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))

@classmethod
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
config = load_config_hf(pretrained_model_name)
model = cls(**config, device=device, dtype=dtype, **kwargs)
model.load_state_dict(
load_state_dict_hf(pretrained_model_name, device={"": device}, dtype=dtype)
)
return model
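
Taken together, the class behaves like a minimal causal LM: passing labels returns a namedtuple carrying both logits and loss, omitting them returns logits only, and save_pretrained simply torch.saves the state dict. A usage sketch, assuming a CUDA build of mamba-ssm is available:

import os

import torch

from axolotl.models.mamba.modeling_mamba import MambaLMHeadModel

model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-2.8b", device="cuda", dtype=torch.bfloat16
)
input_ids = torch.randint(0, model.config.vocab_size, (1, 128), device="cuda")

with_loss = model(input_ids, labels=input_ids)  # CausalLMOutput(logits=..., loss=...)
print(with_loss.loss)

logits_only = model(input_ids)   # CausalLMOutput(logits=...)
print(logits_only.logits.shape)  # (1, 128, padded vocab size)

os.makedirs("./mamba-out", exist_ok=True)
model.save_pretrained("./mamba-out")  # plain torch.save of the state dict
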
6 changes: 4 additions & 2 deletions src/axolotl/train.py
@@ -82,7 +82,8 @@ def train(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
)

model.config.use_cache = False
if hasattr(model, "config"):
model.config.use_cache = False

# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
@@ -92,7 +93,8 @@
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
model.config.save_pretrained(str(Path(cfg.output_dir)))
if hasattr(model, "config"):
model.config.save_pretrained(str(Path(cfg.output_dir)))

# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0: