WIP Mistral dropout #683

Closed
wants to merge 5 commits into from
11 changes: 10 additions & 1 deletion src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
@@ -13,13 +13,14 @@
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
)
from transformers import MistralConfig
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OriginalMistralDecoderLayer,
)
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv

from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.monkeypatch.utils import GaussianDropout, get_cu_seqlens_from_pos_ids

LOG = logging.getLogger("axolotl.monkeypatch.mistral")

@@ -479,6 +480,12 @@ class MistralDecoderLayer(OriginalMistralDecoderLayer):
patched version of MistralDecoderLayer to pass through the precalculated cu_seqlens
"""

def __init__(self, config: MistralConfig):
super().__init__(config)
self.dropout_p = config.dropout_p if hasattr(config, "dropout_p") else None
if self.dropout_p:
self.dropout = GaussianDropout(p=self.dropout_p)

def forward(
self,
hidden_states: torch.Tensor,
@@ -528,6 +535,8 @@ def forward(
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
if self.training and self.dropout_p:
hidden_states = self.dropout(hidden_states)
hidden_states = residual + hidden_states

outputs = (hidden_states,)
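For readers skimming the diff, the logic this patch adds to MistralDecoderLayer boils down to: read an optional dropout_p off the model config at construction time and, only while training, apply dropout to the MLP output before the residual add. The toy block below is an illustrative, self-contained sketch of that wiring; ToyDecoderBlock and the nn.Dropout stand-in for GaussianDropout are assumptions made for the example, not part of the PR.

import torch
from torch import nn


class ToyDecoderBlock(nn.Module):
    """Toy block mirroring the dropout wiring added to MistralDecoderLayer."""

    def __init__(self, hidden_size, dropout_p=None):
        super().__init__()
        self.mlp = nn.Linear(hidden_size, hidden_size)
        self.dropout_p = dropout_p
        if self.dropout_p:
            # stand-in for GaussianDropout from axolotl.monkeypatch.utils
            self.dropout = nn.Dropout(p=self.dropout_p)

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        # dropout only fires while training, exactly like the patched layer
        if self.training and self.dropout_p:
            hidden_states = self.dropout(hidden_states)
        return residual + hidden_states


block = ToyDecoderBlock(hidden_size=8, dropout_p=0.1)
block.train()
out_train = block(torch.randn(2, 4, 8))
block.eval()
out_eval = block(torch.randn(2, 4, 8))  # no dropout applied in eval mode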
18 changes: 18 additions & 0 deletions src/axolotl/monkeypatch/utils.py
@@ -2,6 +2,7 @@
Shared utils for the monkeypatches
"""
import torch
from torch import nn


def get_cu_seqlens(attn_mask):
@@ -101,3 +102,20 @@ def get_cu_seqlens_from_pos_ids(position_ids):
max_seq_lens.append(max_seq_len)

return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)


class GaussianDropout(nn.Module):
"""
Module to apply Gaussian Dropout
"""

def __init__(self, p=0.5): # pylint: disable=invalid-name
super().__init__()
if p <= 0 or p >= 1:
raise ValueError("p must satisfy 0 < p < 1")
self.p = p # pylint: disable=invalid-name

def forward(self, inputs):
stddev = (self.p / (1.0 - self.p)) ** 0.5
epsilon = torch.randn_like(inputs) * stddev + 1.0  # noise centered at 1 preserves the expected activation
return inputs * epsilon
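As a quick sanity check on the new module, the illustrative snippet below (not part of the PR; it assumes axolotl is importable) applies GaussianDropout to a constant tensor and confirms the noise statistics implied by p: multiplicative noise centered at 1 with standard deviation sqrt(p / (1 - p)).

import torch

from axolotl.monkeypatch.utils import GaussianDropout

torch.manual_seed(0)
dropout = GaussianDropout(p=0.5)  # noise stddev = sqrt(0.5 / 0.5) = 1.0

x = torch.ones(100_000)
noisy = dropout(x)
print(noisy.mean())  # ~1.0: the expected value of the input is preserved
print(noisy.std())   # ~1.0: matches sqrt(p / (1 - p)) for p = 0.5

# GaussianDropout(p=1.5) would raise ValueError: p must satisfy 0 < p < 1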
47 changes: 23 additions & 24 deletions src/axolotl/utils/models.py
@@ -17,7 +17,6 @@
AutoTokenizer,
BitsAndBytesConfig,
GPTQConfig,
LlamaConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
)
@@ -99,7 +98,6 @@ def load_model(
Load a model for a given configuration and tokenizer.
"""
base_model = cfg.base_model
base_model_config = cfg.base_model_config
model_type = cfg.model_type
model_config = load_model_config(cfg)

@@ -217,20 +215,21 @@ def load_model(
if cfg.flash_attention and not cfg.sample_packing:
if cfg.is_llama_derived_model or cfg.is_falcon_derived_model:
model_kwargs["use_flash_attention_2"] = True

if cfg.model_config:
for key, val in cfg.model_config.items():
setattr(model_config, key, val)

if cfg.rope_scaling:
setattr(model_config, "rope_scaling", cfg.rope_scaling)

try:
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
from transformers import LlamaForCausalLM

config_kwargs = {}
if cfg.rope_scaling:
config_kwargs["rope_scaling"] = cfg.rope_scaling
config = LlamaConfig.from_pretrained(
base_model_config,
**config_kwargs,
)
model = LlamaForCausalLM.from_pretrained(
base_model,
config=config,
config=model_config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
@@ -268,6 +267,7 @@ def load_model(

model = MixFormerSequentialForCausalLM.from_pretrained(
base_model,
config=model_config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
@@ -278,6 +278,7 @@
if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=model_config,
device_map=cfg.device_map,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
@@ -286,6 +287,7 @@
else:
model = getattr(transformers, model_type).from_pretrained(
base_model,
config=model_config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
@@ -294,30 +296,26 @@
**model_kwargs,
)
else:
config = AutoConfig.from_pretrained(
base_model,
trust_remote_code=cfg.trust_remote_code or False,
)
# Shouldn't be a problem most of the time; will obviously error if the model doesn't support this
# when training starts
if (
hasattr(config, "max_seq_len")
and config.max_seq_len
and cfg.sequence_len > config.max_seq_len
hasattr(model_config, "max_seq_len")
and model_config.max_seq_len
and cfg.sequence_len > model_config.max_seq_len
):
config.max_seq_len = cfg.sequence_len
model_config.max_seq_len = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}")
elif (
hasattr(config, "max_sequence_length")
and config.max_sequence_length
and cfg.sequence_len > config.max_sequence_length
hasattr(model_config, "max_sequence_length")
and model_config.max_sequence_length
and cfg.sequence_len > model_config.max_sequence_length
):
config.max_sequence_length = cfg.sequence_len
model_config.max_sequence_length = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}")
if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
config=model_config,
device_map=cfg.device_map,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
@@ -326,7 +324,7 @@
else:
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
config=model_config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
@@ -341,6 +339,7 @@
LOG.exception(err)
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=model_config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
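The net effect of the models.py changes is that load_model stops constructing ad-hoc LlamaConfig/AutoConfig objects and instead reuses the model_config it already loaded, copies any user overrides from cfg.model_config (plus cfg.rope_scaling) onto it, and passes that same object to every from_pretrained call; that is how a dropout_p set in the training YAML would reach the patched decoder layer. The sketch below reproduces the override step in isolation; the YAML-derived dictionaries and the Mistral checkpoint name are assumptions for illustration, not values taken from the PR.

from transformers import AutoConfig

# what cfg.model_config / cfg.rope_scaling might look like after parsing the YAML
cfg_model_config = {"dropout_p": 0.1}
cfg_rope_scaling = {"type": "linear", "factor": 2.0}

model_config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")

# mirror of the new block in load_model: copy every user override onto the config
for key, val in cfg_model_config.items():
    setattr(model_config, key, val)
if cfg_rope_scaling:
    setattr(model_config, "rope_scaling", cfg_rope_scaling)

# this single config object is then passed as config=model_config to every
# from_pretrained call, so e.g. the patched MistralDecoderLayer sees dropout_p
print(model_config.dropout_p, model_config.rope_scaling)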